diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py index 85dde423..0e0f457f 100644 --- a/backend/app/api/endpoints/accounts.py +++ b/backend/app/api/endpoints/accounts.py @@ -23,8 +23,9 @@ from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCre from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails from app.schemas.user import UserResponse, AccountRoleUpdate from app.core.security import verify_password -from app.api.deps import get_current_active_user, require_account_owner -from app.services.seat_enforcement import check_seat_available +from app.api.deps import get_current_active_user, require_account_owner, require_engineer_or_admin +from app.services.seat_enforcement import check_seat_available, get_seat_usage +from app.schemas.seat_enforcement import SeatUsage _SEAT_CHECKED_ROLES = frozenset({"engineer", "l1_tech"}) @@ -128,6 +129,41 @@ async def get_my_members( return result.scalars().all() +@router.get("/me/seats", response_model=SeatUsage) +async def get_my_account_seat_usage( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_engineer_or_admin)], +): + """Returns engineer + l1_tech seat-usage counts. Accessible to engineer+. + + Powers the SeatCounterWidget on admin/users and account/users surfaces. + """ + account = await _load_account(db, current_user.account_id) + sub = await get_account_subscription(current_user.account_id, db) + if sub is None: + # No subscription → treat as unlimited; return live counts with no limit + from sqlalchemy import func + engineer_count = (await db.execute( + select(func.count(User.id)) + .where(User.account_id == account.id) + .where(User.account_role == "engineer") + .where(User.is_active.is_(True)) + )).scalar_one() + l1_count = (await db.execute( + select(func.count(User.id)) + .where(User.account_id == account.id) + .where(User.account_role == "l1_tech") + .where(User.is_active.is_(True)) + )).scalar_one() + from app.schemas.seat_enforcement import SeatCheckResult + return SeatUsage( + engineer=SeatCheckResult(available=True, current=engineer_count, limit=None, role="engineer"), + l1_tech=SeatCheckResult(available=True, current=l1_count, limit=None, role="l1_tech"), + ) + engineer, l1_tech = await get_seat_usage(account, sub, db) + return SeatUsage(engineer=engineer, l1_tech=l1_tech) + + @router.patch("/me", response_model=AccountResponse) async def update_my_account( data: AccountUpdate, diff --git a/backend/tests/test_invite_seat_enforcement.py b/backend/tests/test_invite_seat_enforcement.py index 11609bfa..5d002775 100644 --- a/backend/tests/test_invite_seat_enforcement.py +++ b/backend/tests/test_invite_seat_enforcement.py @@ -394,3 +394,64 @@ async def test_role_change_demotion_bypasses_seat_check(client: AsyncClient, tes ) assert resp.status_code == 200, resp.text assert resp.json()["account_role"] == "viewer" + + +# --------------------------------------------------------------------------- +# GET /me/seats — seat counter widget endpoint +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_seats_returns_both_role_counts(client: AsyncClient, test_db: AsyncSession): + """GET /accounts/me/seats returns engineer + l1_tech seat usage.""" + owner = await _register(client, email="owner_seats@example.com") + account_id = uuid.UUID(owner["account_id"]) + headers = await _login(client, email="owner_seats@example.com") + await _set_sub(test_db, account_id, seat_limit=5, l1_seat_limit=3) + # Add 2 engineers and 1 l1_tech as members + for i in range(2): + await _add_member(test_db, account_id, role="engineer", suffix=f"e{i}") + await _add_member(test_db, account_id, role="l1_tech", suffix="l1") + + resp = await client.get("/api/v1/accounts/me/seats", headers=headers) + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["engineer"]["role"] == "engineer" + assert body["engineer"]["current"] == 2 + assert body["engineer"]["limit"] == 5 + assert body["engineer"]["available"] is True + assert body["l1_tech"]["role"] == "l1_tech" + assert body["l1_tech"]["current"] == 1 + assert body["l1_tech"]["limit"] == 3 + assert body["l1_tech"]["available"] is True + + +@pytest.mark.asyncio +async def test_get_seats_blocked_for_viewer(client: AsyncClient, test_db: AsyncSession): + """GET /accounts/me/seats → 403 for viewer role (engineer+ required).""" + from app.core.security import get_password_hash + + # Register an owner for the account + owner = await _register(client, email="owner_seats2@example.com") + account_id = uuid.UUID(owner["account_id"]) + await _set_sub(test_db, account_id, seat_limit=5, l1_seat_limit=3) + + # Create a viewer user with a known password directly in the DB + viewer_password = "ViewerPass123!" + viewer = User( + id=uuid.uuid4(), + email="viewer_seats@example.com", + name="Viewer Seats", + account_id=account_id, + account_role="viewer", + role="engineer", # system role field (default) + is_active=True, + password_hash=get_password_hash(viewer_password), + ) + test_db.add(viewer) + await test_db.commit() + + # Log in as the viewer + viewer_headers = await _login(client, email="viewer_seats@example.com", password=viewer_password) + + resp = await client.get("/api/v1/accounts/me/seats", headers=viewer_headers) + assert resp.status_code == 403, resp.text