diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 67717d45..6314e63e 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -7,7 +7,13 @@ from sqlalchemy import select import sentry_sdk from app.core.database import get_db -from app.core.security import decode_token +from jose import JWTError + +from app.core.security import ( + IdleTokenExpired, + decode_refresh_token_strict, + decode_token, +) from app.models.user import User from app.models.plan_limits import PlanLimits from app.core.tenant_context import set_current_account_id, clear_current_account_id @@ -101,12 +107,35 @@ async def get_current_user_optional( async def get_refresh_token_payload( token: Annotated[str, Depends(oauth2_scheme)] ) -> dict: - """Extract and validate a refresh token from the Authorization header.""" - payload = decode_token(token) - if payload is None or payload.get("type") != "refresh": + """Extract and validate a refresh token from the Authorization header. + + Returns one of three outcomes via HTTP 401 `detail`: + - `session_expired_idle` — JWT signature valid but `exp` past + - `invalid_refresh_token` — any other decode failure, or `type != "refresh"` + - (200 path) — returns the decoded payload + + The frontend uses these to choose between the "your session ended for + security" banner and a plain logout redirect. See + docs/plans/2026-05-13-session-expiration-policy.md §4.10. + """ + try: + payload = decode_refresh_token_strict(token) + except IdleTokenExpired: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token", + detail="session_expired_idle", + headers={"WWW-Authenticate": "Bearer"}, + ) + except JWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_refresh_token", + headers={"WWW-Authenticate": "Bearer"}, + ) + if payload.get("type") != "refresh": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_refresh_token", headers={"WWW-Authenticate": "Bearer"}, ) return payload diff --git a/backend/app/core/security.py b/backend/app/core/security.py index f5e2f460..d37fad42 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -5,9 +5,18 @@ import uuid from datetime import datetime, timedelta, timezone from typing import Optional from jose import JWTError, jwt +from jose.exceptions import ExpiredSignatureError from passlib.context import CryptContext from .config import settings + +class IdleTokenExpired(Exception): + """Raised by decode_refresh_token_strict when a refresh JWT is past its `exp`. + + Distinct from JWTError so callers can map idle expiry to `session_expired_idle` + on the wire while all other decode failures map to `invalid_refresh_token`. + """ + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -49,7 +58,14 @@ def hash_token(jti: str) -> str: def decode_token(token: str) -> Optional[dict]: - """Decode and validate a JWT token.""" + """Decode and validate a JWT token. + + Collapses all jose errors (including expiry) into None — preserved for + access tokens, password-reset tokens, and email-verification tokens where + the caller does not need to distinguish expiry from invalid. Refresh tokens + use decode_refresh_token_strict instead so they can map idle expiry to + `session_expired_idle` distinctly. + """ try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) return payload @@ -57,6 +73,24 @@ def decode_token(token: str) -> Optional[dict]: return None +def decode_refresh_token_strict(token: str) -> dict: + """Decode a refresh token, distinguishing idle expiry from invalid. + + Raises: + IdleTokenExpired: token signature is valid but `exp` is past — i.e. the + idle window has elapsed. + JWTError: any other decode failure (bad signature, malformed, wrong + algorithm). + + Type discrimination (`type == "refresh"`) is the caller's responsibility — + this function only inspects the JWT itself. + """ + try: + return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + except ExpiredSignatureError as e: + raise IdleTokenExpired() from e + + def create_password_reset_token(user_id: str) -> str: """Create a JWT password reset token (30-minute expiry, unique JTI).""" jti = str(uuid.uuid4()) diff --git a/backend/tests/test_session_policy.py b/backend/tests/test_session_policy.py new file mode 100644 index 00000000..b0e0d145 --- /dev/null +++ b/backend/tests/test_session_policy.py @@ -0,0 +1,103 @@ +"""Tests for the session-expiration-policy series. + +See docs/plans/2026-05-13-session-expiration-policy.md. +Test numbers below correspond to the cases listed in §6 of the plan. + +This file grows across commits — commit 2 lands the error-detail +taxonomy tests (#11 + a wrong-type case + a bad-signature case). +""" + +import uuid +from datetime import datetime, timedelta, timezone + +import pytest +from httpx import AsyncClient +from jose import jwt + +from app.core.config import settings + + +def _encode_refresh_token( + *, + sub: str, + exp: datetime, + token_type: str = "refresh", + secret: str | None = None, +) -> str: + """Build a refresh JWT with arbitrary `exp` for testing. + + Bypasses create_refresh_token so tests can produce already-expired + tokens, wrong-type tokens, or wrong-signature tokens. + """ + return jwt.encode( + { + "sub": sub, + "type": token_type, + "jti": str(uuid.uuid4()), + "exp": exp, + }, + secret or settings.SECRET_KEY, + algorithm=settings.ALGORITHM, + ) + + +class TestRefreshTokenErrorTaxonomy: + """§6 test #11 — refresh-token error-detail taxonomy. + + `/auth/refresh` distinguishes idle expiry from generic invalid-token + failures via `detail`, so the frontend can choose between the "session + ended for security" banner and a plain logout redirect. + """ + + @pytest.mark.asyncio + async def test_idle_expired_refresh_returns_session_expired_idle( + self, client: AsyncClient, test_user: dict + ): + token = _encode_refresh_token( + sub=test_user["user_data"]["id"], + exp=datetime.now(timezone.utc) - timedelta(seconds=1), + ) + + response = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "session_expired_idle" + + @pytest.mark.asyncio + async def test_wrong_type_token_returns_invalid_refresh_token( + self, client: AsyncClient, test_user: dict + ): + token = _encode_refresh_token( + sub=test_user["user_data"]["id"], + exp=datetime.now(timezone.utc) + timedelta(minutes=5), + token_type="access", + ) + + response = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "invalid_refresh_token" + + @pytest.mark.asyncio + async def test_bad_signature_returns_invalid_refresh_token( + self, client: AsyncClient, test_user: dict + ): + token = _encode_refresh_token( + sub=test_user["user_data"]["id"], + exp=datetime.now(timezone.utc) + timedelta(minutes=5), + secret="not-the-real-secret-key", + ) + + response = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "invalid_refresh_token"