diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index ef42147e..fa73d819 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -110,25 +110,21 @@ async def _mint_session_tokens(user: User, db: AsyncSession) -> Token: ) -async def _refresh_session_tokens( +async def _resolve_refresh_claims( payload: dict, user: User, db: AsyncSession -) -> Token: - """Carry session-policy claims forward across a refresh-token rotation. +) -> tuple[int, int, int]: + """Return (auth_time, idle_max_seconds, abs_max_seconds) for a refresh. - Grandfathers legacy tokens issued before this PR (no auth_time claim) - by snapshotting the account's current policy and treating now() as - auth_time — i.e. one free rotation under the new policy. Caller - commits. - - Does NOT enforce the absolute cap — that lands in the next commit so - the cap can be rolled back independently if needed. + Grandfathers legacy tokens issued before the session-policy PR: tokens + missing any of auth_time/idle_max/abs_max get treated as if just minted + under the account's current policy. One free rotation under the new + rules — see plan §5.1. Callers that have the claims use them as-is. """ auth_time = payload.get("auth_time") idle_max_seconds = payload.get("idle_max") abs_max_seconds = payload.get("abs_max") if auth_time is None or idle_max_seconds is None or abs_max_seconds is None: - # Grandfather path — legacy token from before the session-policy PR. account = ( await db.execute(select(Account).where(Account.id == user.account_id)) ).scalar_one() @@ -137,6 +133,21 @@ async def _refresh_session_tokens( idle_max_seconds = idle_minutes * 60 abs_max_seconds = abs_minutes * 60 + return auth_time, idle_max_seconds, abs_max_seconds + + +async def _mint_with_claims( + user: User, + auth_time: int, + idle_max_seconds: int, + abs_max_seconds: int, + db: AsyncSession, +) -> Token: + """Mint a refresh+access pair carrying explicit session-policy claims. + + Used by /auth/refresh after the grandfather + absolute-cap checks + have already produced the effective claim values. Caller commits. + """ now = datetime.now(timezone.utc) refresh_token_str = create_refresh_token( user_id=str(user.id), @@ -452,13 +463,39 @@ async def refresh_token( payload: Annotated[dict, Depends(get_refresh_token_payload)], db: Annotated[AsyncSession, Depends(get_admin_db)] ): - """Refresh access token using refresh token (rotation: old token is revoked).""" + """Refresh access token, enforcing both idle and absolute session windows. + + Algorithm (see plan §4.5): + + 1. Decode refresh JWT (the dep already rejects idle-expired tokens with + session_expired_idle). + 2. Load the user. If missing or inactive, 401 invalid_refresh_token. + 3. Resolve effective auth_time/idle_max/abs_max (grandfather legacy + tokens that pre-date this PR). + 4. Atomically revoke the JTI regardless of outcome — so an absolute- + expired token cannot be replayed; the second attempt finds it + already revoked and gets invalid_refresh_token instead. + 5. If the atomic UPDATE matched zero rows, 401 invalid_refresh_token. + 6. If now >= auth_time + abs_max, 401 session_expired_absolute. + 7. Otherwise mint new tokens carrying the claims forward. + """ user_id = payload.get("sub") jti = payload.get("jti") - # Atomically revoke the old refresh token (token rotation). - # Using a conditional UPDATE prevents the race where two concurrent - # refresh requests both read revoked_at=NULL and both succeed. + user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none() + if not user or not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid_refresh_token", + ) + + auth_time, idle_max_seconds, abs_max_seconds = await _resolve_refresh_claims( + payload, user, db + ) + + # Atomically revoke the old refresh token first — this consumes the + # token regardless of whether the absolute check passes, so an absolute- + # expired token cannot be replayed. if jti: token_hash = hash_token(jti) result = await db.execute( @@ -471,26 +508,29 @@ async def refresh_token( .returning(RefreshToken.id, RefreshToken.user_id) ) revoked_row = result.fetchone() - if not revoked_row: - # Either the token doesn't exist or was already revoked/used. - # Surfaced to the frontend as a plain logout — not "session - # expired" — because the user did not hit a policy boundary. raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="invalid_refresh_token" + detail="invalid_refresh_token", ) - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - - if not user: + # Absolute-window check. Boundary is `>=`, not `>` — a deadline equal to + # now is expired. The token row has already been revoked above, so the + # client cannot retry this token even though we're raising after the + # consume. + now_unix = int(datetime.now(timezone.utc).timestamp()) + if now_unix >= auth_time + abs_max_seconds: + # Commit the revoke so the consumed-on-failure invariant survives + # any subsequent rollback in the request lifecycle. + await db.commit() raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="invalid_refresh_token" + detail="session_expired_absolute", ) - token = await _refresh_session_tokens(payload, user, db) + token = await _mint_with_claims( + user, auth_time, idle_max_seconds, abs_max_seconds, db + ) await db.commit() return token diff --git a/backend/tests/test_session_policy.py b/backend/tests/test_session_policy.py index a1d1de83..301f351a 100644 --- a/backend/tests/test_session_policy.py +++ b/backend/tests/test_session_policy.py @@ -6,6 +6,7 @@ Test numbers below correspond to the cases listed in §6 of the plan. This file grows across commits: - Commit 2: error-detail taxonomy (#11 + wrong-type + bad-signature) - Commit 3: claims embedded at login + response fields surfaced (#1, #14) +- Commit 4: absolute-cap enforcement + grandfather path (#8, #9, #12) """ import uuid @@ -179,3 +180,133 @@ class TestSessionPolicyClaims: assert new_payload["exp"] >= original_payload["exp"] # JTI rotates. assert new_payload["jti"] != original_payload["jti"] + + +def _backdate_auth_time(refresh_token: str, *, seconds_back: int) -> str: + """Re-sign a refresh JWT with an earlier auth_time, preserving JTI. + + The DB row in refresh_tokens is keyed on hash(jti), so preserving jti + lets the atomic revoke step still find the row. Used to simulate + "this session is past its absolute cap" without waiting two weeks. + """ + payload = jwt.decode( + refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + payload["auth_time"] = payload["auth_time"] - seconds_back + return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + + +class TestAbsoluteCap: + """§6 tests #8, #9, #12 — absolute-cap enforcement and grandfather path.""" + + @pytest.mark.asyncio + async def test_refresh_at_absolute_deadline_rejects( + self, client: AsyncClient, test_user: dict + ): + """§6 test #8 — boundary check uses `>=`, not `>`. + + A token whose auth_time + abs_max equals now() is expired, not + valid. Backdate the original token's auth_time by exactly abs_max + seconds so now >= deadline. + """ + login_resp = await client.post( + "/api/v1/auth/login/json", + json={"email": test_user["email"], "password": test_user["password"]}, + ) + original = login_resp.json()["refresh_token"] + abs_max = jwt.decode( + original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + )["abs_max"] + + expired = _backdate_auth_time(original, seconds_back=abs_max) + + response = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {expired}"}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "session_expired_absolute" + + @pytest.mark.asyncio + async def test_absolute_expired_token_is_consumed( + self, client: AsyncClient, test_user: dict + ): + """§6 test #9 — first attempt returns session_expired_absolute and + revokes the row; second attempt sees the revoked row and returns + invalid_refresh_token. Prevents replay of an absolute-expired token. + """ + login_resp = await client.post( + "/api/v1/auth/login/json", + json={"email": test_user["email"], "password": test_user["password"]}, + ) + original = login_resp.json()["refresh_token"] + abs_max = jwt.decode( + original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + )["abs_max"] + expired = _backdate_auth_time(original, seconds_back=abs_max + 1) + + first = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {expired}"}, + ) + assert first.status_code == 401 + assert first.json()["detail"] == "session_expired_absolute" + + second = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {expired}"}, + ) + assert second.status_code == 401 + assert second.json()["detail"] == "invalid_refresh_token" + + @pytest.mark.asyncio + async def test_grandfather_path_for_legacy_token( + self, client: AsyncClient, test_user: dict, test_db + ): + """§6 test #12 — refresh token issued before this PR (no auth_time + claim) gets one successful rotation; the new token has fresh + auth_time/idle_max/abs_max claims snapshotted from current policy. + """ + from app.core.security import hash_token + from app.models.refresh_token import RefreshToken + + login_resp = await client.post( + "/api/v1/auth/login/json", + json={"email": test_user["email"], "password": test_user["password"]}, + ) + original = login_resp.json()["refresh_token"] + original_payload = jwt.decode( + original, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + + # Strip the new claims to simulate a token issued before this PR. + # JTI preserved so the DB-side revoke still finds the row. + legacy_payload = { + "sub": original_payload["sub"], + "type": "refresh", + "jti": original_payload["jti"], + "exp": original_payload["exp"], + } + legacy_token = jwt.encode( + legacy_payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM + ) + + response = await client.post( + "/api/v1/auth/refresh", + headers={"Authorization": f"Bearer {legacy_token}"}, + ) + + assert response.status_code == 200, response.json() + new_payload = jwt.decode( + response.json()["refresh_token"], + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + ) + assert new_payload.get("auth_time") is not None + assert new_payload.get("idle_max") == 3 * 24 * 60 * 60 + assert new_payload.get("abs_max") == 14 * 24 * 60 * 60 + # auth_time was set to ~now during grandfather, not preserved from + # the legacy token (since the legacy token didn't have one). + now_unix = int(datetime.now(timezone.utc).timestamp()) + assert abs(new_payload["auth_time"] - now_unix) < 10