Persists welcome-wizard Step 1/2/3 progress for self-serve signup Phase 2. The PATCH endpoint rejects a step that would decrease, ignores `data` when action="skip", and is idempotent when the same step is PATCHed again. POST /users/me/onboarding-dismiss-rest backs the wizard's "Skip the rest" button. Both routes are added to _EMAIL_VERIFICATION_ALLOWLIST and _SUBSCRIPTION_GUARD_ALLOWLIST so the wizard runs before email verification and during the trial. Four integration tests cover field writes, skip semantics, the decrease guard, and dismiss-rest. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
339 lines
11 KiB
Python
339 lines
11 KiB
Python
from typing import Annotated, Optional
|
|
from uuid import UUID
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
import sentry_sdk
|
|
|
|
from app.core.database import get_db
|
|
from app.core.security import decode_token
|
|
from app.models.user import User
|
|
from app.models.plan_limits import PlanLimits
|
|
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
|
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
|
|
|
|
# Request paths that remain reachable while a user's must_change_password
# flag is set. get_current_active_user rejects every other path with 403
# "password_change_required" until the password has been changed.
_PASSWORD_CHANGE_ALLOWLIST = {
    "/api/v1/auth/me",
    "/api/v1/auth/logout",
    "/api/v1/auth/password/change",
}
|
|
|
|
# Bearer-token extractor shared by get_current_user and
# get_refresh_token_payload (both read the token from the Authorization
# header through it). tokenUrl is the login route advertised to OpenAPI
# clients for the password flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|
|
|
|
|
async def get_current_user(
    db: Annotated[AsyncSession, Depends(get_admin_db)],
    token: Annotated[str, Depends(oauth2_scheme)],
) -> User:
    """Resolve the authenticated user from a bearer access token.

    Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context
    sets app.current_account_id, so the users table RLS would block the lookup.

    Args:
        db: Admin (RLS-bypassing) session used for the users lookup.
        token: Raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        The ``User`` row whose id equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            cannot be decoded, is not an access token, has a missing or
            malformed ``sub`` claim, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    payload = decode_token(token)
    # Refresh tokens (or any other type) must never authenticate a request.
    if payload is None or payload.get("type") != "access":
        raise credentials_exception

    user_id: Optional[str] = payload.get("sub")
    # UUID() raises AttributeError/TypeError (not ValueError) for non-string
    # input such as an int; reject those as bad credentials instead of
    # letting them surface as a 500.
    if not isinstance(user_id, str):
        raise credentials_exception
    try:
        user_uuid = UUID(user_id)
    except ValueError:
        raise credentials_exception from None

    result = await db.execute(select(User).where(User.id == user_uuid))
    user = result.scalar_one_or_none()
    if user is None:
        raise credentials_exception
    return user
|
|
|
|
|
|
async def get_refresh_token_payload(
    token: Annotated[str, Depends(oauth2_scheme)],
) -> dict:
    """Decode the Authorization bearer token and require it to be a refresh token.

    Returns:
        The decoded claims dict when ``decode_token`` succeeds and the
        ``type`` claim equals ``"refresh"``.

    Raises:
        HTTPException: 401 "Invalid refresh token" (with
            ``WWW-Authenticate: Bearer``) for undecodable tokens or tokens
            of any other type.
    """
    claims = decode_token(token)
    is_refresh_token = claims is not None and claims.get("type") == "refresh"
    if is_refresh_token:
        return claims
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
async def get_current_active_user(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> User:
    """Return the authenticated user after applying account-level gates.

    Gates, in order:
      1. A deactivated user (``is_active`` falsy) gets 403
         "Account has been deactivated".
      2. While ``must_change_password`` is set, every path outside
         ``_PASSWORD_CHANGE_ALLOWLIST`` gets 403 "password_change_required"
         (backend hard lock).

    When both gates pass, the user's id and email are attached to the Sentry
    scope before the user is returned.

    Trial expiry is NOT enforced here; routers opt in through
    require_active_subscription, and this dependency does not modify
    Subscription state. ``db`` is not used by this body; it stays in the
    signature so the dependency's declared parameters are unchanged.
    """
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account has been deactivated",
        )

    blocked_by_password_reset = (
        current_user.must_change_password
        and request.url.path not in _PASSWORD_CHANGE_ALLOWLIST
    )
    if blocked_by_password_reset:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="password_change_required",
        )

    # Attribute any later error events to this user.
    sentry_sdk.set_user({"id": str(current_user.id), "email": current_user.email})
    return current_user
|
|
|
|
|
|
async def require_admin(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Restrict a route to super admins.

    Returns:
        The active user when ``is_super_admin`` is truthy.

    Raises:
        HTTPException: 403 "Super admin access required" otherwise.
    """
    if current_user.is_super_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Super admin access required",
    )
