Files
spliteasy/app/auth.py

144 lines
4.2 KiB
Python

from __future__ import annotations

import json
import logging
import time
from dataclasses import dataclass
from typing import Any

import httpx
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from app.config import settings
from app.models import User
# Request paths that bypass bearer-token authentication. Membership is an
# exact match against ``request.url.path`` (no prefix or pattern matching).
PUBLIC_PATHS: set[str] = {"/health"}
@dataclass
class OIDCConfig:
issuer: str
jwks_uri: str
_oidc_config: OIDCConfig | None = None
async def _fetch_oidc_config(client: httpx.AsyncClient) -> OIDCConfig:
    """Download the provider's discovery document and extract what token checks need.

    Args:
        client: HTTP client used for the request. This function does not
            close it; the caller owns its lifetime.

    Returns:
        The ``issuer`` and ``jwks_uri`` advertised by the provider.

    Raises:
        httpx.HTTPError: the request failed or returned a non-2xx status.
        KeyError: the document lacks ``issuer`` or ``jwks_uri``.
    """
    # OpenID Connect Discovery 1.0 §4: remove any terminating "/" from the
    # issuer before appending the well-known path; otherwise the URL would
    # contain "//". ``str()`` also accepts URL objects in case
    # ``settings.oidc_issuer`` is not a plain string.
    issuer_base = str(settings.oidc_issuer).rstrip("/")
    resp = await client.get(f"{issuer_base}/.well-known/openid-configuration")
    resp.raise_for_status()
    metadata = resp.json()
    return OIDCConfig(issuer=metadata["issuer"], jwks_uri=metadata["jwks_uri"])
async def init_oidc() -> None:
    """Load the provider's discovery metadata into the module-level ``_oidc_config``.

    Token validation raises ``OIDCNotInitialized`` until this has completed.
    The HTTP client lives only for the duration of this call (15 s timeout).
    """
    global _oidc_config
    async with httpx.AsyncClient(timeout=15.0) as http:
        discovered = await _fetch_oidc_config(http)
        _oidc_config = discovered
def _get_key(token: str, jwks: dict[str, Any]) -> dict[str, Any] | None:
    """Pick the JWK whose ``kid`` matches the token's (unverified) header.

    Args:
        token: Compact-serialized JWT whose header names the signing key.
        jwks: Parsed JWKS document; keys are read from its ``"keys"`` list.

    Returns:
        The first key with an equal ``kid``, or ``None`` if the header cannot
        be parsed or no key matches. A header without ``kid`` matches a key
        that also lacks ``kid``, since both compare as ``None``.
    """
    try:
        header = jwt.get_unverified_header(token)
    except JWTError:
        return None
    wanted_kid = header.get("kid")
    candidates = jwks.get("keys", [])
    return next((jwk for jwk in candidates if jwk.get("kid") == wanted_kid), None)
# Seconds a downloaded JWKS document is reused before it is fetched again.
_JWKS_TTL_SECONDS = 300.0
# Minimum age before an unknown ``kid`` may force an early JWKS refresh. This
# stops tokens carrying made-up key ids from triggering a download per request.
_JWKS_MIN_REFRESH_SECONDS = 30.0
# jwks_uri -> (time.monotonic() at fetch, parsed JWKS document).
_jwks_cache: dict[str, tuple[float, dict[str, Any]]] = {}


async def _get_jwks(jwks_uri: str, *, force_refresh: bool = False) -> dict[str, Any]:
    """Return the JWKS at ``jwks_uri``, downloading it only when the cached copy is stale.

    Args:
        jwks_uri: URL of the provider's JSON Web Key Set.
        force_refresh: Re-download even within the TTL, provided the cached
            copy is at least ``_JWKS_MIN_REFRESH_SECONDS`` old. Used when a
            token names a key id the cached set does not contain.

    Raises:
        httpx.HTTPError: the download failed or returned a non-2xx status.
    """
    now = time.monotonic()
    cached = _jwks_cache.get(jwks_uri)
    if cached is not None:
        fetched_at, jwks = cached
        age = now - fetched_at
        refresh_allowed = force_refresh and age >= _JWKS_MIN_REFRESH_SECONDS
        if age < _JWKS_TTL_SECONDS and not refresh_allowed:
            return jwks
    async with httpx.AsyncClient(timeout=15.0) as client:
        resp = await client.get(jwks_uri)
        resp.raise_for_status()
        jwks = resp.json()
    _jwks_cache[jwks_uri] = (now, jwks)
    return jwks


