From fc2f76cb1231659287f6aff23f57be91967de8a6 Mon Sep 17 00:00:00 2001 From: Walter Oggioni Date: Sat, 2 May 2026 06:54:22 +0200 Subject: [PATCH] Add Splitwise clone backend --- app/auth.py | 143 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 app/auth.py diff --git a/app/auth.py b/app/auth.py new file mode 100644 index 0000000..ae36036 --- /dev/null +++ b/app/auth.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + +import httpx +from jose import jwt +from jose.exceptions import JWTError, ExpiredSignatureError +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + +from app.config import settings +from app.models import User + +PUBLIC_PATHS = {"/health"} + + +@dataclass +class OIDCConfig: + issuer: str + jwks_uri: str + + +_oidc_config: OIDCConfig | None = None + + +async def _fetch_oidc_config(client: httpx.AsyncClient) -> OIDCConfig: + url = f"{settings.oidc_issuer}/.well-known/openid-configuration" + resp = await client.get(url) + resp.raise_for_status() + data = resp.json() + issuer = data["issuer"] + jwks_uri = data["jwks_uri"] + return OIDCConfig(issuer=issuer, jwks_uri=jwks_uri) + + +async def init_oidc() -> None: + global _oidc_config + async with httpx.AsyncClient(timeout=15.0) as client: + _oidc_config = await _fetch_oidc_config(client) + + +def _get_key(token: str, jwks: dict[str, Any]) -> dict[str, Any] | None: + try: + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + except JWTError: + return None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + return key + return None + + +async def _validate_token(token: str) -> dict[str, Any]: + if _oidc_config is None: + raise OIDCNotInitialized() + + async with httpx.AsyncClient(timeout=15.0) as client: + jwks_resp = await client.get(_oidc_config.jwks_uri) + jwks_resp.raise_for_status() + jwks = jwks_resp.json() + + key = _get_key(token, jwks) + if key is None: + raise InvalidTokenError("No matching JWKS key found") + + claims = jwt.decode( + token, + key=key, + algorithms=["RS256"], + issuer=_oidc_config.issuer, + audience=settings.audiences, + options={"verify_exp": True}, + ) + return claims + + +async def _upsert_user(sub: str, email: str | None, name: str | None) -> User: + user = await User.filter(sub=sub).first() + if user is not None: + changed = False + if user.email != email: + user.email = email + changed = True + if user.name != name: + user.name = name + changed = True + if changed: + await user.save(update_fields=["email", "name"]) + else: + user = await User.create(sub=sub, email=email, name=name) + return user + + +class OIDCNotInitialized(Exception): + pass + + +class InvalidTokenError(Exception): + pass + + +class OIDCAuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Any) -> JSONResponse: + if request.url.path in PUBLIC_PATHS: + return await call_next(request) + + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return JSONResponse({"detail": "Missing or invalid Authorization header"}, status_code=401) + + token = auth_header.removeprefix("Bearer ").strip() + if not token: + return JSONResponse({"detail": "Empty bearer token"}, status_code=401) + + try: + claims = await _validate_token(token) + except ExpiredSignatureError: + return JSONResponse({"detail": "Token has expired"}, status_code=401) + except InvalidTokenError: + return JSONResponse({"detail": "Invalid token"}, status_code=401) + except OIDCNotInitialized: + return JSONResponse({"detail": "OIDC not initialized"}, status_code=500) + except Exception: + return JSONResponse({"detail": "Invalid token"}, status_code=401) + + sub = claims.get("sub") + if not sub: + return JSONResponse({"detail": "Token missing sub claim"}, status_code=401) + + email = claims.get("email") + name = claims.get("name") or claims.get("preferred_username") + + try: + user = await _upsert_user(sub, email, name) + except Exception: + return JSONResponse({"detail": "Internal server error"}, status_code=500) + + request.state.user = user + return await call_next(request)