"""Authentication routes: registration, login, token refresh, and logout."""

import secrets
import string
from datetime import datetime, timezone
from functools import lru_cache
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.config import settings
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.security import (
    verify_password,
    get_password_hash,
    create_access_token,
    create_refresh_token,
    decode_token,
    hash_token,
)
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.api.deps import get_current_active_user, get_refresh_token_payload

router = APIRouter(prefix="/auth", tags=["authentication"])

# Alphabet for account display codes: uppercase ASCII letters and digits.
_DISPLAY_CODE_ALPHABET = string.ascii_uppercase + string.digits


async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
    """Decode a refresh-token JWT and stage a hash of its ``jti`` in the session.

    The record is only added to the session; the caller must commit. Tokens that
    fail to decode, or whose payload has no ``jti``, are silently not stored.
    """
    payload = decode_token(refresh_token_str)
    if payload and payload.get("jti"):
        db.add(
            RefreshToken(
                token_hash=hash_token(payload["jti"]),
                user_id=user_id,
                expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            )
        )


def _generate_display_code(length: int = 8) -> str:
    """Return a random alphanumeric display code (uppercase letters and digits).

    Args:
        length: Number of characters; defaults to 8.
    """
    # ``secrets`` rather than ``random``: codes should not be predictable.
    return "".join(secrets.choice(_DISPLAY_CODE_ALPHABET) for _ in range(length))


@lru_cache(maxsize=1)
def _dummy_password_hash() -> str:
    """Return a throwaway password hash, computed once on first use.

    Used by ``_authenticate_user`` so that a login for an unknown email still
    performs one hash verification, keeping response time independent of
    whether the email is registered.
    """
    return get_password_hash(secrets.token_urlsafe(16))


async def _authenticate_user(db: AsyncSession, email: str, password: str) -> User:
    """Return the user matching *email* and *password*.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            email is unknown or the password does not match (same message for
            both, so the response does not reveal which one failed).
    """
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()

    # Always run exactly one verification, even for unknown emails, to avoid a
    # timing side channel that would reveal registered addresses.
    candidate_hash = user.password_hash if user is not None else _dummy_password_hash()
    password_ok = verify_password(password, candidate_hash)

    if user is None or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


async def _issue_tokens(db: AsyncSession, user: User) -> Token:
    """Mint an access/refresh token pair for *user*, store the refresh hash, and commit.

    Any other pending changes on the session (e.g. ``last_login`` or a revoked
    predecessor refresh token) are committed in the same transaction.
    """
    # Separate dicts per call, in case a token factory mutates its ``data`` arg.
    access_token = create_access_token(data={"sub": str(user.id)})
    refresh_token_str = create_refresh_token(data={"sub": str(user.id)})

    await _store_refresh_token(db, refresh_token_str, user.id)
    await db.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
    )


async def _get_valid_account_invite(db: AsyncSession, code: str) -> AccountInvite:
    """Look up an account invite by exact *code* and ensure it is still usable.

    Raises:
        HTTPException: 400 if the code is unknown, already used, or expired.
    """
    # NOTE(review): this read does not lock the row, so two concurrent
    # registrations could both consume the same invite; consider
    # ``.with_for_update()`` if the backend supports it.
    result = await db.execute(select(AccountInvite).where(AccountInvite.code == code))
    invite = result.scalar_one_or_none()

    if invite is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid account invite code",
        )
    if invite.is_used:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Account invite code has already been used",
        )
    if invite.is_expired:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Account invite code has expired",
        )
    return invite


async def _get_valid_platform_invite(db: AsyncSession, code: Optional[str]) -> InviteCode:
    """Look up a platform invite code (compared upper-cased) and ensure it is usable.

    Raises:
        HTTPException: 400 if *code* is missing, unknown, already used, or expired.
    """
    if not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invite code is required",
        )

    # NOTE(review): same unlocked-read race as account invites; see above.
    result = await db.execute(select(InviteCode).where(InviteCode.code == code.upper()))
    invite = result.scalar_one_or_none()

    if invite is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid invite code",
        )
    if invite.is_used:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invite code has already been used",
        )
    if invite.is_expired:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invite code has expired",
        )
    return invite