async def _validate_token(token: str) -> dict[str, Any]:
    """Verify ``token`` against the provider's signing keys and return its claims.

    Checks the RS256 signature, the issuer from the discovery document, the
    audience from ``settings.audiences``, and expiry.

    Raises:
        OIDCNotInitialized: ``init_oidc()`` has not loaded the provider metadata.
        InvalidTokenError: no key in the JWKS matches the token's ``kid``.
        httpx.HTTPError: the JWKS could not be downloaded.
        jose.exceptions.JWTError: raised by ``jwt.decode`` (including
            ``ExpiredSignatureError``) when a signature or claim check fails.
    """
    if _oidc_config is None:
        raise OIDCNotInitialized()
    jwks_uri = _oidc_config.jwks_uri
    key = _get_key(token, await _get_jwks(jwks_uri))
    if key is None:
        # The kid may belong to a key the provider rotated in after our last
        # download, so retry once against a freshly fetched key set.
        key = _get_key(token, await _get_jwks(jwks_uri, force_refresh=True))
    if key is None:
        raise InvalidTokenError("No matching JWKS key found")
    # NOTE(review): python-jose documents ``audience`` as a single string. If
    # ``settings.audiences`` is a list, confirm it validates as intended. jose
    # also skips the audience check when the token carries no ``aud`` claim
    # unless options["require_aud"] is set; confirm that is acceptable.
    return jwt.decode(
        token,
        key=key,
        algorithms=["RS256"],
        issuer=_oidc_config.issuer,
        audience=settings.audiences,
        options={"verify_exp": True},
    )
async def _sync_profile(user: User, email: str | None, name: str | None) -> None:
    """Write ``email``/``name`` onto ``user`` and save them if either value changed."""
    changed = False
    if user.email != email:
        user.email = email
        changed = True
    if user.name != name:
        user.name = name
        changed = True
    if changed:
        await user.save(update_fields=["email", "name"])


async def _upsert_user(sub: str, email: str | None, name: str | None) -> User:
    """Return the local user for token subject ``sub``, creating or refreshing it.

    Args:
        sub: The token's ``sub`` claim, used as the lookup key.
        email: Latest e-mail from the token claims; may be ``None``.
        name: Latest display name from the token claims; may be ``None``.

    Returns:
        The existing or newly created ``User``, with email and name kept in
        sync with the supplied values.

    Raises:
        Whatever the ORM raises when the row can neither be created nor found.
    """
    user = await User.filter(sub=sub).first()
    if user is None:
        try:
            return await User.create(sub=sub, email=email, name=name)
        except Exception:
            # Two concurrent first requests for the same subject can both see
            # "no row" and both insert. Assuming ``sub`` is unique in the User
            # table (TODO confirm), the loser's insert fails here, so fall
            # back to the winner's row. The ORM's IntegrityError is not
            # imported in this module, so re-query instead of matching on
            # the type. If no row exists, the failure is genuine: re-raise.
            user = await User.filter(sub=sub).first()
            if user is None:
                raise
    await _sync_profile(user, email, name)
    return user
class OIDCNotInitialized(Exception):
    """Raised when a token is validated before ``init_oidc()`` has loaded the provider metadata."""
class InvalidTokenError(Exception):
    """Raised by token validation when a token is rejected, e.g. no JWKS key matches its ``kid``."""
_logger = logging.getLogger(__name__)


def _unauthorized(detail: str) -> JSONResponse:
    """Build a 401 JSON response carrying the RFC 6750 ``WWW-Authenticate`` challenge."""
    return JSONResponse(
        {"detail": detail},
        status_code=401,
        headers={"WWW-Authenticate": "Bearer"},
    )


class OIDCAuthMiddleware(BaseHTTPMiddleware):
    """Require a valid OIDC bearer token on every path not in ``PUBLIC_PATHS``.

    On success, the local ``User`` matching the token's ``sub`` (created or
    refreshed from the claims) is stored on ``request.state.user`` before the
    request continues down the stack. Otherwise a JSON error response is returned:
    401 for credential problems, 503 when the provider's keys cannot be fetched,
    and 500 for server-side failures.
    """

    async def dispatch(self, request: Request, call_next: Any) -> Response:
        if request.url.path in PUBLIC_PATHS:
            return await call_next(request)

        # RFC 7235 §2.1: the auth-scheme is case-insensitive ("bearer" is valid).
        auth_header = request.headers.get("Authorization", "")
        scheme, sep, credentials = auth_header.partition(" ")
        if not sep or scheme.lower() != "bearer":
            return _unauthorized("Missing or invalid Authorization header")
        token = credentials.strip()
        if not token:
            return _unauthorized("Empty bearer token")

        try:
            claims = await _validate_token(token)
        except ExpiredSignatureError:
            return _unauthorized("Token has expired")
        except InvalidTokenError:
            return _unauthorized("Invalid token")
        except OIDCNotInitialized:
            _logger.error("OIDC auth used before init_oidc() completed")
            return JSONResponse({"detail": "OIDC not initialized"}, status_code=500)
        except httpx.HTTPError:
            # Downloading the provider's JWKS failed. That is an upstream outage,
            # not a bad credential, so don't tell the client its token is invalid.
            _logger.exception("Could not fetch JWKS from the OIDC provider")
            return JSONResponse(
                {"detail": "Authentication service unavailable"}, status_code=503
            )
        except Exception:
            # Signature/claim failures from jose and malformed tokens end up
            # here. Rejections are routine, so log them only at debug level.
            _logger.debug("Rejected bearer token", exc_info=True)
            return _unauthorized("Invalid token")

        sub = claims.get("sub")
        if not sub:
            return _unauthorized("Token missing sub claim")
        email = claims.get("email")
        name = claims.get("name") or claims.get("preferred_username")

        try:
            user = await _upsert_user(sub, email, name)
        except Exception:
            _logger.exception("Failed to load or create user for sub=%s", sub)
            return JSONResponse({"detail": "Internal server error"}, status_code=500)

        request.state.user = user
        return await call_next(request)