131 lines
4.7 KiB
Python
131 lines
4.7 KiB
Python
"""WebSocket connection manager for real-time map updates."""
|
|
from fastapi import WebSocket
|
|
from typing import Dict, List, Set, Optional, Tuple
|
|
from uuid import UUID
|
|
import json
|
|
import logging
|
|
|
|
# Module-level logger named after this module (__name__); ConnectionManager
# reports connects, disconnects and send/close failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionManager:
    """Manage WebSocket connections for real-time map updates.

    Connections are grouped per map so that item events can be broadcast to
    every client currently viewing that map. Each connection is stored with
    the user it belongs to (``None`` for guests) so that all of a user's
    sockets on a map can be closed at once.
    """

    def __init__(self):
        # str(map_id) -> set of (websocket, user_id) pairs; user_id is None for guests.
        # A key only exists while at least one connection is registered under it.
        self.active_connections: Dict[str, Set[Tuple[WebSocket, Optional[UUID]]]] = {}
        # websocket -> str(map_id), kept in step with active_connections.
        self.websocket_to_map: Dict[WebSocket, str] = {}

    def _remove_connection(self, websocket: WebSocket, map_key: str) -> None:
        """Forget ``websocket`` under ``map_key``; a no-op if it is already gone.

        Tolerating a missing map entry matters because callers may invoke this
        after an ``await``, during which another task can already have removed
        the same connection (or the whole map entry).
        """
        connections = self.active_connections.get(map_key)
        if connections is not None:
            remaining = {conn for conn in connections if conn[0] != websocket}
            if remaining:
                self.active_connections[map_key] = remaining
            else:
                # Never keep empty sets around: broadcast relies on key presence.
                del self.active_connections[map_key]
        self.websocket_to_map.pop(websocket, None)

    async def connect(
        self,
        websocket: WebSocket,
        map_id: UUID,
        user_id: Optional[UUID] = None,
    ) -> None:
        """Accept a new WebSocket connection and register it for a map.

        Args:
            websocket: The incoming connection; it is accepted before being
                registered, so a failed accept leaves no state behind.
            map_id: Map the client wants updates for.
            user_id: Owner of the connection, or ``None`` for a guest.
        """
        await websocket.accept()
        map_key = str(map_id)

        connections = self.active_connections.setdefault(map_key, set())
        connections.add((websocket, user_id))
        self.websocket_to_map[websocket] = map_key

        user_info = f"user {user_id}" if user_id else "guest"
        logger.info(
            "Client (%s) connected to map %s. Total connections: %d",
            user_info,
            map_id,
            len(connections),
        )

    def disconnect(self, websocket: WebSocket, map_id: UUID) -> None:
        """Remove a WebSocket connection from the registry.

        The socket itself is not closed here. Calling this for a connection
        that is not (or no longer) registered is harmless.
        """
        self._remove_connection(websocket, str(map_id))
        logger.info("Client disconnected from map %s", map_id)

    async def disconnect_user(self, map_id: UUID, user_id: UUID) -> None:
        """Close and unregister every connection of ``user_id`` on ``map_id``.

        Sockets are closed with code 1008 (policy violation in RFC 6455) and
        reason "Access revoked". A failure while closing is logged and the
        connection is still unregistered.
        """
        map_key = str(map_id)

        # Snapshot the targets first: the registry may change while we await
        # close() below, so it must not be iterated or indexed live.
        targets = [
            websocket
            for websocket, uid in self.active_connections.get(map_key, ())
            if uid == user_id
        ]

        for websocket in targets:
            try:
                await websocket.close(code=1008, reason="Access revoked")
                logger.info("Closed WebSocket for user %s on map %s", user_id, map_id)
            except Exception as e:
                # Best effort: the peer may already be gone; keep going so the
                # remaining sockets are still closed and cleaned up.
                logger.error("Error closing WebSocket for user %s: %s", user_id, e)

            # Another task may have called disconnect() during the await and
            # dropped the map entry; the helper handles that without KeyError.
            self._remove_connection(websocket, map_key)

    async def broadcast_to_map(self, map_id: UUID, message: dict) -> None:
        """Send ``message`` as JSON to every client connected to ``map_id``.

        Clients whose send raises are logged and unregistered afterwards; one
        failing client does not prevent delivery to the others.
        """
        connections = self.active_connections.get(str(map_id))
        if not connections:
            return

        stale = []
        # Iterate over a snapshot: the set can be replaced while awaiting sends.
        for websocket, _user_id in tuple(connections):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error("Error sending message to client: %s", e)
                stale.append(websocket)

        for websocket in stale:
            self.disconnect(websocket, map_id)

    async def send_item_created(self, map_id: UUID, item_data: dict) -> None:
        """Notify clients of ``map_id`` that a new item was created."""
        await self.broadcast_to_map(map_id, {
            "type": "item_created",
            "data": item_data,
        })

    async def send_item_updated(self, map_id: UUID, item_data: dict) -> None:
        """Notify clients of ``map_id`` that an item was updated."""
        await self.broadcast_to_map(map_id, {
            "type": "item_updated",
            "data": item_data,
        })

    async def send_item_deleted(self, map_id: UUID, item_id: str) -> None:
        """Notify clients of ``map_id`` that the item ``item_id`` was deleted."""
        await self.broadcast_to_map(map_id, {
            "type": "item_deleted",
            "data": {"id": item_id},
        })
|
|
|
|
|
|
# Global connection manager instance
|
|
# Created once at import time, so every importer of this module gets this same
# object. Construction only builds two empty dicts (no I/O).
manager = ConnectionManager()
|