"""Authentication API routes.

Covers registration (personal accounts, account invites and platform invite
codes), form and JSON login, refresh-token rotation, logout, profile updates,
password change/reset, email verification, and per-user feature-flag
resolution.
"""

import logging
import secrets
import string
import uuid
from datetime import datetime, timezone, timedelta
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update as sa_update

from app.core.config import settings
from app.core.settings_manager import SettingsManager
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.security import (
    verify_password,
    get_password_hash,
    create_access_token,
    create_refresh_token,
    create_password_reset_token,
    create_email_verification_token,
    decode_token,
    hash_token,
)
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
from app.schemas.token import Token
from app.schemas.auth_password import (
    ChangePasswordRequest,
    ForgotPasswordRequest,
    VerifyResetTokenRequest,
    VerifyResetTokenResponse,
    ResetPasswordRequest,
)
from app.models.password_reset_token import PasswordResetToken
from app.models.email_verification_token import EmailVerificationToken
from app.core.email import EmailService
from app.api.deps import get_current_active_user, get_refresh_token_payload
from app.core.audit import log_audit

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/auth", tags=["authentication"])


async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
    """Decode a refresh token JWT and store its hash in the database.

    Only the hash of the token's ``jti`` claim is persisted; when the decoded
    payload is missing or has no ``jti`` nothing is stored. The row is added to
    the session but not committed -- the caller owns the transaction.
    """
    payload = decode_token(refresh_token_str)
    if payload and payload.get("jti"):
        token_record = RefreshToken(
            token_hash=hash_token(payload["jti"]),
            user_id=user_id,
            expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
        )
        db.add(token_record)


async def _revoke_all_refresh_tokens(db: AsyncSession, user_id) -> None:
    """Revoke every not-yet-revoked refresh token belonging to ``user_id``.

    Issued as a single conditional UPDATE (the same ``sa_update`` pattern the
    refresh endpoint uses) rather than loading and mutating each row. Does not
    commit; the caller owns the transaction.
    """
    await db.execute(
        sa_update(RefreshToken)
        .where(
            RefreshToken.user_id == user_id,
            RefreshToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(timezone.utc))
    )


async def _issue_tokens(db: AsyncSession, user: User) -> Token:
    """Mint a new access/refresh token pair for ``user``, persist it, and commit.

    The commit also flushes any changes the caller staged on the session
    (e.g. ``last_login``). User attributes are read *before* the commit so the
    response never depends on reloading an instance the session may have
    expired on commit.
    """
    user_id = user.id
    subject = str(user_id)
    must_change_password = user.must_change_password

    access_token = create_access_token(data={"sub": subject})
    refresh_token_str = create_refresh_token(data={"sub": subject})

    # Store refresh token hash in DB
    await _store_refresh_token(db, refresh_token_str, user_id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
        must_change_password=must_change_password,
    )


