feat(auth): enforce absolute session cap in /auth/refresh

Fourth commit in the session-expiration-policy series. The gate that
ends "logged in forever" — refresh now rejects tokens whose original
login (auth_time) is older than abs_max seconds.

Algorithm (plan §4.5):
1. Decode JWT (dep already handles idle expiry).
2. Load user; reject inactive/missing as invalid_refresh_token.
3. Resolve effective auth_time/idle_max/abs_max, grandfathering
   pre-PR tokens by snapshotting current account policy.
4. Atomically revoke the JTI regardless of outcome — this consumes
   the token whether or not the absolute check passes, so an
   absolute-expired token cannot be replayed forever.
5. If the atomic UPDATE matched zero rows -> invalid_refresh_token.
6. If now >= auth_time + abs_max -> commit the revoke explicitly
   (so it survives the rollback hook in get_admin_db) and 401
   session_expired_absolute.
7. Otherwise mint via _mint_with_claims, carrying claims forward.

Boundary check uses `>=`, not `>` — a deadline equal to now is
expired. _refresh_session_tokens (commit 3) replaced by two narrower
helpers: _resolve_refresh_claims (grandfather logic, no mint) and
_mint_with_claims (mint with explicit claims, no grandfather). Makes
the endpoint's algorithm read top-down without indirection.

Tests added in test_session_policy.py:
- #8: backdate auth_time by exactly abs_max -> session_expired_absolute
  at the deadline boundary.
- #9: same token tried twice; first returns session_expired_absolute
  AND consumes the row; second returns invalid_refresh_token.
- #12: legacy token without auth_time/idle_max/abs_max gets one
  successful rotation; new JWT carries fresh policy snapshot from
  the account (3d/14d defaults under Strict).

25/25 across test_session_policy + test_auth + test_oauth_callbacks.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-13 16:26:00 -04:00
parent d6a02ee8da
commit b21d2fc234
2 changed files with 197 additions and 26 deletions

View File

