feat(auth): guard login/password paths against OAuth-only users

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-06 15:02:28 -04:00
parent 2ef2350de7
commit b0708ed650
2 changed files with 103 additions and 1 deletions

View File

@@ -62,6 +62,22 @@ def _generate_display_code() -> str:
return ''.join(secrets.choice(chars) for _ in range(8))
async def _reject_if_oauth_only(db: AsyncSession, user) -> None:
"""If the user has no password_hash, raise 400 with a list of linked
providers so the client can redirect them to the right OAuth flow."""
if user is None or user.password_hash is not None:
return
from app.models.oauth_identity import OAuthIdentity
result = await db.execute(
select(OAuthIdentity.provider).where(OAuthIdentity.user_id == user.id)
)
providers = [row for row in result.scalars().all()]
raise HTTPException(
status_code=400,
detail={"error": "use_oauth_provider", "providers": providers},
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
@@ -243,6 +259,7 @@ async def login(
result = await db.execute(select(User).where(User.email == form_data.username))
user = result.scalar_one_or_none()
await _reject_if_oauth_only(db, user)
if not user or not verify_password(form_data.password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -280,6 +297,7 @@ async def login_json(
result = await db.execute(select(User).where(User.email == credentials.email))
user = result.scalar_one_or_none()
await _reject_if_oauth_only(db, user)
if not user or not verify_password(credentials.password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -445,6 +463,7 @@ async def change_password(
db: Annotated[AsyncSession, Depends(get_admin_db)]
):
"""Change the current user's password."""
await _reject_if_oauth_only(db, current_user)
if not verify_password(data.current_password, current_user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -488,7 +507,7 @@ async def forgot_password(
result = await db.execute(select(User).where(User.email == data.email))
user = result.scalar_one_or_none()
if user:
if user and user.password_hash is not None:
# Create reset token JWT
raw_token = create_password_reset_token(str(user.id))
payload = decode_token(raw_token)