Second commit in the session-expiration-policy series. Lands the error-detail taxonomy from §4.10 of the plan. There is no UI-visible change yet, because the frontend interceptor (commit 7) doesn't read the new detail strings, but the wire format is now ready for it.

Today every /auth/refresh failure returns 401 "Invalid refresh token" regardless of cause, so the frontend has no way to distinguish "your session ended for security" from "we don't recognize this token at all."

This commit introduces:

- decode_refresh_token_strict(): wraps jose.jwt.decode and raises a new IdleTokenExpired exception (chained from ExpiredSignatureError) so callers can branch on idle expiry. All other jose failures still propagate as JWTError. The legacy decode_token() is preserved for the access-token, password-reset, and email-verification paths, which don't need the distinction.
- get_refresh_token_payload(): now maps IdleTokenExpired -> "session_expired_idle", and maps JWTError and wrong-type tokens -> "invalid_refresh_token".
- test_session_policy.py: new test file (it will accumulate cases across the series). Three tests cover the taxonomy: an idle-expired token returns session_expired_idle; a wrong-type token returns invalid_refresh_token; a bad signature returns invalid_refresh_token.

20/20 tests pass across test_session_policy + test_auth + test_oauth_callbacks.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
import hashlib
|
|
import secrets
|
|
import string
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
from jose import JWTError, jwt
|
|
from jose.exceptions import ExpiredSignatureError
|
|
from passlib.context import CryptContext
|
|
from .config import settings
|
|
|
|
|
|
class IdleTokenExpired(Exception):
    """Signal that a refresh JWT was rejected because its ``exp`` has passed.

    ``decode_refresh_token_strict`` raises this, chained from jose's
    ``ExpiredSignatureError``. It deliberately does NOT subclass ``JWTError``.
    That lets callers report idle expiry on the wire as ``session_expired_idle``,
    while every other decode failure is reported as ``invalid_refresh_token``.
    """
