from datetime import datetime, timezone, timedelta
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.audit import log_audit
from app.core.email import EmailService
from app.models.user import User
from app.models.invite_code import InviteCode
from app.schemas.invite_code import InviteCodeCreate, InviteCodeResponse, InviteCodeValidation
from app.api.deps import require_admin

router = APIRouter(prefix="/invites", tags=["invites"])


async def _get_invite_by_code(db: AsyncSession, code: str) -> InviteCode:
    """Load an invite code by its code string, or raise 404.

    The code is upper-cased before lookup, exactly as ``validate_invite_code``
    does, so the admin endpoints and the public validation endpoint agree on
    which codes exist.

    Args:
        db: Active async database session.
        code: Invite code as supplied in the request path (any case).

    Returns:
        The matching ``InviteCode`` row.

    Raises:
        HTTPException: 404 if no invite code matches.
    """
    result = await db.execute(
        select(InviteCode).where(InviteCode.code == code.upper())
    )
    invite_code = result.scalar_one_or_none()
    if invite_code is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Invite code not found",
        )
    return invite_code


@router.post("", response_model=InviteCodeResponse, status_code=status.HTTP_201_CREATED)
async def create_invite_code(
    invite_data: InviteCodeCreate,
    current_user: Annotated[User, Depends(require_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Create a new invite code. Admin only.

    If ``invite_data.email`` is set, an invite email is sent, and on success
    ``email_sent_at`` is stamped. The creation is recorded in the audit log.

    Returns:
        The persisted invite code (refreshed after commit).
    """
    invite_code = InviteCode(
        created_by_id=current_user.id,
        expires_at=invite_data.expires_at,
        note=invite_data.note,
        email=invite_data.email,
        assigned_plan=invite_data.assigned_plan,
        trial_duration_days=invite_data.trial_duration_days,
    )
    db.add(invite_code)
    # Flush so DB/ORM-generated fields (id, code) are available for the
    # email and audit entry below.
    await db.flush()

    # Send invite email if email provided.
    # NOTE(review): the email goes out before the transaction commits; if the
    # commit fails, the recipient holds a code that was never persisted.
    email_sent = False
    if invite_data.email:
        email_sent = await EmailService.send_invite_email(
            to_email=invite_data.email,
            code=invite_code.code,
            plan=invite_data.assigned_plan,
            trial_days=invite_data.trial_duration_days,
        )
        if email_sent:
            invite_code.email_sent_at = datetime.now(timezone.utc)

    await log_audit(
        db,
        current_user.id,
        "invite.create",
        "invite_code",
        invite_code.id,
        {
            "code": invite_code.code,
            "plan": invite_data.assigned_plan,
            "email": invite_data.email,
            "email_sent": email_sent,
        },
    )
    await db.commit()
    await db.refresh(invite_code)
    return invite_code


@router.get("", response_model=list[InviteCodeResponse])
async def list_invite_codes(
    current_user: Annotated[User, Depends(require_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """List all invite codes, newest first. Admin only."""
    result = await db.execute(
        select(InviteCode).order_by(InviteCode.created_at.desc())
    )
    return result.scalars().all()


@router.delete("/{code}", status_code=status.HTTP_204_NO_CONTENT)
async def revoke_invite_code(
    code: str,
    current_user: Annotated[User, Depends(require_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Revoke (delete) an unused invite code. Admin only.

    The revocation is recorded in the audit log, consistent with the create and
    resend endpoints.

    Raises:
        HTTPException: 404 if the code does not exist; 400 if it has already
            been used.
    """
    invite_code = await _get_invite_by_code(db, code)

    # NOTE(review): resend_invite_code answers 409 for the same condition;
    # kept at 400 here to avoid changing the public API.
    if invite_code.is_used:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot revoke a used invite code"
        )

    # Log before deleting, while the row's id and fields are still loaded.
    await log_audit(
        db,
        current_user.id,
        "invite.revoke",
        "invite_code",
        invite_code.id,
        {
            "code": invite_code.code,
            "email": invite_code.email,
        },
    )
    await db.delete(invite_code)
    await db.commit()


@router.post("/{code}/resend", response_model=InviteCodeResponse)
async def resend_invite_code(
    code: str,
    current_user: Annotated[User, Depends(require_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Revoke an existing invite code and create a new one with the same properties, then email it.

    The new invite keeps the old invite's email, plan, trial length and note.
    If the old invite had both ``expires_at`` and ``created_at``, the new expiry
    is "now + original validity window". If it had only ``expires_at``, that
    absolute expiry is reused.

    Raises:
        HTTPException: 404 if the code does not exist; 409 if it has already
            been used; 400 if it has no recipient email.
    """
    old_invite = await _get_invite_by_code(db, code)

    if old_invite.is_used:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Cannot resend a used invite code"
        )

    if not old_invite.email:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot resend an invite code without a recipient email"
        )

    # Recalculate expiration from now if the old one had an expiration.
    new_expires_at = None
    if old_invite.expires_at and old_invite.created_at:
        original_duration = old_invite.expires_at - old_invite.created_at
        new_expires_at = datetime.now(timezone.utc) + original_duration
    elif old_invite.expires_at:
        # NOTE(review): this absolute expiry may already be in the past,
        # which would make the new invite expired on creation.
        new_expires_at = old_invite.expires_at

    # Capture properties before deleting (the ORM object is unusable after).
    email = old_invite.email
    assigned_plan = old_invite.assigned_plan
    trial_duration_days = old_invite.trial_duration_days
    note = old_invite.note
    old_code = old_invite.code

    await db.delete(old_invite)
    await db.flush()

    # Create new invite with same properties.
    new_invite = InviteCode(
        created_by_id=current_user.id,
        expires_at=new_expires_at,
        note=note,
        email=email,
        assigned_plan=assigned_plan,
        trial_duration_days=trial_duration_days,
    )
    db.add(new_invite)
    await db.flush()

    # Send email.
    email_sent = await EmailService.send_invite_email(
        to_email=email,
        code=new_invite.code,
        plan=assigned_plan,
        trial_days=trial_duration_days,
    )
    if email_sent:
        new_invite.email_sent_at = datetime.now(timezone.utc)

    await log_audit(
        db,
        current_user.id,
        "invite.resend",
        "invite_code",
        new_invite.id,
        {
            "old_code": old_code,
            "new_code": new_invite.code,
            "email": email,
            "email_sent": email_sent,
        },
    )
    await db.commit()
    await db.refresh(new_invite)
    return new_invite


@router.get("/validate/{code}", response_model=InviteCodeValidation)
@limiter.limit("5/minute")
async def validate_invite_code(
    request: Request,
    code: str,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Check if an invite code is valid. Public endpoint for UX.

    Rate-limited to 5 requests per minute. ``request`` is not used in the body,
    but it is required in the signature by the ``limiter.limit`` decorator.
    The code is upper-cased before lookup.

    Returns:
        ``InviteCodeValidation`` with ``valid`` and a human-readable message.
    """
    result = await db.execute(
        select(InviteCode).where(InviteCode.code == code.upper())
    )
    invite_code = result.scalar_one_or_none()

    if not invite_code:
        return InviteCodeValidation(valid=False, message="Invalid invite code")

    if invite_code.is_used:
        return InviteCodeValidation(valid=False, message="Invite code has already been used")

    if invite_code.is_expired:
        return InviteCodeValidation(valid=False, message="Invite code has expired")

    return InviteCodeValidation(valid=True, message="Invite code is valid")