Third commit in the session-expiration-policy series. Every refresh token issued from now on carries the policy snapshot in its JWT (in seconds, for direct Unix math), and every login/OAuth response surfaces both expiry windows as ISO timestamps. /auth/refresh carries the claims forward unchanged — including auth_time, which never resets on rotation. Does NOT yet enforce the absolute cap — that's commit 4, sequenced so the gate can be reverted independently if pilots hit an edge case. But the wire is fully populated, and a grandfather path is already in _refresh_session_tokens for tokens issued before this PR. Key changes: - core/security.py: create_refresh_token signature changes to (user_id, *, auth_time, idle_max_seconds, abs_max_seconds). Adds resolve_session_policy(account) -> (idle_minutes, absolute_minutes) applying defaults for NULL overrides. - schemas/token.py + schemas/oauth.py: Token and OAuthCallbackResponse gain idle_expires_at + absolute_expires_at (Optional[datetime], Pydantic emits ISO 8601 UTC strings). - endpoints/auth.py: new _mint_session_tokens(user, db) and _refresh_session_tokens(payload, user, db) helpers. /auth/login, /auth/login/json, and /auth/refresh now route through them. The refresh endpoint's pre-existing "Refresh token has been revoked" error is normalized to the taxonomy detail "invalid_refresh_token". - endpoints/oauth.py: both Google and Microsoft callbacks call _mint_session_tokens; OAuthCallbackResponse carries the expiry fields through. - tests: two new cases in test_session_policy.py — login_json embeds the claims with strict defaults (3d/14d -> 259200/1209600 sec) and surfaces matching ISO expiry fields; refresh carries auth_time, idle_max, abs_max forward unchanged across rotation. 35/35 across test_session_policy + test_auth + test_oauth_callbacks + test_account_invite_lookup + test_account_management. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
import hashlib
|
|
import secrets
|
|
import string
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
from jose import JWTError, jwt
|
|
from jose.exceptions import ExpiredSignatureError
|
|
from passlib.context import CryptContext
|
|
from .config import settings
|
|
|
|
|
|
class IdleTokenExpired(Exception):
    """Signal that a refresh JWT's `exp` claim is in the past.

    decode_refresh_token_strict raises this instead of a JWTError. This lets
    callers report idle expiry as `session_expired_idle` on the wire and
    report every other decode failure as `invalid_refresh_token`.
    """
