diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 7f5e017f..7b8a1ae1 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -47,8 +47,16 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/auth", tags=["authentication"]) -async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None: - """Decode a refresh token JWT and store its hash in the database.""" +async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None: + """Decode a refresh token JWT and store its hash in the database. + + Module-public so OAuth callback endpoints (and any future token-issuing + surface) can register the JTI in the ``refresh_tokens`` table the same + way ``/auth/login`` does. Without this the first ``/auth/refresh`` call + will reject the token as "revoked" because no row exists. + + Caller is responsible for committing the session. + """ payload = decode_token(refresh_token_str) if payload and payload.get("jti"): token_record = RefreshToken( @@ -320,7 +328,7 @@ async def login( refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) # Store refresh token hash in DB - await _store_refresh_token(db, refresh_token_str, user.id) + await store_refresh_token(db, refresh_token_str, user.id) await db.commit() return Token( @@ -355,7 +363,7 @@ async def login_json( refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) # Store refresh token hash in DB - await _store_refresh_token(db, refresh_token_str, user.id) + await store_refresh_token(db, refresh_token_str, user.id) await db.commit() return Token( @@ -413,7 +421,7 @@ async def refresh_token( new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) # Store new refresh token - await _store_refresh_token(db, new_refresh_token_str, user.id) + await store_refresh_token(db, new_refresh_token_str, user.id) await db.commit() return Token( diff --git a/backend/app/api/endpoints/oauth.py b/backend/app/api/endpoints/oauth.py index dcf49263..446c686f 100644 --- a/backend/app/api/endpoints/oauth.py +++ b/backend/app/api/endpoints/oauth.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.api.endpoints.auth import store_refresh_token from app.core.admin_database import get_admin_db from app.core.config import settings from app.core.security import create_access_token, create_refresh_token @@ -186,9 +187,16 @@ async def google_callback( account_invite_code=payload.account_invite_code, invited_email=payload.invited_email, ) + refresh_token_str = create_refresh_token({"sub": str(user.id)}) + # Persist the refresh-token JTI so the first /auth/refresh call doesn't + # reject this token as "revoked" (the rotation logic requires a row to + # mark as used). _sign_in_or_register already committed; this needs a + # second commit. + await store_refresh_token(db, refresh_token_str, user.id) + await db.commit() return OAuthCallbackResponse( access_token=create_access_token({"sub": str(user.id)}), - refresh_token=create_refresh_token({"sub": str(user.id)}), + refresh_token=refresh_token_str, is_new_user=is_new, ) @@ -209,8 +217,15 @@ async def microsoft_callback( account_invite_code=payload.account_invite_code, invited_email=payload.invited_email, ) + refresh_token_str = create_refresh_token({"sub": str(user.id)}) + # Persist the refresh-token JTI so the first /auth/refresh call doesn't + # reject this token as "revoked" (the rotation logic requires a row to + # mark as used). _sign_in_or_register already committed; this needs a + # second commit. + await store_refresh_token(db, refresh_token_str, user.id) + await db.commit() return OAuthCallbackResponse( access_token=create_access_token({"sub": str(user.id)}), - refresh_token=create_refresh_token({"sub": str(user.id)}), + refresh_token=refresh_token_str, is_new_user=is_new, ) diff --git a/backend/tests/test_oauth_callbacks.py b/backend/tests/test_oauth_callbacks.py index f31e1688..7c14c7b7 100644 --- a/backend/tests/test_oauth_callbacks.py +++ b/backend/tests/test_oauth_callbacks.py @@ -2,8 +2,10 @@ import uuid import pytest from unittest.mock import patch from sqlalchemy import select +from app.core.security import decode_token, hash_token from app.models.user import User from app.models.oauth_identity import OAuthIdentity +from app.models.refresh_token import RefreshToken from app.models.subscription import Subscription from app.services.oauth_providers import OAuthProfile @@ -118,3 +120,77 @@ async def test_microsoft_callback_creates_user(client, test_db, monkeypatch): select(OAuthIdentity).where(OAuthIdentity.user_id == user.id) )).scalar_one() assert identity.provider == "microsoft" + + +@pytest.mark.asyncio +async def test_oauth_google_callback_stores_refresh_token_jti( + client, test_db, monkeypatch +): + """A successful Google OAuth callback must persist the refresh-token JTI + in the refresh_tokens table — otherwise /auth/refresh rejects it.""" + from app.core.config import settings + monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy") + monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy") + + profile = OAuthProfile( + provider_subject="google_subject_jti_test", + email="jtitest@example.com", + name="JTI Test", + ) + with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile): + response = await client.post( + "/api/v1/auth/google/callback", json={"code": "auth_code_xyz"} + ) + assert response.status_code == 200, response.json() + body = response.json() + refresh_token_str = body["refresh_token"] + + payload = decode_token(refresh_token_str) + assert payload is not None + jti = payload["jti"] + token_hash = hash_token(jti) + + user = (await test_db.execute( + select(User).where(User.email == "jtitest@example.com") + )).scalar_one() + + stored = (await test_db.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + )).scalar_one_or_none() + assert stored is not None, "OAuth callback did not persist refresh-token JTI" + assert stored.user_id == user.id + assert stored.revoked_at is None + + +@pytest.mark.asyncio +async def test_oauth_refresh_works_after_oauth_signup( + client, test_db, monkeypatch +): + """End-to-end: OAuth callback issues a refresh token; calling /auth/refresh + with that token must succeed (not be rejected as revoked).""" + from app.core.config import settings + monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy") + monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy") + + profile = OAuthProfile( + provider_subject="google_subject_refresh_test", + email="refresh-after-oauth@example.com", + name="Refresh After OAuth", + ) + with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile): + callback_resp = await client.post( + "/api/v1/auth/google/callback", json={"code": "auth_code_xyz"} + ) + assert callback_resp.status_code == 200, callback_resp.json() + refresh_token_str = callback_resp.json()["refresh_token"] + + refresh_resp = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {refresh_token_str}"}, + ) + assert refresh_resp.status_code == 200, refresh_resp.json() + refreshed = refresh_resp.json() + assert refreshed["access_token"] + assert refreshed["refresh_token"] + # Token rotation: new refresh token differs from the original. + assert refreshed["refresh_token"] != refresh_token_str