From 47ff8ad2b5e3c7ea84738da0653d6e9a15481db6 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Thu, 28 May 2026 12:49:59 -0400 Subject: [PATCH] feat(l1): enforce seat limits on invite, accept-invite, role-change For engineer + l1_tech roles, check_seat_available is called at each mutation point. Returns 402 Payment Required with structured detail {code: 'seat_limit_exceeded', role, current, limit, upgrade_url} when seats are full. Grandfathering: existing over-seated accounts keep existing users; only new mutations are blocked. Also updates AccountInviteCreate and AccountRoleUpdate schemas to accept l1_tech as a valid role value. Co-Authored-By: Claude Opus 4.7 --- backend/app/api/endpoints/accounts.py | 52 +++ backend/app/api/endpoints/auth.py | 27 ++ backend/app/api/endpoints/oauth.py | 23 ++ backend/app/schemas/account.py | 2 +- backend/app/schemas/user.py | 2 +- backend/tests/test_invite_seat_enforcement.py | 363 ++++++++++++++++++ 6 files changed, 467 insertions(+), 2 deletions(-) create mode 100644 backend/tests/test_invite_seat_enforcement.py diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py index 6b3fab83..c4c37488 100644 --- a/backend/app/api/endpoints/accounts.py +++ b/backend/app/api/endpoints/accounts.py @@ -24,10 +24,50 @@ from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, U from app.schemas.user import UserResponse, AccountRoleUpdate from app.core.security import verify_password from app.api.deps import get_current_active_user, require_account_owner +from app.services.seat_enforcement import check_seat_available + +_SEAT_CHECKED_ROLES = frozenset({"engineer", "l1_tech"}) router = APIRouter(prefix="/accounts", tags=["accounts"]) +async def _load_account(db: AsyncSession, account_id: UUID) -> Account: + """Load an Account by id; raises 404 if missing.""" + result = await db.execute(select(Account).where(Account.id == account_id)) + account = result.scalar_one_or_none() + if account is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found") + return account + + +async def _enforce_seat_limit(db: AsyncSession, account_id: UUID, role: str) -> None: + """Raise HTTP 402 if the account has no capacity for the given role. + + Only fires for seat-counted roles (engineer, l1_tech). + Accounts without a subscription (free / pre-billing) are not blocked. + Grandfathering: if current > limit, existing users keep access; this + helper only blocks new additions. + """ + if role not in _SEAT_CHECKED_ROLES: + return + sub = await get_account_subscription(account_id, db) + if sub is None: + return # no subscription → no enforcement + account = await _load_account(db, account_id) + seat_result = await check_seat_available(account, sub, role, db) + if not seat_result.available: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail={ + "code": "seat_limit_exceeded", + "role": seat_result.role, + "current": seat_result.current, + "limit": seat_result.limit, + "upgrade_url": "/account/billing", + }, + ) + + @router.get("/me", response_model=AccountResponse) async def get_my_account( db: Annotated[AsyncSession, Depends(get_db)], @@ -141,6 +181,11 @@ async def update_member_role( detail="Cannot change your own role" ) + # Seat enforcement: check capacity before promoting to a seat-counted role. + # Demotions (engineer/l1_tech → viewer) and lateral moves skip the check. + if data.account_role != user.account_role: + await _enforce_seat_limit(db, current_user.account_id, data.account_role) + user.account_role = data.account_role await db.commit() await db.refresh(user) @@ -261,6 +306,9 @@ async def create_invite( current_user: Annotated[User, Depends(require_account_owner)] ): """Create an invite to join this account (owner only). Sends invite email.""" + # Seat enforcement: block invite if the target role is at capacity. + await _enforce_seat_limit(db, current_user.account_id, data.role) + code = secrets.token_urlsafe(16) expires_at = None @@ -317,6 +365,10 @@ async def create_invites_bulk( failed: list[dict] = [] for invite_data in payload.invites: try: + # Seat enforcement per invite row — 402 bubbles as an HTTPException + # which is caught below and recorded in `failed`. + await _enforce_seat_limit(db, current_user.account_id, invite_data.role) + code = secrets.token_urlsafe(16) expires_at = None if invite_data.expires_in_days: diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index fa73d819..d6c65ee4 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -289,6 +289,33 @@ async def register( detail="Invite code has expired" ) + # Seat enforcement: re-check at accept time (race-condition guard). + # Fires only when an account invite is being accepted and the target role + # is seat-counted (engineer, l1_tech). Accounts without a subscription + # (free / pre-billing) are not blocked. + if account_invite_record and account_invite_record.role in ("engineer", "l1_tech"): + from app.core.subscriptions import get_account_subscription + from app.services.seat_enforcement import check_seat_available + from app.models.account import Account as _Account + sub = await get_account_subscription(account_invite_record.account_id, db) + if sub is not None: + acct_result = await db.execute( + select(_Account).where(_Account.id == account_invite_record.account_id) + ) + acct = acct_result.scalar_one() + seat_result = await check_seat_available(acct, sub, account_invite_record.role, db) + if not seat_result.available: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail={ + "code": "seat_limit_exceeded", + "role": seat_result.role, + "current": seat_result.current, + "limit": seat_result.limit, + "upgrade_url": "/account/billing", + }, + ) + # Check if email already exists result = await db.execute(select(User).where(User.email == user_data.email)) existing_user = result.scalar_one_or_none() diff --git a/backend/app/api/endpoints/oauth.py b/backend/app/api/endpoints/oauth.py index 233b50b6..d4ed3962 100644 --- a/backend/app/api/endpoints/oauth.py +++ b/backend/app/api/endpoints/oauth.py @@ -118,6 +118,29 @@ async def _sign_in_or_register( if is_new_user: if invite_record is not None: + # Seat enforcement: re-check at OAuth accept time (race-condition guard). + if invite_record.role in ("engineer", "l1_tech"): + from app.core.subscriptions import get_account_subscription + from app.services.seat_enforcement import check_seat_available + sub = await get_account_subscription(invite_record.account_id, db) + if sub is not None: + acct_result = await db.execute( + select(Account).where(Account.id == invite_record.account_id) + ) + acct = acct_result.scalar_one() + seat_result = await check_seat_available(acct, sub, invite_record.role, db) + if not seat_result.available: + raise HTTPException( + status_code=402, + detail={ + "code": "seat_limit_exceeded", + "role": seat_result.role, + "current": seat_result.current, + "limit": seat_result.limit, + "upgrade_url": "/account/billing", + }, + ) + # Join the invited account directly — no personal account, no # trial creation. user = User( diff --git a/backend/app/schemas/account.py b/backend/app/schemas/account.py index 3d1e0c28..b2145b4a 100644 --- a/backend/app/schemas/account.py +++ b/backend/app/schemas/account.py @@ -27,7 +27,7 @@ class TransferOwnershipRequest(BaseModel): class AccountInviteCreate(BaseModel): email: str = Field(..., max_length=255) - role: str = Field("engineer", pattern="^(engineer|viewer)$") + role: str = Field("engineer", pattern="^(engineer|viewer|l1_tech)$") expires_in_days: Optional[int] = Field(None, ge=1, le=30) diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 81d7c8b3..044e84e2 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -72,4 +72,4 @@ class RoleUpdate(BaseModel): class AccountRoleUpdate(BaseModel): # Ownership changes must go through the explicit transfer-ownership flow so # account.owner_id stays consistent with user.account_role. - account_role: str = Field(..., pattern="^(admin|engineer|viewer)$") + account_role: str = Field(..., pattern="^(admin|engineer|viewer|l1_tech)$") diff --git a/backend/tests/test_invite_seat_enforcement.py b/backend/tests/test_invite_seat_enforcement.py new file mode 100644 index 00000000..ddd11218 --- /dev/null +++ b/backend/tests/test_invite_seat_enforcement.py @@ -0,0 +1,363 @@ +"""Integration tests for seat enforcement at invite create, accept-invite, and +role-change endpoints. + +All tests use the `client` + `test_db` fixtures from conftest, which spin up +a fresh schema per test and wire the ASGI app to the test DB. +""" + +import uuid + +import pytest +from httpx import AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.account import Account +from app.models.account_invite import AccountInvite +from app.models.subscription import Subscription +from app.models.user import User + + +# --------------------------------------------------------------------------- +# Test-local helpers +# --------------------------------------------------------------------------- + +async def _register(client: AsyncClient, *, email: str, password: str = "TestPassword123!", name: str = "Test User") -> dict: + resp = await client.post("/api/v1/auth/register", json={"email": email, "password": password, "name": name}) + assert resp.status_code in (200, 201), resp.text + return resp.json() + + +async def _login(client: AsyncClient, *, email: str, password: str = "TestPassword123!") -> dict: + resp = await client.post("/api/v1/auth/login/json", json={"email": email, "password": password}) + assert resp.status_code == 200, resp.text + return {"Authorization": f"Bearer {resp.json()['access_token']}"} + + +async def _set_sub(db: AsyncSession, account_id: uuid.UUID, *, seat_limit: int | None, l1_seat_limit: int | None = None) -> None: + """Replace the account's subscription with specified limits.""" + await db.execute(delete(Subscription).where(Subscription.account_id == account_id)) + db.add(Subscription( + account_id=account_id, + plan="pro", + status="active", + seat_limit=seat_limit, + l1_seat_limit=l1_seat_limit, + )) + await db.commit() + + +async def _add_member(db: AsyncSession, account_id: uuid.UUID, *, role: str, suffix: str | None = None) -> User: + """Directly insert an active user with the given role into the account.""" + s = suffix or str(uuid.uuid4())[:8] + user = User( + id=uuid.uuid4(), + email=f"member-{s}@example.com", + name=f"Member {s}", + account_id=account_id, + account_role=role, + role="engineer", + is_active=True, + ) + db.add(user) + await db.commit() + return user + + +# --------------------------------------------------------------------------- +# Invite create — single invite endpoint +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_invite_engineer_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession): + """POST /me/invites → 402 when engineer seat limit is exhausted.""" + owner = await _register(client, email="owner1@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner1@example.com") + + # seat_limit=1, already 1 engineer → full + await _set_sub(test_db, account_id, seat_limit=1) + # The owner registers as engineer, but is actually 'owner' role — add a separate engineer + await _add_member(test_db, account_id, role="engineer") + + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "new-eng@example.com", "role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + assert body["detail"]["role"] == "engineer" + assert body["detail"]["current"] == 1 + assert body["detail"]["limit"] == 1 + assert "upgrade_url" in body["detail"] + + +@pytest.mark.asyncio +async def test_invite_l1_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession): + """POST /me/invites → 402 when l1_tech seat limit is exhausted.""" + owner = await _register(client, email="owner2@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner2@example.com") + + await _set_sub(test_db, account_id, seat_limit=10, l1_seat_limit=1) + await _add_member(test_db, account_id, role="l1_tech") + + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "new-l1@example.com", "role": "l1_tech"}, + headers=headers, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + assert body["detail"]["role"] == "l1_tech" + assert body["detail"]["current"] == 1 + assert body["detail"]["limit"] == 1 + + +@pytest.mark.asyncio +async def test_invite_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession): + """POST /me/invites → 201 when engineer seats have room.""" + owner = await _register(client, email="owner3@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner3@example.com") + + # seat_limit=5, 0 engineers → plenty of room + await _set_sub(test_db, account_id, seat_limit=5) + + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "new-eng2@example.com", "role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 201, resp.text + + +@pytest.mark.asyncio +async def test_invite_viewer_bypasses_seat_check(client: AsyncClient, test_db: AsyncSession): + """POST /me/invites → 201 for viewer role even when engineer seats full.""" + owner = await _register(client, email="owner4@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner4@example.com") + + # engineer seats exhausted — should not affect viewer invites + await _set_sub(test_db, account_id, seat_limit=1) + await _add_member(test_db, account_id, role="engineer") + + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "viewer@example.com", "role": "viewer"}, + headers=headers, + ) + assert resp.status_code == 201, resp.text + + +@pytest.mark.asyncio +async def test_invite_unlimited_seat_limit_always_succeeds(client: AsyncClient, test_db: AsyncSession): + """POST /me/invites → 201 when seat_limit is None (unlimited).""" + owner = await _register(client, email="owner5@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner5@example.com") + + # seat_limit=None = unlimited + await _set_sub(test_db, account_id, seat_limit=None) + # add many engineers + for i in range(5): + await _add_member(test_db, account_id, role="engineer", suffix=f"bulk{i}") + + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "new-unlimited@example.com", "role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 201, resp.text + + +@pytest.mark.asyncio +async def test_invite_grandfathered_account_blocks_new_invites(client: AsyncClient, test_db: AsyncSession): + """Grandfathering: existing over-seated account keeps existing users but + new engineer invites are still blocked (current > limit → blocked).""" + owner = await _register(client, email="owner6@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner6@example.com") + + # current=3 engineers > seat_limit=2 (over-seated / grandfathered) + await _set_sub(test_db, account_id, seat_limit=2) + for i in range(3): + await _add_member(test_db, account_id, role="engineer", suffix=f"gf{i}") + + # New invite must be blocked + resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "one-more@example.com", "role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + # current (3) > limit (2) — forward enforcement fires, existing users unaffected + assert body["detail"]["current"] == 3 + assert body["detail"]["limit"] == 2 + + +# --------------------------------------------------------------------------- +# Accept-invite race condition — auth.py register path +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_accept_invite_blocked_when_seats_full_at_accept_time(client: AsyncClient, test_db: AsyncSession): + """Race-condition guard: invite created when seats available, but by + accept time someone else consumed the last seat → 402.""" + # Step 1: create an owner and send an invite + owner = await _register(client, email="owner7@example.com") + account_id = uuid.UUID(owner["account_id"]) + owner_headers = await _login(client, email="owner7@example.com") + + await _set_sub(test_db, account_id, seat_limit=2) + + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "race@example.com", "role": "engineer"}, + headers=owner_headers, + ) + assert invite_resp.status_code == 201, invite_resp.text + invite_code = invite_resp.json()["code"] + + # Step 2: fill the seats after the invite was created (race condition) + await _add_member(test_db, account_id, role="engineer", suffix="race1") + await _add_member(test_db, account_id, role="engineer", suffix="race2") + + # Step 3: invitee tries to register — should get 402 + resp = await client.post( + "/api/v1/auth/register", + json={ + "email": "race@example.com", + "password": "TestPassword123!", + "name": "Race User", + "account_invite_code": invite_code, + }, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + + +@pytest.mark.asyncio +async def test_accept_invite_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession): + """Normal accept-invite path works when seats have room.""" + owner = await _register(client, email="owner8@example.com") + account_id = uuid.UUID(owner["account_id"]) + owner_headers = await _login(client, email="owner8@example.com") + + await _set_sub(test_db, account_id, seat_limit=5) + + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "acceptme@example.com", "role": "engineer"}, + headers=owner_headers, + ) + assert invite_resp.status_code == 201, invite_resp.text + invite_code = invite_resp.json()["code"] + + resp = await client.post( + "/api/v1/auth/register", + json={ + "email": "acceptme@example.com", + "password": "TestPassword123!", + "name": "Accept User", + "account_invite_code": invite_code, + }, + ) + assert resp.status_code in (200, 201), resp.text + assert resp.json()["account_id"] == str(account_id) + + +# --------------------------------------------------------------------------- +# Role-change endpoint — PATCH /me/members/{user_id}/role +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_role_change_viewer_to_engineer_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession): + """PATCH /me/members/{id}/role → 402 when promoting viewer → engineer and seats full.""" + owner = await _register(client, email="owner9@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner9@example.com") + + await _set_sub(test_db, account_id, seat_limit=1) + # Fill the engineer seat + await _add_member(test_db, account_id, role="engineer") + # Add a viewer to promote + viewer = await _add_member(test_db, account_id, role="viewer") + + resp = await client.patch( + f"/api/v1/accounts/me/members/{viewer.id}/role", + json={"account_role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + assert body["detail"]["role"] == "engineer" + + +@pytest.mark.asyncio +async def test_role_change_viewer_to_l1_blocked_when_seats_full(client: AsyncClient, test_db: AsyncSession): + """PATCH /me/members/{id}/role → 402 when promoting viewer → l1_tech and l1 seats full.""" + owner = await _register(client, email="owner10@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner10@example.com") + + await _set_sub(test_db, account_id, seat_limit=10, l1_seat_limit=1) + await _add_member(test_db, account_id, role="l1_tech") + viewer = await _add_member(test_db, account_id, role="viewer") + + resp = await client.patch( + f"/api/v1/accounts/me/members/{viewer.id}/role", + json={"account_role": "l1_tech"}, + headers=headers, + ) + assert resp.status_code == 402, resp.text + body = resp.json() + assert body["detail"]["code"] == "seat_limit_exceeded" + assert body["detail"]["role"] == "l1_tech" + + +@pytest.mark.asyncio +async def test_role_change_promotion_succeeds_when_seats_available(client: AsyncClient, test_db: AsyncSession): + """PATCH /me/members/{id}/role → 200 when seats are available.""" + owner = await _register(client, email="owner11@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner11@example.com") + + await _set_sub(test_db, account_id, seat_limit=5) + viewer = await _add_member(test_db, account_id, role="viewer") + + resp = await client.patch( + f"/api/v1/accounts/me/members/{viewer.id}/role", + json={"account_role": "engineer"}, + headers=headers, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["account_role"] == "engineer" + + +@pytest.mark.asyncio +async def test_role_change_demotion_bypasses_seat_check(client: AsyncClient, test_db: AsyncSession): + """PATCH /me/members/{id}/role → 200 for demotions even when seats full.""" + owner = await _register(client, email="owner12@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner12@example.com") + + # Seats full — but demotion should still succeed + await _set_sub(test_db, account_id, seat_limit=1) + engineer = await _add_member(test_db, account_id, role="engineer") + + resp = await client.patch( + f"/api/v1/accounts/me/members/{engineer.id}/role", + json={"account_role": "viewer"}, + headers=headers, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["account_role"] == "viewer"