def _generate_display_code() -> str:
    """Generate a random 8-character alphanumeric display code.

    Uses ``secrets.choice`` over uppercase ASCII letters and digits.
    """
    chars = string.ascii_uppercase + string.digits
    return ''.join(secrets.choice(chars) for _ in range(8))


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
    request: Request,
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Register a new user.

    Supports two flows:
    - account_invite_code: Join an existing account (bypasses platform invite gate)
    - invite_code: Platform invite code (required when REQUIRE_INVITE_CODE is
      enabled; if supplied, it may assign a plan and trial either way)

    After user creation, if no account invite was used, a personal Account and
    a Subscription are created automatically. The subscription is
    ``free``/``active`` unless the platform invite code assigns a plan (and
    optionally a trial period, which makes it ``trialing``).

    Raises:
        HTTPException 400: invalid, used or expired invite code; missing
            required invite code; or email already registered.
    """
    # Check for account invite code FIRST -- bypasses platform invite gate.
    # SELECT FOR UPDATE prevents two concurrent registrations from both
    # reading the same invite as unused and double-spending it.
    account_invite_record = None
    if user_data.account_invite_code:
        result = await db.execute(
            select(AccountInvite)
            .where(AccountInvite.code == user_data.account_invite_code)
            .with_for_update()
        )
        account_invite_record = result.scalar_one_or_none()
        if not account_invite_record:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid account invite code",
            )
        if account_invite_record.is_used:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account invite code has already been used",
            )
        if account_invite_record.is_expired:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account invite code has expired",
            )

    # Validate platform invite code (skip if account invite was provided)
    invite_code_record = None
    if not account_invite_record:
        if settings.REQUIRE_INVITE_CODE and not user_data.invite_code:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code is required",
            )
        if user_data.invite_code:
            # Look up invite code (codes are stored upper-case, so the input is
            # upper-cased) -- applies plan/trial regardless of REQUIRE_INVITE_CODE.
            # FOR UPDATE prevents double-spend by concurrent registrations.
            result = await db.execute(
                select(InviteCode)
                .where(InviteCode.code == user_data.invite_code.upper())
                .with_for_update()
            )
            invite_code_record = result.scalar_one_or_none()
            if not invite_code_record:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid invite code",
                )
            if invite_code_record.is_used:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invite code has already been used",
                )
            if invite_code_record.is_expired:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invite code has expired",
                )

    # Check if email already exists
    result = await db.execute(select(User).where(User.email == user_data.email))
    existing_user = result.scalar_one_or_none()
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    if account_invite_record:
        # Join existing account via account invite
        new_user = User(
            email=user_data.email,
            password_hash=get_password_hash(user_data.password),
            name=user_data.name,
            role="engineer",
            invite_code_id=invite_code_record.id if invite_code_record else None,
            account_id=account_invite_record.account_id,
            account_role=account_invite_record.role,
        )
        db.add(new_user)
        await db.flush()  # Get user ID

        # Mark account invite as used
        account_invite_record.accepted_by_id = new_user.id
        account_invite_record.used_at = datetime.now(timezone.utc)
    else:
        # Create personal Account first (user needs account_id for NOT NULL constraint)
        new_account = Account(
            name=f"{user_data.name}'s Account",
            display_code=_generate_display_code(),
        )
        db.add(new_account)
        await db.flush()  # Get account ID

        new_user = User(
            email=user_data.email,
            password_hash=get_password_hash(user_data.password),
            name=user_data.name,
            role="engineer",
            invite_code_id=invite_code_record.id if invite_code_record else None,
            account_id=new_account.id,
            account_role="owner",
        )
        db.add(new_user)
        await db.flush()  # Get user ID

        # Now set account owner and create subscription
        new_account.owner_id = new_user.id

        # Apply plan/trial from invite code if present
        sub_plan = "free"
        sub_status = "active"
        period_start = None
        period_end = None
        if invite_code_record and invite_code_record.assigned_plan:
            sub_plan = invite_code_record.assigned_plan
            if invite_code_record.trial_duration_days:
                sub_status = "trialing"
                period_start = datetime.now(timezone.utc)
                period_end = period_start + timedelta(days=invite_code_record.trial_duration_days)

        new_subscription = Subscription(
            account_id=new_account.id,
            plan=sub_plan,
            status=sub_status,
            current_period_start=period_start,
            current_period_end=period_end,
        )
        db.add(new_subscription)

    # Mark platform invite code as used
    if invite_code_record:
        invite_code_record.used_by_id = new_user.id
        invite_code_record.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)
    return new_user


@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Login with OAuth2 form data and get access + refresh tokens.

    ``form_data.username`` is treated as the user's email.

    Raises:
        HTTPException 401: unknown email or wrong password (same message for
            both, with a ``WWW-Authenticate: Bearer`` header).
    """
    # Find user by email
    result = await db.execute(select(User).where(User.email == form_data.username))
    user = result.scalar_one_or_none()

    if not user or not verify_password(form_data.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Update last login (committed by _issue_tokens)
    user.last_login = datetime.now(timezone.utc)
    return await _issue_tokens(db, user)


@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
    request: Request,
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Login with JSON body (alternative to form data).

    Raises:
        HTTPException 401: unknown email or wrong password.
    """
    result = await db.execute(select(User).where(User.email == credentials.email))
    user = result.scalar_one_or_none()

    if not user or not verify_password(credentials.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    # Update last login (committed by _issue_tokens)
    user.last_login = datetime.now(timezone.utc)
    return await _issue_tokens(db, user)


@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Refresh access token using refresh token (rotation: old token is revoked).

    Raises:
        HTTPException 401: token has no ``jti``, is unknown or already
            revoked, belongs to a different user than its ``sub`` claim, or
            the user no longer exists.
    """
    user_id = payload.get("sub")
    jti = payload.get("jti")

    # Rotation is enforced through the stored jti hash (the value
    # _store_refresh_token persists). A token without a jti can never be
    # matched or revoked, so accepting it would bypass rotation and logout.
    if not jti:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        )

    # Atomically revoke the old refresh token (token rotation).
    # Using a conditional UPDATE prevents the race where two concurrent
    # refresh requests both read revoked_at=NULL and both succeed.
    result = await db.execute(
        sa_update(RefreshToken)
        .where(
            RefreshToken.token_hash == hash_token(jti),
            RefreshToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(timezone.utc))
        .returning(RefreshToken.id, RefreshToken.user_id)
    )
    revoked_row = result.fetchone()
    if not revoked_row:
        # Either the token doesn't exist or was already revoked/used
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token has been revoked",
        )
    # The stored token must belong to the subject the JWT claims.
    if str(revoked_row.user_id) != str(user_id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    # Same response shape as login, including must_change_password, so a
    # refreshing client keeps the forced-password-change signal.
    return await _issue_tokens(db, user)


@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Get current authenticated user."""
    return current_user


@router.patch("/me", response_model=UserResponse)
async def update_me(
    data: UserUpdate,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Update current user's profile (name, email, phone, job_title, timezone).

    Only fields explicitly present in the request body are touched. Changing
    the email requires ``current_password`` and clears ``email_verified_at``,
    because the new address has not been verified yet.

    Raises:
        HTTPException 400: email change without ``current_password``, or the
            new email is already used by another account.
        HTTPException 401: ``current_password`` is incorrect.
    """
    update_fields = data.model_fields_set - {"current_password"}
    if not update_fields:
        return current_user

    # Email change requires current_password. An explicit null email is
    # ignored, matching how name is handled below.
    if "email" in data.model_fields_set and data.email is not None:
        if not data.current_password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is required to change email",
            )
        if not verify_password(data.current_password, current_user.password_hash):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Current password is incorrect",
            )
        # Check uniqueness
        result = await db.execute(
            select(User).where(User.email == data.email, User.id != current_user.id)
        )
        if result.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered",
            )
        if data.email != current_user.email:
            current_user.email = data.email
            # Verification of the old address must not carry over to the new one.
            current_user.email_verified_at = None

    if "name" in data.model_fields_set and data.name is not None:
        current_user.name = data.name

    # Handle simple string profile fields
    for field in ("phone", "job_title", "timezone"):
        if field in data.model_fields_set:
            setattr(current_user, field, getattr(data, field))

    await log_audit(db, current_user.id, "auth.profile_update", "user", current_user.id)
    await db.commit()
    await db.refresh(current_user)
    return current_user


