feat(deps): add require_active_subscription guard with allowlist

Mounts on Pro routers (trees, sessions, scripts, FlowPilot, etc.) and
returns 402 with structured detail when an account's subscription is
missing or locked. Allowlist bypasses billing/account/auth flows so
users can recover from a lapsed subscription.

Conftest now seeds a default Pro/active Subscription on test_user and
test_admin (delete-then-insert because the register endpoint already
creates a free/active sub by default). Two existing tests adapted to
the new seeded plan; tenant-isolation tests seed Subscription rows for
the accounts they create directly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 14:35:59 -04:00
parent cfe0e6cae6
commit 9ec208f6e7
7 changed files with 245 additions and 29 deletions

View File

@@ -222,3 +222,70 @@ async def require_admin_db(
the user object is needed in the handler. the user object is needed in the handler.
""" """
return db return db
_SUBSCRIPTION_GUARD_ALLOWLIST = {
"/api/v1/auth/me",
"/api/v1/auth/logout",
"/api/v1/auth/password/change",
"/api/v1/auth/email/send-verification",
"/api/v1/auth/email/verify",
"/api/v1/billing/state",
"/api/v1/billing/checkout-session",
"/api/v1/billing/portal-session",
"/api/v1/users/me",
"/api/v1/users/me/onboarding-step",
}
async def require_active_subscription(
request: Request,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
):
"""Returns the Subscription row when the account has access; raises 402
when locked. Mounted on routers requiring Pro entitlement.
'Locked' = (trialing AND current_period_end < now()) OR
(canceled OR incomplete OR no subscription).
Active states: active, complimentary, trialing-with-time-remaining, past_due.
"""
if request.url.path in _SUBSCRIPTION_GUARD_ALLOWLIST:
return None
from app.models.subscription import Subscription
from datetime import datetime, timezone
result = await db.execute(
select(Subscription).where(Subscription.account_id == current_user.account_id)
)
sub = result.scalar_one_or_none()
if sub is None:
raise HTTPException(
status_code=402,
detail={"error": "no_subscription", "upgrade_url": "/account/billing/select-plan"},
)
now = datetime.now(timezone.utc)
is_live = (
sub.status in ("active", "complimentary", "past_due")
or (
sub.status == "trialing"
and sub.current_period_end is not None
and sub.current_period_end > now
)
)
if not is_live:
raise HTTPException(
status_code=402,
detail={
"error": "subscription_inactive",
"status": sub.status,
"plan": sub.plan,
"current_period_end": sub.current_period_end.isoformat() if sub.current_period_end else None,
"upgrade_url": "/account/billing/select-plan",
},
)
return sub

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app.api.deps import require_tenant_context from app.api.deps import require_tenant_context, require_active_subscription
from app.api.endpoints import ( from app.api.endpoints import (
admin, admin,
admin_audit, admin_audit,
@@ -102,23 +102,32 @@ api_router.include_router(admin_survey.router)
api_router.include_router(admin_gallery.router) api_router.include_router(admin_gallery.router)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# User-facing endpoints — tenant context required # User-facing endpoints — tenant context required
#
# _tenant_deps: routers that only require an authenticated user inside a
# tenant (auth/account/admin/non-Pro feature surfaces).
# _pro_deps: routers gated behind an active Pro subscription. Adds
# require_active_subscription which raises 402 unless the
# account's Subscription is active/complimentary/past_due or
# trialing-with-time-remaining. Allowlisted paths in deps.py
# bypass the gate for billing/account admin/auth flows.
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_tenant_deps = [Depends(require_tenant_context)] _tenant_deps = [Depends(require_tenant_context)]
_pro_deps = [Depends(require_tenant_context), Depends(require_active_subscription)]
api_router.include_router(trees.router, dependencies=_tenant_deps) api_router.include_router(trees.router, dependencies=_pro_deps)
api_router.include_router(sidebar.router, dependencies=_tenant_deps) api_router.include_router(sidebar.router, dependencies=_tenant_deps)
api_router.include_router(sessions.router, dependencies=_tenant_deps) api_router.include_router(sessions.router, dependencies=_pro_deps)
api_router.include_router(invite.router, dependencies=_tenant_deps) api_router.include_router(invite.router, dependencies=_tenant_deps)
api_router.include_router(categories.router, dependencies=_tenant_deps) api_router.include_router(categories.router, dependencies=_tenant_deps)
api_router.include_router(tags.router, dependencies=_tenant_deps) api_router.include_router(tags.router, dependencies=_tenant_deps)
api_router.include_router(folders.router, dependencies=_tenant_deps) api_router.include_router(folders.router, dependencies=_tenant_deps)
api_router.include_router(step_categories.router, dependencies=_tenant_deps) api_router.include_router(step_categories.router, dependencies=_pro_deps)
api_router.include_router(steps.router, dependencies=_tenant_deps) api_router.include_router(steps.router, dependencies=_pro_deps)
api_router.include_router(accounts.router, dependencies=_tenant_deps) api_router.include_router(accounts.router, dependencies=_tenant_deps)
api_router.include_router(shares.router, dependencies=_tenant_deps) api_router.include_router(shares.router, dependencies=_tenant_deps)
api_router.include_router(tree_markdown.router, dependencies=_tenant_deps) api_router.include_router(tree_markdown.router, dependencies=_tenant_deps)
api_router.include_router(ratings.router, dependencies=_tenant_deps) api_router.include_router(ratings.router, dependencies=_tenant_deps)
api_router.include_router(analytics.router, dependencies=_tenant_deps) api_router.include_router(analytics.router, dependencies=_pro_deps)
api_router.include_router(target_lists.router, dependencies=_tenant_deps) api_router.include_router(target_lists.router, dependencies=_tenant_deps)
api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps) api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps)
api_router.include_router(feedback.router, dependencies=_tenant_deps) api_router.include_router(feedback.router, dependencies=_tenant_deps)
@@ -126,31 +135,31 @@ api_router.include_router(ai_builder.router, dependencies=_tenant_deps)
api_router.include_router(ai_fix.router, dependencies=_tenant_deps) api_router.include_router(ai_fix.router, dependencies=_tenant_deps)
api_router.include_router(ai_chat.router, dependencies=_tenant_deps) api_router.include_router(ai_chat.router, dependencies=_tenant_deps)
api_router.include_router(copilot.router, dependencies=_tenant_deps) api_router.include_router(copilot.router, dependencies=_tenant_deps)
api_router.include_router(assistant_chat.router, dependencies=_tenant_deps) api_router.include_router(assistant_chat.router, dependencies=_pro_deps)
api_router.include_router(tree_transfer.router, dependencies=_tenant_deps) api_router.include_router(tree_transfer.router, dependencies=_tenant_deps)
api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps) api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps)
api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps) api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps)
api_router.include_router(scripts.router, dependencies=_tenant_deps) api_router.include_router(scripts.router, dependencies=_pro_deps)
api_router.include_router(integrations.router, dependencies=_tenant_deps) api_router.include_router(integrations.router, dependencies=_pro_deps)
api_router.include_router(onboarding.router, dependencies=_tenant_deps) api_router.include_router(onboarding.router, dependencies=_tenant_deps)
api_router.include_router(branding.router, dependencies=_tenant_deps) api_router.include_router(branding.router, dependencies=_tenant_deps)
api_router.include_router(supporting_data.router, dependencies=_tenant_deps) api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps) api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
# session_handoffs queue router must come before ai_sessions to avoid conflict # session_handoffs queue router must come before ai_sessions to avoid conflict
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps) api_router.include_router(session_handoffs.queue_router, dependencies=_pro_deps)
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps) api_router.include_router(session_resolutions.router, dependencies=_pro_deps)
# session_facts mounts under /ai-sessions/{id}/facts — register before ai_sessions # session_facts mounts under /ai-sessions/{id}/facts — register before ai_sessions
# so the {session_id}/facts subpaths take precedence over any future generic catchalls. # so the {session_id}/facts subpaths take precedence over any future generic catchalls.
api_router.include_router(session_facts.router, dependencies=_tenant_deps) api_router.include_router(session_facts.router, dependencies=_pro_deps)
api_router.include_router(session_suggested_fixes.router, dependencies=_tenant_deps) api_router.include_router(session_suggested_fixes.router, dependencies=_pro_deps)
api_router.include_router(draft_templates.router, dependencies=_tenant_deps) api_router.include_router(draft_templates.router, dependencies=_tenant_deps)
api_router.include_router(ai_sessions.router, dependencies=_tenant_deps) api_router.include_router(ai_sessions.router, dependencies=_pro_deps)
api_router.include_router(flow_proposals.router, dependencies=_tenant_deps) api_router.include_router(flow_proposals.router, dependencies=_pro_deps)
api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps) api_router.include_router(flowpilot_analytics.router, dependencies=_pro_deps)
api_router.include_router(notifications.router, dependencies=_tenant_deps) api_router.include_router(notifications.router, dependencies=_tenant_deps)
api_router.include_router(uploads.router, dependencies=_tenant_deps) api_router.include_router(uploads.router, dependencies=_tenant_deps)
api_router.include_router(script_builder.router, dependencies=_tenant_deps) api_router.include_router(script_builder.router, dependencies=_pro_deps)
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps) api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
api_router.include_router(session_branches.router, dependencies=_tenant_deps) api_router.include_router(session_branches.router, dependencies=_pro_deps)
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps) api_router.include_router(session_handoffs.router, dependencies=_pro_deps)
api_router.include_router(device_types.router, dependencies=_tenant_deps) api_router.include_router(device_types.router, dependencies=_tenant_deps)

