"""Tests for enhanced invite codes with plan assignment and trial durations."""
import pytest
from datetime import datetime, timezone, timedelta
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.invite_code import InviteCode
from app.models.subscription import Subscription
from app.models.user import User


async def _create_invite_code(
    client: AsyncClient, headers: dict, payload: dict
) -> str:
    """Create an invite through the admin API and return its ``code``.

    Asserts a 201 response first (reporting the response body on failure) so
    that a broken invite endpoint fails here with a clear message rather than
    as a ``KeyError`` on the missing ``"code"`` field.
    """
    response = await client.post(
        "/api/v1/invites",
        json=payload,
        headers=headers,
    )
    assert response.status_code == 201, response.text
    return response.json()["code"]


async def _register_with_invite(
    client: AsyncClient, *, email: str, name: str, invite_code: str
):
    """Register a new user with ``invite_code`` and return the response ``id``.

    Asserts the registration endpoint answered 201.
    """
    response = await client.post(
        "/api/v1/auth/register",
        json={
            "email": email,
            "password": "SecurePass1",
            "name": name,
            "invite_code": invite_code,
        },
    )
    assert response.status_code == 201, response.text
    return response.json()["id"]


async def _get_subscription_for_user(db: AsyncSession, user_id) -> Subscription:
    """Load the single Subscription row belonging to the user's account.

    ``scalar_one()`` raises if the user or the subscription is missing, or if
    more than one row matches.
    """
    user = (
        await db.execute(select(User).where(User.id == user_id))
    ).scalar_one()
    return (
        await db.execute(
            select(Subscription).where(Subscription.account_id == user.account_id)
        )
    ).scalar_one()


class TestInviteCodeCreation:
    """Test invite code creation with plan/trial fields."""

    @pytest.mark.asyncio
    async def test_create_invite_with_plan(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "pro", "note": "Beta tester"},
            headers=admin_auth_headers,
        )
        assert response.status_code == 201
        data = response.json()
        assert data["assigned_plan"] == "pro"
        assert data["has_trial"] is False
        assert data["trial_duration_days"] is None

    @pytest.mark.asyncio
    async def test_create_invite_with_trial(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "pro", "trial_duration_days": 14},
            headers=admin_auth_headers,
        )
        assert response.status_code == 201
        data = response.json()
        assert data["assigned_plan"] == "pro"
        assert data["trial_duration_days"] == 14
        assert data["has_trial"] is True

    @pytest.mark.asyncio
    async def test_create_invite_with_email(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "enterprise", "email": "beta@example.com"},
            headers=admin_auth_headers,
        )
        assert response.status_code == 201
        data = response.json()
        assert data["email"] == "beta@example.com"
        # Email not sent because RESEND_API_KEY is not configured in the test
        # environment.
        assert data["email_sent"] is False

    @pytest.mark.asyncio
    async def test_free_plan_rejects_trial(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "free", "trial_duration_days": 14},
            headers=admin_auth_headers,
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_trial_duration_bounds(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        # Too low
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "pro", "trial_duration_days": 0},
            headers=admin_auth_headers,
        )
        assert response.status_code == 422

        # Too high
        response = await client.post(
            "/api/v1/invites",
            json={"assigned_plan": "pro", "trial_duration_days": 91},
            headers=admin_auth_headers,
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_default_plan_is_free(
        self, client: AsyncClient, admin_auth_headers: dict
    ):
        response = await client.post(
            "/api/v1/invites",
            json={},
            headers=admin_auth_headers,
        )
        assert response.status_code == 201
        assert response.json()["assigned_plan"] == "free"


class TestRegistrationWithInvitePlan:
    """Test that registration applies invite code plan/trial to subscription."""

    @pytest.mark.asyncio
    async def test_register_with_pro_trial_invite(
        self, client: AsyncClient, admin_auth_headers: dict, test_db: AsyncSession
    ):
        code = await _create_invite_code(
            client,
            admin_auth_headers,
            {"assigned_plan": "pro", "trial_duration_days": 14},
        )

        # Bracket the registration call so the expected trial end can be
        # bounded without depending on the server's exact clock reading.
        before = datetime.now(timezone.utc)
        user_id = await _register_with_invite(
            client,
            email="trial_user@example.com",
            name="Trial User",
            invite_code=code,
        )
        after = datetime.now(timezone.utc)

        sub = await _get_subscription_for_user(test_db, user_id)
        assert sub.plan == "pro"
        assert sub.status == "trialing"
        assert sub.current_period_end is not None
        # The trial must last about the invite's 14 days, which also implies it
        # ends in the future. The one-day margin on each side tolerates, e.g.,
        # server-side rounding to day boundaries.
        assert (
            before + timedelta(days=13)
            <= sub.current_period_end
            <= after + timedelta(days=15)
        )

    @pytest.mark.asyncio
    async def test_register_with_team_no_trial(
        self, client: AsyncClient, admin_auth_headers: dict, test_db: AsyncSession
    ):
        # Despite "team" in the test name, this exercises the "enterprise" plan
        # with no trial.
        code = await _create_invite_code(
            client,
            admin_auth_headers,
            {"assigned_plan": "enterprise"},
        )

        user_id = await _register_with_invite(
            client,
            email="team_user@example.com",
            name="Team User",
            invite_code=code,
        )

        sub = await _get_subscription_for_user(test_db, user_id)
        assert sub.plan == "enterprise"
        # Without a trial the subscription starts directly in "active".
        assert sub.status == "active"


class TestAdminSubscriptionManagement:
    """Test admin subscription plan change and trial extension endpoints."""

    @pytest.mark.asyncio
    async def test_change_user_plan(
        self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
    ):
        user_id = test_user["user_data"]["id"]
        response = await client.put(
            f"/api/v1/admin/users/{user_id}/subscription/plan",
            json={"plan": "pro"},
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        assert response.json()["plan"] == "pro"

    @pytest.mark.asyncio
    async def test_extend_trial(
        self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
    ):
        user_id = test_user["user_data"]["id"]
        response = await client.put(
            f"/api/v1/admin/users/{user_id}/subscription/extend-trial",
            json={"days": 14},
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "trialing"
        assert data["current_period_end"] is not None

    @pytest.mark.asyncio
    async def test_enriched_user_detail(
        self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
    ):
        user_id = test_user["user_data"]["id"]
        response = await client.get(
            f"/api/v1/admin/users/{user_id}",
            headers=admin_auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        # Should have enriched fields
        assert "subscription" in data
        assert "account" in data
        assert "recent_sessions" in data
        assert "total_sessions" in data
        assert "recent_audit_logs" in data
        assert "total_audit_logs" in data