fix(auth): store OAuth refresh token JTI to fix /auth/refresh after OAuth signup

OAuth callbacks (POST /auth/google/callback, POST /auth/microsoft/callback)
issued refresh tokens via create_refresh_token() but never persisted the JTI
in the refresh_tokens table. The /auth/refresh rotation logic does a
conditional UPDATE that requires a matching unrevoked row; without it the
first refresh attempt 401s with "Refresh token has been revoked" and OAuth
users get effectively logged out after the ~5 minute access-token expiry.

- Promote _store_refresh_token to module-public store_refresh_token in
  app.api.endpoints.auth (existing callers in /login, /login/json, /refresh
  updated in-place — same module, just renamed).
- OAuth callbacks now call store_refresh_token(...) + db.commit() after
  _sign_in_or_register returns. _sign_in_or_register already commits the
  user/account/identity rows; the refresh-token row gets its own commit.
- Tests:
  - test_oauth_google_callback_stores_refresh_token_jti — asserts the JTI
    hash is in refresh_tokens after a Google callback.
  - test_oauth_refresh_works_after_oauth_signup — full e2e: callback -> use
    returned refresh token at /auth/refresh -> 200 with rotated tokens.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-07 01:30:14 -04:00
parent fee4cb5b74
commit 5e0c9d2de1
3 changed files with 106 additions and 7 deletions

View File

@@ -47,8 +47,16 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["authentication"]) router = APIRouter(prefix="/auth", tags=["authentication"])
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None: async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database.""" """Decode a refresh token JWT and store its hash in the database.
Module-public so OAuth callback endpoints (and any future token-issuing
surface) can register the JTI in the ``refresh_tokens`` table the same
way ``/auth/login`` does. Without this the first ``/auth/refresh`` call
will reject the token as "revoked" because no row exists.
Caller is responsible for committing the session.
"""
payload = decode_token(refresh_token_str) payload = decode_token(refresh_token_str)
if payload and payload.get("jti"): if payload and payload.get("jti"):
token_record = RefreshToken( token_record = RefreshToken(
@@ -320,7 +328,7 @@ async def login(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB # Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id) await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return Token( return Token(
@@ -355,7 +363,7 @@ async def login_json(
refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB # Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id) await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return Token( return Token(
@@ -413,7 +421,7 @@ async def refresh_token(
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)}) new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token # Store new refresh token
await _store_refresh_token(db, new_refresh_token_str, user.id) await store_refresh_token(db, new_refresh_token_str, user.id)
await db.commit() await db.commit()
return Token( return Token(

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.endpoints.auth import store_refresh_token
from app.core.admin_database import get_admin_db from app.core.admin_database import get_admin_db
from app.core.config import settings from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token from app.core.security import create_access_token, create_refresh_token
@@ -186,9 +187,16 @@ async def google_callback(
account_invite_code=payload.account_invite_code, account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email, invited_email=payload.invited_email,
) )
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse( return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}), access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}), refresh_token=refresh_token_str,
is_new_user=is_new, is_new_user=is_new,
) )
@@ -209,8 +217,15 @@ async def microsoft_callback(
account_invite_code=payload.account_invite_code, account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email, invited_email=payload.invited_email,
) )
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return OAuthCallbackResponse( return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}), access_token=create_access_token({"sub": str(user.id)}),
refresh_token=create_refresh_token({"sub": str(user.id)}), refresh_token=refresh_token_str,
is_new_user=is_new, is_new_user=is_new,
) )

View File

@@ -2,8 +2,10 @@ import uuid
import pytest import pytest
from unittest.mock import patch from unittest.mock import patch
from sqlalchemy import select from sqlalchemy import select
from app.core.security import decode_token, hash_token
from app.models.user import User from app.models.user import User
from app.models.oauth_identity import OAuthIdentity from app.models.oauth_identity import OAuthIdentity
from app.models.refresh_token import RefreshToken
from app.models.subscription import Subscription from app.models.subscription import Subscription
from app.services.oauth_providers import OAuthProfile from app.services.oauth_providers import OAuthProfile
@@ -118,3 +120,77 @@ async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id) select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
)).scalar_one() )).scalar_one()
assert identity.provider == "microsoft" assert identity.provider == "microsoft"
@pytest.mark.asyncio
async def test_oauth_google_callback_stores_refresh_token_jti(
client, test_db, monkeypatch
):
"""A successful Google OAuth callback must persist the refresh-token JTI
in the refresh_tokens table — otherwise /auth/refresh rejects it."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_jti_test",
email="jtitest@example.com",
name="JTI Test",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert response.status_code == 200, response.json()
body = response.json()
refresh_token_str = body["refresh_token"]
payload = decode_token(refresh_token_str)
assert payload is not None
jti = payload["jti"]
token_hash = hash_token(jti)
user = (await test_db.execute(
select(User).where(User.email == "jtitest@example.com")
)).scalar_one()
stored = (await test_db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)).scalar_one_or_none()
assert stored is not None, "OAuth callback did not persist refresh-token JTI"
assert stored.user_id == user.id
assert stored.revoked_at is None
@pytest.mark.asyncio
async def test_oauth_refresh_works_after_oauth_signup(
client, test_db, monkeypatch
):
"""End-to-end: OAuth callback issues a refresh token; calling /auth/refresh
with that token must succeed (not be rejected as revoked)."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_refresh_test",
email="refresh-after-oauth@example.com",
name="Refresh After OAuth",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
callback_resp = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert callback_resp.status_code == 200, callback_resp.json()
refresh_token_str = callback_resp.json()["refresh_token"]
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token_str}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
refreshed = refresh_resp.json()
assert refreshed["access_token"]
assert refreshed["refresh_token"]
# Token rotation: new refresh token differs from the original.
assert refreshed["refresh_token"] != refresh_token_str