|
|
|
|
# Module-wide passlib context shared by verify_password and get_password_hash.
# bcrypt is the only configured scheme; deprecated="auto" marks every scheme
# other than the default as deprecated.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Delegates to the module-level passlib ``pwd_context``.

    Args:
        plain_password: Candidate password in plaintext.
        hashed_password: Previously stored hash to compare against.

    Returns:
        Whether the password matches the hash.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash *password* with the module's passlib context.

    The work factor comes from ``settings.BCRYPT_ROUNDS``.

    Args:
        password: Plaintext password to hash.

    Returns:
        The encoded hash string produced by ``pwd_context``.
    """
    rounds = settings.BCRYPT_ROUNDS
    # NOTE(review): passlib 1.7+ deprecates passing hash settings such as
    # ``rounds`` directly to ``hash()``; it warns and suggests ``.using(...)``.
    # Confirm against the installed passlib version. Moving rounds onto the
    # context must be done together with the ``pwd_context`` definition.
    return pwd_context.hash(password, rounds=rounds)
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. A shallow copy is encoded; *data* itself is not
            modified.
        expires_delta: Token lifetime measured from now (UTC). ``None``, the
            default, falls back to ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.
            Any explicit value is honoured, including ``timedelta(0)``.

    Returns:
        The encoded JWT. It carries ``exp`` and ``type="access"`` in addition
        to *data*; these two override any same-named keys in *data*.
    """
    # Compare against None, not truthiness: timedelta(0) is falsy, so a plain
    # ``if expires_delta:`` would silently swap it for the default lifetime.
    if expires_delta is None:
        lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    else:
        lifetime = expires_delta
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(data: dict) -> str:
    """Create a signed JWT refresh token carrying a unique ``jti``.

    The following claims are added to a shallow copy of *data*, overriding any
    same-named keys:

    - ``exp``: now (UTC) plus ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days;
    - ``type``: ``"refresh"``;
    - ``jti``: a fresh UUID4 string, used for revocation tracking.

    Returns:
        The encoded JWT.
    """
    claims = data.copy()
    expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    token_id = str(uuid.uuid4())
    claims.update(exp=expires_at, type="refresh", jti=token_id)
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def hash_token(jti: str) -> str:
    """Hash a token JTI for secure storage.

    Returns the lowercase hex SHA-256 digest of the JTI's UTF-8 encoding.
    """
    hasher = hashlib.sha256()
    hasher.update(jti.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT, returning ``None`` on any jose failure.

    Every ``JWTError``, including expiry, collapses to ``None``. This helper
    is kept for access, password-reset, and email-verification tokens, whose
    callers don't need to tell expiry apart from other invalidity. Refresh
    tokens use ``decode_refresh_token_strict`` instead, so idle expiry can be
    reported distinctly as ``session_expired_idle``.

    The ``type`` claim is not checked here; that is left to callers.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims dict, or ``None`` if decoding or verification failed.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    return claims
|
|
|
|
|
|
def decode_refresh_token_strict(token: str) -> dict:
    """Decode a refresh JWT, separating idle expiry from every other failure.

    Only the JWT itself is inspected. Checking ``type == "refresh"`` is the
    caller's job.

    Args:
        token: Encoded refresh JWT.

    Returns:
        The decoded claims dict.

    Raises:
        IdleTokenExpired: jose raised ``ExpiredSignatureError``, meaning the
            token's ``exp`` has passed and the idle window has elapsed.
        JWTError: any other jose decode failure, e.g. bad signature, malformed
            token, or disallowed algorithm.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError as expired:
        # Chain the jose error so its traceback is preserved as __cause__.
        raise IdleTokenExpired() from expired
    return claims
|
|
|
|
|
|
def create_password_reset_token(user_id: str) -> str:
    """Create a signed password-reset JWT for *user_id*.

    Claims: ``sub`` (the user id), ``type="password_reset"``, a fresh UUID4
    ``jti``, and an ``exp`` 30 minutes from now (UTC).
    """
    token_id = str(uuid.uuid4())
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
    claims = dict(
        sub=user_id,
        type="password_reset",
        jti=token_id,
        exp=expires_at,
    )
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def create_email_verification_token(user_id: str) -> str:
    """Create a signed email-verification JWT for *user_id*.

    Claims: ``sub`` (the user id), ``type="email_verification"``, a fresh
    UUID4 ``jti``, and an ``exp`` 24 hours from now (UTC).
    """
    token_id = str(uuid.uuid4())
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    claims = dict(
        sub=user_id,
        type="email_verification",
        jti=token_id,
        exp=expires_at,
    )
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def generate_temp_password(length: int = 16) -> str:
    """Generate a temporary password with guaranteed complexity.

    The result contains at least one uppercase letter, one lowercase letter,
    one digit, and one symbol. These characters are never used:
    ``0``, ``O``, ``I``, ``i``, ``l``, ``1`` and ``|``.

    Args:
        length: Exact number of characters to return. It must be at least 4,
            one per required category.

    Returns:
        A password of exactly *length* characters. Every character is chosen
        via the ``secrets`` module.

    Raises:
        ValueError: if *length* is smaller than the number of required
            character categories.
    """
    upper = "ABCDEFGHJKLMNPQRSTUVWXYZ"  # no O, I
    lower = "abcdefghjkmnopqrstuvwxyz"  # no i, l
    digits = "23456789"  # no 0, 1
    symbols = "!@#$%^&*-_+=?"  # no |
    categories = (upper, lower, digits, symbols)

    # A shorter length cannot satisfy the one-per-category guarantee. Fail
    # loudly rather than silently return more characters than requested.
    if length < len(categories):
        raise ValueError(
            f"length must be at least {len(categories)} to include every "
            f"character category, got {length}"
        )

    # Guarantee at least one character from each category...
    chars = [secrets.choice(category) for category in categories]
    # ...then fill the remainder from the combined pool.
    pool = "".join(categories)
    chars.extend(secrets.choice(pool) for _ in range(length - len(categories)))

    # Shuffle with the OS-backed CSPRNG so the guaranteed characters don't
    # sit in predictable leading positions.
    secrets.SystemRandom().shuffle(chars)
    return "".join(chars)
|