feat: add GET /auth/me/feature-flags resolution endpoint
Resolves feature flags for the current user using: account override > plan default > false Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from app.models.refresh_token import RefreshToken
|
|||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.subscription import Subscription
|
from app.models.subscription import Subscription
|
||||||
from app.models.account_invite import AccountInvite
|
from app.models.account_invite import AccountInvite
|
||||||
|
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
|
||||||
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
|
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
|
||||||
from app.schemas.token import Token
|
from app.schemas.token import Token
|
||||||
from app.schemas.auth_password import (
|
from app.schemas.auth_password import (
|
||||||
@@ -718,3 +719,59 @@ async def verify_email(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return {"message": "Email verified successfully"}
|
return {"message": "Email verified successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/feature-flags")
|
||||||
|
async def get_my_feature_flags(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Resolve feature flags for the current user's account and plan."""
|
||||||
|
plan = "free"
|
||||||
|
if current_user.account_id:
|
||||||
|
sub_result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.account_id == current_user.account_id,
|
||||||
|
Subscription.status.in_(["active", "trialing"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = sub_result.scalar_one_or_none()
|
||||||
|
if sub:
|
||||||
|
plan = sub.plan
|
||||||
|
|
||||||
|
flags_result = await db.execute(select(FeatureFlag))
|
||||||
|
flags = flags_result.scalars().all()
|
||||||
|
|
||||||
|
if not flags:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
flag_ids = [f.id for f in flags]
|
||||||
|
|
||||||
|
defaults_result = await db.execute(
|
||||||
|
select(PlanFeatureDefault).where(
|
||||||
|
PlanFeatureDefault.flag_id.in_(flag_ids),
|
||||||
|
PlanFeatureDefault.plan == plan,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
plan_defaults = {d.flag_id: d.enabled for d in defaults_result.scalars().all()}
|
||||||
|
|
||||||
|
overrides: dict = {}
|
||||||
|
if current_user.account_id:
|
||||||
|
overrides_result = await db.execute(
|
||||||
|
select(AccountFeatureOverride).where(
|
||||||
|
AccountFeatureOverride.flag_id.in_(flag_ids),
|
||||||
|
AccountFeatureOverride.account_id == current_user.account_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
overrides = {o.flag_id: o.enabled for o in overrides_result.scalars().all()}
|
||||||
|
|
||||||
|
resolved = {}
|
||||||
|
for flag in flags:
|
||||||
|
if flag.id in overrides:
|
||||||
|
resolved[flag.flag_key] = overrides[flag.id]
|
||||||
|
elif flag.id in plan_defaults:
|
||||||
|
resolved[flag.flag_key] = plan_defaults[flag.id]
|
||||||
|
else:
|
||||||
|
resolved[flag.flag_key] = False
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|||||||
107
backend/tests/test_feature_flags_resolution.py
Normal file
107
backend/tests/test_feature_flags_resolution.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""Integration tests for feature flag resolution endpoint."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
|
async def _seed_feature_flag(db: AsyncSession, flag_key: str, display_name: str):
|
||||||
|
"""Insert a feature flag and return its id."""
|
||||||
|
result = await db.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO feature_flags (id, flag_key, display_name) "
|
||||||
|
"VALUES (gen_random_uuid(), :key, :name) RETURNING id"
|
||||||
|
),
|
||||||
|
{"key": flag_key, "name": display_name},
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
async def _seed_plan_default(db: AsyncSession, flag_id, plan: str, enabled: bool):
|
||||||
|
"""Insert a plan default for a flag."""
|
||||||
|
await db.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) "
|
||||||
|
"VALUES (gen_random_uuid(), :plan, :flag_id, :enabled)"
|
||||||
|
),
|
||||||
|
{"plan": plan, "flag_id": flag_id, "enabled": enabled},
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _seed_account_override(db: AsyncSession, flag_id, account_id, enabled: bool):
|
||||||
|
"""Insert an account override for a flag."""
|
||||||
|
await db.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO account_feature_overrides (id, account_id, flag_id, enabled) "
|
||||||
|
"VALUES (gen_random_uuid(), :account_id, :flag_id, :enabled)"
|
||||||
|
),
|
||||||
|
{"account_id": account_id, "flag_id": flag_id, "enabled": enabled},
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_account_id(db: AsyncSession, user_id: str):
|
||||||
|
"""Get account_id for a user."""
|
||||||
|
result = await db.execute(
|
||||||
|
text("SELECT account_id FROM users WHERE id = :uid"),
|
||||||
|
{"uid": user_id},
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFeatureFlagResolution:
|
||||||
|
"""Tests for GET /auth/me/feature-flags."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_flags_returns_empty(self, client: AsyncClient, auth_headers: dict):
|
||||||
|
"""When no flags exist, returns empty dict."""
|
||||||
|
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_default_resolves(
|
||||||
|
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
|
||||||
|
):
|
||||||
|
"""Flag with plan default for 'free' plan resolves correctly."""
|
||||||
|
flag_id = await _seed_feature_flag(test_db, "test_feature", "Test Feature")
|
||||||
|
await _seed_plan_default(test_db, flag_id, "free", True)
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["test_feature"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_plan_default_resolves_false(
|
||||||
|
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
|
||||||
|
):
|
||||||
|
"""Flag with no plan default for user's plan resolves to false."""
|
||||||
|
flag_id = await _seed_feature_flag(test_db, "pro_only", "Pro Only")
|
||||||
|
await _seed_plan_default(test_db, flag_id, "pro", True)
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["pro_only"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_account_override_beats_plan_default(
|
||||||
|
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
|
||||||
|
):
|
||||||
|
"""Account override takes precedence over plan default."""
|
||||||
|
flag_id = await _seed_feature_flag(test_db, "overridden", "Overridden Flag")
|
||||||
|
await _seed_plan_default(test_db, flag_id, "free", False)
|
||||||
|
account_id = await _get_account_id(test_db, test_user["user_data"]["id"])
|
||||||
|
await _seed_account_override(test_db, flag_id, account_id, True)
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["overridden"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthenticated_returns_401(self, client: AsyncClient):
|
||||||
|
"""Unauthenticated request returns 401."""
|
||||||
|
response = await client.get("/api/v1/auth/me/feature-flags")
|
||||||
|
assert response.status_code == 401
|
||||||
Reference in New Issue
Block a user