feat(auth): embed auth_time/idle_max/abs_max in refresh tokens at every login

Third commit in the session-expiration-policy series. Every refresh token
issued from now on carries the policy snapshot in its JWT (in seconds,
for direct Unix math), and every login/OAuth response surfaces both
expiry windows as ISO timestamps. /auth/refresh carries the claims
forward unchanged — including auth_time, which never resets on rotation.

Does NOT yet enforce the absolute cap — that's commit 4, sequenced so
the gate can be reverted independently if pilots hit an edge case.
But the wire is fully populated, and a grandfather path is already in
_refresh_session_tokens for tokens issued before this PR.

Key changes:
- core/security.py: create_refresh_token signature changes to
  (user_id, *, auth_time, idle_max_seconds, abs_max_seconds). Adds
  resolve_session_policy(account) -> (idle_minutes, absolute_minutes)
  applying defaults for NULL overrides.
- schemas/token.py + schemas/oauth.py: Token and OAuthCallbackResponse
  gain idle_expires_at + absolute_expires_at (Optional[datetime],
  Pydantic emits ISO 8601 UTC strings).
- endpoints/auth.py: new _mint_session_tokens(user, db) and
  _refresh_session_tokens(payload, user, db) helpers. /auth/login,
  /auth/login/json, and /auth/refresh now route through them. The
  refresh endpoint's pre-existing "Refresh token has been revoked"
  error normalized to the taxonomy detail "invalid_refresh_token".
- endpoints/oauth.py: both Google and Microsoft callbacks call
  _mint_session_tokens; OAuthCallbackResponse carries the expiry
  fields through.
- tests: two new cases in test_session_policy.py — login_json embeds
  the claims with strict defaults (3d/14d -> 259200/1209600 sec) and
  surfaces matching ISO expiry fields; refresh carries auth_time,
  idle_max, abs_max forward unchanged across rotation.

35/35 across test_session_policy + test_auth + test_oauth_callbacks +
test_account_invite_lookup + test_account_management.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-13 16:22:53 -04:00
parent 2375948b7a
commit d6a02ee8da
6 changed files with 255 additions and 66 deletions

View File

