feat(auth): session expiration policy (3d idle / 14d absolute) + per-account override + bulk revoke #168
@@ -7,7 +7,13 @@ from sqlalchemy import select
|
|||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.security import decode_token
|
from jose import JWTError
|
||||||
|
|
||||||
|
from app.core.security import (
|
||||||
|
IdleTokenExpired,
|
||||||
|
decode_refresh_token_strict,
|
||||||
|
decode_token,
|
||||||
|
)
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.plan_limits import PlanLimits
|
from app.models.plan_limits import PlanLimits
|
||||||
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
||||||
@@ -101,12 +107,35 @@ async def get_current_user_optional(
|
|||||||
async def get_refresh_token_payload(
|
async def get_refresh_token_payload(
|
||||||
token: Annotated[str, Depends(oauth2_scheme)]
|
token: Annotated[str, Depends(oauth2_scheme)]
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Extract and validate a refresh token from the Authorization header."""
|
"""Extract and validate a refresh token from the Authorization header.
|
||||||
payload = decode_token(token)
|
|
||||||
if payload is None or payload.get("type") != "refresh":
|
Returns one of three outcomes via HTTP 401 `detail`:
|
||||||
|
- `session_expired_idle` — JWT signature valid but `exp` past
|
||||||
|
- `invalid_refresh_token` — any other decode failure, or `type != "refresh"`
|
||||||
|
- (200 path) — returns the decoded payload
|
||||||
|
|
||||||
|
The frontend uses these to choose between the "your session ended for
|
||||||
|
security" banner and a plain logout redirect. See
|
||||||
|
docs/plans/2026-05-13-session-expiration-policy.md §4.10.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = decode_refresh_token_strict(token)
|
||||||
|
except IdleTokenExpired:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid refresh token",
|
detail="session_expired_idle",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_refresh_token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
if payload.get("type") != "refresh":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_refresh_token",
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
return payload
|
return payload
|
||||||
|
|||||||
@@ -5,9 +5,18 @@ import uuid
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
from jose.exceptions import ExpiredSignatureError
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class IdleTokenExpired(Exception):
|
||||||
|
"""Raised by decode_refresh_token_strict when a refresh JWT is past its `exp`.
|
||||||
|
|
||||||
|
Distinct from JWTError so callers can map idle expiry to `session_expired_idle`
|
||||||
|
on the wire while all other decode failures map to `invalid_refresh_token`.
|
||||||
|
"""
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +58,14 @@ def hash_token(jti: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def decode_token(token: str) -> Optional[dict]:
|
def decode_token(token: str) -> Optional[dict]:
|
||||||
"""Decode and validate a JWT token."""
|
"""Decode and validate a JWT token.
|
||||||
|
|
||||||
|
Collapses all jose errors (including expiry) into None — preserved for
|
||||||
|
access tokens, password-reset tokens, and email-verification tokens where
|
||||||
|
the caller does not need to distinguish expiry from invalid. Refresh tokens
|
||||||
|
use decode_refresh_token_strict instead so they can map idle expiry to
|
||||||
|
`session_expired_idle` distinctly.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||||
return payload
|
return payload
|
||||||
@@ -57,6 +73,24 @@ def decode_token(token: str) -> Optional[dict]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def decode_refresh_token_strict(token: str) -> dict:
|
||||||
|
"""Decode a refresh token, distinguishing idle expiry from invalid.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IdleTokenExpired: token signature is valid but `exp` is past — i.e. the
|
||||||
|
idle window has elapsed.
|
||||||
|
JWTError: any other decode failure (bad signature, malformed, wrong
|
||||||
|
algorithm).
|
||||||
|
|
||||||
|
Type discrimination (`type == "refresh"`) is the caller's responsibility —
|
||||||
|
this function only inspects the JWT itself.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||||
|
except ExpiredSignatureError as e:
|
||||||
|
raise IdleTokenExpired() from e
|
||||||
|
|
||||||
|
|
||||||
def create_password_reset_token(user_id: str) -> str:
|
def create_password_reset_token(user_id: str) -> str:
|
||||||
"""Create a JWT password reset token (30-minute expiry, unique JTI)."""
|
"""Create a JWT password reset token (30-minute expiry, unique JTI)."""
|
||||||
jti = str(uuid.uuid4())
|
jti = str(uuid.uuid4())
|
||||||
|
|||||||
103
backend/tests/test_session_policy.py
Normal file
103
backend/tests/test_session_policy.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Tests for the session-expiration-policy series.
|
||||||
|
|
||||||
|
See docs/plans/2026-05-13-session-expiration-policy.md.
|
||||||
|
Test numbers below correspond to the cases listed in §6 of the plan.
|
||||||
|
|
||||||
|
This file grows across commits — commit 2 lands the error-detail
|
||||||
|
taxonomy tests (#11 + a wrong-type case + a bad-signature case).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_refresh_token(
|
||||||
|
*,
|
||||||
|
sub: str,
|
||||||
|
exp: datetime,
|
||||||
|
token_type: str = "refresh",
|
||||||
|
secret: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a refresh JWT with arbitrary `exp` for testing.
|
||||||
|
|
||||||
|
Bypasses create_refresh_token so tests can produce already-expired
|
||||||
|
tokens, wrong-type tokens, or wrong-signature tokens.
|
||||||
|
"""
|
||||||
|
return jwt.encode(
|
||||||
|
{
|
||||||
|
"sub": sub,
|
||||||
|
"type": token_type,
|
||||||
|
"jti": str(uuid.uuid4()),
|
||||||
|
"exp": exp,
|
||||||
|
},
|
||||||
|
secret or settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshTokenErrorTaxonomy:
|
||||||
|
"""§6 test #11 — refresh-token error-detail taxonomy.
|
||||||
|
|
||||||
|
`/auth/refresh` distinguishes idle expiry from generic invalid-token
|
||||||
|
failures via `detail`, so the frontend can choose between the "session
|
||||||
|
ended for security" banner and a plain logout redirect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_idle_expired_refresh_returns_session_expired_idle(
|
||||||
|
self, client: AsyncClient, test_user: dict
|
||||||
|
):
|
||||||
|
token = _encode_refresh_token(
|
||||||
|
sub=test_user["user_data"]["id"],
|
||||||
|
exp=datetime.now(timezone.utc) - timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "session_expired_idle"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wrong_type_token_returns_invalid_refresh_token(
|
||||||
|
self, client: AsyncClient, test_user: dict
|
||||||
|
):
|
||||||
|
token = _encode_refresh_token(
|
||||||
|
sub=test_user["user_data"]["id"],
|
||||||
|
exp=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||||
|
token_type="access",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "invalid_refresh_token"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bad_signature_returns_invalid_refresh_token(
|
||||||
|
self, client: AsyncClient, test_user: dict
|
||||||
|
):
|
||||||
|
token = _encode_refresh_token(
|
||||||
|
sub=test_user["user_data"]["id"],
|
||||||
|
exp=datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||||
|
secret="not-the-real-secret-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "invalid_refresh_token"
|
||||||
Reference in New Issue
Block a user