Code review fixes for feature flags endpoint. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
779 lines
27 KiB
Python
779 lines
27 KiB
Python
import secrets
|
|
import string
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Annotated
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, update as sa_update
|
|
from app.core.config import settings
|
|
from app.core.settings_manager import SettingsManager
|
|
from app.core.database import get_db
|
|
from app.core.rate_limit import limiter
|
|
from app.core.security import (
|
|
verify_password,
|
|
get_password_hash,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
create_password_reset_token,
|
|
create_email_verification_token,
|
|
decode_token,
|
|
hash_token,
|
|
)
|
|
from app.models.user import User
|
|
from app.models.invite_code import InviteCode
|
|
from app.models.refresh_token import RefreshToken
|
|
from app.models.account import Account
|
|
from app.models.subscription import Subscription
|
|
from app.models.account_invite import AccountInvite
|
|
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
|
|
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
|
|
from app.schemas.token import Token
|
|
from app.schemas.auth_password import (
|
|
ChangePasswordRequest,
|
|
ForgotPasswordRequest,
|
|
VerifyResetTokenRequest,
|
|
VerifyResetTokenResponse,
|
|
ResetPasswordRequest,
|
|
)
|
|
from app.models.password_reset_token import PasswordResetToken
|
|
from app.models.email_verification_token import EmailVerificationToken
|
|
from app.core.email import EmailService
|
|
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
|
from app.core.audit import log_audit
|
|
|
|
# All endpoints in this module are mounted under the /auth prefix and grouped
# under the "authentication" tag in FastAPI's generated OpenAPI schema.
router = APIRouter(prefix="/auth", tags=["authentication"])
|
|
|
|
|
|
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
    """Record a newly issued refresh token so it can later be rotated or revoked.

    The JWT is decoded locally to read its ``jti`` and ``exp`` claims. Only a
    hash of the ``jti`` is persisted, never the token itself. The row is added
    to the session but not committed; the caller owns the transaction.

    If the token does not decode, or carries no ``jti``, nothing is stored.
    """
    claims = decode_token(refresh_token_str)
    if not claims or not claims.get("jti"):
        return

    db.add(
        RefreshToken(
            token_hash=hash_token(claims["jti"]),
            user_id=user_id,
            expires_at=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
        )
    )
|
|
|
|
|
|
def _generate_display_code() -> str:
|
|
"""Generate a random 8-character alphanumeric display code."""
|
|
chars = string.ascii_uppercase + string.digits
|
|
return ''.join(secrets.choice(chars) for _ in range(8))
|
|
|
|
|
|
async def _lock_account_invite(db: AsyncSession, code: str) -> AccountInvite:
    """Fetch an account invite under a row lock and make sure it is still usable.

    SELECT ... FOR UPDATE prevents two concurrent registrations from both
    reading the same invite as unused and double-spending it.

    Raises:
        HTTPException(400): the code is unknown, already used, or expired.
    """
    query = (
        select(AccountInvite)
        .where(AccountInvite.code == code)
        .with_for_update()
    )
    invite = (await db.execute(query)).scalar_one_or_none()

    if not invite:
        problem = "Invalid account invite code"
    elif invite.is_used:
        problem = "Account invite code has already been used"
    elif invite.is_expired:
        problem = "Account invite code has expired"
    else:
        return invite

    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=problem)


