From f4606f073ad0a38dfc39844499ad8fbc30716a3d Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Wed, 6 May 2026 14:58:35 -0400 Subject: [PATCH] feat(auth): add Google OAuth callback with oauth_identities linking Co-Authored-By: Claude Opus 4.7 --- backend/app/api/endpoints/oauth.py | 106 ++++++++++++++++++++++++ backend/app/api/router.py | 2 + backend/app/core/config.py | 7 ++ backend/app/schemas/oauth.py | 13 +++ backend/app/services/oauth_providers.py | 71 ++++++++++++++++ backend/tests/test_oauth_callbacks.py | 95 +++++++++++++++++++++ 6 files changed, 294 insertions(+) create mode 100644 backend/app/api/endpoints/oauth.py create mode 100644 backend/app/schemas/oauth.py create mode 100644 backend/app/services/oauth_providers.py create mode 100644 backend/tests/test_oauth_callbacks.py diff --git a/backend/app/api/endpoints/oauth.py b/backend/app/api/endpoints/oauth.py new file mode 100644 index 00000000..ba81d2f0 --- /dev/null +++ b/backend/app/api/endpoints/oauth.py @@ -0,0 +1,106 @@ +import secrets +import string +from datetime import datetime, timezone +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.admin_database import get_admin_db +from app.core.config import settings +from app.core.security import create_access_token, create_refresh_token +from app.models.account import Account +from app.models.oauth_identity import OAuthIdentity +from app.models.user import User +from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse +from app.services.billing import BillingService +from app.services.oauth_providers import ( + google_exchange_code, + microsoft_exchange_code, + OAuthProfile, +) + +router = APIRouter(prefix="/auth", tags=["auth-oauth"]) + + +def _generate_display_code(length: int = 8) -> str: + """Match the helper used by /auth/register — A-Z + 0-9, length 8.""" + alphabet = string.ascii_uppercase + string.digits + return "".join(secrets.choice(alphabet) for _ in range(length)) + + +async def _sign_in_or_register( + db: AsyncSession, provider: str, profile: OAuthProfile +) -> tuple[User, bool]: + """Returns (user, is_new_user). Idempotent on (provider, provider_subject).""" + identity = ( + await db.execute( + select(OAuthIdentity).where( + OAuthIdentity.provider == provider, + OAuthIdentity.provider_subject == profile.provider_subject, + ) + ) + ).scalar_one_or_none() + + if identity: + user = ( + await db.execute(select(User).where(User.id == identity.user_id)) + ).scalar_one() + return user, False + + user = ( + await db.execute(select(User).where(User.email == profile.email)) + ).scalar_one_or_none() + is_new_user = user is None + if is_new_user: + account = Account( + name=f"{profile.name}'s Account", + display_code=_generate_display_code(), + ) + db.add(account) + await db.flush() + user = User( + email=profile.email, + name=profile.name, + password_hash=None, + account_id=account.id, + account_role="owner", + role="engineer", + email_verified_at=datetime.now(timezone.utc), + ) + db.add(user) + await db.flush() + account.owner_id = user.id + await db.flush() + # start_trial commits internally; flushed account/user above. + await BillingService.start_trial(db, account.id) + + db.add( + OAuthIdentity( + user_id=user.id, + provider=provider, + provider_subject=profile.provider_subject, + provider_email_at_link=profile.email, + ) + ) + await db.commit() + await db.refresh(user) + return user, is_new_user + + +@router.post("/google/callback", response_model=OAuthCallbackResponse) +async def google_callback( + payload: OAuthCallbackPayload, + db: Annotated[AsyncSession, Depends(get_admin_db)], +) -> OAuthCallbackResponse: + if not settings.GOOGLE_CLIENT_ID: + raise HTTPException(status_code=503, detail="Google sign-in not configured") + redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback" + profile = await google_exchange_code(payload.code, redirect_uri) + user, is_new = await _sign_in_or_register(db, "google", profile) + return OAuthCallbackResponse( + access_token=create_access_token({"sub": str(user.id)}), + refresh_token=create_refresh_token({"sub": str(user.id)}), + is_new_user=is_new, + ) diff --git a/backend/app/api/router.py b/backend/app/api/router.py index c3018855..01ce9a8a 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -41,6 +41,7 @@ from app.api.endpoints import ( maintenance_schedules, network_diagrams, notifications, + oauth as oauth_endpoints, onboarding, public_templates, ratings, @@ -82,6 +83,7 @@ api_router = APIRouter() # in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS. # --------------------------------------------------------------------------- api_router.include_router(auth.router) +api_router.include_router(oauth_endpoints.router) api_router.include_router(billing.router) # Reachable when subscription locked api_router.include_router(shared.router) # Public share links (no auth) api_router.include_router(shares.public_router) # Public session share links (optional auth) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 23795e42..f2c28593 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -194,6 +194,13 @@ class Settings(BaseSettings): """Check if ConnectWise integration is configured.""" return self.CW_CLIENT_ID is not None + # OAuth providers (self-serve signup) + GOOGLE_CLIENT_ID: Optional[str] = None + GOOGLE_CLIENT_SECRET: Optional[str] = None + MS_CLIENT_ID: Optional[str] = None + MS_CLIENT_SECRET: Optional[str] = None + OAUTH_REDIRECT_BASE: str = "http://localhost:5173" + # Monitoring SENTRY_DSN: Optional[str] = None diff --git a/backend/app/schemas/oauth.py b/backend/app/schemas/oauth.py new file mode 100644 index 00000000..47ddf9ca --- /dev/null +++ b/backend/app/schemas/oauth.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class OAuthCallbackPayload(BaseModel): + code: str + state: str | None = None + + +class OAuthCallbackResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + is_new_user: bool diff --git a/backend/app/services/oauth_providers.py b/backend/app/services/oauth_providers.py new file mode 100644 index 00000000..947743a5 --- /dev/null +++ b/backend/app/services/oauth_providers.py @@ -0,0 +1,71 @@ +"""OAuth provider helpers. Each provider exposes: +- exchange_code(code, redirect_uri) -> OAuthProfile +""" +from dataclasses import dataclass + +import httpx +from app.core.config import settings + + +@dataclass +class OAuthProfile: + provider_subject: str + email: str + name: str + + +async def google_exchange_code(code: str, redirect_uri: str) -> OAuthProfile: + async with httpx.AsyncClient(timeout=10) as cli: + token_response = await cli.post( + "https://oauth2.googleapis.com/token", + data={ + "code": code, + "client_id": settings.GOOGLE_CLIENT_ID, + "client_secret": settings.GOOGLE_CLIENT_SECRET, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + token_response.raise_for_status() + access_token = token_response.json()["access_token"] + + userinfo = await cli.get( + "https://openidconnect.googleapis.com/v1/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + userinfo.raise_for_status() + data = userinfo.json() + return OAuthProfile( + provider_subject=data["sub"], + email=data["email"], + name=data.get("name") or data["email"].split("@")[0], + ) + + +async def microsoft_exchange_code(code: str, redirect_uri: str) -> OAuthProfile: + async with httpx.AsyncClient(timeout=10) as cli: + token_response = await cli.post( + "https://login.microsoftonline.com/common/oauth2/v2.0/token", + data={ + "code": code, + "client_id": settings.MS_CLIENT_ID, + "client_secret": settings.MS_CLIENT_SECRET, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + "scope": "openid email profile", + }, + ) + token_response.raise_for_status() + access_token = token_response.json()["access_token"] + + userinfo = await cli.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {access_token}"}, + ) + userinfo.raise_for_status() + data = userinfo.json() + return OAuthProfile( + provider_subject=data["id"], + email=data.get("mail") or data["userPrincipalName"], + name=data.get("displayName") or data["userPrincipalName"].split("@")[0], + ) diff --git a/backend/tests/test_oauth_callbacks.py b/backend/tests/test_oauth_callbacks.py new file mode 100644 index 00000000..41064214 --- /dev/null +++ b/backend/tests/test_oauth_callbacks.py @@ -0,0 +1,95 @@ +import uuid +import pytest +from unittest.mock import patch +from sqlalchemy import select +from app.models.user import User +from app.models.oauth_identity import OAuthIdentity +from app.models.subscription import Subscription +from app.services.oauth_providers import OAuthProfile + + +@pytest.mark.asyncio +async def test_google_callback_creates_user_account_subscription( + client, test_db, monkeypatch +): + """Brand-new user via Google OAuth -> User + Account + Subscription + OAuthIdentity.""" + from app.core.config import settings + monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy") + monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy") + + profile = OAuthProfile( + provider_subject="google_subject_123", + email="newuser@example.com", + name="New User", + ) + with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile): + response = await client.post( + "/api/v1/auth/google/callback", json={"code": "auth_code_xyz"} + ) + assert response.status_code == 200, response.json() + body = response.json() + assert body["is_new_user"] is True + assert body["access_token"] + + user = (await test_db.execute( + select(User).where(User.email == "newuser@example.com") + )).scalar_one() + assert user.password_hash is None + assert user.email_verified_at is not None + + identity = (await test_db.execute( + select(OAuthIdentity).where(OAuthIdentity.user_id == user.id) + )).scalar_one() + assert identity.provider == "google" + assert identity.provider_subject == "google_subject_123" + + sub = (await test_db.execute( + select(Subscription).where(Subscription.account_id == user.account_id) + )).scalar_one() + assert sub.status == "trialing" + assert sub.plan == "pro" + + +@pytest.mark.asyncio +async def test_google_callback_existing_user_is_idempotent( + client, test_db, test_user, monkeypatch +): + """When test_user's email is already registered, OAuth links + returns the + same user. Two calls with same provider_subject must not duplicate + OAuthIdentity rows.""" + from app.core.config import settings + monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy") + monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy") + + user_id = uuid.UUID(test_user["user_data"]["id"]) + email = test_user["email"] + name = test_user["user_data"]["name"] + + profile = OAuthProfile( + provider_subject="google_subject_456", + email=email, + name=name, + ) + with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile): + r1 = await client.post("/api/v1/auth/google/callback", json={"code": "x"}) + r2 = await client.post("/api/v1/auth/google/callback", json={"code": "x"}) + assert r1.status_code == 200 + assert r2.status_code == 200 + assert r1.json()["is_new_user"] is False + assert r2.json()["is_new_user"] is False + + identities = (await test_db.execute( + select(OAuthIdentity).where(OAuthIdentity.user_id == user_id) + )).scalars().all() + assert len(identities) == 1 + + +@pytest.mark.asyncio +async def test_google_callback_503_when_unconfigured(client, monkeypatch): + from app.core.config import settings + monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None) + + response = await client.post( + "/api/v1/auth/google/callback", json={"code": "x"} + ) + assert response.status_code == 503