From 79942c3fd32e163d7fa5870afa690e985f466cec Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 15:12:12 -0400 Subject: [PATCH] feat(billing): add GET /billing/state aggregating subscription + plan + features Co-Authored-By: Claude Opus 4.7 --- backend/app/api/endpoints/billing.py | 18 +++++- backend/app/schemas/billing.py | 30 ++++++++- backend/app/services/billing.py | 59 ++++++++++++++++++ backend/tests/test_billing_state_endpoint.py | 64 ++++++++++++++++++++ 4 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 backend/tests/test_billing_state_endpoint.py diff --git a/backend/app/api/endpoints/billing.py b/backend/app/api/endpoints/billing.py index 024b1420..23d067d4 100644 --- a/backend/app/api/endpoints/billing.py +++ b/backend/app/api/endpoints/billing.py @@ -9,7 +9,11 @@ from app.core.admin_database import get_admin_db from app.core.config import settings from app.models.account import Account from app.models.user import User -from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse +from app.schemas.billing import ( + BillingStateResponse, + CheckoutSessionCreate, + CheckoutSessionResponse, +) from app.services.billing import BillingService router = APIRouter(prefix="/billing", tags=["billing"]) @@ -34,3 +38,15 @@ async def create_checkout_session( cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan", ) return CheckoutSessionResponse(url=url) + + +@router.get("/state", response_model=BillingStateResponse) +async def get_billing_state( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_admin_db)], +) -> BillingStateResponse: + account = (await db.execute( + select(Account).where(Account.id == current_user.account_id) + )).scalar_one() + state = await BillingService.get_billing_state(db, account) + return BillingStateResponse(**state) diff --git a/backend/app/schemas/billing.py b/backend/app/schemas/billing.py index 51f4db2a..ebe9ab9d 100644 --- a/backend/app/schemas/billing.py +++ b/backend/app/schemas/billing.py @@ -1,4 +1,5 @@ -from typing import Literal +from typing import Literal, Optional, Dict, Any +from datetime import datetime from pydantic import BaseModel @@ -10,3 +11,30 @@ class CheckoutSessionCreate(BaseModel): class CheckoutSessionResponse(BaseModel): url: str + + +class SubscriptionState(BaseModel): + status: str + plan: str + current_period_start: Optional[datetime] + current_period_end: Optional[datetime] + cancel_at_period_end: bool + seat_limit: Optional[int] + has_pro_entitlement: bool + is_paid: bool + + +class PlanBillingState(BaseModel): + display_name: str + description: Optional[str] = None + monthly_price_cents: Optional[int] = None + annual_price_cents: Optional[int] = None + + model_config = {"from_attributes": True} + + +class BillingStateResponse(BaseModel): + subscription: SubscriptionState + plan_billing: Optional[PlanBillingState] + plan_limits: Dict[str, Any] + enabled_features: Dict[str, bool] diff --git a/backend/app/services/billing.py b/backend/app/services/billing.py index c94d1880..a104a5b1 100644 --- a/backend/app/services/billing.py +++ b/backend/app/services/billing.py @@ -105,6 +105,61 @@ class BillingService: ) return session.url + @staticmethod + async def get_billing_state(db: AsyncSession, account): + """Aggregate Subscription + PlanLimits + PlanBilling + resolved feature + flags for the account.""" + from app.models.plan_limits import PlanLimits + from app.models.plan_billing import PlanBilling + from app.models.feature_flag import ( + FeatureFlag, PlanFeatureDefault, AccountFeatureOverride, + ) + + sub = (await db.execute( + select(Subscription).where(Subscription.account_id == account.id) + )).scalar_one_or_none() + if sub is None: + from fastapi import HTTPException + raise HTTPException(status_code=404, detail="No subscription for account") + + pl = (await db.execute( + select(PlanLimits).where(PlanLimits.plan == sub.plan) + )).scalar_one_or_none() + pb = (await db.execute( + select(PlanBilling).where(PlanBilling.plan == sub.plan) + )).scalar_one_or_none() + + # Resolved feature flags: plan defaults overridden by account overrides + defaults = (await db.execute( + select(PlanFeatureDefault, FeatureFlag) + .join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id) + .where(PlanFeatureDefault.plan == sub.plan) + )).all() + resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults} + overrides = (await db.execute( + select(AccountFeatureOverride, FeatureFlag) + .join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id) + .where(AccountFeatureOverride.account_id == account.id) + )).all() + for ovr, flag in overrides: + resolved[flag.flag_key] = ovr.enabled + + return { + "subscription": { + "status": sub.status, + "plan": sub.plan, + "current_period_start": sub.current_period_start, + "current_period_end": sub.current_period_end, + "cancel_at_period_end": sub.cancel_at_period_end, + "seat_limit": sub.seat_limit, + "has_pro_entitlement": sub.has_pro_entitlement, + "is_paid": sub.is_paid, + }, + "plan_billing": pb, + "plan_limits": _plan_limits_to_dict(pl) if pl else {}, + "enabled_features": resolved, + } + @staticmethod async def apply_subscription_event( db: AsyncSession, event_id: str, event_type: str, payload: dict @@ -136,6 +191,10 @@ class BillingService: return True +def _plan_limits_to_dict(pl) -> dict: + return {c.name: getattr(pl, c.name) for c in pl.__table__.columns} + + def _excerpt(payload: dict) -> dict: obj = payload.get("data", {}).get("object", {}) return { diff --git a/backend/tests/test_billing_state_endpoint.py b/backend/tests/test_billing_state_endpoint.py new file mode 100644 index 00000000..2f27b73c --- /dev/null +++ b/backend/tests/test_billing_state_endpoint.py @@ -0,0 +1,64 @@ +import uuid +import pytest +from sqlalchemy import select +from app.models.subscription import Subscription +from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride + + +@pytest.mark.asyncio +async def test_billing_state_returns_subscription_plan_features( + client, test_db, test_user, auth_headers +): + """Subscription is already seeded by test_user fixture (pro/active). + Add a feature flag default for `pro` and verify it shows up in the response.""" + flag = FeatureFlag(flag_key="psa_integration", display_name="PSA Integration") + test_db.add(flag) + await test_db.flush() + test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=True)) + await test_db.commit() + + response = await client.get("/api/v1/billing/state", headers=auth_headers) + assert response.status_code == 200, response.json() + body = response.json() + assert body["subscription"]["status"] == "active" + assert body["subscription"]["plan"] == "pro" + assert body["subscription"]["has_pro_entitlement"] is True + assert body["subscription"]["is_paid"] is True + assert body["enabled_features"]["psa_integration"] is True + # plan_limits should be a dict with the seeded pro limits from conftest + assert body["plan_limits"]["plan"] == "pro" + assert body["plan_limits"]["max_trees"] == 25 + + +@pytest.mark.asyncio +async def test_billing_state_account_override_beats_plan_default( + client, test_db, test_user, auth_headers +): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + + flag = FeatureFlag(flag_key="escalation_mode", display_name="Escalation Mode") + test_db.add(flag) + await test_db.flush() + test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=False)) + test_db.add(AccountFeatureOverride( + account_id=account_id, flag_id=flag.id, enabled=True, + )) + await test_db.commit() + + response = await client.get("/api/v1/billing/state", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["enabled_features"]["escalation_mode"] is True + + +@pytest.mark.asyncio +async def test_billing_state_404_when_no_subscription( + client, test_db, test_user, auth_headers +): + """Wipe the seeded subscription and verify the endpoint surfaces 404.""" + from sqlalchemy import delete + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + await test_db.commit() + + response = await client.get("/api/v1/billing/state", headers=auth_headers) + assert response.status_code == 404