feat(auth): enforce absolute session cap in /auth/refresh
Fourth commit in the session-expiration-policy series. The gate that ends "logged in forever" — refresh now rejects tokens whose original login (auth_time) is older than abs_max seconds. Algorithm (plan §4.5): 1. Decode JWT (dep already handles idle expiry). 2. Load user; reject inactive/missing as invalid_refresh_token. 3. Resolve effective auth_time/idle_max/abs_max, grandfathering pre-PR tokens by snapshotting current account policy. 4. Atomically revoke the JTI regardless of outcome — this consumes the token whether or not the absolute check passes, so an absolute-expired token cannot be replayed forever. 5. If the atomic UPDATE matched zero rows -> invalid_refresh_token. 6. If now >= auth_time + abs_max -> commit the revoke explicitly (so it survives the rollback hook in get_admin_db) and 401 session_expired_absolute. 7. Otherwise mint via _mint_with_claims, carrying claims forward. Boundary check uses `>=`, not `>` — a deadline equal to now is expired. _refresh_session_tokens (commit 3) replaced by two narrower helpers: _resolve_refresh_claims (grandfather logic, no mint) and _mint_with_claims (mint with explicit claims, no grandfather). Makes the endpoint's algorithm read top-down without indirection. Tests added in test_session_policy.py: - #8: backdate auth_time by exactly abs_max -> session_expired_absolute at the deadline boundary. - #9: same token tried twice; first returns session_expired_absolute AND consumes the row; second returns invalid_refresh_token. - #12: legacy token without auth_time/idle_max/abs_max gets one successful rotation; new JWT carries fresh policy snapshot from the account (3d/14d defaults under Strict). 25/25 across test_session_policy + test_auth + test_oauth_callbacks. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -110,25 +110,21 @@ async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _refresh_session_tokens(
|
async def _resolve_refresh_claims(
|
||||||
payload: dict, user: User, db: AsyncSession
|
payload: dict, user: User, db: AsyncSession
|
||||||
) -> Token:
|
) -> tuple[int, int, int]:
|
||||||
"""Carry session-policy claims forward across a refresh-token rotation.
|
"""Return (auth_time, idle_max_seconds, abs_max_seconds) for a refresh.
|
||||||
|
|
||||||
Grandfathers legacy tokens issued before this PR (no auth_time claim)
|
Grandfathers legacy tokens issued before the session-policy PR: tokens
|
||||||
by snapshotting the account's current policy and treating now() as
|
missing any of auth_time/idle_max/abs_max get treated as if just minted
|
||||||
auth_time — i.e. one free rotation under the new policy. Caller
|
under the account's current policy. One free rotation under the new
|
||||||
commits.
|
rules — see plan §5.1. Callers that have the claims use them as-is.
|
||||||
|
|
||||||
Does NOT enforce the absolute cap — that lands in the next commit so
|
|
||||||
the cap can be rolled back independently if needed.
|
|
||||||
"""
|
"""
|
||||||
auth_time = payload.get("auth_time")
|
auth_time = payload.get("auth_time")
|
||||||
idle_max_seconds = payload.get("idle_max")
|
idle_max_seconds = payload.get("idle_max")
|
||||||
abs_max_seconds = payload.get("abs_max")
|
abs_max_seconds = payload.get("abs_max")
|
||||||
|
|
||||||
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
|
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
|
||||||
# Grandfather path — legacy token from before the session-policy PR.
|
|
||||||
account = (
|
account = (
|
||||||
await db.execute(select(Account).where(Account.id == user.account_id))
|
await db.execute(select(Account).where(Account.id == user.account_id))
|
||||||
).scalar_one()
|
).scalar_one()
|
||||||
@@ -137,6 +133,21 @@ async def _refresh_session_tokens(
|
|||||||
idle_max_seconds = idle_minutes * 60
|
idle_max_seconds = idle_minutes * 60
|
||||||
abs_max_seconds = abs_minutes * 60
|
abs_max_seconds = abs_minutes * 60
|
||||||
|
|
||||||
|
return auth_time, idle_max_seconds, abs_max_seconds
|
||||||
|
|
||||||
|
|
||||||
|
async def _mint_with_claims(
|
||||||
|
user: User,
|
||||||
|
auth_time: int,
|
||||||
|
idle_max_seconds: int,
|
||||||
|
abs_max_seconds: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> Token:
|
||||||
|
"""Mint a refresh+access pair carrying explicit session-policy claims.
|
||||||
|
|
||||||
|
Used by /auth/refresh after the grandfather + absolute-cap checks
|
||||||
|
have already produced the effective claim values. Caller commits.
|
||||||
|
"""
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
refresh_token_str = create_refresh_token(
|
refresh_token_str = create_refresh_token(
|
||||||
user_id=str(user.id),
|
user_id=str(user.id),
|
||||||
@@ -452,13 +463,39 @@ async def refresh_token(
|
|||||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
"""Refresh access token, enforcing both idle and absolute session windows.
|
||||||
|
|
||||||
|
Algorithm (see plan §4.5):
|
||||||
|
|
||||||
|
1. Decode refresh JWT (the dep already rejects idle-expired tokens with
|
||||||
|
session_expired_idle).
|
||||||
|
2. Load the user. If missing or inactive, 401 invalid_refresh_token.
|
||||||
|
3. Resolve effective auth_time/idle_max/abs_max (grandfather legacy
|
||||||
|
tokens that pre-date this PR).
|
||||||
|
4. Atomically revoke the JTI regardless of outcome — so an absolute-
|
||||||
|
expired token cannot be replayed; the second attempt finds it
|
||||||
|
already revoked and gets invalid_refresh_token instead.
|
||||||
|
5. If the atomic UPDATE matched zero rows, 401 invalid_refresh_token.
|
||||||
|
6. If now >= auth_time + abs_max, 401 session_expired_absolute.
|
||||||
|
7. Otherwise mint new tokens carrying the claims forward.
|
||||||
|
"""
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("sub")
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
|
|
||||||
# Atomically revoke the old refresh token (token rotation).
|
user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
|
||||||
# Using a conditional UPDATE prevents the race where two concurrent
|
if not user or not user.is_active:
|
||||||
# refresh requests both read revoked_at=NULL and both succeed.
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_refresh_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_time, idle_max_seconds, abs_max_seconds = await _resolve_refresh_claims(
|
||||||
|
payload, user, db
|
||||||
|
)
|
||||||
|
|
||||||
|
# Atomically revoke the old refresh token first — this consumes the
|
||||||
|
# token regardless of whether the absolute check passes, so an absolute-
|
||||||
|
# expired token cannot be replayed.
|
||||||
if jti:
|
if jti:
|
||||||
token_hash = hash_token(jti)
|
token_hash = hash_token(jti)
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -471,26 +508,29 @@ async def refresh_token(
|
|||||||
.returning(RefreshToken.id, RefreshToken.user_id)
|
.returning(RefreshToken.id, RefreshToken.user_id)
|
||||||
)
|
)
|
||||||
revoked_row = result.fetchone()
|
revoked_row = result.fetchone()
|
||||||
|
|
||||||
if not revoked_row:
|
if not revoked_row:
|
||||||
# Either the token doesn't exist or was already revoked/used.
|
|
||||||
# Surfaced to the frontend as a plain logout — not "session
|
|
||||||
# expired" — because the user did not hit a policy boundary.
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="invalid_refresh_token"
|
detail="invalid_refresh_token",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
# Absolute-window check. Boundary is `>=`, not `>` — a deadline equal to
|
||||||
user = result.scalar_one_or_none()
|
# now is expired. The token row has already been revoked above, so the
|
||||||
|
# client cannot retry this token even though we're raising after the
|
||||||
if not user:
|
# consume.
|
||||||
|
now_unix = int(datetime.now(timezone.utc).timestamp())
|
||||||
|
if now_unix >= auth_time + abs_max_seconds:
|
||||||
|
# Commit the revoke so the consumed-on-failure invariant survives
|
||||||
|
# any subsequent rollback in the request lifecycle.
|
||||||
|
await db.commit()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="invalid_refresh_token"
|
detail="session_expired_absolute",
|
||||||
)
|
)
|
||||||
|
|
||||||
token = await _refresh_session_tokens(payload, user, db)
|
token = await _mint_with_claims(
|
||||||
|
user, auth_time, idle_max_seconds, abs_max_seconds, db
|
||||||
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Test numbers below correspond to the cases listed in §6 of the plan.
|
|||||||
This file grows across commits:
|
This file grows across commits:
|
||||||
- Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
|
- Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
|
||||||
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
|
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
|
||||||
|
- Commit 4: absolute-cap enforcement + grandfather path (#8, #9, #12)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
@@ -179,3 +180,133 @@ class TestSessionPolicyClaims:
|
|||||||
assert new_payload["exp"] >= original_payload["exp"]
|
assert new_payload["exp"] >= original_payload["exp"]
|
||||||
# JTI rotates.
|
# JTI rotates.
|
||||||
assert new_payload["jti"] != original_payload["jti"]
|
assert new_payload["jti"] != original_payload["jti"]
|
||||||
|
|
||||||
|
|
||||||
|
def _backdate_auth_time(refresh_token: str, *, seconds_back: int) -> str:
|
||||||
|
"""Re-sign a refresh JWT with an earlier auth_time, preserving JTI.
|
||||||
|
|
||||||
|
The DB row in refresh_tokens is keyed on hash(jti), so preserving jti
|
||||||
|
lets the atomic revoke step still find the row. Used to simulate
|
||||||
|
"this session is past its absolute cap" without waiting two weeks.
|
||||||
|
"""
|
||||||
|
payload = jwt.decode(
|
||||||
|
refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
payload["auth_time"] = payload["auth_time"] - seconds_back
|
||||||
|
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAbsoluteCap:
|
||||||
|
"""§6 tests #8, #9, #12 — absolute-cap enforcement and grandfather path."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_at_absolute_deadline_rejects(
|
||||||
|
self, client: AsyncClient, test_user: dict
|
||||||
|
):
|
||||||
|
"""§6 test #8 — boundary check uses `>=`, not `>`.
|
||||||
|
|
||||||
|
A token whose auth_time + abs_max equals now() is expired, not
|
||||||
|
valid. Backdate the original token's auth_time by exactly abs_max
|
||||||
|
seconds so now >= deadline.
|
||||||
|
"""
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": test_user["email"], "password": test_user["password"]},
|
||||||
|
)
|
||||||
|
original = login_resp.json()["refresh_token"]
|
||||||
|
abs_max = jwt.decode(
|
||||||
|
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)["abs_max"]
|
||||||
|
|
||||||
|
expired = _backdate_auth_time(original, seconds_back=abs_max)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {expired}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "session_expired_absolute"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_absolute_expired_token_is_consumed(
|
||||||
|
self, client: AsyncClient, test_user: dict
|
||||||
|
):
|
||||||
|
"""§6 test #9 — first attempt returns session_expired_absolute and
|
||||||
|
revokes the row; second attempt sees the revoked row and returns
|
||||||
|
invalid_refresh_token. Prevents replay of an absolute-expired token.
|
||||||
|
"""
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": test_user["email"], "password": test_user["password"]},
|
||||||
|
)
|
||||||
|
original = login_resp.json()["refresh_token"]
|
||||||
|
abs_max = jwt.decode(
|
||||||
|
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)["abs_max"]
|
||||||
|
expired = _backdate_auth_time(original, seconds_back=abs_max + 1)
|
||||||
|
|
||||||
|
first = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {expired}"},
|
||||||
|
)
|
||||||
|
assert first.status_code == 401
|
||||||
|
assert first.json()["detail"] == "session_expired_absolute"
|
||||||
|
|
||||||
|
second = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {expired}"},
|
||||||
|
)
|
||||||
|
assert second.status_code == 401
|
||||||
|
assert second.json()["detail"] == "invalid_refresh_token"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grandfather_path_for_legacy_token(
|
||||||
|
self, client: AsyncClient, test_user: dict, test_db
|
||||||
|
):
|
||||||
|
"""§6 test #12 — refresh token issued before this PR (no auth_time
|
||||||
|
claim) gets one successful rotation; the new token has fresh
|
||||||
|
auth_time/idle_max/abs_max claims snapshotted from current policy.
|
||||||
|
"""
|
||||||
|
from app.core.security import hash_token
|
||||||
|
from app.models.refresh_token import RefreshToken
|
||||||
|
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": test_user["email"], "password": test_user["password"]},
|
||||||
|
)
|
||||||
|
original = login_resp.json()["refresh_token"]
|
||||||
|
original_payload = jwt.decode(
|
||||||
|
original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip the new claims to simulate a token issued before this PR.
|
||||||
|
# JTI preserved so the DB-side revoke still finds the row.
|
||||||
|
legacy_payload = {
|
||||||
|
"sub": original_payload["sub"],
|
||||||
|
"type": "refresh",
|
||||||
|
"jti": original_payload["jti"],
|
||||||
|
"exp": original_payload["exp"],
|
||||||
|
}
|
||||||
|
legacy_token = jwt.encode(
|
||||||
|
legacy_payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {legacy_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
new_payload = jwt.decode(
|
||||||
|
response.json()["refresh_token"],
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM],
|
||||||
|
)
|
||||||
|
assert new_payload.get("auth_time") is not None
|
||||||
|
assert new_payload.get("idle_max") == 3 * 24 * 60 * 60
|
||||||
|
assert new_payload.get("abs_max") == 14 * 24 * 60 * 60
|
||||||
|
# auth_time was set to ~now during grandfather, not preserved from
|
||||||
|
# the legacy token (since the legacy token didn't have one).
|
||||||
|
now_unix = int(datetime.now(timezone.utc).timestamp())
|
||||||
|
assert abs(new_payload["auth_time"] - now_unix) < 10
|
||||||
|
|||||||
Reference in New Issue
Block a user