feat(deps): add require_active_subscription guard with allowlist
Mounts on Pro routers (trees, sessions, scripts, FlowPilot, etc.) and returns 402 with structured detail when an account's subscription is missing or locked. Allowlist bypasses billing/account/auth flows so users can recover from a lapsed subscription. Conftest now seeds a default Pro/active Subscription on test_user and test_admin (delete-then-insert because the register endpoint already creates a free/active sub by default). Two existing tests adapted to the new seeded plan; tenant-isolation tests seed Subscription rows for the accounts they create directly. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -222,3 +222,70 @@ async def require_admin_db(
|
|||||||
the user object is needed in the handler.
|
the user object is needed in the handler.
|
||||||
"""
|
"""
|
||||||
return db
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
_SUBSCRIPTION_GUARD_ALLOWLIST = {
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/password/change",
|
||||||
|
"/api/v1/auth/email/send-verification",
|
||||||
|
"/api/v1/auth/email/verify",
|
||||||
|
"/api/v1/billing/state",
|
||||||
|
"/api/v1/billing/checkout-session",
|
||||||
|
"/api/v1/billing/portal-session",
|
||||||
|
"/api/v1/users/me",
|
||||||
|
"/api/v1/users/me/onboarding-step",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def require_active_subscription(
|
||||||
|
request: Request,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
):
|
||||||
|
"""Returns the Subscription row when the account has access; raises 402
|
||||||
|
when locked. Mounted on routers requiring Pro entitlement.
|
||||||
|
|
||||||
|
'Locked' = (trialing AND current_period_end < now()) OR
|
||||||
|
(canceled OR incomplete OR no subscription).
|
||||||
|
Active states: active, complimentary, trialing-with-time-remaining, past_due.
|
||||||
|
"""
|
||||||
|
if request.url.path in _SUBSCRIPTION_GUARD_ALLOWLIST:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == current_user.account_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if sub is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={"error": "no_subscription", "upgrade_url": "/account/billing/select-plan"},
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
is_live = (
|
||||||
|
sub.status in ("active", "complimentary", "past_due")
|
||||||
|
or (
|
||||||
|
sub.status == "trialing"
|
||||||
|
and sub.current_period_end is not None
|
||||||
|
and sub.current_period_end > now
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not is_live:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={
|
||||||
|
"error": "subscription_inactive",
|
||||||
|
"status": sub.status,
|
||||||
|
"plan": sub.plan,
|
||||||
|
"current_period_end": sub.current_period_end.isoformat() if sub.current_period_end else None,
|
||||||
|
"upgrade_url": "/account/billing/select-plan",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return sub
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from app.api.deps import require_tenant_context
|
from app.api.deps import require_tenant_context, require_active_subscription
|
||||||
from app.api.endpoints import (
|
from app.api.endpoints import (
|
||||||
admin,
|
admin,
|
||||||
admin_audit,
|
admin_audit,
|
||||||
@@ -102,23 +102,32 @@ api_router.include_router(admin_survey.router)
|
|||||||
api_router.include_router(admin_gallery.router)
|
api_router.include_router(admin_gallery.router)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# User-facing endpoints — tenant context required
|
# User-facing endpoints — tenant context required
|
||||||
|
#
|
||||||
|
# _tenant_deps: routers that only require an authenticated user inside a
|
||||||
|
# tenant (auth/account/admin/non-Pro feature surfaces).
|
||||||
|
# _pro_deps: routers gated behind an active Pro subscription. Adds
|
||||||
|
# require_active_subscription which raises 402 unless the
|
||||||
|
# account's Subscription is active/complimentary/past_due or
|
||||||
|
# trialing-with-time-remaining. Allowlisted paths in deps.py
|
||||||
|
# bypass the gate for billing/account admin/auth flows.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
_tenant_deps = [Depends(require_tenant_context)]
|
_tenant_deps = [Depends(require_tenant_context)]
|
||||||
|
_pro_deps = [Depends(require_tenant_context), Depends(require_active_subscription)]
|
||||||
|
|
||||||
api_router.include_router(trees.router, dependencies=_tenant_deps)
|
api_router.include_router(trees.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(sidebar.router, dependencies=_tenant_deps)
|
api_router.include_router(sidebar.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(sessions.router, dependencies=_tenant_deps)
|
api_router.include_router(sessions.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(invite.router, dependencies=_tenant_deps)
|
api_router.include_router(invite.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(categories.router, dependencies=_tenant_deps)
|
api_router.include_router(categories.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(tags.router, dependencies=_tenant_deps)
|
api_router.include_router(tags.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(folders.router, dependencies=_tenant_deps)
|
api_router.include_router(folders.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(step_categories.router, dependencies=_tenant_deps)
|
api_router.include_router(step_categories.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(steps.router, dependencies=_tenant_deps)
|
api_router.include_router(steps.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(accounts.router, dependencies=_tenant_deps)
|
api_router.include_router(accounts.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(shares.router, dependencies=_tenant_deps)
|
api_router.include_router(shares.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(tree_markdown.router, dependencies=_tenant_deps)
|
api_router.include_router(tree_markdown.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(ratings.router, dependencies=_tenant_deps)
|
api_router.include_router(ratings.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(analytics.router, dependencies=_tenant_deps)
|
api_router.include_router(analytics.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(target_lists.router, dependencies=_tenant_deps)
|
api_router.include_router(target_lists.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps)
|
api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(feedback.router, dependencies=_tenant_deps)
|
api_router.include_router(feedback.router, dependencies=_tenant_deps)
|
||||||
@@ -126,31 +135,31 @@ api_router.include_router(ai_builder.router, dependencies=_tenant_deps)
|
|||||||
api_router.include_router(ai_fix.router, dependencies=_tenant_deps)
|
api_router.include_router(ai_fix.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(ai_chat.router, dependencies=_tenant_deps)
|
api_router.include_router(ai_chat.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(copilot.router, dependencies=_tenant_deps)
|
api_router.include_router(copilot.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(assistant_chat.router, dependencies=_tenant_deps)
|
api_router.include_router(assistant_chat.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(tree_transfer.router, dependencies=_tenant_deps)
|
api_router.include_router(tree_transfer.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps)
|
api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps)
|
api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(scripts.router, dependencies=_tenant_deps)
|
api_router.include_router(scripts.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(integrations.router, dependencies=_tenant_deps)
|
api_router.include_router(integrations.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
|
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(branding.router, dependencies=_tenant_deps)
|
api_router.include_router(branding.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
|
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
|
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
|
||||||
# session_handoffs queue router must come before ai_sessions to avoid conflict
|
# session_handoffs queue router must come before ai_sessions to avoid conflict
|
||||||
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
|
api_router.include_router(session_handoffs.queue_router, dependencies=_pro_deps)
|
||||||
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
|
api_router.include_router(session_resolutions.router, dependencies=_pro_deps)
|
||||||
# session_facts mounts under /ai-sessions/{id}/facts — register before ai_sessions
|
# session_facts mounts under /ai-sessions/{id}/facts — register before ai_sessions
|
||||||
# so the {session_id}/facts subpaths take precedence over any future generic catchalls.
|
# so the {session_id}/facts subpaths take precedence over any future generic catchalls.
|
||||||
api_router.include_router(session_facts.router, dependencies=_tenant_deps)
|
api_router.include_router(session_facts.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(session_suggested_fixes.router, dependencies=_tenant_deps)
|
api_router.include_router(session_suggested_fixes.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(draft_templates.router, dependencies=_tenant_deps)
|
api_router.include_router(draft_templates.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(ai_sessions.router, dependencies=_tenant_deps)
|
api_router.include_router(ai_sessions.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(flow_proposals.router, dependencies=_tenant_deps)
|
api_router.include_router(flow_proposals.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps)
|
api_router.include_router(flowpilot_analytics.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(notifications.router, dependencies=_tenant_deps)
|
api_router.include_router(notifications.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(uploads.router, dependencies=_tenant_deps)
|
api_router.include_router(uploads.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(script_builder.router, dependencies=_tenant_deps)
|
api_router.include_router(script_builder.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
||||||
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
|
api_router.include_router(session_branches.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)
|
api_router.include_router(session_handoffs.router, dependencies=_pro_deps)
|
||||||
api_router.include_router(device_types.router, dependencies=_tenant_deps)
|
api_router.include_router(device_types.router, dependencies=_tenant_deps)
|
||||||
|
|||||||
@@ -248,13 +248,23 @@ async def client(test_db: AsyncSession):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_user(client):
|
async def test_user(client, test_db):
|
||||||
"""
|
"""
|
||||||
Create a test user and return their credentials.
|
Create a test user and return their credentials.
|
||||||
|
|
||||||
|
Also seeds a default active Pro Subscription so Pro-guarded routes work
|
||||||
|
in tests. Phase 1 Task 11 added require_active_subscription; without
|
||||||
|
this seed every existing test that hits a Pro router would 402. The
|
||||||
|
register endpoint creates a default `free`/`active` Subscription, so
|
||||||
|
we delete-then-insert to avoid the unique account_id constraint.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with email, password, and user_data
|
dict with email, password, and user_data
|
||||||
"""
|
"""
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import delete
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
user_data = {
|
user_data = {
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"password": "TestPassword123!",
|
"password": "TestPassword123!",
|
||||||
@@ -264,6 +274,13 @@ async def test_user(client):
|
|||||||
response = await client.post("/api/v1/auth/register", json=user_data)
|
response = await client.post("/api/v1/auth/register", json=user_data)
|
||||||
assert response.status_code == 200 or response.status_code == 201
|
assert response.status_code == 200 or response.status_code == 201
|
||||||
|
|
||||||
|
account_id = uuid.UUID(response.json()["account_id"])
|
||||||
|
await test_db.execute(
|
||||||
|
delete(Subscription).where(Subscription.account_id == account_id)
|
||||||
|
)
|
||||||
|
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"email": user_data["email"],
|
"email": user_data["email"],
|
||||||
"password": user_data["password"],
|
"password": user_data["password"],
|
||||||
@@ -346,11 +363,14 @@ async def test_admin(client, test_db):
|
|||||||
Create a test super-admin user.
|
Create a test super-admin user.
|
||||||
|
|
||||||
Registers as engineer (the only role available at registration),
|
Registers as engineer (the only role available at registration),
|
||||||
then promotes to super_admin directly via the DB session.
|
then promotes to super_admin directly via the DB session. Also
|
||||||
|
seeds a default active Pro Subscription (see test_user docstring).
|
||||||
"""
|
"""
|
||||||
|
import uuid
|
||||||
from uuid import UUID as PyUUID
|
from uuid import UUID as PyUUID
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, delete
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
admin_data = {
|
admin_data = {
|
||||||
"email": "admin@example.com",
|
"email": "admin@example.com",
|
||||||
@@ -365,6 +385,12 @@ async def test_admin(client, test_db):
|
|||||||
result = await test_db.execute(select(User).where(User.id == user_id))
|
result = await test_db.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one()
|
user = result.scalar_one()
|
||||||
user.is_super_admin = True
|
user.is_super_admin = True
|
||||||
|
|
||||||
|
account_id = uuid.UUID(response.json()["account_id"])
|
||||||
|
await test_db.execute(
|
||||||
|
delete(Subscription).where(Subscription.account_id == account_id)
|
||||||
|
)
|
||||||
|
test_db.add(Subscription(account_id=account_id, plan="pro", status="active"))
|
||||||
await test_db.commit()
|
await test_db.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -21,17 +21,21 @@ class TestAccountEndpoints:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
|
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
|
||||||
"""Test getting current user's subscription details."""
|
"""Test getting current user's subscription details.
|
||||||
|
|
||||||
|
The test_user fixture seeds a Pro/active Subscription so
|
||||||
|
Pro-guarded routers work; reflect that in the expected plan.
|
||||||
|
"""
|
||||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "subscription" in data
|
assert "subscription" in data
|
||||||
assert "limits" in data
|
assert "limits" in data
|
||||||
assert "usage" in data
|
assert "usage" in data
|
||||||
assert data["subscription"]["plan"] == "free"
|
assert data["subscription"]["plan"] == "pro"
|
||||||
assert data["subscription"]["status"] == "active"
|
assert data["subscription"]["status"] == "active"
|
||||||
assert data["limits"]["max_trees"] == 3
|
assert data["limits"]["max_trees"] == 25
|
||||||
assert data["limits"]["max_sessions_per_month"] == 20
|
assert data["limits"]["max_sessions_per_month"] == 200
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
|
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
|
||||||
|
|||||||
89
backend/tests/test_subscription_guards.py
Normal file
89
backend/tests/test_subscription_guards.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Tests for require_active_subscription dependency.
|
||||||
|
|
||||||
|
Verifies the 402 gating logic for Pro-guarded routers and the allowlist
|
||||||
|
that lets billing/account/auth flows through even when locked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from sqlalchemy import delete
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_subscription(test_db, account_id, **fields):
|
||||||
|
"""Replace any existing Subscription on the account with one matching `fields`."""
|
||||||
|
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||||
|
test_db.add(Subscription(account_id=account_id, **fields))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_active_subscription_passes(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(test_db, account_id, plan="pro", status="active")
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code != 402
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complimentary_subscription_passes(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(test_db, account_id, plan="pro", status="complimentary")
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code != 402
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trialing_unexpired_passes(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(
|
||||||
|
test_db, account_id,
|
||||||
|
plan="pro", status="trialing",
|
||||||
|
current_period_end=datetime.now(timezone.utc) + timedelta(days=5),
|
||||||
|
)
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code != 402
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trialing_expired_returns_402(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(
|
||||||
|
test_db, account_id,
|
||||||
|
plan="pro", status="trialing",
|
||||||
|
current_period_end=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||||
|
)
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code == 402
|
||||||
|
body = response.json()
|
||||||
|
assert body["detail"]["error"] == "subscription_inactive"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_canceled_returns_402(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code == 402
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_subscription_returns_402(client, test_db, test_user, auth_headers):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
# Remove the seeded default subscription
|
||||||
|
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||||
|
await test_db.commit()
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code == 402
|
||||||
|
body = response.json()
|
||||||
|
assert body["detail"]["error"] == "no_subscription"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_me_bypasses_guard(client, test_db, test_user, auth_headers):
|
||||||
|
"""Allowlisted route works even when subscription is canceled."""
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await _set_subscription(test_db, account_id, plan="pro", status="canceled")
|
||||||
|
response = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
@@ -10,8 +10,15 @@ class TestSubscriptionLimits:
|
|||||||
"""Test suite for subscription plan limits."""
|
"""Test suite for subscription plan limits."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
|
async def test_free_plan_tree_limit(
|
||||||
|
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
|
||||||
|
):
|
||||||
"""Test that free plan has tree creation limit of 3."""
|
"""Test that free plan has tree creation limit of 3."""
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
sub = (await test_db.execute(select(Subscription))).scalar_one()
|
||||||
|
sub.plan = "free"
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
tree_template = {
|
tree_template = {
|
||||||
"name": "Limit Test Tree",
|
"name": "Limit Test Tree",
|
||||||
"tree_structure": {
|
"tree_structure": {
|
||||||
@@ -90,8 +97,15 @@ class TestSubscriptionLimits:
|
|||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
|
async def test_free_plan_limits_correct(
|
||||||
|
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
|
||||||
|
):
|
||||||
"""Test that free plan limits are correct."""
|
"""Test that free plan limits are correct."""
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
sub = (await test_db.execute(select(Subscription))).scalar_one()
|
||||||
|
sub.plan = "free"
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
limits = response.json()["limits"]
|
limits = response.json()["limits"]
|
||||||
|
|||||||
@@ -12,13 +12,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.tree import Tree
|
from app.models.tree import Tree
|
||||||
|
from app.models.subscription import Subscription
|
||||||
from app.core.security import get_password_hash
|
from app.core.security import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _create_account_and_user(db: AsyncSession, prefix: str):
|
async def _create_account_and_user(db: AsyncSession, prefix: str):
|
||||||
"""Create a fresh account + engineer user. Returns (account, user, plain_password)."""
|
"""Create a fresh account + engineer user. Returns (account, user, plain_password).
|
||||||
|
|
||||||
|
Seeds a default active Pro Subscription for the account so requests pass
|
||||||
|
the require_active_subscription guard added in Phase 1 Task 11.
|
||||||
|
"""
|
||||||
password = "TestPass123!"
|
password = "TestPass123!"
|
||||||
account = Account(
|
account = Account(
|
||||||
name=f"{prefix}-corp",
|
name=f"{prefix}-corp",
|
||||||
@@ -36,6 +41,7 @@ async def _create_account_and_user(db: AsyncSession, prefix: str):
|
|||||||
account_role="engineer",
|
account_role="engineer",
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
|
db.add(Subscription(account_id=account.id, plan="pro", status="active"))
|
||||||
await db.flush()
|
await db.flush()
|
||||||
return account, user, password
|
return account, user, password
|
||||||
|
|
||||||
@@ -168,6 +174,7 @@ async def test_ai_session_search_cannot_see_other_users_sessions(
|
|||||||
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
|
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
|
||||||
test_db.add(account)
|
test_db.add(account)
|
||||||
await test_db.flush()
|
await test_db.flush()
|
||||||
|
test_db.add(Subscription(account_id=account.id, plan="pro", status="active"))
|
||||||
|
|
||||||
password = "TestPass123!"
|
password = "TestPass123!"
|
||||||
user_a = User(
|
user_a = User(
|
||||||
|
|||||||
Reference in New Issue
Block a user