feat(auth): embed auth_time/idle_max/abs_max in refresh tokens at every login

Third commit in the session-expiration-policy series. Every refresh token
issued from now on carries the policy snapshot in its JWT (in seconds,
for direct Unix math), and every login/OAuth response surfaces both
expiry windows as ISO timestamps. /auth/refresh carries the claims
forward unchanged — including auth_time, which never resets on rotation.

Does NOT yet enforce the absolute cap — that's commit 4, sequenced so
the gate can be reverted independently if pilots hit an edge case.
But the wire is fully populated, and a grandfather path is already in
_refresh_session_tokens for tokens issued before this PR.

Key changes:
- core/security.py: create_refresh_token signature changes to
  (user_id, *, auth_time, idle_max_seconds, abs_max_seconds). Adds
  resolve_session_policy(account) -> (idle_minutes, absolute_minutes)
  applying defaults for NULL overrides.
- schemas/token.py + schemas/oauth.py: Token and OAuthCallbackResponse
  gain idle_expires_at + absolute_expires_at (Optional[datetime],
  Pydantic emits ISO 8601 UTC strings).
- endpoints/auth.py: new _mint_session_tokens(user, db) and
  _refresh_session_tokens(payload, user, db) helpers. /auth/login,
  /auth/login/json, and /auth/refresh now route through them. The
  refresh endpoint's pre-existing "Refresh token has been revoked"
  error normalized to the taxonomy detail "invalid_refresh_token".
- endpoints/oauth.py: both Google and Microsoft callbacks call
  _mint_session_tokens; OAuthCallbackResponse carries the expiry
  fields through.
- tests: two new cases in test_session_policy.py — login_json embeds
  the claims with strict defaults (3d/14d -> 259200/1209600 sec) and
  surfaces matching ISO expiry fields; refresh carries auth_time,
  idle_max, abs_max forward unchanged across rotation.

35/35 across test_session_policy + test_auth + test_oauth_callbacks +
test_account_invite_lookup + test_account_management.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-13 16:22:53 -04:00
parent 2375948b7a
commit d6a02ee8da
6 changed files with 255 additions and 66 deletions

View File

@@ -20,6 +20,7 @@ from app.core.security import (
create_email_verification_token,
decode_token,
hash_token,
resolve_session_policy,
)
from app.models.user import User
from app.models.invite_code import InviteCode
@@ -67,6 +68,97 @@ async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id)
db.add(token_record)
async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
"""Mint a fresh refresh+access pair for a new login.
Snapshots the account's current session policy into the refresh JWT
(auth_time/idle_max/abs_max) and registers the JTI in refresh_tokens.
Caller is responsible for committing the session. Use this for every
NEW login (password, OAuth, etc.) — for /auth/refresh use
_refresh_session_tokens instead, which carries claims forward.
See docs/plans/2026-05-13-session-expiration-policy.md §4.6.
"""
account = (
await db.execute(select(Account).where(Account.id == user.account_id))
).scalar_one()
idle_minutes, abs_minutes = resolve_session_policy(account)
idle_max_seconds = idle_minutes * 60
abs_max_seconds = abs_minutes * 60
now = datetime.now(timezone.utc)
auth_time_unix = int(now.timestamp())
refresh_token_str = create_refresh_token(
user_id=str(user.id),
auth_time=auth_time_unix,
idle_max_seconds=idle_max_seconds,
abs_max_seconds=abs_max_seconds,
)
access_token = create_access_token(data={"sub": str(user.id)})
await store_refresh_token(db, refresh_token_str, user.id)
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
absolute_expires_at=datetime.fromtimestamp(
auth_time_unix + abs_max_seconds, tz=timezone.utc
),
)
async def _refresh_session_tokens(
payload: dict, user: User, db: AsyncSession
) -> Token:
"""Carry session-policy claims forward across a refresh-token rotation.
Grandfathers legacy tokens issued before this PR (no auth_time claim)
by snapshotting the account's current policy and treating now() as
auth_time — i.e. one free rotation under the new policy. Caller
commits.
Does NOT enforce the absolute cap — that lands in the next commit so
the cap can be rolled back independently if needed.
"""
auth_time = payload.get("auth_time")
idle_max_seconds = payload.get("idle_max")
abs_max_seconds = payload.get("abs_max")
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
# Grandfather path — legacy token from before the session-policy PR.
account = (
await db.execute(select(Account).where(Account.id == user.account_id))
).scalar_one()
idle_minutes, abs_minutes = resolve_session_policy(account)
auth_time = int(datetime.now(timezone.utc).timestamp())
idle_max_seconds = idle_minutes * 60
abs_max_seconds = abs_minutes * 60
now = datetime.now(timezone.utc)
refresh_token_str = create_refresh_token(
user_id=str(user.id),
auth_time=auth_time,
idle_max_seconds=idle_max_seconds,
abs_max_seconds=abs_max_seconds,
)
access_token = create_access_token(data={"sub": str(user.id)})
await store_refresh_token(db, refresh_token_str, user.id)
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
absolute_expires_at=datetime.fromtimestamp(
auth_time + abs_max_seconds, tz=timezone.utc
),
)
def _generate_display_code() -> str:
"""Generate a random 8-character alphanumeric display code."""
chars = string.ascii_uppercase + string.digits
@@ -323,20 +415,9 @@ async def login(
# Update last login
user.last_login = datetime.now(timezone.utc)
# Create tokens
access_token = create_access_token(data={"sub": str(user.id)})
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await store_refresh_token(db, refresh_token_str, user.id)
token = await _mint_session_tokens(user, db)
await db.commit()
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
)
return token
@router.post("/login/json", response_model=Token)
@@ -359,19 +440,9 @@ async def login_json(
user.last_login = datetime.now(timezone.utc)
access_token = create_access_token(data={"sub": str(user.id)})
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await store_refresh_token(db, refresh_token_str, user.id)
token = await _mint_session_tokens(user, db)
await db.commit()
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
)
return token
@router.post("/refresh", response_model=Token)
@@ -402,10 +473,12 @@ async def refresh_token(
revoked_row = result.fetchone()
if not revoked_row:
# Either the token doesn't exist or was already revoked/used
# Either the token doesn't exist or was already revoked/used.
# Surfaced to the frontend as a plain logout — not "session
# expired" — because the user did not hit a policy boundary.
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked"
detail="invalid_refresh_token"
)
result = await db.execute(select(User).where(User.id == user_id))
@@ -414,21 +487,12 @@ async def refresh_token(
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found"
detail="invalid_refresh_token"
)
access_token = create_access_token(data={"sub": str(user.id)})
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token
await store_refresh_token(db, new_refresh_token_str, user.id)
token = await _refresh_session_tokens(payload, user, db)
await db.commit()
return Token(
access_token=access_token,
refresh_token=new_refresh_token_str,
token_type="bearer"
)
return token
@router.get("/me", response_model=UserResponse)

