import secrets
import string
from datetime import datetime, timezone
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token
from app.models.account import Account
from app.models.oauth_identity import OAuthIdentity
from app.models.user import User
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
from app.services.billing import BillingService
from app.services.oauth_providers import (
    google_exchange_code,
    microsoft_exchange_code,
    OAuthProfile,
)

router = APIRouter(prefix="/auth", tags=["auth-oauth"])


def _generate_display_code(length: int = 8) -> str:
    """Return a random display code of *length* characters from A-Z and 0-9.

    Matches the helper used by /auth/register (A-Z + 0-9, length 8 by default).
    Uses ``secrets`` rather than ``random`` so codes are not predictable.
    """
    alphabet = string.ascii_uppercase + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


async def _create_owner_user(db: AsyncSession, profile: OAuthProfile) -> User:
    """Create a new Account plus its owner User for *profile*, and start the trial.

    The account and user are flushed so their ids exist. Then
    ``BillingService.start_trial`` is called, which commits internally. So both
    rows are persisted before the caller links the OAuth identity.

    NOTE(review): ``display_code`` is random with no collision retry. If the
    column is unique, a collision surfaces as a DB error. Confirm this matches
    the /auth/register behaviour.
    """
    account = Account(
        name=f"{profile.name}'s Account",
        display_code=_generate_display_code(),
    )
    db.add(account)
    await db.flush()  # assigns account.id, needed for the user's account_id

    user = User(
        email=profile.email,
        name=profile.name,
        password_hash=None,  # no local password; sign-in goes through the provider
        account_id=account.id,
        account_role="owner",
        role="engineer",
        # Marked verified immediately. This relies on the provider having
        # verified the address; confirm in app.services.oauth_providers.
        email_verified_at=datetime.now(timezone.utc),
    )
    db.add(user)
    await db.flush()  # assigns user.id so the account can point at its owner

    account.owner_id = user.id
    await db.flush()

    # start_trial commits internally; flushed account/user above.
    await BillingService.start_trial(db, account.id)
    return user


async def _sign_in_or_register(
    db: AsyncSession, provider: str, profile: OAuthProfile
) -> tuple[User, bool]:
    """Resolve an OAuth profile to a local user, creating one when necessary.

    Resolution order:
      1. An ``OAuthIdentity`` already exists for
         ``(provider, profile.provider_subject)``: return its user unchanged.
      2. A ``User`` exists whose email equals ``profile.email``: link a new
         identity to that user.
      3. Otherwise create a new Account and owner User (trial started), then
         link a new identity to it.

    Idempotent on (provider, provider_subject): once the identity row exists,
    later calls only read.

    Returns:
        ``(user, is_new_user)``. ``is_new_user`` is True only when step 3
        created the user.
    """
    identity = (
        await db.execute(
            select(OAuthIdentity).where(
                OAuthIdentity.provider == provider,
                OAuthIdentity.provider_subject == profile.provider_subject,
            )
        )
    ).scalar_one_or_none()
    if identity is not None:
        user = (
            await db.execute(select(User).where(User.id == identity.user_id))
        ).scalar_one()
        return user, False

    # NOTE(review): this attaches the provider identity to ANY existing user
    # with the same email. That is only safe if the provider's email claim is
    # verified. Microsoft's `email` claim in particular is not guaranteed to
    # be verified. Confirm that google_exchange_code / microsoft_exchange_code
    # reject unverified emails; otherwise this allows account takeover.
    user = (
        await db.execute(select(User).where(User.email == profile.email))
    ).scalar_one_or_none()
    is_new_user = user is None
    if is_new_user:
        user = await _create_owner_user(db, profile)

    # NOTE(review): two concurrent first-time callbacks for the same subject can
    # both reach this point. Presumably a unique constraint on
    # (provider, provider_subject) makes one fail; confirm and consider handling it.
    db.add(
        OAuthIdentity(
            user_id=user.id,
            provider=provider,
            provider_subject=profile.provider_subject,
            provider_email_at_link=profile.email,
        )
    )
    await db.commit()
    await db.refresh(user)
    return user, is_new_user


def _callback_redirect_uri(provider: str) -> str:
    """Return the redirect URI for *provider* used in the code exchange.

    The result is ``{OAUTH_REDIRECT_BASE}/auth/{provider}/callback``.
    """
    return f"{settings.OAUTH_REDIRECT_BASE}/auth/{provider}/callback"


def _token_response(user: User, is_new_user: bool) -> OAuthCallbackResponse:
    """Build the callback response with fresh access and refresh tokens for *user*."""
    # Separate claim dicts on purpose: if a token helper mutates its input
    # (e.g. adds "exp"/"type"), those claims must not leak into the other token.
    return OAuthCallbackResponse(
        access_token=create_access_token({"sub": str(user.id)}),
        refresh_token=create_refresh_token({"sub": str(user.id)}),
        is_new_user=is_new_user,
    )


@router.post("/google/callback", response_model=OAuthCallbackResponse)
async def google_callback(
    payload: OAuthCallbackPayload,
    db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> OAuthCallbackResponse:
    """Exchange a Google authorization code and sign the user in (or up).

    Raises:
        HTTPException: 503 when Google sign-in is not configured.
    """
    if not settings.GOOGLE_CLIENT_ID:
        raise HTTPException(status_code=503, detail="Google sign-in not configured")
    redirect_uri = _callback_redirect_uri("google")
    profile = await google_exchange_code(payload.code, redirect_uri)
    user, is_new = await _sign_in_or_register(db, "google", profile)
    return _token_response(user, is_new)


@router.post("/microsoft/callback", response_model=OAuthCallbackResponse)
async def microsoft_callback(
    payload: OAuthCallbackPayload,
    db: Annotated[AsyncSession, Depends(get_admin_db)],
) -> OAuthCallbackResponse:
    """Exchange a Microsoft authorization code and sign the user in (or up).

    Raises:
        HTTPException: 503 when Microsoft sign-in is not configured.
    """
    if not settings.MS_CLIENT_ID:
        raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
    redirect_uri = _callback_redirect_uri("microsoft")
    profile = await microsoft_exchange_code(payload.code, redirect_uri)
    user, is_new = await _sign_in_or_register(db, "microsoft", profile)
    return _token_response(user, is_new)