feat(auth): add Google OAuth callback with oauth_identities linking
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
106
backend/app/api/endpoints/oauth.py
Normal file
106
backend/app/api/endpoints/oauth.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.admin_database import get_admin_db
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.security import create_access_token, create_refresh_token
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
from app.services.oauth_providers import (
|
||||||
|
google_exchange_code,
|
||||||
|
microsoft_exchange_code,
|
||||||
|
OAuthProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth-oauth"])
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_display_code(length: int = 8) -> str:
|
||||||
|
"""Match the helper used by /auth/register — A-Z + 0-9, length 8."""
|
||||||
|
alphabet = string.ascii_uppercase + string.digits
|
||||||
|
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
async def _sign_in_or_register(
|
||||||
|
db: AsyncSession, provider: str, profile: OAuthProfile
|
||||||
|
) -> tuple[User, bool]:
|
||||||
|
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
|
||||||
|
identity = (
|
||||||
|
await db.execute(
|
||||||
|
select(OAuthIdentity).where(
|
||||||
|
OAuthIdentity.provider == provider,
|
||||||
|
OAuthIdentity.provider_subject == profile.provider_subject,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if identity:
|
||||||
|
user = (
|
||||||
|
await db.execute(select(User).where(User.id == identity.user_id))
|
||||||
|
).scalar_one()
|
||||||
|
return user, False
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await db.execute(select(User).where(User.email == profile.email))
|
||||||
|
).scalar_one_or_none()
|
||||||
|
is_new_user = user is None
|
||||||
|
if is_new_user:
|
||||||
|
account = Account(
|
||||||
|
name=f"{profile.name}'s Account",
|
||||||
|
display_code=_generate_display_code(),
|
||||||
|
)
|
||||||
|
db.add(account)
|
||||||
|
await db.flush()
|
||||||
|
user = User(
|
||||||
|
email=profile.email,
|
||||||
|
name=profile.name,
|
||||||
|
password_hash=None,
|
||||||
|
account_id=account.id,
|
||||||
|
account_role="owner",
|
||||||
|
role="engineer",
|
||||||
|
email_verified_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
account.owner_id = user.id
|
||||||
|
await db.flush()
|
||||||
|
# start_trial commits internally; flushed account/user above.
|
||||||
|
await BillingService.start_trial(db, account.id)
|
||||||
|
|
||||||
|
db.add(
|
||||||
|
OAuthIdentity(
|
||||||
|
user_id=user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_subject=profile.provider_subject,
|
||||||
|
provider_email_at_link=profile.email,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user, is_new_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/google/callback", response_model=OAuthCallbackResponse)
|
||||||
|
async def google_callback(
|
||||||
|
payload: OAuthCallbackPayload,
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
) -> OAuthCallbackResponse:
|
||||||
|
if not settings.GOOGLE_CLIENT_ID:
|
||||||
|
raise HTTPException(status_code=503, detail="Google sign-in not configured")
|
||||||
|
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
|
||||||
|
profile = await google_exchange_code(payload.code, redirect_uri)
|
||||||
|
user, is_new = await _sign_in_or_register(db, "google", profile)
|
||||||
|
return OAuthCallbackResponse(
|
||||||
|
access_token=create_access_token({"sub": str(user.id)}),
|
||||||
|
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||||
|
is_new_user=is_new,
|
||||||
|
)
|
||||||
@@ -41,6 +41,7 @@ from app.api.endpoints import (
|
|||||||
maintenance_schedules,
|
maintenance_schedules,
|
||||||
network_diagrams,
|
network_diagrams,
|
||||||
notifications,
|
notifications,
|
||||||
|
oauth as oauth_endpoints,
|
||||||
onboarding,
|
onboarding,
|
||||||
public_templates,
|
public_templates,
|
||||||
ratings,
|
ratings,
|
||||||
@@ -82,6 +83,7 @@ api_router = APIRouter()
|
|||||||
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
api_router.include_router(auth.router)
|
api_router.include_router(auth.router)
|
||||||
|
api_router.include_router(oauth_endpoints.router)
|
||||||
api_router.include_router(billing.router) # Reachable when subscription locked
|
api_router.include_router(billing.router) # Reachable when subscription locked
|
||||||
api_router.include_router(shared.router) # Public share links (no auth)
|
api_router.include_router(shared.router) # Public share links (no auth)
|
||||||
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
||||||
|
|||||||
@@ -194,6 +194,13 @@ class Settings(BaseSettings):
|
|||||||
"""Check if ConnectWise integration is configured."""
|
"""Check if ConnectWise integration is configured."""
|
||||||
return self.CW_CLIENT_ID is not None
|
return self.CW_CLIENT_ID is not None
|
||||||
|
|
||||||
|
# OAuth providers (self-serve signup)
|
||||||
|
GOOGLE_CLIENT_ID: Optional[str] = None
|
||||||
|
GOOGLE_CLIENT_SECRET: Optional[str] = None
|
||||||
|
MS_CLIENT_ID: Optional[str] = None
|
||||||
|
MS_CLIENT_SECRET: Optional[str] = None
|
||||||
|
OAUTH_REDIRECT_BASE: str = "http://localhost:5173"
|
||||||
|
|
||||||
# Monitoring
|
# Monitoring
|
||||||
SENTRY_DSN: Optional[str] = None
|
SENTRY_DSN: Optional[str] = None
|
||||||
|
|
||||||
|
|||||||
13
backend/app/schemas/oauth.py
Normal file
13
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackPayload(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
is_new_user: bool
|
||||||
71
backend/app/services/oauth_providers.py
Normal file
71
backend/app/services/oauth_providers.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""OAuth provider helpers. Each provider exposes:
|
||||||
|
- exchange_code(code, redirect_uri) -> OAuthProfile
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthProfile:
|
||||||
|
provider_subject: str
|
||||||
|
email: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
async def google_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as cli:
|
||||||
|
token_response = await cli.post(
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": settings.GOOGLE_CLIENT_ID,
|
||||||
|
"client_secret": settings.GOOGLE_CLIENT_SECRET,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token_response.raise_for_status()
|
||||||
|
access_token = token_response.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo = await cli.get(
|
||||||
|
"https://openidconnect.googleapis.com/v1/userinfo",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo.raise_for_status()
|
||||||
|
data = userinfo.json()
|
||||||
|
return OAuthProfile(
|
||||||
|
provider_subject=data["sub"],
|
||||||
|
email=data["email"],
|
||||||
|
name=data.get("name") or data["email"].split("@")[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def microsoft_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as cli:
|
||||||
|
token_response = await cli.post(
|
||||||
|
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": settings.MS_CLIENT_ID,
|
||||||
|
"client_secret": settings.MS_CLIENT_SECRET,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token_response.raise_for_status()
|
||||||
|
access_token = token_response.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo = await cli.get(
|
||||||
|
"https://graph.microsoft.com/v1.0/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo.raise_for_status()
|
||||||
|
data = userinfo.json()
|
||||||
|
return OAuthProfile(
|
||||||
|
provider_subject=data["id"],
|
||||||
|
email=data.get("mail") or data["userPrincipalName"],
|
||||||
|
name=data.get("displayName") or data["userPrincipalName"].split("@")[0],
|
||||||
|
)
|
||||||
95
backend/tests/test_oauth_callbacks.py
Normal file
95
backend/tests/test_oauth_callbacks.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
from app.services.oauth_providers import OAuthProfile
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_creates_user_account_subscription(
|
||||||
|
client, test_db, monkeypatch
|
||||||
|
):
|
||||||
|
"""Brand-new user via Google OAuth -> User + Account + Subscription + OAuthIdentity."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||||
|
|
||||||
|
profile = OAuthProfile(
|
||||||
|
provider_subject="google_subject_123",
|
||||||
|
email="newuser@example.com",
|
||||||
|
name="New User",
|
||||||
|
)
|
||||||
|
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
body = response.json()
|
||||||
|
assert body["is_new_user"] is True
|
||||||
|
assert body["access_token"]
|
||||||
|
|
||||||
|
user = (await test_db.execute(
|
||||||
|
select(User).where(User.email == "newuser@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
assert user.password_hash is None
|
||||||
|
assert user.email_verified_at is not None
|
||||||
|
|
||||||
|
identity = (await test_db.execute(
|
||||||
|
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert identity.provider == "google"
|
||||||
|
assert identity.provider_subject == "google_subject_123"
|
||||||
|
|
||||||
|
sub = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == user.account_id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert sub.status == "trialing"
|
||||||
|
assert sub.plan == "pro"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_existing_user_is_idempotent(
|
||||||
|
client, test_db, test_user, monkeypatch
|
||||||
|
):
|
||||||
|
"""When test_user's email is already registered, OAuth links + returns the
|
||||||
|
same user. Two calls with same provider_subject must not duplicate
|
||||||
|
OAuthIdentity rows."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||||
|
|
||||||
|
user_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
email = test_user["email"]
|
||||||
|
name = test_user["user_data"]["name"]
|
||||||
|
|
||||||
|
profile = OAuthProfile(
|
||||||
|
provider_subject="google_subject_456",
|
||||||
|
email=email,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||||
|
r1 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||||
|
r2 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||||
|
assert r1.status_code == 200
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert r1.json()["is_new_user"] is False
|
||||||
|
assert r2.json()["is_new_user"] is False
|
||||||
|
|
||||||
|
identities = (await test_db.execute(
|
||||||
|
select(OAuthIdentity).where(OAuthIdentity.user_id == user_id)
|
||||||
|
)).scalars().all()
|
||||||
|
assert len(identities) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_503_when_unconfigured(client, monkeypatch):
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/google/callback", json={"code": "x"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 503
|
||||||
Reference in New Issue
Block a user