"""WebSocket connection manager for real-time map updates."""
from fastapi import WebSocket
from typing import Dict, List, Set, Optional, Tuple
from uuid import UUID
import json
import logging

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manages WebSocket connections for real-time map updates.

    Connections are grouped per map (keyed by ``str(map_id)``) and remember
    the authenticated ``user_id`` they were opened with, or ``None`` for
    guests, so access revocation can target them.
    """

    def __init__(self):
        # map_id -> Set of (WebSocket, user_id) tuples
        self.active_connections: Dict[str, Set[Tuple[WebSocket, Optional[UUID]]]] = {}
        # websocket -> map_id mapping for quick lookup
        self.websocket_to_map: Dict[WebSocket, str] = {}

    async def connect(self, websocket: WebSocket, map_id: UUID, user_id: Optional[UUID] = None):
        """Accept a new WebSocket connection and register it for ``map_id``.

        Args:
            websocket: The incoming connection; it is accepted here.
            map_id: Map whose updates this client subscribes to.
            user_id: Authenticated user, or ``None`` for a guest connection.
        """
        await websocket.accept()
        map_key = str(map_id)
        # Store websocket with user_id
        self.active_connections.setdefault(map_key, set()).add((websocket, user_id))
        self.websocket_to_map[websocket] = map_key

        user_info = f"user {user_id}" if user_id else "guest"
        logger.info(
            f"Client ({user_info}) connected to map {map_id}. "
            f"Total connections: {len(self.active_connections[map_key])}"
        )

    def _remove_websockets(self, map_key: str, websockets: Set[WebSocket]) -> None:
        """Unregister ``websockets`` from ``map_key`` and drop the key once empty.

        Contains no ``await``, so other coroutines never see the registry
        half-updated. Tolerates sockets or keys that are already gone.
        """
        for websocket in websockets:
            self.websocket_to_map.pop(websocket, None)
        remaining = {
            conn
            for conn in self.active_connections.get(map_key, ())
            if conn[0] not in websockets
        }
        if remaining:
            self.active_connections[map_key] = remaining
        else:
            # Clean up empty map entry (pop: the key may already be absent).
            self.active_connections.pop(map_key, None)

    def disconnect(self, websocket: WebSocket, map_id: UUID):
        """Remove a WebSocket connection from ``map_id``'s registry.

        Safe to call for a socket that was already removed, e.g. after a
        revocation or a failed broadcast.
        """
        self._remove_websockets(str(map_id), {websocket})
        logger.info(f"Client disconnected from map {map_id}")

    async def _close_matching(
        self,
        map_id: UUID,
        user_id: Optional[UUID],
        reason: str,
        label: str,
    ) -> None:
        """Close every connection on ``map_id`` opened by ``user_id``.

        ``user_id=None`` selects guest connections. ``label`` is used only in
        log messages. Closing uses code 1008 with the given ``reason``.
        """
        map_key = str(map_id)
        targets = {
            websocket
            for websocket, uid in self.active_connections.get(map_key, ())
            if (uid is None if user_id is None else uid == user_id)
        }
        if not targets:
            return

        # Unregister BEFORE awaiting close(). While we await, other coroutines
        # (broadcast_to_map, disconnect) run. Removing first means revoked
        # clients receive no further broadcasts, and a concurrent removal of
        # the map key cannot make this loop raise KeyError.
        self._remove_websockets(map_key, targets)

        for websocket in targets:
            try:
                await websocket.close(code=1008, reason=reason)
                logger.info(f"Closed WebSocket for {label} on map {map_id}")
            except Exception as e:
                logger.error(f"Error closing WebSocket for {label}: {e}")

    async def disconnect_user(self, map_id: UUID, user_id: UUID):
        """Disconnect all connections for a specific user on a specific map."""
        await self._close_matching(map_id, user_id, "Access revoked", f"user {user_id}")

    async def disconnect_guests(self, map_id: UUID):
        """Disconnect all guest connections (no user_id) on a specific map."""
        await self._close_matching(map_id, None, "Share link revoked", "guest")

    async def broadcast_to_map(self, map_id: UUID, message: dict):
        """Send ``message`` as JSON to every client connected to ``map_id``.

        Clients whose send fails are logged and unregistered afterwards.
        """
        map_key = str(map_id)
        if map_key not in self.active_connections:
            return

        # Snapshot the set: each awaited send yields to other coroutines,
        # which may connect/disconnect and mutate the live set.
        connections = self.active_connections[map_key].copy()
        disconnected = []

        for websocket, _user_id in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending message to client: {e}")
                disconnected.append(websocket)

        # Remove disconnected clients
        for websocket in disconnected:
            self.disconnect(websocket, map_id)

    async def send_item_created(self, map_id: UUID, item_data: dict):
        """Notify clients that a new item was created."""
        await self.broadcast_to_map(map_id, {
            "type": "item_created",
            "data": item_data
        })

    async def send_item_updated(self, map_id: UUID, item_data: dict):
        """Notify clients that an item was updated."""
        await self.broadcast_to_map(map_id, {
            "type": "item_updated",
            "data": item_data
        })

    async def send_item_deleted(self, map_id: UUID, item_id: str):
        """Notify clients that an item was deleted."""
        await self.broadcast_to_map(map_id, {
            "type": "item_deleted",
            "data": {"id": item_id}
        })


# Global connection manager instance
manager = ConnectionManager()