"""WebSocket routes for real-time updates.""" from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query from sqlalchemy.orm import Session from uuid import UUID from typing import Optional import logging from app.database import get_db from app.websocket.connection_manager import manager from app.services.map_share_service import check_map_access from app.dependencies import get_user_from_token logger = logging.getLogger(__name__) router = APIRouter(tags=["websocket"]) @router.websocket("/ws/maps/{map_id}") async def websocket_endpoint( websocket: WebSocket, map_id: UUID, token: Optional[str] = Query(None), share_token: Optional[str] = Query(None), db: Session = Depends(get_db) ): """ WebSocket endpoint for real-time map updates. Clients can connect using: - JWT token (authenticated users) - Share token (guest access) Example: ws://localhost:8000/ws/maps/{map_id}?token={jwt_token} Example: ws://localhost:8000/ws/maps/{map_id}?share_token={share_token} """ # Verify access to the map user = None if token: try: user = get_user_from_token(token, db) except Exception as e: logger.error(f"Invalid token: {e}") await websocket.close(code=1008, reason="Invalid token") return # Check map access has_access, permission = check_map_access(db, map_id, user, share_token) if not has_access: await websocket.close(code=1008, reason="Access denied") return await manager.connect(websocket, map_id) try: # Send initial connection message await websocket.send_json({ "type": "connected", "data": { "map_id": str(map_id), "permission": permission.value } }) # Keep connection alive and listen for messages while True: # Receive messages (for potential future use like cursor position, etc.) data = await websocket.receive_json() # Echo back for now (can add more features later) logger.info(f"Received message from client: {data}") except WebSocketDisconnect: manager.disconnect(websocket, map_id) logger.info(f"Client disconnected from map {map_id}") except Exception as e: logger.error(f"WebSocket error: {e}") manager.disconnect(websocket, map_id)