@@ -20,6 +20,7 @@ from app.core.security import (
create_email_verification_token, create_email_verification_token,
decode_token, decode_token,
hash_token, hash_token,
resolve_session_policy,
) )
from app.models.user import User from app.models.user import User
from app.models.invite_code import InviteCode from app.models.invite_code import InviteCode
@@ -67,6 +68,97 @@ async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id)
db.add(token_record) db.add(token_record)
async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
"""Mint a fresh refresh+access pair for a new login.
Snapshots the account's current session policy into the refresh JWT
(auth_time/idle_max/abs_max) and registers the JTI in refresh_tokens.
Caller is responsible for committing the session. Use this for every
NEW login (password, OAuth, etc.) — for /auth/refresh use
_refresh_session_tokens instead, which carries claims forward.
See docs/plans/2026-05-13-session-expiration-policy.md §4.6.
"""
account = (
await db.execute(select(Account).where(Account.id == user.account_id))
).scalar_one()
idle_minutes, abs_minutes = resolve_session_policy(account)
idle_max_seconds = idle_minutes * 60
abs_max_seconds = abs_minutes * 60
now = datetime.now(timezone.utc)
auth_time_unix = int(now.timestamp())
refresh_token_str = create_refresh_token(
user_id=str(user.id),
auth_time=auth_time_unix,
idle_max_seconds=idle_max_seconds,
abs_max_seconds=abs_max_seconds,
)
access_token = create_access_token(data={"sub": str(user.id)})
await store_refresh_token(db, refresh_token_str, user.id)
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
absolute_expires_at=datetime.fromtimestamp(
auth_time_unix + abs_max_seconds, tz=timezone.utc
),
)
async def _refresh_session_tokens(
payload: dict, user: User, db: AsyncSession
) -> Token:
"""Carry session-policy claims forward across a refresh-token rotation.
Grandfathers legacy tokens issued before this PR (no auth_time claim)
by snapshotting the account's current policy and treating now() as
auth_time — i.e. one free rotation under the new policy. Caller
commits.
Does NOT enforce the absolute cap — that lands in the next commit so
the cap can be rolled back independently if needed.
"""
auth_time = payload.get("auth_time")
idle_max_seconds = payload.get("idle_max")
abs_max_seconds = payload.get("abs_max")
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
# Grandfather path — legacy token from before the session-policy PR.
account = (
await db.execute(select(Account).where(Account.id == user.account_id))
).scalar_one()
idle_minutes, abs_minutes = resolve_session_policy(account)
auth_time = int(datetime.now(timezone.utc).timestamp())
idle_max_seconds = idle_minutes * 60
abs_max_seconds = abs_minutes * 60
now = datetime.now(timezone.utc)
refresh_token_str = create_refresh_token(
user_id=str(user.id),
auth_time=auth_time,
idle_max_seconds=idle_max_seconds,
abs_max_seconds=abs_max_seconds,
)
access_token = create_access_token(data={"sub": str(user.id)})
await store_refresh_token(db, refresh_token_str, user.id)
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
absolute_expires_at=datetime.fromtimestamp(
auth_time + abs_max_seconds, tz=timezone.utc
),
)
def _generate_display_code() -> str: def _generate_display_code() -> str:
"""Generate a random 8-character alphanumeric display code.""" """Generate a random 8-character alphanumeric display code."""
chars = string.ascii_uppercase + string.digits chars = string.ascii_uppercase + string.digits
@@ -323,20 +415,9 @@ async def login(
# Update last login # Update last login
user.last_login = datetime.now(timezone.utc) user.last_login = datetime.now(timezone.utc)
# Create tokens token = await _mint_session_tokens(user, db)
access_token = create_access_token(data={"sub": str(user.id)})
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return token
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
)
@router.post("/login/json", response_model=Token) @router.post("/login/json", response_model=Token)
@@ -359,19 +440,9 @@ async def login_json(
user.last_login = datetime.now(timezone.utc) user.last_login = datetime.now(timezone.utc)
access_token = create_access_token(data={"sub": str(user.id)}) token = await _mint_session_tokens(user, db)
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return token
return Token(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer",
must_change_password=user.must_change_password,
)
@router.post("/refresh", response_model=Token) @router.post("/refresh", response_model=Token)
@@ -402,10 +473,12 @@ async def refresh_token(
revoked_row = result.fetchone() revoked_row = result.fetchone()
if not revoked_row: if not revoked_row:
# Either the token doesn't exist or was already revoked/used # Either the token doesn't exist or was already revoked/used.
# Surfaced to the frontend as a plain logout — not "session
# expired" — because the user did not hit a policy boundary.
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked" detail="invalid_refresh_token"
) )
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
@@ -414,21 +487,12 @@ async def refresh_token(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found" detail="invalid_refresh_token"
) )
access_token = create_access_token(data={"sub": str(user.id)}) token = await _refresh_session_tokens(payload, user, db)
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token
await store_refresh_token(db, new_refresh_token_str, user.id)
await db.commit() await db.commit()
return token
return Token(
access_token=access_token,
refresh_token=new_refresh_token_str,
token_type="bearer"
)
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)

View File

