diff --git a/app/resources/expenses.py b/app/resources/expenses.py new file mode 100644 index 0000000..e342d90 --- /dev/null +++ b/app/resources/expenses.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from starlette.requests import Request +from starlette.responses import JSONResponse +from tortoise.transactions import atomic + +from app.models import ExpenseItem, ExpenseParticipant, ExpenseReport, ExpenseShare, User +from app.schemas import ( + ExpenseItemCreate, + ExpenseItemOut, + ExpenseReportCreate, + ExpenseReportListOut, + ExpenseReportOut, + ExpenseShareOut, +) + + +def _build_report_out(report: ExpenseReport) -> dict: + items_out: list[dict] = [] + for item in report.items: + shares_out: list[dict] = [] + for share in item.shares: + shares_out.append( + ExpenseShareOut( + user_sub=share.user.sub, + user_name=share.user.name, + percentage=share.percentage, + ).model_dump() + ) + items_out.append( + ExpenseItemOut( + id=item.id, + description=item.description, + amount=item.amount, + shares=shares_out, + ).model_dump() + ) + + participant_subs: set[str] = set() + for item in report.items: + for share in item.shares: + participant_subs.add(share.user.sub) + + participants_out = [ + {"sub": sub, "name": None} for sub in sorted(participant_subs) + ] + + return ExpenseReportOut( + id=report.id, + title=report.title, + creator_sub=report.creator.sub, + creator_name=report.creator.name, + created_at=report.created_at.isoformat(), + items=items_out, + participants=participants_out, + ).model_dump() + + +async def create_expense_report(request: Request) -> JSONResponse: + user: User = request.state.user + try: + body = await request.json() + except Exception: + return JSONResponse({"detail": "Invalid JSON"}, status_code=400) + + try: + payload = ExpenseReportCreate.model_validate(body) + except Exception as e: + return JSONResponse({"detail": str(e)}, status_code=400) + + all_subs: set[str] = set() + for item in payload.items: + all_subs.update(item.participants) + all_subs.add(user.sub) + + existing_users = await User.filter(sub__in=list(all_subs)).all() + sub_to_user: dict[str, User] = {u.sub: u for u in existing_users} + missing = all_subs - set(sub_to_user.keys()) + if missing: + return JSONResponse( + {"detail": f"Unknown user(s): {', '.join(sorted(missing))}"}, + status_code=400, + ) + + @atomic() + async def _create() -> dict: + report = await ExpenseReport.create(title=payload.title, creator=user) + + for sub in all_subs: + await ExpenseParticipant.create( + report=report, user=sub_to_user[sub] + ) + + for item_data in payload.items: + item = await ExpenseItem.create( + report=report, + description=item_data.description, + amount=item_data.amount, + ) + + if item_data.shares is not None: + for sub, pct in item_data.shares.items(): + await ExpenseShare.create( + item=item, + user=sub_to_user[sub], + percentage=pct, + ) + else: + equal_pct = 100.0 / len(item_data.participants) + for sub in item_data.participants: + await ExpenseShare.create( + item=item, + user=sub_to_user[sub], + percentage=equal_pct, + ) + + report = await ( + ExpenseReport.filter(id=report.id) + .prefetch_related( + "items__shares__user", + "creator", + ) + .first() + ) + assert report is not None + return _build_report_out(report) + + try: + result = await _create() + except Exception as e: + return JSONResponse({"detail": str(e)}, status_code=500) + + return JSONResponse(result, status_code=201) + + +async def list_expense_reports(request: Request) -> JSONResponse: + user: User = request.state.user + + participation_ids = await ExpenseParticipant.filter(user=user).values_list( + "report_id", flat=True + ) + + if not participation_ids: + return JSONResponse( + ExpenseReportListOut(reports=[], total_count=0).model_dump() + ) + + reports = await ( + ExpenseReport.filter(id__in=list(participation_ids)) + .order_by("-created_at") + .prefetch_related( + "items__shares__user", + "creator", + ) + ) + + reports_out = [_build_report_out(r) for r in reports] + + return JSONResponse( + ExpenseReportListOut( + reports=reports_out, + total_count=len(reports_out), + ).model_dump() + )