async def _lock_platform_invite(db: AsyncSession, code: str) -> InviteCode:
    """Fetch a platform invite (upper-cased code) under a row lock and validate it.

    The lock prevents a double-spend by concurrent registrations.

    Raises:
        HTTPException(400): the code is unknown, already used, or expired.
    """
    query = (
        select(InviteCode)
        .where(InviteCode.code == code.upper())
        .with_for_update()
    )
    invite = (await db.execute(query)).scalar_one_or_none()

    if not invite:
        problem = "Invalid invite code"
    elif invite.is_used:
        problem = "Invite code has already been used"
    elif invite.is_expired:
        problem = "Invite code has expired"
    else:
        return invite

    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=problem)


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
    request: Request,
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Create a new user.

    Two registration paths exist:

    - ``account_invite_code``: the user joins an existing account with the
      role stored on the invite. This path skips the platform invite gate.
    - ``invite_code``: a platform invite. It is mandatory when
      ``settings.REQUIRE_INVITE_CODE`` is set. When present, its assigned plan
      and trial length are applied even if the gate is off.

    Without an account invite, a personal Account owned by the new user is
    created, along with its initial Subscription, in the same transaction.

    Raises:
        HTTPException(400): an invalid, used, or expired invite; a missing
            required invite; or an email that is already registered.
    """
    # The account invite is checked first because it bypasses the platform gate.
    account_invite = (
        await _lock_account_invite(db, user_data.account_invite_code)
        if user_data.account_invite_code
        else None
    )

    platform_invite = None
    if not account_invite:
        if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code is required"
            )
        if user_data.invite_code:
            platform_invite = await _lock_platform_invite(db, user_data.invite_code)

    duplicate = await db.execute(select(User).where(User.email == user_data.email))
    if duplicate.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    if account_invite:
        # Join the inviting account with the role granted by the invite.
        new_user = User(
            email=user_data.email,
            password_hash=get_password_hash(user_data.password),
            name=user_data.name,
            role="engineer",
            invite_code_id=platform_invite.id if platform_invite else None,
            account_id=account_invite.account_id,
            account_role=account_invite.role,
        )
        db.add(new_user)
        await db.flush()  # populates new_user.id

        # Consume the account invite.
        account_invite.accepted_by_id = new_user.id
        account_invite.used_at = datetime.now(timezone.utc)
    else:
        # The Account row has to exist first: the user's account_id is NOT NULL.
        new_account = Account(
            name=f"{user_data.name}'s Account",
            display_code=_generate_display_code(),
        )
        db.add(new_account)
        await db.flush()  # populates new_account.id

        new_user = User(
            email=user_data.email,
            password_hash=get_password_hash(user_data.password),
            name=user_data.name,
            role="engineer",
            invite_code_id=platform_invite.id if platform_invite else None,
            account_id=new_account.id,
            account_role="owner",
        )
        db.add(new_user)
        await db.flush()  # populates new_user.id

        new_account.owner_id = new_user.id

        # Default to an active free plan; a platform invite may assign a
        # different plan and, optionally, a trial window.
        plan_name, plan_status = "free", "active"
        trial_start = trial_end = None
        if platform_invite and platform_invite.assigned_plan:
            plan_name = platform_invite.assigned_plan
            if platform_invite.trial_duration_days:
                plan_status = "trialing"
                trial_start = datetime.now(timezone.utc)
                trial_end = trial_start + timedelta(days=platform_invite.trial_duration_days)

        db.add(
            Subscription(
                account_id=new_account.id,
                plan=plan_name,
                status=plan_status,
                current_period_start=trial_start,
                current_period_end=trial_end,
            )
        )

    # Consume the platform invite, if one was redeemed.
    if platform_invite:
        platform_invite.used_by_id = new_user.id
        platform_invite.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)
    return new_user
|
|
|
|
|
|
@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Exchange OAuth2 password-form credentials for an access/refresh token pair.

    The form's ``username`` field is matched against ``User.email``. On
    success, ``last_login`` is stamped and a hash of the new refresh token is
    stored so it can later be rotated or revoked.

    Raises:
        HTTPException(401): unknown email or wrong password. Both cases share
            one message so the response does not reveal which one failed.
    """
    lookup = await db.execute(select(User).where(User.email == form_data.username))
    user = lookup.scalar_one_or_none()

    password_ok = bool(user) and verify_password(form_data.password, user.password_hash)
    if not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)

    subject = str(user.id)
    access_token = create_access_token(data={"sub": subject})
    refresh_token_str = create_refresh_token(data={"sub": subject})

    await _store_refresh_token(db, refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        must_change_password=user.must_change_password,
    )