@router.post("/logout")
async def logout(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Logout user by revoking the refresh token.

    Always reports success; an unknown or already-revoked token is a no-op.
    """
    jti = payload.get("jti")
    if jti:
        result = await db.execute(
            select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
        )
        stored_token = result.scalar_one_or_none()
        if stored_token and not stored_token.is_revoked:
            stored_token.revoked_at = datetime.now(timezone.utc)
            await db.commit()

    return {"message": "Successfully logged out"}


@router.post("/password/change")
@limiter.limit("5/minute")
async def change_password(
    request: Request,
    data: ChangePasswordRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Change the current user's password and revoke all of their refresh tokens.

    Raises:
        HTTPException 401: ``current_password`` is incorrect.
        HTTPException 400: new password equals the current one.
    """
    if not verify_password(data.current_password, current_user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Current password is incorrect",
        )

    if data.current_password == data.new_password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be different from current password",
        )

    current_user.password_hash = get_password_hash(data.new_password)
    current_user.must_change_password = False

    # Revoke all refresh tokens for this user (forces re-login elsewhere)
    await _revoke_all_refresh_tokens(db, current_user.id)

    await log_audit(db, current_user.id, "auth.password_change", "user", current_user.id)
    await db.commit()

    return {"message": "Password changed successfully"}


