Add Splitwise clone backend
This commit is contained in:
143
app/auth.py
Normal file
143
app/auth.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError, ExpiredSignatureError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
from app.models import User
|
||||
|
||||
PUBLIC_PATHS = {"/health"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OIDCConfig:
|
||||
issuer: str
|
||||
jwks_uri: str
|
||||
|
||||
|
||||
_oidc_config: OIDCConfig | None = None
|
||||
|
||||
|
||||
async def _fetch_oidc_config(client: httpx.AsyncClient) -> OIDCConfig:
|
||||
url = f"{settings.oidc_issuer}/.well-known/openid-configuration"
|
||||
resp = await client.get(url)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
issuer = data["issuer"]
|
||||
jwks_uri = data["jwks_uri"]
|
||||
return OIDCConfig(issuer=issuer, jwks_uri=jwks_uri)
|
||||
|
||||
|
||||
async def init_oidc() -> None:
|
||||
global _oidc_config
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
_oidc_config = await _fetch_oidc_config(client)
|
||||
|
||||
|
||||
def _get_key(token: str, jwks: dict[str, Any]) -> dict[str, Any] | None:
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
kid = unverified_header.get("kid")
|
||||
except JWTError:
|
||||
return None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
async def _validate_token(token: str) -> dict[str, Any]:
|
||||
if _oidc_config is None:
|
||||
raise OIDCNotInitialized()
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
jwks_resp = await client.get(_oidc_config.jwks_uri)
|
||||
jwks_resp.raise_for_status()
|
||||
jwks = jwks_resp.json()
|
||||
|
||||
key = _get_key(token, jwks)
|
||||
if key is None:
|
||||
raise InvalidTokenError("No matching JWKS key found")
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key=key,
|
||||
algorithms=["RS256"],
|
||||
issuer=_oidc_config.issuer,
|
||||
audience=settings.audiences,
|
||||
options={"verify_exp": True},
|
||||
)
|
||||
return claims
|
||||
|
||||
|
||||
async def _upsert_user(sub: str, email: str | None, name: str | None) -> User:
|
||||
user = await User.filter(sub=sub).first()
|
||||
if user is not None:
|
||||
changed = False
|
||||
if user.email != email:
|
||||
user.email = email
|
||||
changed = True
|
||||
if user.name != name:
|
||||
user.name = name
|
||||
changed = True
|
||||
if changed:
|
||||
await user.save(update_fields=["email", "name"])
|
||||
else:
|
||||
user = await User.create(sub=sub, email=email, name=name)
|
||||
return user
|
||||
|
||||
|
||||
class OIDCNotInitialized(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTokenError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OIDCAuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Any) -> JSONResponse:
|
||||
if request.url.path in PUBLIC_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return JSONResponse({"detail": "Missing or invalid Authorization header"}, status_code=401)
|
||||
|
||||
token = auth_header.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return JSONResponse({"detail": "Empty bearer token"}, status_code=401)
|
||||
|
||||
try:
|
||||
claims = await _validate_token(token)
|
||||
except ExpiredSignatureError:
|
||||
return JSONResponse({"detail": "Token has expired"}, status_code=401)
|
||||
except InvalidTokenError:
|
||||
return JSONResponse({"detail": "Invalid token"}, status_code=401)
|
||||
except OIDCNotInitialized:
|
||||
return JSONResponse({"detail": "OIDC not initialized"}, status_code=500)
|
||||
except Exception:
|
||||
return JSONResponse({"detail": "Invalid token"}, status_code=401)
|
||||
|
||||
sub = claims.get("sub")
|
||||
if not sub:
|
||||
return JSONResponse({"detail": "Token missing sub claim"}, status_code=401)
|
||||
|
||||
email = claims.get("email")
|
||||
name = claims.get("name") or claims.get("preferred_username")
|
||||
|
||||
try:
|
||||
user = await _upsert_user(sub, email, name)
|
||||
except Exception:
|
||||
return JSONResponse({"detail": "Internal server error"}, status_code=500)
|
||||
|
||||
request.state.user = user
|
||||
return await call_next(request)
|
||||
Reference in New Issue
Block a user