|
|
|
|
|
|
@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
    request: Request,
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Log in with a JSON body (``email`` + ``password``) and issue a token pair.

    This is the JSON alternative to the form-based login. On success,
    ``last_login`` is stamped and a hash of the new refresh token is stored so
    it can later be rotated or revoked.

    Raises:
        HTTPException(401): unknown email or wrong password (one shared
            message). The response includes a ``WWW-Authenticate: Bearer``
            challenge, which RFC 7235 requires on every 401.
    """
    result = await db.execute(select(User).where(User.email == credentials.email))
    user = result.scalar_one_or_none()

    if not user or not verify_password(credentials.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            # Previously missing here, unlike the form-based /login endpoint.
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)

    access_token = create_access_token(data={"sub": str(user.id)})
    refresh_token_str = create_refresh_token(data={"sub": str(user.id)})

    # Store refresh token hash in DB
    await _store_refresh_token(db, refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        must_change_password=user.must_change_password,
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Exchange a refresh token for a new access/refresh pair (token rotation).

    The presented refresh token is revoked atomically before the new pair is
    minted, so each refresh token can be redeemed at most once. The response
    carries ``must_change_password`` just as the login endpoints do. This
    keeps a pending forced password change in effect across refreshes;
    before, the field was dropped here and fell back to the schema default.

    Raises:
        HTTPException(401): the refresh token is unknown or already revoked,
            or its subject user no longer exists.
    """
    user_id = payload.get("sub")
    jti = payload.get("jti")

    # A conditional UPDATE ... RETURNING revokes and checks in one statement,
    # so two concurrent refresh requests cannot both observe revoked_at IS NULL.
    # NOTE(review): a token without a "jti" claim skips this block entirely
    # and can never be revoked; confirm every issued refresh token carries one.
    if jti:
        token_hash = hash_token(jti)
        result = await db.execute(
            sa_update(RefreshToken)
            .where(
                RefreshToken.token_hash == token_hash,
                RefreshToken.revoked_at.is_(None),
            )
            .values(revoked_at=datetime.now(timezone.utc))
            .returning(RefreshToken.id, RefreshToken.user_id)
        )
        revoked_row = result.fetchone()

        if not revoked_row:
            # Either the token doesn't exist or was already revoked/used
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token has been revoked"
            )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found"
        )

    access_token = create_access_token(data={"sub": str(user.id)})
    new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})

    # Store the new refresh token's hash; the old one was revoked above.
    await _store_refresh_token(db, new_refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=new_refresh_token_str,
        token_type="bearer",
        must_change_password=user.must_change_password,
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Return the caller's own user record.

    The user is resolved by the ``get_current_active_user`` dependency. This
    handler just echoes it back, serialized through ``UserResponse``.
    """
    return current_user
|
|
|
|
|
|
@router.patch("/me", response_model=UserResponse)
async def update_me(
    data: UserUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Partially update the caller's profile (name, email, phone, job title, timezone).

    Only fields explicitly present in the request body (``model_fields_set``)
    are applied. Changing the email requires ``current_password``.

    When the email actually changes, verification is reset. The new address
    has not been proven to belong to the user, so ``email_verified_at`` is
    cleared. Any unused verification tokens are also consumed, because the
    email-verification endpoint only checks the token's user id, not the
    address the link was sent to.

    Raises:
        HTTPException(400): email change without ``current_password``, or the
            new email belongs to another user.
        HTTPException(401): ``current_password`` is incorrect.
    """
    update_fields = data.model_fields_set - {"current_password"}
    if not update_fields:
        return current_user

    # An explicit null email is ignored, mirroring how ``name`` is handled below.
    if "email" in data.model_fields_set and data.email is not None:
        if not data.current_password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is required to change email"
            )
        if not verify_password(data.current_password, current_user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Current password is incorrect"
            )
        # Check uniqueness
        result = await db.execute(
            select(User).where(User.email == data.email, User.id != current_user.id)
        )
        if result.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered"
            )

        if data.email != current_user.email:
            current_user.email = data.email
            # The verified status belonged to the old address.
            current_user.email_verified_at = None
            # Burn outstanding links issued for the old address by marking them
            # used, the same marker the verification endpoint sets on consumption.
            await db.execute(
                sa_update(EmailVerificationToken)
                .where(
                    EmailVerificationToken.user_id == current_user.id,
                    EmailVerificationToken.used_at.is_(None),
                )
                .values(used_at=datetime.now(timezone.utc))
            )

    if "name" in data.model_fields_set and data.name is not None:
        current_user.name = data.name

    # Simple string profile fields may be set, or cleared with an explicit null.
    for field in ("phone", "job_title", "timezone"):
        if field in data.model_fields_set:
            setattr(current_user, field, getattr(data, field))

    await log_audit(db, current_user.id, "auth.profile_update", "user", current_user.id)
    await db.commit()
    await db.refresh(current_user)
    return current_user