|
|
|
|
|
|
async def require_engineer_or_admin(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Allow super admins and account members with role owner, admin, or engineer.

    Any other account role (e.g. viewers) is rejected.

    Raises:
        HTTPException: 403 "Engineer or admin access required".
    """
    if current_user.is_super_admin or current_user.account_role in (
        "owner",
        "admin",
        "engineer",
    ):
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Engineer or admin access required",
    )
|
|
|
|
|
|
async def require_team_admin(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Allow super admins, team admins, and account owners.

    Raises:
        HTTPException: 403 "Team admin access required" for everyone else.
    """
    may_administer_team = (
        current_user.is_super_admin
        or current_user.is_team_admin
        or current_user.account_role == "owner"
    )
    if may_administer_team:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Team admin access required",
    )
|
|
|
|
|
|
async def require_account_owner(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Allow only the account owner or a super admin.

    Raises:
        HTTPException: 403 "Account owner access required" for anyone else.
    """
    is_owner_or_super = (
        current_user.is_super_admin or current_user.account_role == "owner"
    )
    if is_owner_or_super:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Account owner access required",
    )
|
|
|
|
|
|
def get_service_account_id(request: Request) -> Optional[UUID]:
    """Look up the ResolutionFlow service account UUID cached on ``app.state``.

    Returns ``None`` when ``service_account_id`` was never stored, e.g. in
    test environments where lifespan startup did not run.
    """
    app_state = request.app.state
    return getattr(app_state, "service_account_id", None)
|
|
|
|
|
|
async def get_plan_limits_for_user(
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> Optional[PlanLimits]:
    """Fetch the plan limits for the active user's account.

    Delegates to ``app.core.subscriptions.get_user_plan_limits`` with the
    user's ``account_id`` and the request's (RLS-scoped) session.
    """
    # Imported at call time rather than module load; presumably to avoid an
    # import cycle with app.core.subscriptions — confirm before hoisting.
    from app.core.subscriptions import get_user_plan_limits as lookup_plan_limits

    account_id = current_user.account_id
    return await lookup_plan_limits(account_id, db)
|
|
|
|
|
|
async def require_tenant_context(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Scope the request to the user's account for row-level security.

    Yield-style dependency: before the endpoint runs it records the user's
    ``account_id`` via ``set_current_account_id`` (the ContextVar that the
    SQLAlchemy transaction-begin listener reads to issue
    set_config('app.current_account_id', …, true) on every transaction), and
    after the endpoint finishes it undoes that with ``clear_current_account_id``.

    A user without an ``account_id`` is rejected with 403 rather than falling
    back to PLATFORM_ACCOUNT_ID, which would grant platform-scope access to a
    malformed account.

    Applied to every user-facing router. NOT applied to /admin/* routers or
    public endpoints (auth, shared, webhooks).
    """
    account_id = current_user.account_id
    if account_id is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account required",
        )

    context_token = set_current_account_id(account_id)
    try:
        yield
    finally:
        # Runs even when the endpoint raises, so the tenant id never leaks
        # past this request's scope.
        clear_current_account_id(context_token)
|
|
|
|
|
|
async def require_admin_db(
    db: Annotated[AsyncSession, Depends(get_admin_db)],
    current_user: Annotated[User, Depends(require_admin)],
) -> AsyncSession:
    """Hand out a BYPASSRLS admin session, but only to super admins.

    Intended as the ``db`` dependency of /admin/* endpoints that query
    RLS-protected tables, in place of ``Depends(get_db)``. Endpoints that
    also need the user object must still declare their own current_user
    dependency; this one does not expose it.
    """
    # The user value itself is irrelevant here: declaring
    # Depends(require_admin) is what enforces the super-admin check.
    del current_user
    return db