View File

@@ -7,10 +7,9 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.endpoints.auth import store_refresh_token
from app.api.endpoints.auth import _mint_session_tokens
from app.core.admin_database import get_admin_db
from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.oauth_identity import OAuthIdentity
@@ -187,17 +186,14 @@ async def google_callback(
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
token = await _mint_session_tokens(user, db)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
access_token=token.access_token,
refresh_token=token.refresh_token,
is_new_user=is_new,
idle_expires_at=token.idle_expires_at,
absolute_expires_at=token.absolute_expires_at,
)
@@ -217,15 +213,12 @@ async def microsoft_callback(
account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email,
)
refresh_token_str = create_refresh_token({"sub": str(user.id)})
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
token = await _mint_session_tokens(user, db)
await db.commit()
return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}),
refresh_token=refresh_token_str,
access_token=token.access_token,
refresh_token=token.refresh_token,
is_new_user=is_new,
idle_expires_at=token.idle_expires_at,
absolute_expires_at=token.absolute_expires_at,
)

View File

@@ -42,14 +42,54 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
return encoded_jwt
def create_refresh_token(data: dict) -> str:
"""Create a JWT refresh token with a unique jti for revocation tracking."""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
def create_refresh_token(
user_id: str,
*,
auth_time: int,
idle_max_seconds: int,
abs_max_seconds: int,
) -> str:
"""Create a JWT refresh token with session-policy claims embedded.
The JWT carries five claims beyond the standard `sub`/`type`/`jti`:
- `auth_time`: Unix-seconds timestamp of the original login; never reset
on rotation. Used by `/auth/refresh` to enforce the absolute cap.
- `idle_max`: idle window in seconds, snapshotted from the account's
policy at login. Carried forward across rotations unchanged.
- `abs_max`: absolute lifetime in seconds, snapshotted at login.
- `exp`: current idle deadline (`now + idle_max`). Standard JWT expiry.
See docs/plans/2026-05-13-session-expiration-policy.md §4.2 for the unit
convention (everything outside the JWT is minutes; inside the JWT it's
seconds so `auth_time + abs_max` is direct Unix math).
"""
now = datetime.now(timezone.utc)
expire = now + timedelta(seconds=idle_max_seconds)
jti = str(uuid.uuid4())
to_encode.update({"exp": expire, "type": "refresh", "jti": jti})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
to_encode = {
"sub": user_id,
"type": "refresh",
"jti": jti,
"exp": expire,
"auth_time": auth_time,
"idle_max": idle_max_seconds,
"abs_max": abs_max_seconds,
}
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def resolve_session_policy(account) -> tuple[int, int]:
"""Return (idle_minutes, absolute_minutes) for an account.
NULL overrides fall back to the system defaults from Settings. Partial
overrides (one column NULL, one set) are intentionally allowed at this
layer; the PATCH /accounts/me/security endpoint validates the resolved
effective values to enforce idle <= absolute. See plan §4.3.
"""
idle = account.session_idle_minutes or settings.SESSION_IDLE_MINUTES_DEFAULT
absolute = account.session_absolute_minutes or settings.SESSION_ABSOLUTE_MINUTES_DEFAULT
return idle, absolute
def hash_token(jti: str) -> str:

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from pydantic import BaseModel
@@ -16,6 +18,11 @@ class OAuthCallbackResponse(BaseModel):
refresh_token: str
token_type: str = "bearer"
is_new_user: bool
# Session-policy expiry windows — mirrors Token in token.py so the
# frontend can drive expiry-soon toasts identically for password and
# OAuth logins.
idle_expires_at: datetime | None = None
absolute_expires_at: datetime | None = None
class InviteLookupResponse(BaseModel):

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
@@ -7,6 +8,12 @@ class Token(BaseModel):
refresh_token: str
token_type: str = "bearer"
must_change_password: bool = False
# Session-policy expiry windows derived from the refresh JWT. Frontend
# uses these to drive the "your session ends soon" toast and to know
# when /auth/refresh will reject for absolute expiry. See
# docs/plans/2026-05-13-session-expiration-policy.md §4.2.
idle_expires_at: Optional[datetime] = None
absolute_expires_at: Optional[datetime] = None
class TokenPayload(BaseModel):