@@ -110,25 +110,21 @@ async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
)
async def _refresh_session_tokens(
async def _resolve_refresh_claims(
payload: dict, user: User, db: AsyncSession
) -> Token:
"""Carry session-policy claims forward across a refresh-token rotation.
) -> tuple[int, int, int]:
"""Return (auth_time, idle_max_seconds, abs_max_seconds) for a refresh.
Grandfathers legacy tokens issued before this PR (no auth_time claim)
by snapshotting the account's current policy and treating now() as
auth_time — i.e. one free rotation under the new policy. Caller
commits.
Does NOT enforce the absolute cap — that lands in the next commit so
the cap can be rolled back independently if needed.
Grandfathers legacy tokens issued before the session-policy PR: tokens
missing any of auth_time/idle_max/abs_max get treated as if just minted
under the account's current policy. One free rotation under the new
rules — see plan §5.1. Callers that have the claims use them as-is.
"""
auth_time = payload.get("auth_time")
idle_max_seconds = payload.get("idle_max")
abs_max_seconds = payload.get("abs_max")
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
# Grandfather path — legacy token from before the session-policy PR.
account = (
await db.execute(select(Account).where(Account.id == user.account_id))
).scalar_one()
@@ -137,6 +133,21 @@ async def _refresh_session_tokens(
idle_max_seconds = idle_minutes * 60
abs_max_seconds = abs_minutes * 60
return auth_time, idle_max_seconds, abs_max_seconds
async def _mint_with_claims(
user: User,
auth_time: int,
idle_max_seconds: int,
abs_max_seconds: int,
db: AsyncSession,
) -> Token:
"""Mint a refresh+access pair carrying explicit session-policy claims.
Used by /auth/refresh after the grandfather + absolute-cap checks
have already produced the effective claim values. Caller commits.
"""
now = datetime.now(timezone.utc)
refresh_token_str = create_refresh_token(
user_id=str(user.id),
@@ -452,13 +463,39 @@ async def refresh_token(
payload: Annotated[dict, Depends(get_refresh_token_payload)],
db: Annotated[AsyncSession, Depends(get_admin_db)]
):
"""Refresh access token using refresh token (rotation: old token is revoked)."""
"""Refresh access token, enforcing both idle and absolute session windows.
Algorithm (see plan §4.5):
1. Decode refresh JWT (the dep already rejects idle-expired tokens with
session_expired_idle).
2. Load the user. If missing or inactive, 401 invalid_refresh_token.
3. Resolve effective auth_time/idle_max/abs_max (grandfather legacy
tokens that pre-date this PR).
4. Atomically revoke the JTI regardless of outcome — so an absolute-
expired token cannot be replayed; the second attempt finds it
already revoked and gets invalid_refresh_token instead.
5. If the atomic UPDATE matched zero rows, 401 invalid_refresh_token.
6. If now >= auth_time + abs_max, 401 session_expired_absolute.
7. Otherwise mint new tokens carrying the claims forward.
"""
user_id = payload.get("sub")
jti = payload.get("jti")
# Atomically revoke the old refresh token (token rotation).
# Using a conditional UPDATE prevents the race where two concurrent
# refresh requests both read revoked_at=NULL and both succeed.
user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_refresh_token",
)
auth_time, idle_max_seconds, abs_max_seconds = await _resolve_refresh_claims(
payload, user, db
)
# Atomically revoke the old refresh token first — this consumes the
# token regardless of whether the absolute check passes, so an absolute-
# expired token cannot be replayed.
if jti:
token_hash = hash_token(jti)
result = await db.execute(
@@ -471,26 +508,29 @@ async def refresh_token(
.returning(RefreshToken.id, RefreshToken.user_id)
)
revoked_row = result.fetchone()
if not revoked_row:
# Either the token doesn't exist or was already revoked/used.
# Surfaced to the frontend as a plain logout — not "session
# expired" — because the user did not hit a policy boundary.
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_refresh_token"
detail="invalid_refresh_token",
)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
# Absolute-window check. Boundary is `>=`, not `>` — a deadline equal to
# now is expired. The token row has already been revoked above, so the
# client cannot retry this token even though we're raising after the
# consume.
now_unix = int(datetime.now(timezone.utc).timestamp())
if now_unix >= auth_time + abs_max_seconds:
# Commit the revoke so the consumed-on-failure invariant survives
# any subsequent rollback in the request lifecycle.
await db.commit()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_refresh_token"
detail="session_expired_absolute",
)
token = await _refresh_session_tokens(payload, user, db)
token = await _mint_with_claims(
user, auth_time, idle_max_seconds, abs_max_seconds, db
)
await db.commit()
return token

View File

@@ -6,6 +6,7 @@ Test numbers below correspond to the cases listed in §6 of the plan.
This file grows across commits:
- Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
- Commit 4: absolute-cap enforcement + grandfather path (#8, #9, #12)
"""
import uuid
@@ -179,3 +180,133 @@ class TestSessionPolicyClaims:
assert new_payload["exp"] >= original_payload["exp"]
# JTI rotates.
assert new_payload["jti"] != original_payload["jti"]
def _backdate_auth_time(refresh_token: str, *, seconds_back: int) -> str:
"""Re-sign a refresh JWT with an earlier auth_time, preserving JTI.
The DB row in refresh_tokens is keyed on hash(jti), so preserving jti
lets the atomic revoke step still find the row. Used to simulate
"this session is past its absolute cap" without waiting two weeks.
"""
payload = jwt.decode(
refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
payload["auth_time"] = payload["auth_time"] - seconds_back
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
class TestAbsoluteCap:
"""§6 tests #8, #9, #12 — absolute-cap enforcement and grandfather path."""
@pytest.mark.asyncio
async def test_refresh_at_absolute_deadline_rejects(
self, client: AsyncClient, test_user: dict
):
"""§6 test #8 — boundary check uses `>=`, not `>`.
A token whose auth_time + abs_max equals now() is expired, not
valid. Backdate the original token's auth_time by exactly abs_max
seconds so now >= deadline.
"""
login_resp = await client.post(
"/api/v1/auth/login/json",
json={"email": test_user["email"], "password": test_user["password"]},
)
original = login_resp.json()["refresh_token"]
abs_max = jwt.decode(
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)["abs_max"]
expired = _backdate_auth_time(original, seconds_back=abs_max)
response = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {expired}"},
)
assert response.status_code == 401
assert response.json()["detail"] == "session_expired_absolute"
@pytest.mark.asyncio
async def test_absolute_expired_token_is_consumed(
self, client: AsyncClient, test_user: dict
):
"""§6 test #9 — first attempt returns session_expired_absolute and
revokes the row; second attempt sees the revoked row and returns
invalid_refresh_token. Prevents replay of an absolute-expired token.
"""
login_resp = await client.post(
"/api/v1/auth/login/json",
json={"email": test_user["email"], "password": test_user["password"]},
)
original = login_resp.json()["refresh_token"]
abs_max = jwt.decode(
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)["abs_max"]
expired = _backdate_auth_time(original, seconds_back=abs_max + 1)
first = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {expired}"},
)
assert first.status_code == 401
assert first.json()["detail"] == "session_expired_absolute"
second = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {expired}"},
)
assert second.status_code == 401
assert second.json()["detail"] == "invalid_refresh_token"
@pytest.mark.asyncio
async def test_grandfather_path_for_legacy_token(
self, client: AsyncClient, test_user: dict, test_db
):
"""§6 test #12 — refresh token issued before this PR (no auth_time
claim) gets one successful rotation; the new token has fresh
auth_time/idle_max/abs_max claims snapshotted from current policy.
"""
from app.core.security import hash_token
from app.models.refresh_token import RefreshToken
login_resp = await client.post(
"/api/v1/auth/login/json",
json={"email": test_user["email"], "password": test_user["password"]},
)
original = login_resp.json()["refresh_token"]
original_payload = jwt.decode(
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Strip the new claims to simulate a token issued before this PR.
# JTI preserved so the DB-side revoke still finds the row.
legacy_payload = {
"sub": original_payload["sub"],
"type": "refresh",
"jti": original_payload["jti"],
"exp": original_payload["exp"],
}
legacy_token = jwt.encode(
legacy_payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
response = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {legacy_token}"},
)
assert response.status_code == 200, response.json()
new_payload = jwt.decode(
response.json()["refresh_token"],
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
)
assert new_payload.get("auth_time") is not None
assert new_payload.get("idle_max") == 3 * 24 * 60 * 60
assert new_payload.get("abs_max") == 14 * 24 * 60 * 60
# auth_time was set to ~now during grandfather, not preserved from
# the legacy token (since the legacy token didn't have one).
now_unix = int(datetime.now(timezone.utc).timestamp())
assert abs(new_payload["auth_time"] - now_unix) < 10