View File

@@ -248,13 +248,23 @@ async def client(test_db: AsyncSession):
@pytest.fixture @pytest.fixture
async def test_user(client): async def test_user(client, test_db):
""" """
Create a test user and return their credentials. Create a test user and return their credentials.
Also seeds a default active Pro Subscription so Pro-guarded routes work
in tests. Phase 1 Task 11 added require_active_subscription; without
this seed every existing test that hits a Pro router would 402. The
register endpoint creates a default `free`/`active` Subscription, so
we delete-then-insert to avoid the unique account_id constraint.
Returns: Returns:
dict with email, password, and user_data dict with email, password, and user_data
""" """
import uuid
from sqlalchemy import delete
from app.models.subscription import Subscription
user_data = { user_data = {
"email": "test@example.com", "email": "test@example.com",
"password": "TestPassword123!", "password": "TestPassword123!",
@@ -264,6 +274,13 @@ async def test_user(client):
response = await client.post("/api/v1/auth/register", json=user_data) response = await client.post("/api/v1/auth/register", json=user_data)
assert response.status_code == 200 or response.status_code == 201 assert response.status_code == 200 or response.status_code == 201
account_id = uuid.UUID(response.json()["account_id"])
await test_db.execute(
delete(Subscription).where(Subscription.account_id == account_id)
)
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
await test_db.commit()
return { return {
"email": user_data["email"], "email": user_data["email"],
"password": user_data["password"], "password": user_data["password"],
@@ -346,11 +363,14 @@ async def test_admin(client, test_db):
Create a test super-admin user. Create a test super-admin user.
Registers as engineer (the only role available at registration), Registers as engineer (the only role available at registration),
then promotes to super_admin directly via the DB session. then promotes to super_admin directly via the DB session. Also
seeds a default active Pro Subscription (see test_user docstring).
""" """
import uuid
from uuid import UUID as PyUUID from uuid import UUID as PyUUID
from sqlalchemy import select from sqlalchemy import select, delete
from app.models.user import User from app.models.user import User
from app.models.subscription import Subscription
admin_data = { admin_data = {
"email": "admin@example.com", "email": "admin@example.com",
@@ -365,6 +385,12 @@ async def test_admin(client, test_db):
result = await test_db.execute(select(User).where(User.id == user_id)) result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one() user = result.scalar_one()
user.is_super_admin = True user.is_super_admin = True
account_id = uuid.UUID(response.json()["account_id"])
await test_db.execute(
delete(Subscription).where(Subscription.account_id == account_id)
)
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
await test_db.commit() await test_db.commit()
return { return {

View File

@@ -21,17 +21,21 @@ class TestAccountEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict): async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's subscription details.""" """Test getting current user's subscription details.
The test_user fixture seeds a Pro/active Subscription so
Pro-guarded routers work; reflect that in the expected plan.
"""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "subscription" in data assert "subscription" in data
assert "limits" in data assert "limits" in data
assert "usage" in data assert "usage" in data
assert data["subscription"]["plan"] == "free" assert data["subscription"]["plan"] == "pro"
assert data["subscription"]["status"] == "active" assert data["subscription"]["status"] == "active"
assert data["limits"]["max_trees"] == 3 assert data["limits"]["max_trees"] == 25
assert data["limits"]["max_sessions_per_month"] == 20 assert data["limits"]["max_sessions_per_month"] == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict): async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):

View File

@@ -0,0 +1,89 @@
"""Tests for require_active_subscription dependency.
Verifies the 402 gating logic for Pro-guarded routers and the allowlist
that lets billing/account/auth flows through even when locked.
"""
import uuid
import pytest
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from app.models.subscription import Subscription
async def _set_subscription(test_db, account_id, **fields):
"""Replace any existing Subscription on the account with one matching `fields`."""
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(account_id=account_id, **fields))
await test_db.commit()
@pytest.mark.asyncio
async def test_active_subscription_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="active")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_complimentary_subscription_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="complimentary")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_trialing_unexpired_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(
test_db, account_id,
plan="pro", status="trialing",
current_period_end=datetime.now(timezone.utc) + timedelta(days=5),
)
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_trialing_expired_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(
test_db, account_id,
plan="pro", status="trialing",
current_period_end=datetime.now(timezone.utc) - timedelta(hours=1),
)
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
body = response.json()
assert body["detail"]["error"] == "subscription_inactive"
@pytest.mark.asyncio
async def test_canceled_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
@pytest.mark.asyncio
async def test_no_subscription_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
# Remove the seeded default subscription
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
await test_db.commit()
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
body = response.json()
assert body["detail"]["error"] == "no_subscription"
@pytest.mark.asyncio
async def test_auth_me_bypasses_guard(client, test_db, test_user, auth_headers):
"""Allowlisted route works even when subscription is canceled."""
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
response = await client.get("/api/v1/auth/me", headers=auth_headers)
assert response.status_code == 200

View File

@@ -10,8 +10,15 @@ class TestSubscriptionLimits:
"""Test suite for subscription plan limits.""" """Test suite for subscription plan limits."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict): async def test_free_plan_tree_limit(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that free plan has tree creation limit of 3.""" """Test that free plan has tree creation limit of 3."""
from app.models.subscription import Subscription
sub = (await test_db.execute(select(Subscription))).scalar_one()
sub.plan = "free"
await test_db.commit()
tree_template = { tree_template = {
"name": "Limit Test Tree", "name": "Limit Test Tree",
"tree_structure": { "tree_structure": {
@@ -90,8 +97,15 @@ class TestSubscriptionLimits:
assert response.status_code == 201 assert response.status_code == 201
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict): async def test_free_plan_limits_correct(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that free plan limits are correct.""" """Test that free plan limits are correct."""
from app.models.subscription import Subscription
sub = (await test_db.execute(select(Subscription))).scalar_one()
sub.plan = "free"
await test_db.commit()
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
limits = response.json()["limits"] limits = response.json()["limits"]

View File

@@ -12,13 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.account import Account from app.models.account import Account
from app.models.user import User from app.models.user import User
from app.models.tree import Tree from app.models.tree import Tree
from app.models.subscription import Subscription
from app.core.security import get_password_hash from app.core.security import get_password_hash
# ── Helpers ────────────────────────────────────────────────────────────────── # ── Helpers ──────────────────────────────────────────────────────────────────
async def _create_account_and_user(db: AsyncSession, prefix: str): async def _create_account_and_user(db: AsyncSession, prefix: str):
"""Create a fresh account + engineer user. Returns (account, user, plain_password).""" """Create a fresh account + engineer user. Returns (account, user, plain_password).
Seeds a default active Pro Subscription for the account so requests pass
the require_active_subscription guard added in Phase 1 Task 11.
"""
password = "TestPass123!" password = "TestPass123!"
account = Account( account = Account(
name=f"{prefix}-corp", name=f"{prefix}-corp",
@@ -36,6 +41,7 @@ async def _create_account_and_user(db: AsyncSession, prefix: str):
account_role="engineer", account_role="engineer",
) )
db.add(user) db.add(user)
db.add(Subscription(account_id=account.id, plan="pro", status="active"))
await db.flush() await db.flush()
return account, user, password return account, user, password
@@ -168,6 +174,7 @@ async def test_ai_session_search_cannot_see_other_users_sessions(
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8]) account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
test_db.add(account) test_db.add(account)
await test_db.flush() await test_db.flush()
test_db.add(Subscription(account_id=account.id, plan="pro", status="active"))
password = "TestPass123!" password = "TestPass123!"
user_a = User( user_a = User(