from __future__ import annotations import json from dataclasses import dataclass from typing import Any import httpx from jose import jwt from jose.exceptions import JWTError, ExpiredSignatureError from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse from app.config import settings from app.models import User PUBLIC_PATHS = {"/health"} @dataclass class OIDCConfig: issuer: str jwks_uri: str _oidc_config: OIDCConfig | None = None async def _fetch_oidc_config(client: httpx.AsyncClient) -> OIDCConfig: url = f"{settings.oidc_issuer}/.well-known/openid-configuration" resp = await client.get(url) resp.raise_for_status() data = resp.json() issuer = data["issuer"] jwks_uri = data["jwks_uri"] return OIDCConfig(issuer=issuer, jwks_uri=jwks_uri) async def init_oidc() -> None: global _oidc_config async with httpx.AsyncClient(timeout=15.0) as client: _oidc_config = await _fetch_oidc_config(client) def _get_key(token: str, jwks: dict[str, Any]) -> dict[str, Any] | None: try: unverified_header = jwt.get_unverified_header(token) kid = unverified_header.get("kid") except JWTError: return None for key in jwks.get("keys", []): if key.get("kid") == kid: return key return None async def _validate_token(token: str) -> dict[str, Any]: if _oidc_config is None: raise OIDCNotInitialized() async with httpx.AsyncClient(timeout=15.0) as client: jwks_resp = await client.get(_oidc_config.jwks_uri) jwks_resp.raise_for_status() jwks = jwks_resp.json() key = _get_key(token, jwks) if key is None: raise InvalidTokenError("No matching JWKS key found") claims = jwt.decode( token, key=key, algorithms=["RS256"], issuer=_oidc_config.issuer, audience=settings.audiences, options={"verify_exp": True}, ) return claims async def _upsert_user(sub: str, email: str | None, name: str | None) -> User: user = await User.filter(sub=sub).first() if user is not None: changed = False if user.email != email: user.email = email changed = True if user.name != name: user.name = name changed = True if changed: await user.save(update_fields=["email", "name"]) else: user = await User.create(sub=sub, email=email, name=name) return user class OIDCNotInitialized(Exception): pass class InvalidTokenError(Exception): pass class OIDCAuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: Any) -> JSONResponse: if request.url.path in PUBLIC_PATHS: return await call_next(request) auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): return JSONResponse({"detail": "Missing or invalid Authorization header"}, status_code=401) token = auth_header.removeprefix("Bearer ").strip() if not token: return JSONResponse({"detail": "Empty bearer token"}, status_code=401) try: claims = await _validate_token(token) except ExpiredSignatureError: return JSONResponse({"detail": "Token has expired"}, status_code=401) except InvalidTokenError: return JSONResponse({"detail": "Invalid token"}, status_code=401) except OIDCNotInitialized: return JSONResponse({"detail": "OIDC not initialized"}, status_code=500) except Exception: return JSONResponse({"detail": "Invalid token"}, status_code=401) sub = claims.get("sub") if not sub: return JSONResponse({"detail": "Token missing sub claim"}, status_code=401) email = claims.get("email") name = claims.get("name") or claims.get("preferred_username") try: user = await _upsert_user(sub, email, name) except Exception: return JSONResponse({"detail": "Internal server error"}, status_code=500) request.state.user = user return await call_next(request)