diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index a1ae7ea5..f5927f5c 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -222,3 +222,70 @@ async def require_admin_db( the user object is needed in the handler. """ return db + + +_SUBSCRIPTION_GUARD_ALLOWLIST = { + "/api/v1/auth/me", + "/api/v1/auth/logout", + "/api/v1/auth/password/change", + "/api/v1/auth/email/send-verification", + "/api/v1/auth/email/verify", + "/api/v1/billing/state", + "/api/v1/billing/checkout-session", + "/api/v1/billing/portal-session", + "/api/v1/users/me", + "/api/v1/users/me/onboarding-step", +} + + +async def require_active_subscription( + request: Request, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_admin_db)], +): + """Returns the Subscription row when the account has access; raises 402 + when locked. Mounted on routers requiring Pro entitlement. + + 'Locked' = (trialing AND current_period_end < now()) OR + (canceled OR incomplete OR no subscription). + Active states: active, complimentary, trialing-with-time-remaining, past_due. + """ + if request.url.path in _SUBSCRIPTION_GUARD_ALLOWLIST: + return None + + from app.models.subscription import Subscription + from datetime import datetime, timezone + + result = await db.execute( + select(Subscription).where(Subscription.account_id == current_user.account_id) + ) + sub = result.scalar_one_or_none() + + if sub is None: + raise HTTPException( + status_code=402, + detail={"error": "no_subscription", "upgrade_url": "/account/billing/select-plan"}, + ) + + now = datetime.now(timezone.utc) + is_live = ( + sub.status in ("active", "complimentary", "past_due") + or ( + sub.status == "trialing" + and sub.current_period_end is not None + and sub.current_period_end > now + ) + ) + if not is_live: + raise HTTPException( + status_code=402, + detail={ + "error": "subscription_inactive", + "status": sub.status, + "plan": sub.plan, + "current_period_end": sub.current_period_end.isoformat() if sub.current_period_end else None, + "upgrade_url": "/account/billing/select-plan", + }, + ) + + return sub diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 5a51bdd9..edc6ec45 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends -from app.api.deps import require_tenant_context +from app.api.deps import require_tenant_context, require_active_subscription from app.api.endpoints import ( admin, admin_audit, @@ -102,23 +102,32 @@ api_router.include_router(admin_survey.router) api_router.include_router(admin_gallery.router) # --------------------------------------------------------------------------- # User-facing endpoints — tenant context required +# +# _tenant_deps: routers that only require an authenticated user inside a +# tenant (auth/account/admin/non-Pro feature surfaces). +# _pro_deps: routers gated behind an active Pro subscription. Adds +# require_active_subscription which raises 402 unless the +# account's Subscription is active/complimentary/past_due or +# trialing-with-time-remaining. Allowlisted paths in deps.py +# bypass the gate for billing/account admin/auth flows. # --------------------------------------------------------------------------- _tenant_deps = [Depends(require_tenant_context)] +_pro_deps = [Depends(require_tenant_context), Depends(require_active_subscription)] -api_router.include_router(trees.router, dependencies=_tenant_deps) +api_router.include_router(trees.router, dependencies=_pro_deps) api_router.include_router(sidebar.router, dependencies=_tenant_deps) -api_router.include_router(sessions.router, dependencies=_tenant_deps) +api_router.include_router(sessions.router, dependencies=_pro_deps) api_router.include_router(invite.router, dependencies=_tenant_deps) api_router.include_router(categories.router, dependencies=_tenant_deps) api_router.include_router(tags.router, dependencies=_tenant_deps) api_router.include_router(folders.router, dependencies=_tenant_deps) -api_router.include_router(step_categories.router, dependencies=_tenant_deps) -api_router.include_router(steps.router, dependencies=_tenant_deps) +api_router.include_router(step_categories.router, dependencies=_pro_deps) +api_router.include_router(steps.router, dependencies=_pro_deps) api_router.include_router(accounts.router, dependencies=_tenant_deps) api_router.include_router(shares.router, dependencies=_tenant_deps) api_router.include_router(tree_markdown.router, dependencies=_tenant_deps) api_router.include_router(ratings.router, dependencies=_tenant_deps) -api_router.include_router(analytics.router, dependencies=_tenant_deps) +api_router.include_router(analytics.router, dependencies=_pro_deps) api_router.include_router(target_lists.router, dependencies=_tenant_deps) api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps) api_router.include_router(feedback.router, dependencies=_tenant_deps) @@ -126,31 +135,31 @@ api_router.include_router(ai_builder.router, dependencies=_tenant_deps) api_router.include_router(ai_fix.router, dependencies=_tenant_deps) api_router.include_router(ai_chat.router, dependencies=_tenant_deps) api_router.include_router(copilot.router, dependencies=_tenant_deps) -api_router.include_router(assistant_chat.router, dependencies=_tenant_deps) +api_router.include_router(assistant_chat.router, dependencies=_pro_deps) api_router.include_router(tree_transfer.router, dependencies=_tenant_deps) api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps) api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps) -api_router.include_router(scripts.router, dependencies=_tenant_deps) -api_router.include_router(integrations.router, dependencies=_tenant_deps) +api_router.include_router(scripts.router, dependencies=_pro_deps) +api_router.include_router(integrations.router, dependencies=_pro_deps) api_router.include_router(onboarding.router, dependencies=_tenant_deps) api_router.include_router(branding.router, dependencies=_tenant_deps) api_router.include_router(supporting_data.router, dependencies=_tenant_deps) api_router.include_router(network_diagrams.router, dependencies=_tenant_deps) # session_handoffs queue router must come before ai_sessions to avoid conflict -api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps) -api_router.include_router(session_resolutions.router, dependencies=_tenant_deps) +api_router.include_router(session_handoffs.queue_router, dependencies=_pro_deps) +api_router.include_router(session_resolutions.router, dependencies=_pro_deps) # session_facts mounts under /ai-sessions/{id}/facts — register before ai_sessions # so the {session_id}/facts subpaths take precedence over any future generic catchalls. -api_router.include_router(session_facts.router, dependencies=_tenant_deps) -api_router.include_router(session_suggested_fixes.router, dependencies=_tenant_deps) +api_router.include_router(session_facts.router, dependencies=_pro_deps) +api_router.include_router(session_suggested_fixes.router, dependencies=_pro_deps) api_router.include_router(draft_templates.router, dependencies=_tenant_deps) -api_router.include_router(ai_sessions.router, dependencies=_tenant_deps) -api_router.include_router(flow_proposals.router, dependencies=_tenant_deps) -api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps) +api_router.include_router(ai_sessions.router, dependencies=_pro_deps) +api_router.include_router(flow_proposals.router, dependencies=_pro_deps) +api_router.include_router(flowpilot_analytics.router, dependencies=_pro_deps) api_router.include_router(notifications.router, dependencies=_tenant_deps) api_router.include_router(uploads.router, dependencies=_tenant_deps) -api_router.include_router(script_builder.router, dependencies=_tenant_deps) +api_router.include_router(script_builder.router, dependencies=_pro_deps) api_router.include_router(beta_feedback.router, dependencies=_tenant_deps) -api_router.include_router(session_branches.router, dependencies=_tenant_deps) -api_router.include_router(session_handoffs.router, dependencies=_tenant_deps) +api_router.include_router(session_branches.router, dependencies=_pro_deps) +api_router.include_router(session_handoffs.router, dependencies=_pro_deps) api_router.include_router(device_types.router, dependencies=_tenant_deps) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 606609df..5f1d21c9 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -248,13 +248,23 @@ async def client(test_db: AsyncSession): @pytest.fixture -async def test_user(client): +async def test_user(client, test_db): """ Create a test user and return their credentials. + Also seeds a default active Pro Subscription so Pro-guarded routes work + in tests. Phase 1 Task 11 added require_active_subscription; without + this seed every existing test that hits a Pro router would 402. The + register endpoint creates a default `free`/`active` Subscription, so + we delete-then-insert to avoid the unique account_id constraint. + Returns: dict with email, password, and user_data """ + import uuid + from sqlalchemy import delete + from app.models.subscription import Subscription + user_data = { "email": "test@example.com", "password": "TestPassword123!", @@ -264,6 +274,13 @@ async def test_user(client): response = await client.post("/api/v1/auth/register", json=user_data) assert response.status_code == 200 or response.status_code == 201 + account_id = uuid.UUID(response.json()["account_id"]) + await test_db.execute( + delete(Subscription).where(Subscription.account_id == account_id) + ) + test_db.add(Subscription(account_id=account_id, plan="pro", status="active")) + await test_db.commit() + return { "email": user_data["email"], "password": user_data["password"], @@ -346,11 +363,14 @@ async def test_admin(client, test_db): Create a test super-admin user. Registers as engineer (the only role available at registration), - then promotes to super_admin directly via the DB session. + then promotes to super_admin directly via the DB session. Also + seeds a default active Pro Subscription (see test_user docstring). """ + import uuid from uuid import UUID as PyUUID - from sqlalchemy import select + from sqlalchemy import select, delete from app.models.user import User + from app.models.subscription import Subscription admin_data = { "email": "admin@example.com", @@ -365,6 +385,12 @@ async def test_admin(client, test_db): result = await test_db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() user.is_super_admin = True + + account_id = uuid.UUID(response.json()["account_id"]) + await test_db.execute( + delete(Subscription).where(Subscription.account_id == account_id) + ) + test_db.add(Subscription(account_id=account_id, plan="pro", status="active")) await test_db.commit() return { diff --git a/backend/tests/test_account_management.py b/backend/tests/test_account_management.py index a8fed198..60031fba 100644 --- a/backend/tests/test_account_management.py +++ b/backend/tests/test_account_management.py @@ -21,17 +21,21 @@ class TestAccountEndpoints: @pytest.mark.asyncio async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict): - """Test getting current user's subscription details.""" + """Test getting current user's subscription details. + + The test_user fixture seeds a Pro/active Subscription so + Pro-guarded routers work; reflect that in the expected plan. + """ response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) assert response.status_code == 200 data = response.json() assert "subscription" in data assert "limits" in data assert "usage" in data - assert data["subscription"]["plan"] == "free" + assert data["subscription"]["plan"] == "pro" assert data["subscription"]["status"] == "active" - assert data["limits"]["max_trees"] == 3 - assert data["limits"]["max_sessions_per_month"] == 20 + assert data["limits"]["max_trees"] == 25 + assert data["limits"]["max_sessions_per_month"] == 200 @pytest.mark.asyncio async def test_get_my_members(self, client: AsyncClient, auth_headers: dict): diff --git a/backend/tests/test_subscription_guards.py b/backend/tests/test_subscription_guards.py new file mode 100644 index 00000000..ab07f146 --- /dev/null +++ b/backend/tests/test_subscription_guards.py @@ -0,0 +1,89 @@ +"""Tests for require_active_subscription dependency. + +Verifies the 402 gating logic for Pro-guarded routers and the allowlist +that lets billing/account/auth flows through even when locked. +""" + +import uuid +import pytest +from datetime import datetime, timezone, timedelta +from sqlalchemy import delete +from app.models.subscription import Subscription + + +async def _set_subscription(test_db, account_id, **fields): + """Replace any existing Subscription on the account with one matching `fields`.""" + await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + test_db.add(Subscription(account_id=account_id, **fields)) + await test_db.commit() + + +@pytest.mark.asyncio +async def test_active_subscription_passes(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription(test_db, account_id, plan="pro", status="active") + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code != 402 + + +@pytest.mark.asyncio +async def test_complimentary_subscription_passes(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription(test_db, account_id, plan="pro", status="complimentary") + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code != 402 + + +@pytest.mark.asyncio +async def test_trialing_unexpired_passes(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription( + test_db, account_id, + plan="pro", status="trialing", + current_period_end=datetime.now(timezone.utc) + timedelta(days=5), + ) + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code != 402 + + +@pytest.mark.asyncio +async def test_trialing_expired_returns_402(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription( + test_db, account_id, + plan="pro", status="trialing", + current_period_end=datetime.now(timezone.utc) - timedelta(hours=1), + ) + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code == 402 + body = response.json() + assert body["detail"]["error"] == "subscription_inactive" + + +@pytest.mark.asyncio +async def test_canceled_returns_402(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription(test_db, account_id, plan="pro", status="canceled") + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code == 402 + + +@pytest.mark.asyncio +async def test_no_subscription_returns_402(client, test_db, test_user, auth_headers): + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + # Remove the seeded default subscription + await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + await test_db.commit() + response = await client.get("/api/v1/trees", headers=auth_headers) + assert response.status_code == 402 + body = response.json() + assert body["detail"]["error"] == "no_subscription" + + +@pytest.mark.asyncio +async def test_auth_me_bypasses_guard(client, test_db, test_user, auth_headers): + """Allowlisted route works even when subscription is canceled.""" + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + await _set_subscription(test_db, account_id, plan="pro", status="canceled") + response = await client.get("/api/v1/auth/me", headers=auth_headers) + assert response.status_code == 200 diff --git a/backend/tests/test_subscription_limits.py b/backend/tests/test_subscription_limits.py index 6f07266f..196861e7 100644 --- a/backend/tests/test_subscription_limits.py +++ b/backend/tests/test_subscription_limits.py @@ -10,8 +10,15 @@ class TestSubscriptionLimits: """Test suite for subscription plan limits.""" @pytest.mark.asyncio - async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict): + async def test_free_plan_tree_limit( + self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession + ): """Test that free plan has tree creation limit of 3.""" + from app.models.subscription import Subscription + sub = (await test_db.execute(select(Subscription))).scalar_one() + sub.plan = "free" + await test_db.commit() + tree_template = { "name": "Limit Test Tree", "tree_structure": { @@ -90,8 +97,15 @@ class TestSubscriptionLimits: assert response.status_code == 201 @pytest.mark.asyncio - async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict): + async def test_free_plan_limits_correct( + self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession + ): """Test that free plan limits are correct.""" + from app.models.subscription import Subscription + sub = (await test_db.execute(select(Subscription))).scalar_one() + sub.plan = "free" + await test_db.commit() + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) assert response.status_code == 200 limits = response.json()["limits"] diff --git a/backend/tests/test_tenant_isolation_p0.py b/backend/tests/test_tenant_isolation_p0.py index 4ef9729f..dfa70f7e 100644 --- a/backend/tests/test_tenant_isolation_p0.py +++ b/backend/tests/test_tenant_isolation_p0.py @@ -12,13 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.account import Account from app.models.user import User from app.models.tree import Tree +from app.models.subscription import Subscription from app.core.security import get_password_hash # ── Helpers ────────────────────────────────────────────────────────────────── async def _create_account_and_user(db: AsyncSession, prefix: str): - """Create a fresh account + engineer user. Returns (account, user, plain_password).""" + """Create a fresh account + engineer user. Returns (account, user, plain_password). + + Seeds a default active Pro Subscription for the account so requests pass + the require_active_subscription guard added in Phase 1 Task 11. + """ password = "TestPass123!" account = Account( name=f"{prefix}-corp", @@ -36,6 +41,7 @@ async def _create_account_and_user(db: AsyncSession, prefix: str): account_role="engineer", ) db.add(user) + db.add(Subscription(account_id=account.id, plan="pro", status="active")) await db.flush() return account, user, password @@ -168,6 +174,7 @@ async def test_ai_session_search_cannot_see_other_users_sessions( account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8]) test_db.add(account) await test_db.flush() + test_db.add(Subscription(account_id=account.id, plan="pro", status="active")) password = "TestPass123!" user_a = User(