|
|
|
|
|
|
# Paths that require_active_subscription lets through without checking the
# account's subscription state.
_SUBSCRIPTION_GUARD_ALLOWLIST = {
    # Session and credential management.
    "/api/v1/auth/me",
    "/api/v1/auth/logout",
    "/api/v1/auth/password/change",
    # Email verification.
    "/api/v1/auth/email/send-verification",
    "/api/v1/auth/email/verify",
    # Profile and welcome-wizard progress (the wizard runs during the trial).
    "/api/v1/users/me",
    "/api/v1/users/me/onboarding-step",
    "/api/v1/users/me/onboarding-dismiss-rest",
    # Billing; presumably kept open so a locked account can still reach
    # checkout or the portal to pay.
    "/api/v1/billing/state",
    "/api/v1/billing/checkout-session",
    "/api/v1/billing/portal-session",
}
|
|
|
|
|
|
async def require_active_subscription(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_admin_db)],
):
    """Gate Pro-entitled routers on the account's subscription state.

    Mounted on routers requiring Pro entitlement.

    Returns:
        ``None`` when the request path is in ``_SUBSCRIPTION_GUARD_ALLOWLIST``
        (the guard is skipped without touching the database); otherwise the
        account's ``Subscription`` row when it grants access.

    Access is granted when ``status`` is ``active``, ``complimentary`` or
    ``past_due``, or when it is ``trialing`` with a ``current_period_end``
    strictly later than now. Everything else is "locked":

    * no account, or no subscription row for the account;
    * ``trialing`` with a missing or already-reached ``current_period_end``;
    * any other status (e.g. ``canceled``, ``incomplete``).

    Raises:
        HTTPException: 402 with ``{"error": "no_subscription", ...}`` when
            there is no subscription row, or
            ``{"error": "subscription_inactive", "status", "plan",
            "current_period_end", ...}`` when the subscription is locked.
            Both payloads carry ``upgrade_url``.
    """
    if request.url.path in _SUBSCRIPTION_GUARD_ALLOWLIST:
        return None

    from datetime import datetime, timezone

    from app.models.subscription import Subscription

    upgrade_url = "/account/billing/select-plan"

    # `Subscription.account_id == None` would compile to `IS NULL` and could
    # match an orphaned row, so a user without an account is treated exactly
    # like an account with no subscription.
    sub = None
    if current_user.account_id is not None:
        result = await db.execute(
            select(Subscription).where(
                Subscription.account_id == current_user.account_id
            )
        )
        sub = result.scalar_one_or_none()

    if sub is None:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail={"error": "no_subscription", "upgrade_url": upgrade_url},
        )

    now = datetime.now(timezone.utc)
    is_live = (
        sub.status in ("active", "complimentary", "past_due")
        or (
            sub.status == "trialing"
            and sub.current_period_end is not None
            and sub.current_period_end > now
        )
    )
    if not is_live:
        period_end = sub.current_period_end
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail={
                "error": "subscription_inactive",
                "status": sub.status,
                "plan": sub.plan,
                "current_period_end": period_end.isoformat() if period_end else None,
                "upgrade_url": upgrade_url,
            },
        )

    return sub
|
|
|
|
|
|
# Paths that require_verified_email_after_grace lets through even after an
# unverified user's grace period has ended.
_EMAIL_VERIFICATION_ALLOWLIST = {
    # Verification flow itself.
    "/api/v1/auth/email/send-verification",
    "/api/v1/auth/email/verify",
    # Session and credential management.
    "/api/v1/auth/me",
    "/api/v1/auth/logout",
    "/api/v1/auth/password/change",
    # Billing.
    "/api/v1/billing/state",
    "/api/v1/billing/checkout-session",
    "/api/v1/billing/portal-session",
    # Profile and welcome-wizard progress (the wizard runs before verification).
    "/api/v1/users/me",
    "/api/v1/users/me/onboarding-step",
    "/api/v1/users/me/onboarding-dismiss-rest",
}

# Days after users.created_at during which an unverified email is tolerated.
VERIFICATION_GRACE_DAYS = 7
|
|
|
|
|
|
async def require_verified_email_after_grace(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Block unverified users once their email-verification grace period ends.

    The request passes when any of these hold:

    * the path is in ``_EMAIL_VERIFICATION_ALLOWLIST``;
    * ``current_user.email_verified_at`` is set;
    * now (UTC) is still before ``created_at + VERIFICATION_GRACE_DAYS`` days.

    Per the original note, OAuth signups are expected to pass via the second
    rule because the /auth/{google,microsoft}/callback handlers set
    ``users.email_verified_at`` (provider-attested); that code is not visible
    here, so confirm there.

    Raises:
        HTTPException: 403 with ``{"error": "email_not_verified",
            "grace_ended_at": <ISO timestamp>, "resend_url": ...}``.
    """
    from datetime import datetime, timedelta, timezone

    if request.url.path in _EMAIL_VERIFICATION_ALLOWLIST:
        return
    if current_user.email_verified_at is not None:
        return

    grace_ends = current_user.created_at + timedelta(days=VERIFICATION_GRACE_DAYS)
    if datetime.now(timezone.utc) < grace_ends:
        return

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail={
            "error": "email_not_verified",
            "grace_ended_at": grace_ends.isoformat(),
            "resend_url": "/api/v1/auth/email/send-verification",
        },
    )
|