feat(auth): embed auth_time/idle_max/abs_max in refresh tokens at every login
Third commit in the session-expiration-policy series. Every refresh token issued from now on carries the policy snapshot in its JWT (in seconds, for direct Unix math), and every login/OAuth response surfaces both expiry windows as ISO timestamps. /auth/refresh carries the claims forward unchanged — including auth_time, which never resets on rotation. Does NOT yet enforce the absolute cap — that's commit 4, sequenced so the gate can be reverted independently if pilots hit an edge case. But the wire is fully populated, and a grandfather path is already in _refresh_session_tokens for tokens issued before this PR. Key changes: - core/security.py: create_refresh_token signature changes to (user_id, *, auth_time, idle_max_seconds, abs_max_seconds). Adds resolve_session_policy(account) -> (idle_minutes, absolute_minutes) applying defaults for NULL overrides. - schemas/token.py + schemas/oauth.py: Token and OAuthCallbackResponse gain idle_expires_at + absolute_expires_at (Optional[datetime], Pydantic emits ISO 8601 UTC strings). - endpoints/auth.py: new _mint_session_tokens(user, db) and _refresh_session_tokens(payload, user, db) helpers. /auth/login, /auth/login/json, and /auth/refresh now route through them. The refresh endpoint's pre-existing "Refresh token has been revoked" error normalized to the taxonomy detail "invalid_refresh_token". - endpoints/oauth.py: both Google and Microsoft callbacks call _mint_session_tokens; OAuthCallbackResponse carries the expiry fields through. - tests: two new cases in test_session_policy.py — login_json embeds the claims with strict defaults (3d/14d -> 259200/1209600 sec) and surfaces matching ISO expiry fields; refresh carries auth_time, idle_max, abs_max forward unchanged across rotation. 35/35 across test_session_policy + test_auth + test_oauth_callbacks + test_account_invite_lookup + test_account_management. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from app.core.security import (
|
||||
create_email_verification_token,
|
||||
decode_token,
|
||||
hash_token,
|
||||
resolve_session_policy,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
@@ -67,6 +68,97 @@ async def store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id)
|
||||
db.add(token_record)
|
||||
|
||||
|
||||
async def _mint_session_tokens(user: User, db: AsyncSession) -> Token:
|
||||
"""Mint a fresh refresh+access pair for a new login.
|
||||
|
||||
Snapshots the account's current session policy into the refresh JWT
|
||||
(auth_time/idle_max/abs_max) and registers the JTI in refresh_tokens.
|
||||
Caller is responsible for committing the session. Use this for every
|
||||
NEW login (password, OAuth, etc.) — for /auth/refresh use
|
||||
_refresh_session_tokens instead, which carries claims forward.
|
||||
|
||||
See docs/plans/2026-05-13-session-expiration-policy.md §4.6.
|
||||
"""
|
||||
account = (
|
||||
await db.execute(select(Account).where(Account.id == user.account_id))
|
||||
).scalar_one()
|
||||
idle_minutes, abs_minutes = resolve_session_policy(account)
|
||||
idle_max_seconds = idle_minutes * 60
|
||||
abs_max_seconds = abs_minutes * 60
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
auth_time_unix = int(now.timestamp())
|
||||
|
||||
refresh_token_str = create_refresh_token(
|
||||
user_id=str(user.id),
|
||||
auth_time=auth_time_unix,
|
||||
idle_max_seconds=idle_max_seconds,
|
||||
abs_max_seconds=abs_max_seconds,
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
|
||||
absolute_expires_at=datetime.fromtimestamp(
|
||||
auth_time_unix + abs_max_seconds, tz=timezone.utc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _refresh_session_tokens(
|
||||
payload: dict, user: User, db: AsyncSession
|
||||
) -> Token:
|
||||
"""Carry session-policy claims forward across a refresh-token rotation.
|
||||
|
||||
Grandfathers legacy tokens issued before this PR (no auth_time claim)
|
||||
by snapshotting the account's current policy and treating now() as
|
||||
auth_time — i.e. one free rotation under the new policy. Caller
|
||||
commits.
|
||||
|
||||
Does NOT enforce the absolute cap — that lands in the next commit so
|
||||
the cap can be rolled back independently if needed.
|
||||
"""
|
||||
auth_time = payload.get("auth_time")
|
||||
idle_max_seconds = payload.get("idle_max")
|
||||
abs_max_seconds = payload.get("abs_max")
|
||||
|
||||
if auth_time is None or idle_max_seconds is None or abs_max_seconds is None:
|
||||
# Grandfather path — legacy token from before the session-policy PR.
|
||||
account = (
|
||||
await db.execute(select(Account).where(Account.id == user.account_id))
|
||||
).scalar_one()
|
||||
idle_minutes, abs_minutes = resolve_session_policy(account)
|
||||
auth_time = int(datetime.now(timezone.utc).timestamp())
|
||||
idle_max_seconds = idle_minutes * 60
|
||||
abs_max_seconds = abs_minutes * 60
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
refresh_token_str = create_refresh_token(
|
||||
user_id=str(user.id),
|
||||
auth_time=auth_time,
|
||||
idle_max_seconds=idle_max_seconds,
|
||||
abs_max_seconds=abs_max_seconds,
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
idle_expires_at=now + timedelta(seconds=idle_max_seconds),
|
||||
absolute_expires_at=datetime.fromtimestamp(
|
||||
auth_time + abs_max_seconds, tz=timezone.utc
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
"""Generate a random 8-character alphanumeric display code."""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
@@ -323,20 +415,9 @@ async def login(
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
token = await _mint_session_tokens(user, db)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@router.post("/login/json", response_model=Token)
|
||||
@@ -359,19 +440,9 @@ async def login_json(
|
||||
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
token = await _mint_session_tokens(user, db)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
@@ -402,10 +473,12 @@ async def refresh_token(
|
||||
revoked_row = result.fetchone()
|
||||
|
||||
if not revoked_row:
|
||||
# Either the token doesn't exist or was already revoked/used
|
||||
# Either the token doesn't exist or was already revoked/used.
|
||||
# Surfaced to the frontend as a plain logout — not "session
|
||||
# expired" — because the user did not hit a policy boundary.
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has been revoked"
|
||||
detail="invalid_refresh_token"
|
||||
)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
@@ -414,21 +487,12 @@ async def refresh_token(
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found"
|
||||
detail="invalid_refresh_token"
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store new refresh token
|
||||
await store_refresh_token(db, new_refresh_token_str, user.id)
|
||||
token = await _refresh_session_tokens(payload, user, db)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token_str,
|
||||
token_type="bearer"
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
|
||||
@@ -7,10 +7,9 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.endpoints.auth import store_refresh_token
|
||||
from app.api.endpoints.auth import _mint_session_tokens
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
@@ -187,17 +186,14 @@ async def google_callback(
|
||||
account_invite_code=payload.account_invite_code,
|
||||
invited_email=payload.invited_email,
|
||||
)
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
|
||||
# reject this token as "revoked" (the rotation logic requires a row to
|
||||
# mark as used). _sign_in_or_register already committed; this needs a
|
||||
# second commit.
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
token = await _mint_session_tokens(user, db)
|
||||
await db.commit()
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=refresh_token_str,
|
||||
access_token=token.access_token,
|
||||
refresh_token=token.refresh_token,
|
||||
is_new_user=is_new,
|
||||
idle_expires_at=token.idle_expires_at,
|
||||
absolute_expires_at=token.absolute_expires_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -217,15 +213,12 @@ async def microsoft_callback(
|
||||
account_invite_code=payload.account_invite_code,
|
||||
invited_email=payload.invited_email,
|
||||
)
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
# Persist the refresh-token JTI so the first /auth/refresh call doesn't
|
||||
# reject this token as "revoked" (the rotation logic requires a row to
|
||||
# mark as used). _sign_in_or_register already committed; this needs a
|
||||
# second commit.
|
||||
await store_refresh_token(db, refresh_token_str, user.id)
|
||||
token = await _mint_session_tokens(user, db)
|
||||
await db.commit()
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=refresh_token_str,
|
||||
access_token=token.access_token,
|
||||
refresh_token=token.refresh_token,
|
||||
is_new_user=is_new,
|
||||
idle_expires_at=token.idle_expires_at,
|
||||
absolute_expires_at=token.absolute_expires_at,
|
||||
)
|
||||
|
||||
@@ -42,14 +42,54 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
"""Create a JWT refresh token with a unique jti for revocation tracking."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
def create_refresh_token(
|
||||
user_id: str,
|
||||
*,
|
||||
auth_time: int,
|
||||
idle_max_seconds: int,
|
||||
abs_max_seconds: int,
|
||||
) -> str:
|
||||
"""Create a JWT refresh token with session-policy claims embedded.
|
||||
|
||||
The JWT carries five claims beyond the standard `sub`/`type`/`jti`:
|
||||
|
||||
- `auth_time`: Unix-seconds timestamp of the original login; never reset
|
||||
on rotation. Used by `/auth/refresh` to enforce the absolute cap.
|
||||
- `idle_max`: idle window in seconds, snapshotted from the account's
|
||||
policy at login. Carried forward across rotations unchanged.
|
||||
- `abs_max`: absolute lifetime in seconds, snapshotted at login.
|
||||
- `exp`: current idle deadline (`now + idle_max`). Standard JWT expiry.
|
||||
|
||||
See docs/plans/2026-05-13-session-expiration-policy.md §4.2 for the unit
|
||||
convention (everything outside the JWT is minutes; inside the JWT it's
|
||||
seconds so `auth_time + abs_max` is direct Unix math).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expire = now + timedelta(seconds=idle_max_seconds)
|
||||
jti = str(uuid.uuid4())
|
||||
to_encode.update({"exp": expire, "type": "refresh", "jti": jti})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
to_encode = {
|
||||
"sub": user_id,
|
||||
"type": "refresh",
|
||||
"jti": jti,
|
||||
"exp": expire,
|
||||
"auth_time": auth_time,
|
||||
"idle_max": idle_max_seconds,
|
||||
"abs_max": abs_max_seconds,
|
||||
}
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def resolve_session_policy(account) -> tuple[int, int]:
|
||||
"""Return (idle_minutes, absolute_minutes) for an account.
|
||||
|
||||
NULL overrides fall back to the system defaults from Settings. Partial
|
||||
overrides (one column NULL, one set) are intentionally allowed at this
|
||||
layer; the PATCH /accounts/me/security endpoint validates the resolved
|
||||
effective values to enforce idle <= absolute. See plan §4.3.
|
||||
"""
|
||||
idle = account.session_idle_minutes or settings.SESSION_IDLE_MINUTES_DEFAULT
|
||||
absolute = account.session_absolute_minutes or settings.SESSION_ABSOLUTE_MINUTES_DEFAULT
|
||||
return idle, absolute
|
||||
|
||||
|
||||
def hash_token(jti: str) -> str:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -16,6 +18,11 @@ class OAuthCallbackResponse(BaseModel):
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
is_new_user: bool
|
||||
# Session-policy expiry windows — mirrors Token in token.py so the
|
||||
# frontend can drive expiry-soon toasts identically for password and
|
||||
# OAuth logins.
|
||||
idle_expires_at: datetime | None = None
|
||||
absolute_expires_at: datetime | None = None
|
||||
|
||||
|
||||
class InviteLookupResponse(BaseModel):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -7,6 +8,12 @@ class Token(BaseModel):
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
must_change_password: bool = False
|
||||
# Session-policy expiry windows derived from the refresh JWT. Frontend
|
||||
# uses these to drive the "your session ends soon" toast and to know
|
||||
# when /auth/refresh will reject for absolute expiry. See
|
||||
# docs/plans/2026-05-13-session-expiration-policy.md §4.2.
|
||||
idle_expires_at: Optional[datetime] = None
|
||||
absolute_expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
See docs/plans/2026-05-13-session-expiration-policy.md.
|
||||
Test numbers below correspond to the cases listed in §6 of the plan.
|
||||
|
||||
This file grows across commits — commit 2 lands the error-detail
|
||||
taxonomy tests (#11 + a wrong-type case + a bad-signature case).
|
||||
This file grows across commits:
|
||||
- Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature)
|
||||
- Commit 3: claims embedded at login + response fields surfaced (#1, #14)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -101,3 +102,80 @@ class TestRefreshTokenErrorTaxonomy:
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "invalid_refresh_token"
|
||||
|
||||
|
||||
class TestSessionPolicyClaims:
|
||||
"""§6 tests #1 and #14 — session-policy claims stamped at login.
|
||||
|
||||
Every token-issuing endpoint embeds auth_time/idle_max/abs_max in
|
||||
the refresh JWT and surfaces idle_expires_at/absolute_expires_at on
|
||||
the response.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_json_embeds_session_claims_with_defaults(
|
||||
self, client: AsyncClient, test_user: dict
|
||||
):
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={
|
||||
"email": test_user["email"],
|
||||
"password": test_user["password"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
body = response.json()
|
||||
after = datetime.now(timezone.utc)
|
||||
|
||||
# Response surfaces both expiry windows as ISO strings.
|
||||
assert body["idle_expires_at"] is not None
|
||||
assert body["absolute_expires_at"] is not None
|
||||
idle_at = datetime.fromisoformat(body["idle_expires_at"])
|
||||
abs_at = datetime.fromisoformat(body["absolute_expires_at"])
|
||||
# Strict default: 3 days idle, 14 days absolute.
|
||||
assert timedelta(days=3) - timedelta(seconds=10) <= idle_at - before <= timedelta(days=3) + timedelta(seconds=10)
|
||||
assert timedelta(days=14) - timedelta(seconds=10) <= abs_at - before <= timedelta(days=14) + timedelta(seconds=10)
|
||||
|
||||
# JWT carries the claims in seconds, plus auth_time as Unix seconds.
|
||||
decoded = jwt.decode(
|
||||
body["refresh_token"], settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
assert decoded["idle_max"] == 3 * 24 * 60 * 60 # 259200
|
||||
assert decoded["abs_max"] == 14 * 24 * 60 * 60 # 1209600
|
||||
assert int(before.timestamp()) <= decoded["auth_time"] <= int(after.timestamp())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_carries_claims_forward_unchanged(
|
||||
self, client: AsyncClient, test_user: dict
|
||||
):
|
||||
# Login produces the original session.
|
||||
login_resp = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={"email": test_user["email"], "password": test_user["password"]},
|
||||
)
|
||||
original_refresh = login_resp.json()["refresh_token"]
|
||||
original_payload = jwt.decode(
|
||||
original_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
# Refresh rotates the token but must carry auth_time/idle_max/abs_max
|
||||
# forward unchanged so the absolute window doesn't slide.
|
||||
refresh_resp = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
headers={"Authorization": f"Bearer {original_refresh}"},
|
||||
)
|
||||
assert refresh_resp.status_code == 200, refresh_resp.json()
|
||||
new_refresh = refresh_resp.json()["refresh_token"]
|
||||
new_payload = jwt.decode(
|
||||
new_refresh, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
|
||||
assert new_payload["auth_time"] == original_payload["auth_time"]
|
||||
assert new_payload["idle_max"] == original_payload["idle_max"]
|
||||
assert new_payload["abs_max"] == original_payload["abs_max"]
|
||||
# Idle deadline does slide because exp = now + idle_max.
|
||||
assert new_payload["exp"] >= original_payload["exp"]
|
||||
# JTI rotates.
|
||||
assert new_payload["jti"] != original_payload["jti"]
|
||||
|
||||
Reference in New Issue
Block a user