@@ -7,10 +7,9 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.endpoints.auth import store_refresh_token from app.api.endpoints.auth import _mint_session_tokens
from app.core.admin_database import get_admin_db from app.core.admin_database import get_admin_db
from app.core.config import settings from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token
from app.models.account import Account from app.models.account import Account
from app.models.account_invite import AccountInvite from app.models.account_invite import AccountInvite
from app.models.oauth_identity import OAuthIdentity from app.models.oauth_identity import OAuthIdentity
@@ -187,17 +186,14 @@ async def google_callback(
account_invite_code=payload.account_invite_code, account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email, invited_email=payload.invited_email,
) )
refresh_token_str = create_refresh_token({"sub": str(user.id)}) token = await _mint_session_tokens(user, db)
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return OAuthCallbackResponse( return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}), access_token=token.access_token,
refresh_token=refresh_token_str, refresh_token=token.refresh_token,
is_new_user=is_new, is_new_user=is_new,
idle_expires_at=token.idle_expires_at,
absolute_expires_at=token.absolute_expires_at,
) )
@@ -217,15 +213,12 @@ async def microsoft_callback(
account_invite_code=payload.account_invite_code, account_invite_code=payload.account_invite_code,
invited_email=payload.invited_email, invited_email=payload.invited_email,
) )
refresh_token_str = create_refresh_token({"sub": str(user.id)}) token = await _mint_session_tokens(user, db)
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
# reject this token as "revoked" (the rotation logic requires a row to
# mark as used). _sign_in_or_register already committed; this needs a
# second commit.
await store_refresh_token(db, refresh_token_str, user.id)
await db.commit() await db.commit()
return OAuthCallbackResponse( return OAuthCallbackResponse(
access_token=create_access_token({"sub": str(user.id)}), access_token=token.access_token,
refresh_token=refresh_token_str, refresh_token=token.refresh_token,
is_new_user=is_new, is_new_user=is_new,
idle_expires_at=token.idle_expires_at,
absolute_expires_at=token.absolute_expires_at,
) )

View File

@@ -42,14 +42,54 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
return encoded_jwt return encoded_jwt
def create_refresh_token(data: dict) -> str: def create_refresh_token(
"""Create a JWT refresh token with a unique jti for revocation tracking.""" user_id: str,
to_encode = data.copy() *,
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) auth_time: int,
idle_max_seconds: int,
abs_max_seconds: int,
) -> str:
"""Create a JWT refresh token with session-policy claims embedded.
The JWT carries five claims beyond the standard `sub`/`type`/`jti`:
- `auth_time`: Unix-seconds timestamp of the original login; never reset
on rotation. Used by `/auth/refresh` to enforce the absolute cap.
- `idle_max`: idle window in seconds, snapshotted from the account's
policy at login. Carried forward across rotations unchanged.
- `abs_max`: absolute lifetime in seconds, snapshotted at login.
- `exp`: current idle deadline (`now + idle_max`). Standard JWT expiry.
See docs/plans/2026-05-13-session-expiration-policy.md §4.2 for the unit
convention (everything outside the JWT is minutes; inside the JWT it's
seconds so `auth_time + abs_max` is direct Unix math).
"""
now = datetime.now(timezone.utc)
expire = now + timedelta(seconds=idle_max_seconds)
jti = str(uuid.uuid4()) jti = str(uuid.uuid4())
to_encode.update({"exp": expire, "type": "refresh", "jti": jti}) to_encode = {
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) "sub": user_id,
return encoded_jwt "type": "refresh",
"jti": jti,
"exp": expire,
"auth_time": auth_time,
"idle_max": idle_max_seconds,
"abs_max": abs_max_seconds,
}
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def resolve_session_policy(account) -> tuple[int, int]:
"""Return (idle_minutes, absolute_minutes) for an account.
NULL overrides fall back to the system defaults from Settings. Partial
overrides (one column NULL, one set) are intentionally allowed at this
layer; the PATCH /accounts/me/security endpoint validates the resolved
effective values to enforce idle <= absolute. See plan §4.3.
"""
idle = account.session_idle_minutes or settings.SESSION_IDLE_MINUTES_DEFAULT
absolute = account.session_absolute_minutes or settings.SESSION_ABSOLUTE_MINUTES_DEFAULT
return idle, absolute
def hash_token(jti: str) -> str: def hash_token(jti: str) -> str:

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel
@@ -16,6 +18,11 @@ class OAuthCallbackResponse(BaseModel):
refresh_token: str refresh_token: str
token_type: str = "bearer" token_type: str = "bearer"
is_new_user: bool is_new_user: bool
# Session-policy expiry windows — mirrors Token in token.py so the
# frontend can drive expiry-soon toasts identically for password and
# OAuth logins.
idle_expires_at: datetime | None = None
absolute_expires_at: datetime | None = None
class InviteLookupResponse(BaseModel): class InviteLookupResponse(BaseModel):

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -7,6 +8,12 @@ class Token(BaseModel):
refresh_token: str refresh_token: str
token_type: str = "bearer" token_type: str = "bearer"
must_change_password: bool = False must_change_password: bool = False
# Session-policy expiry windows derived from the refresh JWT. Frontend
# uses these to drive the "your session ends soon" toast and to know
# when /auth/refresh will reject for absolute expiry. See
# docs/plans/2026-05-13-session-expiration-policy.md §4.2.
idle_expires_at: Optional[datetime] = None
absolute_expires_at: Optional[datetime] = None
class TokenPayload(BaseModel): class TokenPayload(BaseModel):