|
|
|
|
|
|
@router.post("/logout")
async def logout(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Revoke the presented refresh token.

    The same success message is returned whether the token was revoked just
    now, was already revoked, is unknown, or has no ``jti`` claim. Access
    tokens are not affected.
    """
    jti = payload.get("jti")
    if not jti:
        return {"message": "Successfully logged out"}

    lookup = await db.execute(
        select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
    )
    stored = lookup.scalar_one_or_none()

    if stored and not stored.is_revoked:
        stored.revoked_at = datetime.now(timezone.utc)
        await db.commit()

    return {"message": "Successfully logged out"}
|
|
|
|
|
|
@router.post("/password/change")
@limiter.limit("5/minute")
async def change_password(
    request: Request,
    data: ChangePasswordRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Replace the caller's password after re-checking the current one.

    This also clears ``must_change_password`` and revokes every unrevoked
    refresh token the user holds, so no existing session can mint new access
    tokens.

    Raises:
        HTTPException(401): ``current_password`` is incorrect.
        HTTPException(400): the new password equals the current one.
    """
    if not verify_password(data.current_password, current_user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Current password is incorrect"
        )

    if data.new_password == data.current_password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be different from current password"
        )

    current_user.password_hash = get_password_hash(data.new_password)
    current_user.must_change_password = False

    live_tokens = await db.execute(
        select(RefreshToken).where(
            RefreshToken.user_id == current_user.id,
            RefreshToken.revoked_at.is_(None)
        )
    )
    for stale in live_tokens.scalars():
        stale.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, current_user.id, "auth.password_change", "user", current_user.id)
    await db.commit()

    return {"message": "Password changed successfully"}
