import hashlib
import secrets
import string
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from passlib.context import CryptContext

from .config import settings


class IdleTokenExpired(Exception):
    """Raised by decode_refresh_token_strict when a refresh JWT is past its `exp`.

    Distinct from JWTError so callers can map idle expiry to
    `session_expired_idle` on the wire while all other decode failures map to
    `invalid_refresh_token`.
    """


# Module-wide password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password with bcrypt using `settings.BCRYPT_ROUNDS` as the cost.

    NOTE(review): passlib 1.7+ documents passing settings such as `rounds`
    directly to `hash()` as deprecated in favour of
    `.using(rounds=...).hash(...)`; confirm against the installed version.
    """
    return pwd_context.hash(password, rounds=settings.BCRYPT_ROUNDS)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is copied, never mutated.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT. `exp` and `type="access"` are applied last, so they
        override any same-named keys in *data*.
    """
    to_encode = data.copy()
    # Compare against None, not truthiness: timedelta(0) is falsy and would
    # otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_refresh_token(
    user_id: str,
    *,
    auth_time: int,
    idle_max_seconds: int,
    abs_max_seconds: int,
) -> str:
    """Create a JWT refresh token with session-policy claims embedded.

    The JWT carries four claims beyond the standard `sub`/`type`/`jti`:

    - `auth_time`: Unix-seconds timestamp of the original login; never reset
      on rotation. Used by `/auth/refresh` to enforce the absolute cap.
    - `idle_max`: idle window in seconds, snapshotted from the account's
      policy at login. Carried forward across rotations unchanged.
    - `abs_max`: absolute lifetime in seconds, snapshotted at login.
    - `exp`: current idle deadline (`now + idle_max`). Standard JWT expiry.

    See docs/plans/2026-05-13-session-expiration-policy.md §4.2 for the unit
    convention (everything outside the JWT is minutes; inside the JWT it's
    seconds so `auth_time + abs_max` is direct Unix math).
    """
    now = datetime.now(timezone.utc)
    expire = now + timedelta(seconds=idle_max_seconds)
    jti = str(uuid.uuid4())
    to_encode = {
        "sub": user_id,
        "type": "refresh",
        "jti": jti,
        "exp": expire,
        "auth_time": auth_time,
        "idle_max": idle_max_seconds,
        "abs_max": abs_max_seconds,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def resolve_session_policy(account) -> tuple[int, int]:
    """Return (idle_minutes, absolute_minutes) for an account.

    NULL overrides fall back to the system defaults from Settings. The
    fallback uses truthiness, so a stored override of 0 is treated the same
    as NULL and also yields the default.

    Partial overrides (one column NULL, one set) are intentionally allowed at
    this layer; the PATCH /accounts/me/security endpoint validates the
    resolved effective values to enforce idle <= absolute. See plan §4.3.
    """
    idle = account.session_idle_minutes or settings.SESSION_IDLE_MINUTES_DEFAULT
    absolute = account.session_absolute_minutes or settings.SESSION_ABSOLUTE_MINUTES_DEFAULT
    return idle, absolute


def hash_token(jti: str) -> str:
    """Return the SHA-256 hex digest of a token JTI, for storage in place of the raw value."""
    return hashlib.sha256(jti.encode()).hexdigest()


def decode_token(token: str) -> Optional[dict]:
    """Decode and validate a JWT token.

    Collapses all jose errors (including expiry) into None — preserved for
    access tokens, password-reset tokens, and email-verification tokens where
    the caller does not need to distinguish expiry from invalid. Refresh
    tokens use decode_refresh_token_strict instead so they can map idle
    expiry to `session_expired_idle` distinctly.
    """
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None


def decode_refresh_token_strict(token: str) -> dict:
    """Decode a refresh token, distinguishing idle expiry from invalid.

    Raises:
        IdleTokenExpired: token signature is valid but `exp` is past — i.e.
            the idle window has elapsed.
        JWTError: any other decode failure (bad signature, malformed, wrong
            algorithm).

    Type discrimination (`type == "refresh"`) is the caller's responsibility
    — this function only inspects the JWT itself.
    """
    try:
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except ExpiredSignatureError as e:
        # Only expiry is translated; every other jose error propagates as-is.
        raise IdleTokenExpired() from e


def create_password_reset_token(user_id: str) -> str:
    """Create a JWT password reset token (30-minute expiry, unique JTI)."""
    jti = str(uuid.uuid4())
    expire = datetime.now(timezone.utc) + timedelta(minutes=30)
    to_encode = {
        "sub": user_id,
        "type": "password_reset",
        "jti": jti,
        "exp": expire,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_email_verification_token(user_id: str) -> str:
    """Create a JWT email verification token (24-hour expiry, unique JTI)."""
    jti = str(uuid.uuid4())
    expire = datetime.now(timezone.utc) + timedelta(hours=24)
    to_encode = {
        "sub": user_id,
        "type": "email_verification",
        "jti": jti,
        "exp": expire,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def generate_temp_password(length: int = 16) -> str:
    """Generate a temporary password with guaranteed complexity.

    Includes at least 1 uppercase, 1 lowercase, 1 digit, and 1 symbol.
    Excludes ambiguous characters: 0, O, I, l, 1, | (lowercase i is also
    omitted).

    Args:
        length: Total number of characters. Must be at least 4 (one per
            required character category).

    Raises:
        ValueError: if *length* is smaller than the number of required
            categories, since the result could not have the requested length.
    """
    upper = "ABCDEFGHJKLMNPQRSTUVWXYZ"  # no O, I
    lower = "abcdefghjkmnopqrstuvwxyz"  # no i, l
    digits = "23456789"  # no 0, 1
    symbols = "!@#$%^&*-_+=?"  # no |
    categories = (upper, lower, digits, symbols)

    if length < len(categories):
        raise ValueError(
            f"length must be at least {len(categories)} to include every character category"
        )

    # Guarantee at least one character from each category.
    chars = [secrets.choice(category) for category in categories]
    # Fill the rest from the combined pool.
    pool = "".join(categories)
    chars.extend(secrets.choice(pool) for _ in range(length - len(chars)))

    # Shuffle so the guaranteed characters are not always the first four;
    # SystemRandom draws from the same OS randomness source as `secrets`.
    secrets.SystemRandom().shuffle(chars)
    return "".join(chars)