commit 3ecf2e8dce6e22365a84d984c1dd6a62d65413bc Author: Shihaam Abdul Rahman Date: Fri Dec 20 06:02:52 2024 +0500 init diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..eed8546 --- /dev/null +++ b/.env.example @@ -0,0 +1,3 @@ +OMADA_USERNAME= +OMADA_PASSWORD= +API_KEYS=donotusethisinproduction,a2ndapikey,anda3rdone diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2f32dfa --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.env +venv/ + diff --git a/proxy.py b/proxy.py new file mode 100644 index 0000000..5023e09 --- /dev/null +++ b/proxy.py @@ -0,0 +1,175 @@ +from fastapi import FastAPI, Request, Response, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader +from fastapi import Security +import httpx +import os +from dotenv import load_dotenv +import json +from typing import Optional, Set +import time + +load_dotenv() + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Load API keys from environment +API_KEYS: Set[str] = set(os.getenv("API_KEYS", "").split(",")) +if not API_KEYS: + raise ValueError("API_KEYS environment variable must be set") + +api_key_header = APIKeyHeader(name="X-API-Key") + +# Global variables to store authentication data +auth_data = { + "session_id": None, + "omadac_id": None, + "token": None, + "last_auth": 0 +} + +OMADA_BASE_URL = os.getenv("OMADA_BASE_URL", "https://omada.sarlink.link") +USERNAME = os.getenv("OMADA_USERNAME") +PASSWORD = os.getenv("OMADA_PASSWORD") + +if not all([USERNAME, PASSWORD]): + raise ValueError("OMADA_USERNAME and OMADA_PASSWORD must be set in environment") + +async def validate_api_key(api_key: str = Security(api_key_header)) -> str: + """Validate the API key""" + if api_key not in API_KEYS: + raise HTTPException( + status_code=401, + detail="Invalid API key", + headers={"WWW-Authenticate": "API key"}, + ) + return api_key + +async def login_to_omada() -> bool: + """Login to Omada controller and update auth data""" + try: + async with httpx.AsyncClient(verify=False) as client: + response = await client.post( + f"{OMADA_BASE_URL}/api/v2/login", + json={"username": USERNAME, "password": PASSWORD}, + headers={"Content-Type": "application/json"} + ) + + if response.status_code != 200: + print(f"Login failed: {response.text}") + return False + + data = response.json() + if data.get("errorCode") != 0: + print(f"Login error: {data}") + return False + + # Extract session ID from cookies + cookies = response.headers.get("set-cookie") + if cookies: + for cookie in cookies.split(", "): + if "TPOMADA_SESSIONID" in cookie: + session_id = cookie.split(";")[0].split("=")[1] + auth_data["session_id"] = session_id + break + + auth_data["omadac_id"] = data["result"]["omadacId"] + auth_data["token"] = data["result"]["token"] + auth_data["last_auth"] = time.time() + + print("Successfully authenticated with Omada controller") + return True + except Exception as e: + print(f"Login error: {str(e)}") + return False + +def is_auth_valid() -> bool: + """Check if current authentication is valid""" + return all([ + auth_data["session_id"], + auth_data["omadac_id"], + auth_data["token"], + time.time() - auth_data["last_auth"] < 3600 # 1 hour timeout + ]) + +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) +async def proxy_request(request: Request, path: str, api_key: str = Security(validate_api_key)): + """Forward requests to Omada controller with proper authentication""" + # Ensure we're authenticated with Omada + if not is_auth_valid(): + if not await login_to_omada(): + raise HTTPException(status_code=500, detail="Failed to authenticate with Omada controller") + + # Get the request body if any + body = None + if request.method not in ["GET", "HEAD"]: + body = await request.body() + + # Prepare headers + headers = dict(request.headers) + headers.pop("host", None) + headers.pop("x-api-key", None) # Remove our API key from forwarded request + headers["Cookie"] = f"TPOMADA_SESSIONID={auth_data['session_id']}" + headers["csrf-token"] = auth_data["token"] + + # Construct the target URL + path = path.lstrip("/") + if "api/v2" in path and auth_data["omadac_id"] not in path: + # Insert omadac_id for API requests if not present + parts = path.split("api/v2", 1) + target_url = f"{OMADA_BASE_URL}/{auth_data['omadac_id']}/api/v2{parts[1]}" + else: + target_url = f"{OMADA_BASE_URL}/{path}" + + # Add query parameters + query = request.url.query + if query: + target_url = f"{target_url}?{query}" + + try: + async with httpx.AsyncClient(verify=False) as client: + response = await client.request( + method=request.method, + url=target_url, + headers=headers, + content=body, + follow_redirects=False # Don't automatically follow redirects + ) + + # Handle redirects to login page + if response.status_code == 302 and "/login" in response.headers.get("location", ""): + if await login_to_omada(): + # Retry the request with new credentials + headers["Cookie"] = f"TPOMADA_SESSIONID={auth_data['session_id']}" + headers["csrf-token"] = auth_data["token"] + response = await client.request( + method=request.method, + url=target_url, + headers=headers, + content=body, + follow_redirects=True + ) + + # Create response with same status code and headers + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9bb0a0f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +annotated-types==0.7.0 +anyio==4.7.0 +certifi==2024.12.14 +click==8.1.7 +fastapi==0.115.6 +h11==0.14.0 +httpcore==1.0.7 +httpx==0.28.1 +idna==3.10 +pydantic==2.10.4 +pydantic_core==2.27.2 +python-dotenv==1.0.1 +sniffio==1.3.1 +starlette==0.41.3 +typing_extensions==4.12.2 +uvicorn==0.34.0