@router.post("/password/forgot")
@limiter.limit("3/minute")
async def forgot_password(
    request: Request,
    data: ForgotPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Request a password reset email. Always returns success (anti-enumeration).

    Email delivery is best-effort: a send failure is logged rather than
    surfaced, because an error only for existing accounts would reveal which
    emails are registered.
    """
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()

    if user:
        # Read attributes up front; the commits below may expire the instance.
        user_id = user.id
        to_email = user.email

        # Create reset token JWT; only its jti hash is stored.
        raw_token = create_password_reset_token(str(user_id))
        payload = decode_token(raw_token)
        if payload and payload.get("jti"):
            token_record = PasswordResetToken(
                token_hash=hash_token(payload["jti"]),
                user_id=user_id,
                expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
            db.add(token_record)
            await db.commit()

        # Send email (best-effort)
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={raw_token}"
        try:
            await EmailService.send_password_reset_email(
                to_email=to_email,
                reset_url=reset_url,
            )
        except Exception:
            # Deliberately not re-raised: see the docstring (anti-enumeration).
            logger.exception("Failed to send password reset email for user %s", user_id)

        await log_audit(db, user_id, "auth.password_reset.request", "user", user_id)
        await db.commit()

    return {"message": "If an account with that email exists, a reset link has been sent."}


@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
async def verify_reset_token(
    data: VerifyResetTokenRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Verify a password reset token is valid.

    Returns ``valid=False`` (never raises) for malformed, wrong-type, unknown,
    used or expired tokens; otherwise ``valid=True`` plus the user's email.
    """
    payload = decode_token(data.token)
    if not payload or payload.get("type") != "password_reset":
        return VerifyResetTokenResponse(valid=False)

    jti = payload.get("jti")
    if not jti:
        return VerifyResetTokenResponse(valid=False)

    result = await db.execute(
        select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
    )
    token_record = result.scalar_one_or_none()
    if not token_record or not token_record.is_valid:
        return VerifyResetTokenResponse(valid=False)

    # Get user email for display
    user_result = await db.execute(select(User.email).where(User.id == token_record.user_id))
    email = user_result.scalar_one_or_none()

    return VerifyResetTokenResponse(valid=True, email=email)


@router.post("/password/reset")
@limiter.limit("5/minute")
async def reset_password(
    request: Request,
    data: ResetPasswordRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Reset password using a valid, single-use reset token.

    Also clears ``must_change_password`` and revokes all refresh tokens.

    Raises:
        HTTPException 400: malformed, wrong-type, used or expired token, or
            the token's user no longer exists.
    """
    payload = decode_token(data.token)
    if not payload or payload.get("type") != "password_reset":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired reset token",
        )

    jti = payload.get("jti")
    user_id = payload.get("sub")
    if not jti or not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token",
        )

    # Validate token in DB (single-use).
    # FOR UPDATE prevents two concurrent reset requests from both succeeding.
    result = await db.execute(
        select(PasswordResetToken)
        .where(PasswordResetToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    token_record = result.scalar_one_or_none()
    if not token_record or not token_record.is_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Reset token has already been used or has expired",
        )

    # Get user
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid reset token",
        )

    # Update password
    user.password_hash = get_password_hash(data.new_password)
    user.must_change_password = False

    # Mark token as used
    token_record.used_at = datetime.now(timezone.utc)

    # Revoke all refresh tokens
    await _revoke_all_refresh_tokens(db, user.id)

    await log_audit(db, user.id, "auth.password_reset.complete", "user", user.id)
    await db.commit()

    return {"message": "Password has been reset successfully"}


@router.get("/email/verification-status")
async def get_verification_status(
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Check if email verification is enabled on the platform (defaults to enabled)."""
    enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
    return {"enabled": enabled}


@router.post("/email/send-verification")
@limiter.limit("3/minute")
async def send_verification_email(
    request: Request,
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Send an email verification link to the current user.

    Raises:
        HTTPException 403: email verification is disabled platform-wide.
        HTTPException 400: the user's email is already verified.
    """
    verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
    if not verification_enabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Email verification is currently disabled",
        )

    if current_user.email_verified_at is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email is already verified",
        )

    # Read before commit; the commit below may expire the instance.
    to_email = current_user.email

    raw_token = create_email_verification_token(str(current_user.id))
    payload = decode_token(raw_token)
    if payload and payload.get("jti"):
        token_record = EmailVerificationToken(
            token_hash=hash_token(payload["jti"]),
            user_id=current_user.id,
            expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
        )
        db.add(token_record)
        await db.commit()

    verification_url = f"{settings.FRONTEND_URL}/verify-email?token={raw_token}"
    await EmailService.send_email_verification_email(
        to_email=to_email,
        verification_url=verification_url,
    )

    return {"message": "Verification email sent"}


@router.post("/email/verify")
async def verify_email(
    data: dict,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Verify an email using a single-use token. Public endpoint.

    Expects a JSON object with a string ``token`` field.

    Raises:
        HTTPException 400: missing/non-string token, malformed or wrong-type
            token, used or expired token, or the token's user no longer exists.
    """
    token = data.get("token")
    # The body is an untyped dict, so guard against non-string values too.
    if not token or not isinstance(token, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Token is required",
        )

    payload = decode_token(token)
    if not payload or payload.get("type") != "email_verification":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired verification token",
        )

    jti = payload.get("jti")
    user_id = payload.get("sub")
    if not jti or not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token",
        )

    # FOR UPDATE prevents two concurrent verification requests from both succeeding.
    result = await db.execute(
        select(EmailVerificationToken)
        .where(EmailVerificationToken.token_hash == hash_token(jti))
        .with_for_update()
    )
    token_record = result.scalar_one_or_none()
    if not token_record or not token_record.is_valid:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Verification token has already been used or has expired",
        )

    # Mark token as used
    token_record.used_at = datetime.now(timezone.utc)

    # Mark user email as verified
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid verification token",
        )

    user.email_verified_at = datetime.now(timezone.utc)

    await log_audit(db, user.id, "auth.email_verified", "user", user.id)
    await db.commit()

    return {"message": "Email verified successfully"}


