Fix broken JWT token refresh that caused "Failed to load trees" after an idle timeout. The refresh endpoint expected the token as a query parameter, but the frontend sent it in the Authorization header. Add a proper dependency (get_refresh_token_payload) and a refresh queue to handle concurrent 401s. Also fix seed trees not being visible to non-admin users by updating the seed script to set is_public/is_default on existing trees. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
194 lines
6.0 KiB
Python
194 lines
6.0 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Annotated
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import get_db
|
|
from app.core.security import (
|
|
verify_password,
|
|
get_password_hash,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
)
|
|
from app.models.user import User
|
|
from app.models.invite_code import InviteCode
|
|
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
|
from app.schemas.token import Token
|
|
from app.api.deps import get_current_user, get_refresh_token_payload
|
|
|
|
# All authentication endpoints below are mounted under the "/auth" prefix and
# grouped under the "authentication" tag.
router = APIRouter(tags=["authentication"], prefix="/auth")
|
|
|
|
|
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Register a new user.

    When ``settings.REQUIRE_INVITE_CODE`` is enabled the payload must carry an
    invite code that exists, is not already used and is not expired. The code
    is then linked to the new account in the same transaction.

    Args:
        user_data: Registration payload (email, password, name, role and an
            optional invite code).
        db: Async database session.

    Returns:
        The newly created ``User``, serialized through ``UserResponse``.

    Raises:
        HTTPException: 400 if the invite code is required but missing, unknown,
            already used or expired, or if the email is already registered.
    """
    invite_code_record = None
    if settings.REQUIRE_INVITE_CODE:
        if not user_data.invite_code:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code is required",
            )

        # Only the submitted code is upper-cased, so the lookup is
        # case-insensitive only if stored codes are upper-case as well.
        result = await db.execute(
            select(InviteCode).where(InviteCode.code == user_data.invite_code.upper())
        )
        invite_code_record = result.scalar_one_or_none()

        if not invite_code_record:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid invite code",
            )

        if invite_code_record.is_used:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code has already been used",
            )

        if invite_code_record.is_expired:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite code has expired",
            )

    # Reject duplicate emails up front so the client gets a clear 400.
    result = await db.execute(select(User).where(User.email == user_data.email))
    existing_user = result.scalar_one_or_none()
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    new_user = User(
        email=user_data.email,
        password_hash=get_password_hash(user_data.password),
        name=user_data.name,
        # NOTE(review): role is taken straight from the client payload; confirm
        # UserCreate restricts it, otherwise any registrant can self-assign a
        # privileged role.
        role=user_data.role,
        invite_code_id=invite_code_record.id if invite_code_record else None,
    )
    db.add(new_user)

    if invite_code_record:
        # Primary-key defaults are only applied when the INSERT is emitted, so
        # new_user.id can still be None right after db.add(). Flush first so
        # used_by_id references the real row. The flush does not commit; it
        # stays in the current transaction until the commit below.
        await db.flush()
        invite_code_record.used_by_id = new_user.id
        invite_code_record.used_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(new_user)

    return new_user
|
|
|
|
|
|
@router.post("/login", response_model=Token)
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Authenticate with OAuth2 password-form fields and issue a token pair.

    The form's ``username`` field is matched against the user's email address.
    On success the user's ``last_login`` timestamp is updated and committed
    before the tokens are created.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when no
            user has that email or the password does not verify.
    """
    lookup = await db.execute(select(User).where(User.email == form_data.username))
    account = lookup.scalar_one_or_none()

    # Same short-circuit as "not account or not verify_password(...)": the
    # password is only checked when an account was found.
    if not (account and verify_password(form_data.password, account.password_hash)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    account.last_login = datetime.now(timezone.utc)
    await db.commit()

    # Each token helper receives its own claims dict.
    return Token(
        access_token=create_access_token(data={"sub": str(account.id)}),
        refresh_token=create_refresh_token(data={"sub": str(account.id)}),
        token_type="bearer",
    )
|
|
|
|
|
|
@router.post("/login/json", response_model=Token)
async def login_json(
    credentials: UserLogin,
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Login with a JSON body (alternative to the OAuth2 form endpoint).

    Looks the user up by ``credentials.email``, verifies the password, then
    records ``last_login`` and issues a new access/refresh token pair.

    Args:
        credentials: JSON payload with ``email`` and ``password``.
        db: Async database session.

    Returns:
        A ``Token`` with bearer access and refresh tokens.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            email is unknown or the password does not verify.
    """
    result = await db.execute(select(User).where(User.email == credentials.email))
    user = result.scalar_one_or_none()

    if not user or not verify_password(credentials.password, user.password_hash):
        # Send the same challenge header as /login; 401 responses are expected
        # to carry WWW-Authenticate.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user.last_login = datetime.now(timezone.utc)
    await db.commit()

    access_token = create_access_token(data={"sub": str(user.id)})
    refresh_token = create_refresh_token(data={"sub": str(user.id)})

    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=Token)
async def refresh_token(
    payload: Annotated[dict, Depends(get_refresh_token_payload)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Exchange a refresh token for a new access/refresh token pair.

    The incoming token is read by the ``get_refresh_token_payload`` dependency.
    Presumably that dependency rejects expired or non-refresh tokens; confirm
    in ``app.api.deps``. This handler only resolves the payload's ``sub`` claim
    to an existing user. A new refresh token is issued alongside the access
    token. The old one is not revoked here, because tokens are stateless.

    Args:
        payload: Decoded refresh-token claims.
        db: Async database session.

    Returns:
        A ``Token`` with a new access token and a new refresh token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            payload has no ``sub`` claim or the referenced user does not exist.
    """
    user_id = payload.get("sub")
    if not user_id:
        # Without this guard the query below would become "User.id IS NULL"
        # and surface as a misleading "User not found".
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token = create_access_token(data={"sub": str(user.id)})
    new_refresh_token = create_refresh_token(data={"sub": str(user.id)})

    return Token(
        access_token=access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the user that the ``get_current_user`` dependency resolved.

    All authentication work happens in that dependency; this handler simply
    echoes its result, serialized through ``UserResponse``.
    """
    return current_user
|
|
|
|
|
|
@router.post("/logout")
async def logout():
    """Acknowledge a logout request; the client should discard its tokens.

    Tokens are stateless JWTs, so nothing is invalidated on the server.
    Revoking tokens server-side would require something like a token denylist.
    """
    acknowledgement = {"message": "Successfully logged out"}
    return acknowledgement
|