|
|
|
|
|
|
@router.post("/password/forgot")
@limiter.limit("3/minute")
async def forgot_password(
    request: Request,
    data: ForgotPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Start a password reset by emailing a single-use reset link.

    The same message is returned whether or not the email is registered, so
    the response body does not reveal which addresses have accounts.
    """
    lookup = await db.execute(select(User).where(User.email == data.email))
    user = lookup.scalar_one_or_none()

    if not user:
        return {"message": "If an account with that email exists, a reset link has been sent."}

    raw_token = create_password_reset_token(str(user.id))
    claims = decode_token(raw_token)
    if claims and claims.get("jti"):
        # Only a hash of the jti is stored; the raw JWT travels in the link.
        db.add(
            PasswordResetToken(
                token_hash=hash_token(claims["jti"]),
                user_id=user.id,
                expires_at=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
            )
        )
        await db.commit()

    # Intended as best-effort delivery.
    # NOTE(review): nothing here catches an exception from EmailService; if it
    # raises, the request fails with a 500 only for registered emails, which
    # undermines the anti-enumeration reply. Confirm EmailService swallows errors.
    await EmailService.send_password_reset_email(
        to_email=user.email,
        reset_url=f"{settings.FRONTEND_URL}/reset-password?token={raw_token}",
    )

    await log_audit(db, user.id, "auth.password_reset.request", "user", user.id)
    await db.commit()

    return {"message": "If an account with that email exists, a reset link has been sent."}
|
|
|
|
|
|
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
async def verify_reset_token(
    data: VerifyResetTokenRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Check, without consuming it, whether a password-reset token is still usable.

    Returns ``valid=False`` for a token that does not decode, is not a
    ``password_reset`` token, lacks a ``jti``, is unknown, or is no longer
    valid. Otherwise returns ``valid=True`` with the owner's email for display.
    """
    claims = decode_token(data.token)
    is_reset_token = bool(claims) and claims.get("type") == "password_reset"
    jti = claims.get("jti") if is_reset_token else None
    if not jti:
        return VerifyResetTokenResponse(valid=False)

    lookup = await db.execute(
        select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
    )
    reset_record = lookup.scalar_one_or_none()
    if not reset_record or not reset_record.is_valid:
        return VerifyResetTokenResponse(valid=False)

    email_lookup = await db.execute(select(User.email).where(User.id == reset_record.user_id))
    return VerifyResetTokenResponse(valid=True, email=email_lookup.scalar_one_or_none())
|
|
|
|
|
|
@router.post("/password/reset")
@limiter.limit("5/minute")
async def reset_password(
    request: Request,
    data: ResetPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Set a new password using a single-use password-reset token.

    On success, in one commit: the token is marked used,
    ``must_change_password`` is cleared, and every unrevoked refresh token of
    the user is revoked.

    Raises:
        HTTPException(400): the token does not decode or is the wrong type;
            it lacks ``jti``/``sub``; it is unknown, used, or expired; or its
            user no longer exists.
    """
    claims = decode_token(data.token)
    if not claims or claims.get("type") != "password_reset":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired reset token"
        )

    jti, user_id = claims.get("jti"), claims.get("sub")
    if not (jti and user_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token"
        )

    # Row lock: two concurrent resets using the same link must not both pass
    # the validity check below.
    locked = await db.execute(
        select(PasswordResetToken)
        .where(PasswordResetToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    reset_record = locked.scalar_one_or_none()
    if not reset_record or not reset_record.is_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Reset token has already been used or has expired"
        )

    owner_lookup = await db.execute(select(User).where(User.id == user_id))
    owner = owner_lookup.scalar_one_or_none()
    if not owner:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token"
        )

    owner.password_hash = get_password_hash(data.new_password)
    owner.must_change_password = False
    reset_record.used_at = datetime.now(timezone.utc)

    # Revoke every live refresh token so no existing session can mint new
    # access tokens with the old credentials.
    live_tokens = await db.execute(
        select(RefreshToken).where(
            RefreshToken.user_id == owner.id,
            RefreshToken.revoked_at.is_(None)
        )
    )
    for stale in live_tokens.scalars():
        stale.revoked_at = datetime.now(timezone.utc)

    await log_audit(db, owner.id, "auth.password_reset.complete", "user", owner.id)
    await db.commit()

    return {"message": "Password has been reset successfully"}
|
|
|
|
|
|
@router.get("/email/verification-status")
async def get_verification_status(
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Report whether email verification is enabled for the platform.

    Reads the ``email_verification_enabled`` setting through
    ``SettingsManager``, passing ``default=True``.
    """
    is_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
    return {"enabled": is_enabled}
|
|
|
|
|
|
@router.post("/email/send-verification")
@limiter.limit("3/minute")
async def send_verification_email(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Email the caller a link that verifies their current address.

    Raises:
        HTTPException(403): email verification is disabled platform-wide.
        HTTPException(400): the caller's email is already verified.
    """
    if not await SettingsManager.get("email_verification_enabled", db, default=True):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Email verification is currently disabled"
        )

    if current_user.email_verified_at is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email is already verified"
        )

    raw_token = create_email_verification_token(str(current_user.id))
    claims = decode_token(raw_token)
    if claims and claims.get("jti"):
        # Only a hash of the jti is stored; the raw JWT travels in the link.
        db.add(
            EmailVerificationToken(
                token_hash=hash_token(claims["jti"]),
                user_id=current_user.id,
                expires_at=datetime.fromtimestamp(claims["exp"], tz=timezone.utc),
            )
        )
        await db.commit()

    await EmailService.send_email_verification_email(
        to_email=current_user.email,
        verification_url=f"{settings.FRONTEND_URL}/verify-email?token={raw_token}",
    )

    return {"message": "Verification email sent"}
|
|
|
|
|
|
@router.post("/email/verify")
async def verify_email(
    data: dict,
    db: Annotated[AsyncSession, Depends(get_db)]
):
    """Consume an email-verification token and mark the owner's email verified.

    This is a public endpoint: the token itself is the credential. It expects
    a JSON object body with a non-empty string ``token``.

    Raises:
        HTTPException(400): the token is missing or not a string; it does
            not decode as an ``email_verification`` JWT; it lacks
            ``jti``/``sub``; it is unknown, used, or expired; or its user no
            longer exists.
    """
    token = data.get("token")
    # The body is an untyped dict, so "token" may be any JSON value (number,
    # list, object). Only a non-empty string can be a JWT, so anything else is
    # rejected before it reaches decode_token.
    if not isinstance(token, str) or not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Token is required"
        )

    payload = decode_token(token)
    if not payload or payload.get("type") != "email_verification":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification token"
        )

    jti = payload.get("jti")
    user_id = payload.get("sub")
    if not jti or not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token"
        )

    # FOR UPDATE prevents two concurrent verification requests from both succeeding.
    result = await db.execute(
        select(EmailVerificationToken)
        .where(EmailVerificationToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    token_record = result.scalar_one_or_none()

    if not token_record or not token_record.is_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Verification token has already been used or has expired"
        )

    # Mark token as used
    token_record.used_at = datetime.now(timezone.utc)

    # Mark user email as verified
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token"
        )

    user.email_verified_at = datetime.now(timezone.utc)
    await log_audit(db, user.id, "auth.email_verified", "user", user.id)
    await db.commit()

    return {"message": "Email verified successfully"}
|
|
|
|
|
|
@router.get("/me/feature-flags", response_model=dict[str, bool])
async def get_my_feature_flags(
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> dict[str, bool]:
    """Resolve every defined feature flag to on/off for the caller's account.

    Precedence, highest first:

    1. an ``AccountFeatureOverride`` row for the caller's account;
    2. the ``PlanFeatureDefault`` row for the account's plan. The plan comes
       from an ``active`` or ``trialing`` subscription, or is ``"free"`` when
       there is none;
    3. ``False``.

    Returns:
        A mapping of ``flag_key`` to its resolved value. It is empty when no
        flags are defined.
    """
    plan = "free"
    if current_user.account_id:
        # LIMIT 1 + first(): with scalar_one_or_none(), an account holding more
        # than one active/trialing subscription row raised MultipleResultsFound
        # and turned this endpoint into a 500.
        # NOTE(review): which row wins among duplicates is unspecified; add an
        # ORDER BY once the intended precedence is decided.
        sub_result = await db.execute(
            select(Subscription)
            .where(
                Subscription.account_id == current_user.account_id,
                Subscription.status.in_(["active", "trialing"]),
            )
            .limit(1)
        )
        sub = sub_result.scalars().first()
        if sub:
            plan = sub.plan

    flags_result = await db.execute(select(FeatureFlag))
    flags = flags_result.scalars().all()
    if not flags:
        # Nothing to resolve; this also skips the IN () queries below.
        return {}

    flag_ids = [f.id for f in flags]

    defaults_result = await db.execute(
        select(PlanFeatureDefault).where(
            PlanFeatureDefault.flag_id.in_(flag_ids),
            PlanFeatureDefault.plan == plan,
        )
    )
    plan_defaults = {d.flag_id: d.enabled for d in defaults_result.scalars().all()}

    overrides: dict[uuid.UUID, bool] = {}
    if current_user.account_id:
        overrides_result = await db.execute(
            select(AccountFeatureOverride).where(
                AccountFeatureOverride.flag_id.in_(flag_ids),
                AccountFeatureOverride.account_id == current_user.account_id,
            )
        )
        overrides = {o.flag_id: o.enabled for o in overrides_result.scalars().all()}

    # Override beats plan default, which beats the hard-coded False.
    return {
        flag.flag_key: overrides.get(flag.id, plan_defaults.get(flag.id, False))
        for flag in flags
    }
|