From cfe0e6cae62f82021d04461f4a7a0dfc70166861 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 04:02:20 -0400 Subject: [PATCH] refactor(deps): remove trial auto-downgrade; expiry now non-mutating per spec Co-Authored-By: Claude Opus 4.7 --- backend/app/api/deps.py | 29 +++--------- ...est_get_current_active_user_no_mutation.py | 45 +++++++++++++++++++ 2 files changed, 50 insertions(+), 24 deletions(-) create mode 100644 backend/tests/test_get_current_active_user_no_mutation.py diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 79770ed9..a1ae7ea5 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -83,11 +83,12 @@ async def get_current_active_user( current_user: Annotated[User, Depends(get_current_user)], db: Annotated[AsyncSession, Depends(get_admin_db)], ) -> User: - """Ensure user is active (not disabled). Auto-downgrades expired trials. - Enforces must_change_password — blocks all routes except allowlist. + """Ensure user is active (not disabled). Enforces must_change_password — + blocks all routes except allowlist. - Uses get_admin_db: runs before require_tenant_context sets the ContextVar, - so tenant-scoped tables (subscriptions) would return 0 rows via app role. + Trial expiry enforcement now happens via require_active_subscription in + individual routers, NOT here. This dep no longer mutates Subscription + state. """ if not current_user.is_active: raise HTTPException( @@ -106,26 +107,6 @@ async def get_current_active_user( # Set Sentry user context for error attribution sentry_sdk.set_user({"id": str(current_user.id), "email": current_user.email}) - # Lightweight trial expiry check - if current_user.account_id: - from app.models.subscription import Subscription - from datetime import datetime, timezone - result = await db.execute( - select(Subscription).where(Subscription.account_id == current_user.account_id) - ) - subscription = result.scalar_one_or_none() - if ( - subscription - and subscription.status == "trialing" - and subscription.current_period_end - and subscription.current_period_end < datetime.now(timezone.utc) - ): - subscription.plan = "free" - subscription.status = "active" - subscription.current_period_end = None - subscription.current_period_start = None - await db.commit() - return current_user diff --git a/backend/tests/test_get_current_active_user_no_mutation.py b/backend/tests/test_get_current_active_user_no_mutation.py new file mode 100644 index 00000000..c2911ecb --- /dev/null +++ b/backend/tests/test_get_current_active_user_no_mutation.py @@ -0,0 +1,45 @@ +import uuid +import pytest +from datetime import datetime, timezone, timedelta +from sqlalchemy import select +from app.models.subscription import Subscription + + +@pytest.mark.asyncio +async def test_expired_trial_is_not_mutated_by_get_current_active_user( + test_db, client, test_user, auth_headers +): + """The previous deps.py:109 logic mutated trialing→active+free on expiry. + That's gone. An expired-trial Subscription should retain status='trialing' + and current_period_end after any authenticated request.""" + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + + # If a Subscription already exists for this account (e.g. created by + # the register handler), update it; otherwise insert a new one. + existing = await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + sub = existing.scalar_one_or_none() + expired_end = datetime.now(timezone.utc) - timedelta(hours=1) + if sub is None: + sub = Subscription( + account_id=account_id, + plan="pro", + status="trialing", + current_period_end=expired_end, + ) + test_db.add(sub) + else: + sub.plan = "pro" + sub.status = "trialing" + sub.current_period_end = expired_end + await test_db.commit() + + # Call any authenticated endpoint that goes through get_current_active_user. + response = await client.get("/api/v1/auth/me", headers=auth_headers) + assert response.status_code == 200 + + await test_db.refresh(sub) + assert sub.status == "trialing" + assert sub.plan == "pro" + assert sub.current_period_end is not None