View File

@@ -3,8 +3,9 @@
See docs/plans/2026-05-13-session-expiration-policy.md. See docs/plans/2026-05-13-session-expiration-policy.md.
Test numbers below correspond to the cases listed in §6 of the plan. Test numbers below correspond to the cases listed in §6 of the plan.
This file grows across commits — commit 2 lands the error-detail This file grows across commits:
taxonomy tests (#11 + a wrong-type case + a bad-signature case). - Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
""" """
import uuid import uuid
@@ -101,3 +102,80 @@ class TestRefreshTokenErrorTaxonomy:
assert response.status_code == 401 assert response.status_code == 401
assert response.json()["detail"] == "invalid_refresh_token" assert response.json()["detail"] == "invalid_refresh_token"
class TestSessionPolicyClaims:
"""§6 tests #1 and #14 — session-policy claims stamped at login.
Every token-issuing endpoint embeds auth_time/idle_max/abs_max in
the refresh JWT and surfaces idle_expires_at/absolute_expires_at on
the response.
"""
@pytest.mark.asyncio
async def test_login_json_embeds_session_claims_with_defaults(
self, client: AsyncClient, test_user: dict
):
before = datetime.now(timezone.utc)
response = await client.post(
"/api/v1/auth/login/json",
json={
"email": test_user["email"],
"password": test_user["password"],
},
)
assert response.status_code == 200, response.json()
body = response.json()
after = datetime.now(timezone.utc)
# Response surfaces both expiry windows as ISO strings.
assert body["idle_expires_at"] is not None
assert body["absolute_expires_at"] is not None
idle_at = datetime.fromisoformat(body["idle_expires_at"])
abs_at = datetime.fromisoformat(body["absolute_expires_at"])
# Strict default: 3 days idle, 14 days absolute.
assert timedelta(days=3) - timedelta(seconds=10) <= idle_at - before <= timedelta(days=3) + timedelta(seconds=10)
assert timedelta(days=14) - timedelta(seconds=10) <= abs_at - before <= timedelta(days=14) + timedelta(seconds=10)
# JWT carries the claims in seconds, plus auth_time as Unix seconds.
decoded = jwt.decode(
body["refresh_token"], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
assert decoded["idle_max"] == 3 * 24 * 60 * 60 # 259200
assert decoded["abs_max"] == 14 * 24 * 60 * 60 # 1209600
assert int(before.timestamp()) <= decoded["auth_time"] <= int(after.timestamp())
@pytest.mark.asyncio
async def test_refresh_carries_claims_forward_unchanged(
self, client: AsyncClient, test_user: dict
):
# Login produces the original session.
login_resp = await client.post(
"/api/v1/auth/login/json",
json={"email": test_user["email"], "password": test_user["password"]},
)
original_refresh = login_resp.json()["refresh_token"]
original_payload = jwt.decode(
original_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# Refresh rotates the token but must carry auth_time/idle_max/abs_max
# forward unchanged so the absolute window doesn't slide.
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {original_refresh}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
new_refresh = refresh_resp.json()["refresh_token"]
new_payload = jwt.decode(
new_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
assert new_payload["auth_time"] == original_payload["auth_time"]
assert new_payload["idle_max"] == original_payload["idle_max"]
assert new_payload["abs_max"] == original_payload["abs_max"]
# Idle deadline does slide because exp = now + idle_max.
assert new_payload["exp"] >= original_payload["exp"]
# JTI rotates.
assert new_payload["jti"] != original_payload["jti"]