fix(auth): store OAuth refresh token JTI to fix /auth/refresh after OAuth signup

OAuth callbacks (POST /auth/google/callback, POST /auth/microsoft/callback)
issued refresh tokens via create_refresh_token() but never persisted the JTI
in the refresh_tokens table. The /auth/refresh rotation logic does a
conditional UPDATE that requires a matching unrevoked row; without it the
first refresh attempt 401s with "Refresh token has been revoked" and OAuth
users get effectively logged out after the ~5 minute access-token expiry.

- Promote _store_refresh_token to module-public store_refresh_token in
  app.api.endpoints.auth (existing callers in /login, /login/json, /refresh
  updated in-place — same module, just renamed).
- OAuth callbacks now call store_refresh_token(...) + db.commit() after
  _sign_in_or_register returns. _sign_in_or_register already commits the
  user/account/identity rows; the refresh-token row gets its own commit.
- Tests:
  - test_oauth_google_callback_stores_refresh_token_jti — asserts the JTI
    hash is in refresh_tokens after a Google callback.
  - test_oauth_refresh_works_after_oauth_signup — full e2e: callback -> use
    returned refresh token at /auth/refresh -> 200 with rotated tokens.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-07 01:30:14 -04:00
parent fee4cb5b74
commit 5e0c9d2de1
3 changed files with 106 additions and 7 deletions

View File

@@ -2,8 +2,10 @@ import uuid
import pytest
from unittest.mock import patch
from sqlalchemy import select
from app.core.security import decode_token, hash_token
from app.models.user import User
from app.models.oauth_identity import OAuthIdentity
from app.models.refresh_token import RefreshToken
from app.models.subscription import Subscription
from app.services.oauth_providers import OAuthProfile
@@ -118,3 +120,77 @@ async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
)).scalar_one()
assert identity.provider == "microsoft"
@pytest.mark.asyncio
async def test_oauth_google_callback_stores_refresh_token_jti(
client, test_db, monkeypatch
):
"""A successful Google OAuth callback must persist the refresh-token JTI
in the refresh_tokens table — otherwise /auth/refresh rejects it."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_jti_test",
email="jtitest@example.com",
name="JTI Test",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert response.status_code == 200, response.json()
body = response.json()
refresh_token_str = body["refresh_token"]
payload = decode_token(refresh_token_str)
assert payload is not None
jti = payload["jti"]
token_hash = hash_token(jti)
user = (await test_db.execute(
select(User).where(User.email == "jtitest@example.com")
)).scalar_one()
stored = (await test_db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)).scalar_one_or_none()
assert stored is not None, "OAuth callback did not persist refresh-token JTI"
assert stored.user_id == user.id
assert stored.revoked_at is None
@pytest.mark.asyncio
async def test_oauth_refresh_works_after_oauth_signup(
client, test_db, monkeypatch
):
"""End-to-end: OAuth callback issues a refresh token; calling /auth/refresh
with that token must succeed (not be rejected as revoked)."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_refresh_test",
email="refresh-after-oauth@example.com",
name="Refresh After OAuth",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
callback_resp = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert callback_resp.status_code == 200, callback_resp.json()
refresh_token_str = callback_resp.json()["refresh_token"]
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token_str}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
refreshed = refresh_resp.json()
assert refreshed["access_token"]
assert refreshed["refresh_token"]
# Token rotation: new refresh token differs from the original.
assert refreshed["refresh_token"] != refresh_token_str