"""Tests for the require_active_subscription dependency.

Verifies the 402 gating logic for Pro-guarded routers and the allowlist that
lets billing/account/auth flows through even when the account is locked.
"""

import uuid
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy import delete

from app.models.subscription import Subscription

# A route served by a Pro-guarded router; used to probe the subscription gate.
_GUARDED_URL = "/api/v1/trees"
# An allowlisted route that must stay reachable regardless of subscription state.
_ALLOWLISTED_URL = "/api/v1/auth/me"


def _account_id(test_user):
    """Return the test user's account id (``user_data["account_id"]``) as a UUID."""
    return uuid.UUID(test_user["user_data"]["account_id"])


async def _delete_subscriptions(test_db, account_id):
    """Issue a DELETE for every Subscription row on ``account_id`` (no commit)."""
    await test_db.execute(
        delete(Subscription).where(Subscription.account_id == account_id)
    )


async def _clear_subscription(test_db, account_id):
    """Remove all Subscriptions on the account (e.g. a seeded default) and commit."""
    await _delete_subscriptions(test_db, account_id)
    await test_db.commit()


async def _set_subscription(test_db, account_id, **fields):
    """Replace any existing Subscription on the account with one matching ``fields``.

    The delete and the insert are persisted by a single commit, so the request
    under test never observes the intermediate "no subscription" state.
    """
    await _delete_subscriptions(test_db, account_id)
    test_db.add(Subscription(account_id=account_id, **fields))
    await test_db.commit()


# --- States that must pass the gate -----------------------------------------
# These only assert that the guard did not fire (status != 402); the guarded
# route's own success status is not what this module is testing.


@pytest.mark.asyncio
async def test_active_subscription_passes(client, test_db, test_user, auth_headers):
    await _set_subscription(
        test_db, _account_id(test_user), plan="pro", status="active"
    )

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    assert response.status_code != 402


@pytest.mark.asyncio
async def test_complimentary_subscription_passes(
    client, test_db, test_user, auth_headers
):
    await _set_subscription(
        test_db, _account_id(test_user), plan="pro", status="complimentary"
    )

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    assert response.status_code != 402


@pytest.mark.asyncio
async def test_trialing_unexpired_passes(client, test_db, test_user, auth_headers):
    await _set_subscription(
        test_db,
        _account_id(test_user),
        plan="pro",
        status="trialing",
        # Trial still has days left, so access must be granted.
        current_period_end=datetime.now(timezone.utc) + timedelta(days=5),
    )

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    assert response.status_code != 402


# --- States that must be blocked with 402 -----------------------------------


@pytest.mark.asyncio
async def test_trialing_expired_returns_402(client, test_db, test_user, auth_headers):
    await _set_subscription(
        test_db,
        _account_id(test_user),
        plan="pro",
        status="trialing",
        # Trial period ended an hour ago: status alone is not enough to pass.
        current_period_end=datetime.now(timezone.utc) - timedelta(hours=1),
    )

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    assert response.status_code == 402
    body = response.json()
    assert body["detail"]["error"] == "subscription_inactive"


@pytest.mark.asyncio
async def test_canceled_returns_402(client, test_db, test_user, auth_headers):
    await _set_subscription(
        test_db, _account_id(test_user), plan="pro", status="canceled"
    )

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    # NOTE(review): the error code for a canceled subscription is not pinned
    # here; consider asserting body["detail"]["error"] once confirmed.
    assert response.status_code == 402


@pytest.mark.asyncio
async def test_no_subscription_returns_402(client, test_db, test_user, auth_headers):
    # Remove the seeded default subscription
    await _clear_subscription(test_db, _account_id(test_user))

    response = await client.get(_GUARDED_URL, headers=auth_headers)

    assert response.status_code == 402
    body = response.json()
    assert body["detail"]["error"] == "no_subscription"


# --- Allowlist ---------------------------------------------------------------


@pytest.mark.asyncio
async def test_auth_me_bypasses_guard(client, test_db, test_user, auth_headers):
    """Allowlisted route works even when subscription is canceled."""
    await _set_subscription(
        test_db, _account_id(test_user), plan="pro", status="canceled"
    )

    response = await client.get(_ALLOWLISTED_URL, headers=auth_headers)

    assert response.status_code == 200