public shares now work
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
"""WebSocket routes for real-time updates."""
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from app.database import get_db
|
||||
from app.database import SessionLocal
|
||||
from app.websocket.connection_manager import manager
|
||||
from app.services.map_share_service import check_map_access
|
||||
from app.dependencies import get_user_from_token
|
||||
@@ -20,11 +19,10 @@ async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
map_id: UUID,
|
||||
token: Optional[str] = Query(None),
|
||||
share_token: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db)
|
||||
share_token: Optional[str] = Query(None)
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time map updates.
|
||||
WebSocket endpoint for real-time updates.
|
||||
|
||||
Clients can connect using:
|
||||
- JWT token (authenticated users)
|
||||
@@ -33,23 +31,38 @@ async def websocket_endpoint(
|
||||
Example: ws://localhost:8000/ws/maps/{map_id}?token={jwt_token}
|
||||
Example: ws://localhost:8000/ws/maps/{map_id}?share_token={share_token}
|
||||
"""
|
||||
# Verify access to the map
|
||||
user = None
|
||||
if token:
|
||||
try:
|
||||
user = get_user_from_token(token, db)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid token: {e}")
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
# Accept the connection first
|
||||
await websocket.accept()
|
||||
|
||||
# Create a temporary DB session just for authentication
|
||||
# This session will be closed immediately after checking access
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Verify access to the map
|
||||
user = None
|
||||
if token:
|
||||
try:
|
||||
user = get_user_from_token(token, db)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid token: {e}")
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
return
|
||||
|
||||
# Check map access
|
||||
has_access, permission = check_map_access(db, map_id, user, share_token)
|
||||
if not has_access:
|
||||
await websocket.close(code=1008, reason="Access denied")
|
||||
return
|
||||
|
||||
# Check map access
|
||||
has_access, permission = check_map_access(db, map_id, user, share_token)
|
||||
if not has_access:
|
||||
await websocket.close(code=1008, reason="Access denied")
|
||||
return
|
||||
# Store permission for the connection
|
||||
permission_value = permission.value
|
||||
finally:
|
||||
# CRITICAL: Close the DB session immediately after authentication
|
||||
db.close()
|
||||
|
||||
await manager.connect(websocket, map_id)
|
||||
# Add to connection manager (don't call accept again)
|
||||
manager.active_connections.setdefault(str(map_id), set()).add(websocket)
|
||||
logger.info(f"Client connected to map {map_id}. Total connections: {len(manager.active_connections[str(map_id)])}")
|
||||
|
||||
try:
|
||||
# Send initial connection message
|
||||
@@ -57,7 +70,7 @@ async def websocket_endpoint(
|
||||
"type": "connected",
|
||||
"data": {
|
||||
"map_id": str(map_id),
|
||||
"permission": permission.value
|
||||
"permission": permission_value
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user