|
|
|
|
# Module-level passlib context shared by verify_password and get_password_hash.
# bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored *hashed_password*.

    The comparison is delegated to the module's passlib context.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Return a bcrypt hash of *password*.

    The work factor is taken from ``settings.BCRYPT_ROUNDS``.
    """
    work_factor = settings.BCRYPT_ROUNDS
    return pwd_context.hash(password, rounds=work_factor)
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``sub``). The dict is copied, not mutated.
            Any ``exp`` or ``type`` key it contains is overwritten.
        expires_delta: Token lifetime measured from now. ``None`` selects the
            default of ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT string, tagged with ``type="access"``.
    """
    # Explicit None check: ``timedelta(0)`` is falsy, and a truthiness test
    # would silently replace a caller-supplied zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(
    user_id: str,
    *,
    auth_time: int,
    idle_max_seconds: int,
    abs_max_seconds: int,
) -> str:
    """Create a JWT refresh token with session-policy claims embedded.

    Besides the identifying `sub`/`type`/`jti` claims, the JWT carries four
    session-policy claims:

    - `auth_time`: Unix-seconds timestamp of the original login; never reset
      on rotation. Intended to let `/auth/refresh` enforce the absolute cap.
    - `idle_max`: idle window in seconds, snapshotted from the account's
      policy at login. Carried forward across rotations unchanged.
    - `abs_max`: absolute lifetime in seconds, snapshotted at login.
    - `exp`: current idle deadline (`now + idle_max`). Standard JWT expiry.

    `exp` is NOT clamped to `auth_time + abs_max`. Any absolute-cap check
    must be performed by the caller that decodes the token.

    See docs/plans/2026-05-13-session-expiration-policy.md §4.2 for the unit
    convention (everything outside the JWT is minutes; inside the JWT it's
    seconds so `auth_time + abs_max` is direct Unix math).

    Args:
        user_id: Subject stored in `sub`.
        auth_time: Unix-seconds timestamp of the original authentication.
        idle_max_seconds: Idle window in seconds; also sets `exp`.
        abs_max_seconds: Absolute session lifetime in seconds.

    Returns:
        The encoded JWT string, tagged with ``type="refresh"``.
    """
    now = datetime.now(timezone.utc)
    expire = now + timedelta(seconds=idle_max_seconds)
    # Fresh UUID4 per call, so every issued (or rotated) token has its own jti.
    jti = str(uuid.uuid4())
    to_encode = {
        "sub": user_id,
        "type": "refresh",
        "jti": jti,
        "exp": expire,
        "auth_time": auth_time,
        "idle_max": idle_max_seconds,
        "abs_max": abs_max_seconds,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def resolve_session_policy(account) -> tuple[int, int]:
    """Return the effective ``(idle_minutes, absolute_minutes)`` for *account*.

    Each per-account override falls back independently to its system default
    from Settings. The fallback triggers for any falsy override, so ``None``
    (NULL) and ``0`` both yield the default.

    Partial overrides (one column NULL, one set) are intentionally allowed at
    this layer. The PATCH /accounts/me/security endpoint validates the
    resolved effective values to enforce idle <= absolute. See plan §4.3.
    """
    idle_override = account.session_idle_minutes
    idle = idle_override if idle_override else settings.SESSION_IDLE_MINUTES_DEFAULT

    absolute_override = account.session_absolute_minutes
    absolute = (
        absolute_override
        if absolute_override
        else settings.SESSION_ABSOLUTE_MINUTES_DEFAULT
    )

    return idle, absolute
|
|
|
|
|
|
def hash_token(jti: str) -> str:
    """Return the hex SHA-256 digest of *jti* (UTF-8 encoded) for storage."""
    hasher = hashlib.sha256()
    hasher.update(jti.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
def decode_token(token: str) -> Optional[dict]:
    """Decode and validate a JWT, returning its claims or ``None``.

    Every jose error, including expiry, is collapsed into ``None``. This suits
    access tokens, password-reset tokens and email-verification tokens, where
    the caller does not need to tell expiry apart from invalidity. Refresh
    tokens go through decode_refresh_token_strict instead, so that idle expiry
    can be mapped to `session_expired_idle` on its own.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        # Expired, tampered, or malformed: all look the same to the caller.
        return None
    return claims
|
|
|
|
|
|
def decode_refresh_token_strict(token: str) -> dict:
    """Decode a refresh JWT, keeping idle expiry distinct from invalidity.

    Only ``ExpiredSignatureError`` is translated. Every other jose error
    propagates to the caller unchanged.

    Raises:
        IdleTokenExpired: jose reported the token's `exp` as past, i.e. the
            idle window has elapsed.
        JWTError: any other decode failure (bad signature, malformed token,
            wrong algorithm).

    Checking ``type == "refresh"`` is left to the caller; this function only
    inspects the JWT itself.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError as expired:
        raise IdleTokenExpired() from expired
    return claims
|
|
|
|
|
|
def create_password_reset_token(user_id: str) -> str:
    """Return a signed password-reset JWT for *user_id*.

    The token expires 30 minutes after issue and carries a fresh UUID4 `jti`.
    """
    claims = {
        "sub": user_id,
        "type": "password_reset",
        "jti": str(uuid.uuid4()),
        "exp": datetime.now(timezone.utc) + timedelta(minutes=30),
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def create_email_verification_token(user_id: str) -> str:
    """Return a signed email-verification JWT for *user_id*.

    The token expires 24 hours after issue and carries a fresh UUID4 `jti`.
    """
    claims = {
        "sub": user_id,
        "type": "email_verification",
        "jti": str(uuid.uuid4()),
        "exp": datetime.now(timezone.utc) + timedelta(hours=24),
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def generate_temp_password(length: int = 16) -> str:
    """Generate a temporary password with guaranteed complexity.

    The result contains at least one uppercase letter, one lowercase letter,
    one digit, and one symbol. Visually ambiguous characters are excluded:
    ``0``, ``O``, ``I``, ``i``, ``l``, ``1`` and ``|``.

    Args:
        length: Total number of characters in the password. Must be at least
            4, so that one character from every required class fits.

    Returns:
        A password of exactly *length* characters, drawn using ``secrets``.

    Raises:
        ValueError: If *length* is smaller than the number of required
            character classes.
    """
    upper = "ABCDEFGHJKLMNPQRSTUVWXYZ"  # no I, O
    lower = "abcdefghjkmnopqrstuvwxyz"  # no i, l
    digits = "23456789"  # no 0, 1
    symbols = "!@#$%^&*-_+=?"
    categories = (upper, lower, digits, symbols)

    if length < len(categories):
        # range() with a negative count is empty. Without this guard, a short
        # request would silently return a longer, 4-character password.
        raise ValueError(
            f"length must be at least {len(categories)} to include every "
            f"character class, got {length}"
        )

    # Guarantee at least one of each category, then fill from the full pool.
    chars = [secrets.choice(group) for group in categories]
    pool = "".join(categories)
    chars.extend(secrets.choice(pool) for _ in range(length - len(chars)))

    # SystemRandom draws from os.urandom, so the position shuffle is as
    # unpredictable as the character choices themselves.
    secrets.SystemRandom().shuffle(chars)
    return "".join(chars)
|