View File

@@ -3,8 +3,9 @@
See docs/plans/2026-05-13-session-expiration-policy.md.
Test numbers below correspond to the cases listed in §6 of the plan.
This file grows across commits — commit 2 lands the error-detail
taxonomy tests (#11 + a wrong-type case + a bad-signature case).
This file grows across commits:
- Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
"""
import uuid
@@ -101,3 +102,80 @@ class TestRefreshTokenErrorTaxonomy:
assert response.status_code == 401
assert response.json()["detail"] == "invalid_refresh_token"
class TestSessionPolicyClaims:
"""§6 tests #1 and #14 — session-policy claims stamped at login.
Every token-issuing endpoint embeds auth_time/idle_max/abs_max in
the refresh JWT and surfaces idle_expires_at/absolute_expires_at on
the response.
"""
@pytest.mark.asyncio
async def test_login_json_embeds_session_claims_with_defaults(
self, client: AsyncClient, test_user: dict
):
before = datetime.now(timezone.utc)
response = await client.post(
"/api/v1/auth/login/json",
json={
"email": test_user["email"],
"password": test_user["password"],
},
)
assert response.status_code == 200, response.json()
body = response.json()
after = datetime.now(timezone.utc)
# Response surfaces both expiry windows as ISO strings.
assert body["idle_expires_at"] is not None
assert body["absolute_expires_at"] is not None
idle_at = datetime.fromisoformat(body["idle_expires_at"])
abs_at = datetime.fromisoformat(body["absolute_expires_at"])
# Strict default: 3 days idle, 14 days absolute.
assert timedelta(days=3) - timedelta(seconds=10) <= idle_at - before <= timedelta(days=3) + timedelta(seconds=10)
assert timedelta(days=14) - timedelta(seconds=10) <= abs_at - before <= timedelta(days=14) + timedelta(seconds=10)
# JWT carries the claims in seconds, plus auth_time as Unix seconds.
decoded = jwt.decode(
body["refresh_token"], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
assert decoded["idle_max"] == 3 * 24 * 60 * 60 # 259200
assert decoded["abs_max"] == 14 * 24 * 60 * 60 # 1209600
assert int(before.timestamp()) <= decoded["auth_time"] <= int(after.timestamp())
@pytest.mark.asyncio
async def test_refresh_carries_claims_forward_unchanged(
self, client: AsyncClient, test_user: dict
):
# Login produces the original session.
login_resp = await client.post(
"/api/v1/auth/login/json",
json={"email": test_user["email"], "password": test_user["password"]},
)
original_refresh = login_resp.json()["refresh_token"]
original_payload = jwt.decode(
original_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Refresh rotates the token but must carry auth_time/idle_max/abs_max
# forward unchanged so the absolute window doesn't slide.
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {original_refresh}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
new_refresh = refresh_resp.json()["refresh_token"]
new_payload = jwt.decode(
new_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
assert new_payload["auth_time"] == original_payload["auth_time"]
assert new_payload["idle_max"] == original_payload["idle_max"]
assert new_payload["abs_max"] == original_payload["abs_max"]
# Idle deadline does slide because exp = now + idle_max.
assert new_payload["exp"] >= original_payload["exp"]
# JTI rotates.
assert new_payload["jti"] != original_payload["jti"]