"""WebSocket routes for real-time updates."""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from uuid import UUID
from typing import Optional
import logging

from app.database import SessionLocal
from app.websocket.connection_manager import manager
from app.services.map_share_service import check_map_access
from app.dependencies import get_user_from_token

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])


def _register_connection(websocket: WebSocket, map_key: str, user_id) -> int:
    """Record *websocket* as a subscriber of *map_key* in the shared manager.

    The endpoint accepts the socket itself (so it can report auth failures
    with a close code and reason), and ``manager.connect`` would try to
    accept it again, so the manager's bookkeeping is updated directly here.

    Returns the number of connections now registered for *map_key*.
    """
    connections = manager.active_connections.setdefault(map_key, set())
    connections.add((websocket, user_id))
    manager.websocket_to_map[websocket] = map_key
    return len(connections)


@router.websocket("/ws/maps/{map_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    map_id: UUID,
    token: Optional[str] = Query(None),
    share_token: Optional[str] = Query(None),
):
    """
    WebSocket endpoint for real-time updates.

    Clients can connect using:
    - JWT token (authenticated users)
    - Share token (guest access)

    Example: ws://localhost:8000/ws/maps/{map_id}?token={jwt_token}
    Example: ws://localhost:8000/ws/maps/{map_id}?share_token={share_token}

    If the JWT cannot be resolved to a user, or ``check_map_access`` denies
    access, the socket is closed with code 1008 and a reason string.
    Otherwise the client receives a ``{"type": "connected", ...}`` message
    carrying its permission. The connection is then kept open until the
    client disconnects. Once registered, the connection is removed from the
    connection manager on every exit path.
    """
    # Accept first so the auth failures below can be reported through
    # websocket.close(code=1008, reason=...).
    await websocket.accept()

    # Short-lived DB session used only for authentication. It is closed
    # before the long-running receive loop so it is not held open for the
    # lifetime of the socket.
    db = SessionLocal()
    try:
        user = None
        if token:
            try:
                user = get_user_from_token(token, db)
            except Exception as e:
                logger.error(f"Invalid token: {e}")
                await websocket.close(code=1008, reason="Invalid token")
                return

        has_access, permission = check_map_access(db, map_id, user, share_token)
        if not has_access:
            await websocket.close(code=1008, reason="Access denied")
            return

        # Captured while the session is still open, for the greeting below.
        permission_value = permission.value
    finally:
        db.close()

    user_id = user.id if user else None
    total = _register_connection(websocket, str(map_id), user_id)

    user_info = f"user {user_id}" if user_id is not None else "guest"
    logger.info(f"Client ({user_info}) connected to map {map_id}. Total connections: {total}")

    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "map_id": str(map_id),
                "permission": permission_value,
            },
        })

        # Keep the connection open and drain incoming messages. They are not
        # acted on yet (reserved for future features such as cursor
        # position). They are logged at DEBUG so chatty clients cannot flood
        # the INFO log with their payloads.
        while True:
            data = await websocket.receive_json()
            logger.debug(f"Received message from client: {data}")
    except WebSocketDisconnect:
        logger.info(f"Client disconnected from map {map_id}")
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"WebSocket error: {e}")
    finally:
        # Runs on every exit path. This includes asyncio.CancelledError, a
        # BaseException that `except Exception` would not catch. Without it,
        # the manager would keep referencing a dead socket.
        manager.disconnect(websocket, map_id)