feat(deps): add require_active_subscription guard with allowlist

Mounts on Pro routers (trees, sessions, scripts, FlowPilot, etc.) and
returns 402 with structured detail when an account's subscription is
missing or locked. Allowlist bypasses billing/account/auth flows so
users can recover from a lapsed subscription.

Conftest now seeds a default Pro/active Subscription on test_user and
test_admin (delete-then-insert because the register endpoint already
creates a free/active sub by default). Two existing tests adapted to
the new seeded plan; tenant-isolation tests seed Subscription rows for
the accounts they create directly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 14:35:59 -04:00
parent cfe0e6cae6
commit 9ec208f6e7
7 changed files with 245 additions and 29 deletions

View File

@@ -248,13 +248,23 @@ async def client(test_db: AsyncSession):
@pytest.fixture
async def test_user(client):
async def test_user(client, test_db):
"""
Create a test user and return their credentials.
Also seeds a default active Pro Subscription so Pro-guarded routes work
in tests. Phase 1 Task 11 added require_active_subscription; without
this seed every existing test that hits a Pro router would 402. The
register endpoint creates a default `free`/`active` Subscription, so
we delete-then-insert to avoid the unique account_id constraint.
Returns:
dict with email, password, and user_data
"""
import uuid
from sqlalchemy import delete
from app.models.subscription import Subscription
user_data = {
"email": "test@example.com",
"password": "TestPassword123!",
@@ -264,6 +274,13 @@ async def test_user(client):
response = await client.post("/api/v1/auth/register", json=user_data)
assert response.status_code == 200 or response.status_code == 201
account_id = uuid.UUID(response.json()["account_id"])
await test_db.execute(
delete(Subscription).where(Subscription.account_id == account_id)
)
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
await test_db.commit()
return {
"email": user_data["email"],
"password": user_data["password"],
@@ -346,11 +363,14 @@ async def test_admin(client, test_db):
Create a test super-admin user.
Registers as engineer (the only role available at registration),
then promotes to super_admin directly via the DB session.
then promotes to super_admin directly via the DB session. Also
seeds a default active Pro Subscription (see test_user docstring).
"""
import uuid
from uuid import UUID as PyUUID
from sqlalchemy import select
from sqlalchemy import select, delete
from app.models.user import User
from app.models.subscription import Subscription
admin_data = {
"email": "admin@example.com",
@@ -365,6 +385,12 @@ async def test_admin(client, test_db):
result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one()
user.is_super_admin = True
account_id = uuid.UUID(response.json()["account_id"])
await test_db.execute(
delete(Subscription).where(Subscription.account_id == account_id)
)
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
await test_db.commit()
return {

View File

@@ -21,17 +21,21 @@ class TestAccountEndpoints:
@pytest.mark.asyncio
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's subscription details."""
"""Test getting current user's subscription details.
The test_user fixture seeds a Pro/active Subscription so
Pro-guarded routers work; reflect that in the expected plan.
"""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "subscription" in data
assert "limits" in data
assert "usage" in data
assert data["subscription"]["plan"] == "free"
assert data["subscription"]["plan"] == "pro"
assert data["subscription"]["status"] == "active"
assert data["limits"]["max_trees"] == 3
assert data["limits"]["max_sessions_per_month"] == 20
assert data["limits"]["max_trees"] == 25
assert data["limits"]["max_sessions_per_month"] == 200
@pytest.mark.asyncio
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):

View File

@@ -0,0 +1,89 @@
"""Tests for require_active_subscription dependency.
Verifies the 402 gating logic for Pro-guarded routers and the allowlist
that lets billing/account/auth flows through even when locked.
"""
import uuid
import pytest
from datetime import datetime, timezone, timedelta
from sqlalchemy import delete
from app.models.subscription import Subscription
async def _set_subscription(test_db, account_id, **fields):
"""Replace any existing Subscription on the account with one matching `fields`."""
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(account_id=account_id, **fields))
await test_db.commit()
@pytest.mark.asyncio
async def test_active_subscription_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="active")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_complimentary_subscription_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="complimentary")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_trialing_unexpired_passes(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(
test_db, account_id,
plan="pro", status="trialing",
current_period_end=datetime.now(timezone.utc) + timedelta(days=5),
)
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code != 402
@pytest.mark.asyncio
async def test_trialing_expired_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(
test_db, account_id,
plan="pro", status="trialing",
current_period_end=datetime.now(timezone.utc) - timedelta(hours=1),
)
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
body = response.json()
assert body["detail"]["error"] == "subscription_inactive"
@pytest.mark.asyncio
async def test_canceled_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
@pytest.mark.asyncio
async def test_no_subscription_returns_402(client, test_db, test_user, auth_headers):
account_id = uuid.UUID(test_user["user_data"]["account_id"])
# Remove the seeded default subscription
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
await test_db.commit()
response = await client.get("/api/v1/trees", headers=auth_headers)
assert response.status_code == 402
body = response.json()
assert body["detail"]["error"] == "no_subscription"
@pytest.mark.asyncio
async def test_auth_me_bypasses_guard(client, test_db, test_user, auth_headers):
"""Allowlisted route works even when subscription is canceled."""
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
response = await client.get("/api/v1/auth/me", headers=auth_headers)
assert response.status_code == 200

View File

@@ -10,8 +10,15 @@ class TestSubscriptionLimits:
"""Test suite for subscription plan limits."""
@pytest.mark.asyncio
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
async def test_free_plan_tree_limit(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that free plan has tree creation limit of 3."""
from app.models.subscription import Subscription
sub = (await test_db.execute(select(Subscription))).scalar_one()
sub.plan = "free"
await test_db.commit()
tree_template = {
"name": "Limit Test Tree",
"tree_structure": {
@@ -90,8 +97,15 @@ class TestSubscriptionLimits:
assert response.status_code == 201
@pytest.mark.asyncio
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
async def test_free_plan_limits_correct(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that free plan limits are correct."""
from app.models.subscription import Subscription
sub = (await test_db.execute(select(Subscription))).scalar_one()
sub.plan = "free"
await test_db.commit()
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
limits = response.json()["limits"]

View File

@@ -12,13 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.models.subscription import Subscription
from app.core.security import get_password_hash
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _create_account_and_user(db: AsyncSession, prefix: str):
"""Create a fresh account + engineer user. Returns (account, user, plain_password)."""
"""Create a fresh account + engineer user. Returns (account, user, plain_password).
Seeds a default active Pro Subscription for the account so requests pass
the require_active_subscription guard added in Phase 1 Task 11.
"""
password = "TestPass123!"
account = Account(
name=f"{prefix}-corp",
@@ -36,6 +41,7 @@ async def _create_account_and_user(db: AsyncSession, prefix: str):
account_role="engineer",
)
db.add(user)
db.add(Subscription(account_id=account.id, plan="pro", status="active"))
await db.flush()
return account, user, password
@@ -168,6 +174,7 @@ async def test_ai_session_search_cannot_see_other_users_sessions(
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
test_db.add(account)
await test_db.flush()
test_db.add(Subscription(account_id=account.id, plan="pro", status="active"))
password = "TestPass123!"
user_a = User(