diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index ed913441..75d8a6e2 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -26,6 +26,7 @@ from app.models.refresh_token import RefreshToken from app.models.account import Account from app.models.subscription import Subscription from app.models.account_invite import AccountInvite +from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate from app.schemas.token import Token from app.schemas.auth_password import ( @@ -718,3 +719,59 @@ async def verify_email( await db.commit() return {"message": "Email verified successfully"} + + +@router.get("/me/feature-flags") +async def get_my_feature_flags( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> dict[str, bool]: + """Resolve feature flags for the current user's account and plan.""" + plan = "free" + if current_user.account_id: + sub_result = await db.execute( + select(Subscription).where( + Subscription.account_id == current_user.account_id, + Subscription.status.in_(["active", "trialing"]), + ) + ) + sub = sub_result.scalar_one_or_none() + if sub: + plan = sub.plan + + flags_result = await db.execute(select(FeatureFlag)) + flags = flags_result.scalars().all() + + if not flags: + return {} + + flag_ids = [f.id for f in flags] + + defaults_result = await db.execute( + select(PlanFeatureDefault).where( + PlanFeatureDefault.flag_id.in_(flag_ids), + PlanFeatureDefault.plan == plan, + ) + ) + plan_defaults = {d.flag_id: d.enabled for d in defaults_result.scalars().all()} + + overrides: dict = {} + if current_user.account_id: + overrides_result = await db.execute( + select(AccountFeatureOverride).where( + AccountFeatureOverride.flag_id.in_(flag_ids), + AccountFeatureOverride.account_id == current_user.account_id, + ) + ) + overrides = {o.flag_id: o.enabled for o in overrides_result.scalars().all()} + + resolved = {} + for flag in flags: + if flag.id in overrides: + resolved[flag.flag_key] = overrides[flag.id] + elif flag.id in plan_defaults: + resolved[flag.flag_key] = plan_defaults[flag.id] + else: + resolved[flag.flag_key] = False + + return resolved diff --git a/backend/tests/test_feature_flags_resolution.py b/backend/tests/test_feature_flags_resolution.py new file mode 100644 index 00000000..85b07bba --- /dev/null +++ b/backend/tests/test_feature_flags_resolution.py @@ -0,0 +1,107 @@ +"""Integration tests for feature flag resolution endpoint.""" + +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import text + + +async def _seed_feature_flag(db: AsyncSession, flag_key: str, display_name: str): + """Insert a feature flag and return its id.""" + result = await db.execute( + text( + "INSERT INTO feature_flags (id, flag_key, display_name) " + "VALUES (gen_random_uuid(), :key, :name) RETURNING id" + ), + {"key": flag_key, "name": display_name}, + ) + await db.commit() + return result.scalar_one() + + +async def _seed_plan_default(db: AsyncSession, flag_id, plan: str, enabled: bool): + """Insert a plan default for a flag.""" + await db.execute( + text( + "INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) " + "VALUES (gen_random_uuid(), :plan, :flag_id, :enabled)" + ), + {"plan": plan, "flag_id": flag_id, "enabled": enabled}, + ) + await db.commit() + + +async def _seed_account_override(db: AsyncSession, flag_id, account_id, enabled: bool): + """Insert an account override for a flag.""" + await db.execute( + text( + "INSERT INTO account_feature_overrides (id, account_id, flag_id, enabled) " + "VALUES (gen_random_uuid(), :account_id, :flag_id, :enabled)" + ), + {"account_id": account_id, "flag_id": flag_id, "enabled": enabled}, + ) + await db.commit() + + +async def _get_account_id(db: AsyncSession, user_id: str): + """Get account_id for a user.""" + result = await db.execute( + text("SELECT account_id FROM users WHERE id = :uid"), + {"uid": user_id}, + ) + return result.scalar_one() + + +class TestFeatureFlagResolution: + """Tests for GET /auth/me/feature-flags.""" + + @pytest.mark.asyncio + async def test_no_flags_returns_empty(self, client: AsyncClient, auth_headers: dict): + """When no flags exist, returns empty dict.""" + response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers) + assert response.status_code == 200 + assert response.json() == {} + + @pytest.mark.asyncio + async def test_plan_default_resolves( + self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession + ): + """Flag with plan default for 'free' plan resolves correctly.""" + flag_id = await _seed_feature_flag(test_db, "test_feature", "Test Feature") + await _seed_plan_default(test_db, flag_id, "free", True) + + response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["test_feature"] is True + + @pytest.mark.asyncio + async def test_no_plan_default_resolves_false( + self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession + ): + """Flag with no plan default for user's plan resolves to false.""" + flag_id = await _seed_feature_flag(test_db, "pro_only", "Pro Only") + await _seed_plan_default(test_db, flag_id, "pro", True) + + response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["pro_only"] is False + + @pytest.mark.asyncio + async def test_account_override_beats_plan_default( + self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession + ): + """Account override takes precedence over plan default.""" + flag_id = await _seed_feature_flag(test_db, "overridden", "Overridden Flag") + await _seed_plan_default(test_db, flag_id, "free", False) + account_id = await _get_account_id(test_db, test_user["user_data"]["id"]) + await _seed_account_override(test_db, flag_id, account_id, True) + + response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["overridden"] is True + + @pytest.mark.asyncio + async def test_unauthenticated_returns_401(self, client: AsyncClient): + """Unauthenticated request returns 401.""" + response = await client.get("/api/v1/auth/me/feature-flags") + assert response.status_code == 401