@router.get("/me/feature-flags", response_model=dict[str, bool])
async def get_my_feature_flags(
    current_user: Annotated[User, Depends(get_current_active_user)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> dict[str, bool]:
    """Resolve feature flags for the current user's account and plan.

    Precedence per flag: account override > plan default > ``False``. The plan
    comes from the account's ``active``/``trialing`` subscription, falling back
    to ``"free"`` when there is none.

    Returns:
        Mapping of ``flag_key`` to its resolved boolean; empty when no flags
        are defined.
    """
    plan = "free"
    if current_user.account_id:
        sub_result = await db.execute(
            select(Subscription).where(
                Subscription.account_id == current_user.account_id,
                Subscription.status.in_(["active", "trialing"]),
            )
        )
        sub = sub_result.scalar_one_or_none()
        if sub:
            plan = sub.plan

    flags_result = await db.execute(select(FeatureFlag))
    flags = flags_result.scalars().all()
    if not flags:
        return {}

    flag_ids = [f.id for f in flags]

    defaults_result = await db.execute(
        select(PlanFeatureDefault).where(
            PlanFeatureDefault.flag_id.in_(flag_ids),
            PlanFeatureDefault.plan == plan,
        )
    )
    plan_defaults = {d.flag_id: d.enabled for d in defaults_result.scalars().all()}

    overrides: dict[uuid.UUID, bool] = {}
    if current_user.account_id:
        overrides_result = await db.execute(
            select(AccountFeatureOverride).where(
                AccountFeatureOverride.flag_id.in_(flag_ids),
                AccountFeatureOverride.account_id == current_user.account_id,
            )
        )
        overrides = {o.flag_id: o.enabled for o in overrides_result.scalars().all()}

    # Account override wins, then the plan default, then disabled.
    return {
        flag.flag_key: overrides.get(flag.id, plan_defaults.get(flag.id, False))
        for flag in flags
    }