@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
    request: Request,
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Register a new user.

    Supports two flows:
    - account_invite_code: Join an existing account (bypasses platform invite gate)
    - invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)

    After user creation, if no account invite was used, a personal Account and
    free Subscription are created automatically.

    Raises:
        HTTPException: 400 for an invalid/used/expired invite, a missing
            required invite code, or an already-registered email.
    """
    # Account invite is checked first; when present it bypasses the platform gate.
    account_invite_record = None
    if user_data.account_invite_code:
        account_invite_record = await _get_valid_account_invite(
            db, user_data.account_invite_code
        )

    invite_code_record = None
    if account_invite_record is None and settings.REQUIRE_INVITE_CODE:
        invite_code_record = await _get_valid_platform_invite(db, user_data.invite_code)

    result = await db.execute(select(User).where(User.email == user_data.email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    password_hash = get_password_hash(user_data.password)
    invite_code_id = invite_code_record.id if invite_code_record else None

    if account_invite_record is not None:
        # Join the inviting account with the role the invite grants.
        new_user = User(
            email=user_data.email,
            password_hash=password_hash,
            name=user_data.name,
            role="engineer",
            invite_code_id=invite_code_id,
            account_id=account_invite_record.account_id,
            account_role=account_invite_record.role,
        )
        db.add(new_user)
        await db.flush()  # populate new_user.id

        account_invite_record.accepted_by_id = new_user.id
        account_invite_record.used_at = datetime.now(timezone.utc)
    else:
        # The account row is created first because the user needs an account_id
        # (NOT NULL); the account's owner is filled in once the user has an id.
        new_account = Account(
            name=f"{user_data.name}'s Account",
            display_code=_generate_display_code(),
        )
        db.add(new_account)
        await db.flush()  # populate new_account.id

        new_user = User(
            email=user_data.email,
            password_hash=password_hash,
            name=user_data.name,
            role="engineer",
            invite_code_id=invite_code_id,
            account_id=new_account.id,
            account_role="owner",
        )
        db.add(new_user)
        await db.flush()  # populate new_user.id

        new_account.owner_id = new_user.id
        db.add(
            Subscription(
                account_id=new_account.id,
                plan="free",
                status="active",
            )
        )

    if invite_code_record is not None:
        invite_code_record.used_by_id = new_user.id
        invite_code_record.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)
    return new_user


@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Login and get access token.

    OAuth2 password-form flow: the ``username`` field carries the user's email.
    """
    user = await _authenticate_user(db, form_data.username, form_data.password)
    user.last_login = datetime.now(timezone.utc)
    return await _issue_tokens(db, user)


@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
    request: Request,
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Login with JSON body (alternative to form data)."""
    user = await _authenticate_user(db, credentials.email, credentials.password)
    user.last_login = datetime.now(timezone.utc)
    return await _issue_tokens(db, user)


@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
    request: Request,
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Refresh access token using refresh token (rotation: old token is revoked).

    Raises:
        HTTPException: 401 if the presented token is unknown to the token
            store, has already been revoked, or its subject no longer exists.
    """
    user_id = payload.get("sub")
    jti = payload.get("jti")

    if jti:
        result = await db.execute(
            select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
        )
        stored_token = result.scalar_one_or_none()
        if stored_token is None:
            # Every refresh token minted in this module is persisted through
            # _store_refresh_token, so an unknown jti was either issued
            # elsewhere or its record no longer exists. Such a token can be
            # neither rotated nor revoked, so it must not be honoured.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token not recognized",
            )
        if stored_token.is_revoked:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token has been revoked",
            )
        # Rotation: the presented refresh token is single-use.
        stored_token.revoked_at = datetime.now(timezone.utc)
    # NOTE(review): a token without a ``jti`` skips the revocation check
    # entirely (kept from the original behaviour). Confirm that
    # create_refresh_token always sets ``jti`` and then reject such tokens.

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    # Commits the revocation above together with the new refresh token.
    return await _issue_tokens(db, user)


@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: Annotated[User, Depends(get_current_active_user)],
):
    """Get current authenticated user."""
    return current_user


@router.post("/logout")
async def logout(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Logout user by revoking the refresh token.

    Succeeds even when the token has no ``jti``, is not stored, or is already
    revoked; in those cases nothing is written.
    """
    jti = payload.get("jti")
    if jti:
        result = await db.execute(
            select(RefreshToken).where(RefreshToken.token_hash == hash_token(jti))
        )
        stored_token = result.scalar_one_or_none()
        if stored_token and not stored_token.is_revoked:
            stored_token.revoked_at = datetime.now(timezone.utc)
            await db.commit()

    return {"message": "Successfully logged out"}