private shares and revokation works
This commit is contained in:
@@ -63,7 +63,7 @@ async def revoke_map_share(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Revoke map share from a user."""
|
||||
map_share_service.delete_map_share(db, map_id, share_id, current_user)
|
||||
await map_share_service.delete_map_share(db, map_id, share_id, current_user)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -60,9 +60,20 @@ async def websocket_endpoint(
|
||||
# CRITICAL: Close the DB session immediately after authentication
|
||||
db.close()
|
||||
|
||||
# Add to connection manager (don't call accept again)
|
||||
manager.active_connections.setdefault(str(map_id), set()).add(websocket)
|
||||
logger.info(f"Client connected to map {map_id}. Total connections: {len(manager.active_connections[str(map_id)])}")
|
||||
# Add to connection manager (don't call accept again, connect method will accept)
|
||||
# Note: We need to call connect but it will try to accept again, so we skip it
|
||||
# Instead, manually add the connection
|
||||
user_id = user.id if user else None
|
||||
map_key = str(map_id)
|
||||
|
||||
if map_key not in manager.active_connections:
|
||||
manager.active_connections[map_key] = set()
|
||||
|
||||
manager.active_connections[map_key].add((websocket, user_id))
|
||||
manager.websocket_to_map[websocket] = map_key
|
||||
|
||||
user_info = f"user {user_id}" if user_id else "guest"
|
||||
logger.info(f"Client ({user_info}) connected to map {map_id}. Total connections: {len(manager.active_connections[map_key])}")
|
||||
|
||||
try:
|
||||
# Send initial connection message
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.models.map_share import SharePermission
|
||||
|
||||
class MapShareCreate(BaseModel):
|
||||
"""Schema for creating a map share with a specific user."""
|
||||
user_id: UUID
|
||||
user_identifier: str # Can be username, email, or UUID
|
||||
permission: SharePermission = SharePermission.READ
|
||||
|
||||
|
||||
|
||||
@@ -8,11 +8,53 @@ from shapely.geometry import shape, Point, LineString
|
||||
import json
|
||||
|
||||
from app.models.map_item import MapItem
|
||||
from app.models.map import Map
|
||||
from app.models.map_share import MapShare, SharePermission
|
||||
from app.models.user import User
|
||||
from app.schemas.map_item import MapItemCreate, MapItemUpdate
|
||||
from app.services.map_service import get_map_by_id
|
||||
|
||||
|
||||
def check_edit_permission(db: Session, map_id: UUID, user: User) -> None:
|
||||
"""Check if user has edit permission on a map. Raises exception if not."""
|
||||
map_obj = db.query(Map).filter(Map.id == map_id).first()
|
||||
|
||||
if not map_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Map not found"
|
||||
)
|
||||
|
||||
# Owner always has edit permission
|
||||
if map_obj.owner_id == user.id:
|
||||
return
|
||||
|
||||
# Admin always has edit permission
|
||||
if user.is_admin:
|
||||
return
|
||||
|
||||
# Check if user has share access
|
||||
share = db.query(MapShare).filter(
|
||||
MapShare.map_id == map_id,
|
||||
MapShare.user_id == user.id
|
||||
).first()
|
||||
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this map"
|
||||
)
|
||||
|
||||
# Check if share permission is EDIT
|
||||
if share.permission != SharePermission.EDIT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You only have read-only access to this map"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_map_items(db: Session, map_id: UUID, user: Optional[User] = None) -> List[MapItem]:
|
||||
"""Get all items for a map."""
|
||||
# Verify user has access to the map
|
||||
@@ -60,8 +102,8 @@ def geography_to_geojson(geography) -> dict:
|
||||
|
||||
def create_map_item(db: Session, map_id: UUID, item_data: MapItemCreate, user: User) -> MapItem:
|
||||
"""Create a new map item."""
|
||||
# Verify user has access to the map
|
||||
get_map_by_id(db, map_id, user)
|
||||
# Verify user has edit permission on the map
|
||||
check_edit_permission(db, map_id, user)
|
||||
|
||||
# Convert GeoJSON to PostGIS geography
|
||||
geometry_wkt = geojson_to_geography(item_data.geometry)
|
||||
@@ -142,6 +184,9 @@ def update_map_item(db: Session, item_id: UUID, item_data: MapItemUpdate, user:
|
||||
"""Update a map item."""
|
||||
item = get_map_item_by_id(db, item_id, user)
|
||||
|
||||
# Verify user has edit permission on the map
|
||||
check_edit_permission(db, item.map_id, user)
|
||||
|
||||
# Update fields if provided
|
||||
if item_data.type is not None:
|
||||
item.type = item_data.type
|
||||
@@ -162,6 +207,9 @@ def delete_map_item(db: Session, item_id: UUID, user: User) -> None:
|
||||
"""Delete a map item."""
|
||||
item = get_map_item_by_id(db, item_id, user)
|
||||
|
||||
# Verify user has edit permission on the map
|
||||
check_edit_permission(db, item.map_id, user)
|
||||
|
||||
# Capture map_id and item_id before deletion for WebSocket broadcast
|
||||
map_id = item.map_id
|
||||
deleted_item_id = str(item.id)
|
||||
|
||||
@@ -6,12 +6,28 @@ from fastapi import HTTPException, status
|
||||
|
||||
from app.models.map import Map
|
||||
from app.models.user import User
|
||||
from app.models.map_share import MapShare
|
||||
from app.schemas.map import MapCreate, MapUpdate
|
||||
|
||||
|
||||
def get_user_maps(db: Session, user_id: UUID) -> List[Map]:
|
||||
"""Get all maps owned by a user."""
|
||||
return db.query(Map).filter(Map.owner_id == user_id).order_by(Map.updated_at.desc()).all()
|
||||
"""Get all maps owned by or shared with a user."""
|
||||
# Get owned maps
|
||||
owned_maps = db.query(Map).filter(Map.owner_id == user_id).all()
|
||||
|
||||
# Get shared maps
|
||||
shared_map_ids = db.query(MapShare.map_id).filter(MapShare.user_id == user_id).all()
|
||||
shared_map_ids = [share.map_id for share in shared_map_ids]
|
||||
|
||||
shared_maps = []
|
||||
if shared_map_ids:
|
||||
shared_maps = db.query(Map).filter(Map.id.in_(shared_map_ids)).all()
|
||||
|
||||
# Combine and sort by updated_at
|
||||
all_maps = owned_maps + shared_maps
|
||||
all_maps.sort(key=lambda m: m.updated_at, reverse=True)
|
||||
|
||||
return all_maps
|
||||
|
||||
|
||||
def get_map_by_id(db: Session, map_id: UUID, user: Optional[User] = None) -> Map:
|
||||
@@ -26,7 +42,15 @@ def get_map_by_id(db: Session, map_id: UUID, user: Optional[User] = None) -> Map
|
||||
|
||||
# If user is provided, check authorization
|
||||
if user:
|
||||
if map_obj.owner_id != user.id and not user.is_admin:
|
||||
# Check if user is owner, admin, or has been granted access via share
|
||||
is_owner = map_obj.owner_id == user.id
|
||||
is_admin = user.is_admin
|
||||
has_share_access = db.query(MapShare).filter(
|
||||
MapShare.map_id == map_id,
|
||||
MapShare.user_id == user.id
|
||||
).first() is not None
|
||||
|
||||
if not (is_owner or is_admin or has_share_access):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to access this map"
|
||||
|
||||
@@ -33,18 +33,37 @@ def create_map_share(
|
||||
detail="Only the map owner can share it"
|
||||
)
|
||||
|
||||
# Check if user exists
|
||||
target_user = db.query(User).filter(User.id == share_data.user_id).first()
|
||||
# Look up user by username, email, or UUID
|
||||
target_user = None
|
||||
user_identifier = share_data.user_identifier.strip()
|
||||
|
||||
# Try UUID first
|
||||
try:
|
||||
user_uuid = UUID(user_identifier)
|
||||
target_user = db.query(User).filter(User.id == user_uuid).first()
|
||||
except ValueError:
|
||||
# Not a valid UUID, try username or email
|
||||
target_user = db.query(User).filter(
|
||||
(User.username == user_identifier) | (User.email == user_identifier)
|
||||
).first()
|
||||
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
detail=f"User not found with identifier: {user_identifier}"
|
||||
)
|
||||
|
||||
# Prevent sharing with self
|
||||
if target_user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot share map with yourself"
|
||||
)
|
||||
|
||||
# Check if already shared
|
||||
existing_share = db.query(MapShare).filter(
|
||||
MapShare.map_id == map_id,
|
||||
MapShare.user_id == share_data.user_id
|
||||
MapShare.user_id == target_user.id
|
||||
).first()
|
||||
|
||||
if existing_share:
|
||||
@@ -58,7 +77,7 @@ def create_map_share(
|
||||
# Create new share
|
||||
share = MapShare(
|
||||
map_id=map_id,
|
||||
user_id=share_data.user_id,
|
||||
user_id=target_user.id,
|
||||
permission=share_data.permission,
|
||||
shared_by=current_user.id
|
||||
)
|
||||
@@ -120,7 +139,7 @@ def update_map_share(
|
||||
return share
|
||||
|
||||
|
||||
def delete_map_share(
|
||||
async def delete_map_share(
|
||||
db: Session,
|
||||
map_id: UUID,
|
||||
share_id: UUID,
|
||||
@@ -146,9 +165,16 @@ def delete_map_share(
|
||||
detail="Share not found"
|
||||
)
|
||||
|
||||
# Get user_id before deleting the share
|
||||
revoked_user_id = share.user_id
|
||||
|
||||
db.delete(share)
|
||||
db.commit()
|
||||
|
||||
# Disconnect the user's WebSocket connections
|
||||
from app.websocket.connection_manager import manager
|
||||
await manager.disconnect_user(map_id, revoked_user_id)
|
||||
|
||||
|
||||
def create_share_link(
|
||||
db: Session,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""WebSocket connection manager for real-time map updates."""
|
||||
from fastapi import WebSocket
|
||||
from typing import Dict, List, Set
|
||||
from typing import Dict, List, Set, Optional, Tuple
|
||||
from uuid import UUID
|
||||
import json
|
||||
import logging
|
||||
@@ -12,10 +12,12 @@ class ConnectionManager:
|
||||
"""Manages WebSocket connections for real-time map updates."""
|
||||
|
||||
def __init__(self):
|
||||
# map_id -> Set of WebSocket connections
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
# map_id -> Set of (WebSocket, user_id) tuples
|
||||
self.active_connections: Dict[str, Set[Tuple[WebSocket, Optional[UUID]]]] = {}
|
||||
# websocket -> map_id mapping for quick lookup
|
||||
self.websocket_to_map: Dict[WebSocket, str] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, map_id: UUID):
|
||||
async def connect(self, websocket: WebSocket, map_id: UUID, user_id: Optional[UUID] = None):
|
||||
"""Accept a new WebSocket connection for a map."""
|
||||
await websocket.accept()
|
||||
map_key = str(map_id)
|
||||
@@ -23,20 +25,63 @@ class ConnectionManager:
|
||||
if map_key not in self.active_connections:
|
||||
self.active_connections[map_key] = set()
|
||||
|
||||
self.active_connections[map_key].add(websocket)
|
||||
logger.info(f"Client connected to map {map_id}. Total connections: {len(self.active_connections[map_key])}")
|
||||
# Store websocket with user_id
|
||||
self.active_connections[map_key].add((websocket, user_id))
|
||||
self.websocket_to_map[websocket] = map_key
|
||||
|
||||
user_info = f"user {user_id}" if user_id else "guest"
|
||||
logger.info(f"Client ({user_info}) connected to map {map_id}. Total connections: {len(self.active_connections[map_key])}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket, map_id: UUID):
|
||||
"""Remove a WebSocket connection."""
|
||||
map_key = str(map_id)
|
||||
|
||||
if map_key in self.active_connections:
|
||||
self.active_connections[map_key].discard(websocket)
|
||||
# Find and remove the tuple containing this websocket
|
||||
self.active_connections[map_key] = {
|
||||
conn for conn in self.active_connections[map_key]
|
||||
if conn[0] != websocket
|
||||
}
|
||||
if not self.active_connections[map_key]:
|
||||
del self.active_connections[map_key]
|
||||
|
||||
# Remove from websocket_to_map
|
||||
self.websocket_to_map.pop(websocket, None)
|
||||
|
||||
logger.info(f"Client disconnected from map {map_id}")
|
||||
|
||||
async def disconnect_user(self, map_id: UUID, user_id: UUID):
|
||||
"""Disconnect all connections for a specific user on a specific map."""
|
||||
map_key = str(map_id)
|
||||
|
||||
if map_key not in self.active_connections:
|
||||
return
|
||||
|
||||
# Find all websockets for this user
|
||||
connections_to_close = [
|
||||
websocket for websocket, uid in self.active_connections[map_key]
|
||||
if uid == user_id
|
||||
]
|
||||
|
||||
# Close each connection
|
||||
for websocket in connections_to_close:
|
||||
try:
|
||||
await websocket.close(code=1008, reason="Access revoked")
|
||||
logger.info(f"Closed WebSocket for user {user_id} on map {map_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing WebSocket for user {user_id}: {e}")
|
||||
|
||||
# Remove from active connections
|
||||
self.active_connections[map_key] = {
|
||||
conn for conn in self.active_connections[map_key]
|
||||
if conn[0] != websocket
|
||||
}
|
||||
self.websocket_to_map.pop(websocket, None)
|
||||
|
||||
# Clean up empty map entry
|
||||
if not self.active_connections[map_key]:
|
||||
del self.active_connections[map_key]
|
||||
|
||||
async def broadcast_to_map(self, map_id: UUID, message: dict):
|
||||
"""Broadcast a message to all clients connected to a specific map."""
|
||||
map_key = str(map_id)
|
||||
@@ -48,16 +93,16 @@ class ConnectionManager:
|
||||
connections = self.active_connections[map_key].copy()
|
||||
disconnected = []
|
||||
|
||||
for connection in connections:
|
||||
for websocket, user_id in connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending message to client: {e}")
|
||||
disconnected.append(connection)
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Remove disconnected clients
|
||||
for connection in disconnected:
|
||||
self.disconnect(connection, map_id)
|
||||
for websocket in disconnected:
|
||||
self.disconnect(websocket, map_id)
|
||||
|
||||
async def send_item_created(self, map_id: UUID, item_data: dict):
|
||||
"""Notify clients that a new item was created."""
|
||||
|
||||
Reference in New Issue
Block a user