"""Integration tests for admin plan limits and account override endpoints."""

from unittest.mock import AsyncMock, patch

import pytest
from httpx import AsyncClient
from sqlalchemy import select

from app.models.plan_billing import PlanBilling


async def _get_plan_billing(db, plan: str) -> "PlanBilling | None":
    """Return the ``PlanBilling`` row for *plan*, or ``None`` if absent.

    ``populate_existing`` makes SQLAlchemy overwrite any instance already held
    in this session's identity map with the freshly selected column values.
    Without it, re-reading a row that the endpoint wrote through its own
    session can hand back the stale in-memory object (e.g. when the session
    is configured with ``expire_on_commit=False``), letting an assertion pass
    against values the database no longer holds.
    """
    result = await db.execute(
        select(PlanBilling)
        .where(PlanBilling.plan == plan)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one_or_none()


async def _delete_plan_billing(db, plan: str) -> None:
    """Delete the ``PlanBilling`` row for *plan* if one exists, then commit."""
    existing = await _get_plan_billing(db, plan)
    if existing is not None:
        await db.delete(existing)
        await db.commit()


class TestAdminPlanLimits:
    @pytest.mark.asyncio
    async def test_list_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List all plan limits."""
        response = await client.get(
            "/api/v1/admin/plan-limits", headers=admin_auth_headers
        )
        assert response.status_code == 200
        plans = response.json()
        assert len(plans) >= 3  # free, pro, team seeded in conftest
        plan_names = [p["plan"] for p in plans]
        assert "free" in plan_names

    @pytest.mark.asyncio
    async def test_update_plan_limits(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """Update a plan's limits."""
        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "free",
                "max_trees": 5,
                "max_sessions_per_month": 30,
                "max_users": 2,
                "custom_branding": False,
                "priority_support": False,
                "export_formats": ["markdown", "text"],
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["max_trees"] == 5

    @pytest.mark.asyncio
    async def test_list_account_overrides(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """List account overrides."""
        response = await client.get(
            "/api/v1/admin/account-overrides", headers=admin_auth_headers
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    @pytest.mark.asyncio
    async def test_non_admin_cannot_access(
        self, client: AsyncClient, auth_headers: dict
    ):
        """Non-admin gets 403."""
        response = await client.get(
            "/api/v1/admin/plan-limits", headers=auth_headers
        )
        assert response.status_code == 403

    @pytest.mark.asyncio
    async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """GET /admin/plan-limits returns plan_billing fields when a row
        exists, and None for plans that don't have one yet."""
        # Seed the plan_billing row for "pro", or overwrite an existing one,
        # so the assertions below never depend on values some fixture may
        # already have stored there.
        pro_billing = {
            "display_name": "Pro",
            "description": "For working teams",
            "monthly_price_cents": 4900,
            "annual_price_cents": 49000,
            "stripe_product_id": "prod_seed",
            "stripe_monthly_price_id": "price_seed_m",
            "stripe_annual_price_id": "price_seed_a",
            "is_public": True,
            "is_archived": False,
            "sort_order": 10,
        }
        existing = await _get_plan_billing(test_db, "pro")
        if existing is None:
            test_db.add(PlanBilling(plan="pro", **pro_billing))
        else:
            for field, value in pro_billing.items():
                setattr(existing, field, value)
        await test_db.commit()

        response = await client.get(
            "/api/v1/admin/plan-limits", headers=admin_auth_headers
        )
        assert response.status_code == 200
        plans_by_name = {p["plan"]: p for p in response.json()}

        assert "pro" in plans_by_name
        pro = plans_by_name["pro"]
        assert pro["display_name"] == "Pro"
        assert pro["monthly_price_cents"] == 4900
        assert pro["stripe_monthly_price_id"] == "price_seed_m"
        assert pro["is_public"] is True
        assert pro["is_archived"] is False
        assert pro["sort_order"] == 10

        # A plan without a plan_billing row should still be listed, with
        # None billing fields. Only check "free" when it really has no row.
        free = plans_by_name.get("free")
        if free is not None and await _get_plan_billing(test_db, "free") is None:
            assert free["display_name"] is None
            assert free["monthly_price_cents"] is None
            assert free["stripe_product_id"] is None

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_creates_plan_billing_row(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits upserts a plan_billing row when billing
        fields are included in the body."""
        # Ensure no plan_billing row exists for "enterprise" yet.
        await _delete_plan_billing(test_db, "enterprise")

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text", "pdf"],
                "display_name": "Enterprise",
                "description": "For large organizations",
                "monthly_price_cents": 9900,
                "annual_price_cents": 99000,
                "stripe_product_id": "prod_enterprise_test",
                "stripe_monthly_price_id": "price_enterprise_m",
                "stripe_annual_price_id": "price_enterprise_a",
                "is_public": True,
                "is_archived": False,
                "sort_order": 20,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text
        body = response.json()
        assert body["display_name"] == "Enterprise"
        assert body["monthly_price_cents"] == 9900
        assert body["stripe_product_id"] == "prod_enterprise_test"
        assert body["sort_order"] == 20

        # Confirm the row was actually persisted. Committing ends this
        # session's transaction so the next read sees the request's write.
        await test_db.commit()
        pb = await _get_plan_billing(test_db, "enterprise")
        assert pb is not None
        assert pb.display_name == "Enterprise"
        assert pb.monthly_price_cents == 9900
        assert pb.stripe_monthly_price_id == "price_enterprise_m"
        assert pb.is_public is True

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_does_not_null_out_required_fields(
        self, client: AsyncClient, admin_auth_headers: dict, test_db
    ):
        """PUT /admin/plan-limits must not NULL out NOT NULL columns on the
        plan_billing row when the caller passes explicit nulls.

        The set of guarded fields is {display_name, is_public, is_archived,
        sort_order}.
        """
        # Seed a plan_billing row for "enterprise" with non-default values for
        # every NOT NULL field so we can detect any clobbering.
        await _delete_plan_billing(test_db, "enterprise")
        test_db.add(
            PlanBilling(
                plan="enterprise",
                display_name="Enterprise Seeded",
                is_public=False,
                is_archived=True,
                sort_order=5,
            )
        )
        await test_db.commit()

        response = await client.put(
            "/api/v1/admin/plan-limits",
            json={
                "plan": "enterprise",
                "max_trees": None,
                "max_sessions_per_month": None,
                "max_users": None,
                "custom_branding": True,
                "priority_support": True,
                "export_formats": ["markdown", "text"],
                # Explicit nulls for every NOT NULL plan_billing field.
                "display_name": None,
                "is_public": None,
                "is_archived": None,
                "sort_order": None,
            },
            headers=admin_auth_headers,
        )
        assert response.status_code == 200, response.text

        # Confirm the seeded NOT NULL values were preserved. The seeded
        # instance is still in this session's identity map, so the re-read
        # must use populate_existing (via _get_plan_billing) or it could
        # return the cached object and hide a clobbered column.
        await test_db.commit()
        pb = await _get_plan_billing(test_db, "enterprise")
        assert pb is not None
        assert pb.display_name == "Enterprise Seeded"
        assert pb.is_public is False
        assert pb.is_archived is True
        assert pb.sort_order == 5

    @pytest.mark.asyncio
    async def test_admin_plan_limits_put_invalidates_billing_cache(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        """PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
        with the account_ids on the affected plan."""
        # Patch the staticmethod on the class. The endpoint imports
        # BillingService at module load, so patch the symbol on the class
        # itself — both the import and the dotted reference resolve to it.
        with patch(
            "app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
            new_callable=AsyncMock,
        ) as spy:
            response = await client.put(
                "/api/v1/admin/plan-limits",
                json={
                    "plan": "pro",
                    "max_trees": 25,
                    "max_sessions_per_month": 500,
                    "max_users": 10,
                    "custom_branding": True,
                    "priority_support": True,
                    "export_formats": ["markdown", "text"],
                },
                headers=admin_auth_headers,
            )
        assert response.status_code == 200, response.text

        spy.assert_awaited_once()
        (account_ids_arg,) = spy.await_args.args
        # admin fixture seeds an active Pro Subscription, so we expect at
        # least one account_id in the invalidation list.
        assert isinstance(account_ids_arg, list)
        assert len(account_ids_arg) >= 1