feat(auth): guard login/password paths against OAuth-only users
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -62,6 +62,22 @@ def _generate_display_code() -> str:
|
|||||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||||
|
|
||||||
|
|
||||||
|
async def _reject_if_oauth_only(db: AsyncSession, user) -> None:
|
||||||
|
"""If the user has no password_hash, raise 400 with a list of linked
|
||||||
|
providers so the client can redirect them to the right OAuth flow."""
|
||||||
|
if user is None or user.password_hash is not None:
|
||||||
|
return
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthIdentity.provider).where(OAuthIdentity.user_id == user.id)
|
||||||
|
)
|
||||||
|
providers = [row for row in result.scalars().all()]
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": "use_oauth_provider", "providers": providers},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
@limiter.limit("3/minute")
|
@limiter.limit("3/minute")
|
||||||
async def register(
|
async def register(
|
||||||
@@ -243,6 +259,7 @@ async def login(
|
|||||||
result = await db.execute(select(User).where(User.email == form_data.username))
|
result = await db.execute(select(User).where(User.email == form_data.username))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
await _reject_if_oauth_only(db, user)
|
||||||
if not user or not verify_password(form_data.password, user.password_hash):
|
if not user or not verify_password(form_data.password, user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -280,6 +297,7 @@ async def login_json(
|
|||||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
await _reject_if_oauth_only(db, user)
|
||||||
if not user or not verify_password(credentials.password, user.password_hash):
|
if not user or not verify_password(credentials.password, user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -445,6 +463,7 @@ async def change_password(
|
|||||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Change the current user's password."""
|
"""Change the current user's password."""
|
||||||
|
await _reject_if_oauth_only(db, current_user)
|
||||||
if not verify_password(data.current_password, current_user.password_hash):
|
if not verify_password(data.current_password, current_user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -488,7 +507,7 @@ async def forgot_password(
|
|||||||
result = await db.execute(select(User).where(User.email == data.email))
|
result = await db.execute(select(User).where(User.email == data.email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user:
|
if user and user.password_hash is not None:
|
||||||
# Create reset token JWT
|
# Create reset token JWT
|
||||||
raw_token = create_password_reset_token(str(user.id))
|
raw_token = create_password_reset_token(str(user.id))
|
||||||
payload = decode_token(raw_token)
|
payload = decode_token(raw_token)
|
||||||
|
|||||||
83
backend/tests/test_oauth_only_user_paths.py
Normal file
83
backend/tests/test_oauth_only_user_paths.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_oauth_only_user(test_db, email, *, with_identity=True):
|
||||||
|
"""Create an OAuth-only user (password_hash=None) directly in the test DB."""
|
||||||
|
import secrets
|
||||||
|
account = Account(
|
||||||
|
name=f"{email}-acct",
|
||||||
|
display_code=secrets.token_hex(4).upper(),
|
||||||
|
)
|
||||||
|
test_db.add(account)
|
||||||
|
await test_db.flush()
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
name="OAuth User",
|
||||||
|
password_hash=None,
|
||||||
|
account_id=account.id,
|
||||||
|
account_role="owner",
|
||||||
|
)
|
||||||
|
test_db.add(user)
|
||||||
|
await test_db.flush()
|
||||||
|
if with_identity:
|
||||||
|
test_db.add(OAuthIdentity(
|
||||||
|
user_id=user.id, provider="google",
|
||||||
|
provider_subject=f"google_{email}",
|
||||||
|
provider_email_at_link=email,
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_form_rejects_oauth_only_user_with_helpful_error(client, test_db):
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-only@example.com")
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data={"username": "oauth-only@example.com", "password": "wontwork"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
body = response.json()
|
||||||
|
assert body["detail"]["error"] == "use_oauth_provider"
|
||||||
|
assert "google" in body["detail"]["providers"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_json_rejects_oauth_only_user(client, test_db):
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-only2@example.com")
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": "oauth-only2@example.com", "password": "wontwork"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["detail"]["error"] == "use_oauth_provider"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_password_forgot_silent_for_oauth_only_user(client, test_db):
|
||||||
|
"""OAuth-only users get the generic message; no email is sent."""
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-forgot@example.com", with_identity=False)
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
with patch("app.core.email.EmailService.send_password_reset_email", new_callable=AsyncMock) as mock_send:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/password/forgot",
|
||||||
|
json={"email": "oauth-forgot@example.com"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_for_password_user_still_works(client, test_user):
|
||||||
|
"""Regression: existing password-based login must still succeed."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": test_user["email"], "password": test_user["password"]},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["access_token"]
|
||||||
Reference in New Issue
Block a user