feat: user management — admin create, password reset, archive/delete, quick invite
Phase 1: must_change_password enforcement + change password endpoint/page Phase 2: Admin user creation (M365-style) with temp password Phase 3: Password reset (self-service forgot + admin-triggered) Phase 4: User archive (soft delete) + hard delete with precheck Phase 5: Quick invite from admin Users page Also fixes: - Auto-create subscription for accounts missing one - Hard delete precheck ignores sole-member personal accounts - Seed script patches tree nodes for validation compliance Migrations: 031 (must_change_password), 032 (password_reset_tokens), 033 (user soft delete) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
@@ -10,6 +10,13 @@ from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
|
||||
# Routes that are allowed even when must_change_password is True
|
||||
_PASSWORD_CHANGE_ALLOWLIST = {
|
||||
"/api/v1/auth/password/change",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/me",
|
||||
}
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
@@ -65,16 +72,26 @@ async def get_refresh_token_payload(
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> User:
|
||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials."""
|
||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
||||
Enforces must_change_password — blocks all routes except allowlist."""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account has been deactivated"
|
||||
)
|
||||
|
||||
# Enforce must_change_password (backend hard lock)
|
||||
if current_user.must_change_password:
|
||||
if request.url.path not in _PASSWORD_CHANGE_ALLOWLIST:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="password_change_required"
|
||||
)
|
||||
|
||||
# Lightweight trial expiry check
|
||||
if current_user.account_id:
|
||||
from app.models.subscription import Subscription
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
@@ -8,14 +10,21 @@ from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_password_hash, generate_temp_password, create_password_reset_token, decode_token, hash_token
|
||||
from app.core.email import EmailService
|
||||
from app.models.user import User
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.password_reset_token import PasswordResetToken
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.session import Session
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.tree import Tree
|
||||
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
|
||||
from app.schemas.admin import MoveUserAccount
|
||||
from app.schemas.admin import MoveUserAccount, AdminUserCreate, AdminUserCreateResponse, AdminPasswordReset, AdminPasswordResetResponse, HardDeleteCheckResponse
|
||||
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
|
||||
from app.schemas.user_detail import (
|
||||
UserDetailResponse, AccountSummary, SubscriptionSummary,
|
||||
@@ -34,11 +43,14 @@ async def list_users(
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account")
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account"),
|
||||
include_archived: bool = Query(False, description="Include archived (soft-deleted) users"),
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
|
||||
if not include_archived:
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
if is_active is not None:
|
||||
query = query.where(User.is_active == is_active)
|
||||
if role:
|
||||
@@ -53,6 +65,137 @@ async def list_users(
|
||||
return users
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
"""Generate a random 8-character alphanumeric display code."""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
@router.post("/users", response_model=AdminUserCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
data: AdminUserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Create a new user with a temporary password (super admin only).
|
||||
|
||||
Supports two modes:
|
||||
- existing: Join an existing account (resolved by display_code)
|
||||
- personal: Create a new personal account for the user
|
||||
"""
|
||||
# Validate mode-specific fields
|
||||
if data.account_mode == "existing":
|
||||
if not data.account_display_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="account_display_code is required for existing mode",
|
||||
)
|
||||
if not data.account_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="account_role is required for existing mode",
|
||||
)
|
||||
|
||||
# Check email uniqueness
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered",
|
||||
)
|
||||
|
||||
# Generate temp password
|
||||
temp_password = generate_temp_password()
|
||||
password_hash = get_password_hash(temp_password)
|
||||
|
||||
if data.account_mode == "existing":
|
||||
# Resolve account by display code
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.display_code == data.account_display_code)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found for the given display code",
|
||||
)
|
||||
|
||||
new_user = User(
|
||||
email=data.email,
|
||||
password_hash=password_hash,
|
||||
name=data.name,
|
||||
role="engineer",
|
||||
account_id=account.id,
|
||||
account_role=data.account_role,
|
||||
must_change_password=True,
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush()
|
||||
|
||||
else:
|
||||
# Personal mode: create new account + user as owner
|
||||
new_account = Account(
|
||||
name=f"{data.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_user = User(
|
||||
email=data.email,
|
||||
password_hash=password_hash,
|
||||
name=data.name,
|
||||
role="engineer",
|
||||
account_id=new_account.id,
|
||||
account_role="owner",
|
||||
must_change_password=True,
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush()
|
||||
|
||||
new_account.owner_id = new_user.id
|
||||
|
||||
# Create free subscription for the new account
|
||||
new_subscription = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_subscription)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "user.create_admin", "user", new_user.id,
|
||||
{"email": data.email, "account_mode": data.account_mode},
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# Send welcome email (best-effort)
|
||||
email_sent = False
|
||||
if data.send_email:
|
||||
email_sent = await EmailService.send_welcome_email(
|
||||
to_email=data.email,
|
||||
temp_password=temp_password,
|
||||
)
|
||||
|
||||
return AdminUserCreateResponse(
|
||||
user={
|
||||
"id": str(new_user.id),
|
||||
"email": new_user.email,
|
||||
"name": new_user.name,
|
||||
"role": new_user.role,
|
||||
"is_active": new_user.is_active,
|
||||
"is_super_admin": new_user.is_super_admin,
|
||||
"account_id": str(new_user.account_id) if new_user.account_id else None,
|
||||
"account_role": new_user.account_role,
|
||||
"must_change_password": new_user.must_change_password,
|
||||
"created_at": new_user.created_at.isoformat() if new_user.created_at else None,
|
||||
},
|
||||
temporary_password=temp_password,
|
||||
email_sent=email_sent,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserDetailResponse)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
@@ -162,6 +305,7 @@ async def get_user(
|
||||
is_super_admin=user.is_super_admin,
|
||||
is_team_admin=getattr(user, "is_team_admin", False),
|
||||
created_at=user.created_at,
|
||||
deleted_at=user.deleted_at,
|
||||
account=account_summary, subscription=subscription_summary,
|
||||
invite_code_used=invite_code_used,
|
||||
recent_sessions=recent_sessions, total_sessions=total_sessions,
|
||||
@@ -321,7 +465,14 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
|
||||
)
|
||||
subscription = sub_result.scalar_one_or_none()
|
||||
if not subscription:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Subscription not found")
|
||||
# Auto-create a free subscription for accounts that predate the subscription system
|
||||
subscription = Subscription(
|
||||
account_id=user.account_id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(subscription)
|
||||
await db.flush()
|
||||
return user, subscription
|
||||
|
||||
|
||||
@@ -372,3 +523,357 @@ async def extend_user_trial(
|
||||
await db.commit()
|
||||
return {"plan": subscription.plan, "status": subscription.status,
|
||||
"current_period_end": subscription.current_period_end}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/password-reset", response_model=AdminPasswordResetResponse)
|
||||
async def admin_reset_password(
|
||||
user_id: UUID,
|
||||
data: AdminPasswordReset,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Admin-triggered password reset (super admin only).
|
||||
|
||||
Two modes:
|
||||
- email_link: sends a reset email to the user
|
||||
- temp_password: generates a temp password and returns it once
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
# Revoke all refresh tokens
|
||||
rt_result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user.id,
|
||||
RefreshToken.revoked_at.is_(None)
|
||||
)
|
||||
)
|
||||
for rt in rt_result.scalars().all():
|
||||
rt.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
user.must_change_password = True
|
||||
|
||||
if data.mode == "email_link":
|
||||
# Create reset token and send email
|
||||
raw_token = create_password_reset_token(str(user.id))
|
||||
payload = decode_token(raw_token)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = PasswordResetToken(
|
||||
token_hash=hash_token(payload["jti"]),
|
||||
user_id=user.id,
|
||||
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||
created_by_admin_id=current_user.id,
|
||||
)
|
||||
db.add(token_record)
|
||||
|
||||
await log_audit(db, current_user.id, "user.password_reset.admin_email", "user", user.id)
|
||||
await db.commit()
|
||||
|
||||
email_sent = False
|
||||
if settings.FRONTEND_URL:
|
||||
reset_url = f"{settings.FRONTEND_URL}/reset-password?token={raw_token}"
|
||||
email_sent = await EmailService.send_password_reset_email(
|
||||
to_email=user.email, reset_url=reset_url,
|
||||
)
|
||||
|
||||
return AdminPasswordResetResponse(
|
||||
message="Password reset email sent" if email_sent else "Reset token created (email not configured)",
|
||||
email_sent=email_sent,
|
||||
)
|
||||
|
||||
else: # temp_password
|
||||
temp_pw = generate_temp_password()
|
||||
user.password_hash = get_password_hash(temp_pw)
|
||||
|
||||
await log_audit(db, current_user.id, "user.password_reset.admin_temp", "user", user.id)
|
||||
await db.commit()
|
||||
|
||||
return AdminPasswordResetResponse(
|
||||
message="Temporary password generated",
|
||||
temporary_password=temp_pw,
|
||||
email_sent=False,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/archive", response_model=UserResponse)
|
||||
async def archive_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Archive (soft delete) a user (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot archive yourself")
|
||||
|
||||
if user.deleted_at:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User is already archived")
|
||||
|
||||
user.deleted_at = datetime.now(timezone.utc)
|
||||
user.deleted_by = current_user.id
|
||||
user.is_active = False
|
||||
|
||||
# Revoke all refresh tokens
|
||||
rt_result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.user_id == user.id, RefreshToken.revoked_at.is_(None))
|
||||
)
|
||||
for rt in rt_result.scalars().all():
|
||||
rt.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
await log_audit(db, current_user.id, "user.archive", "user", user.id)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/restore", response_model=UserResponse)
|
||||
async def restore_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Restore an archived user (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
if not user.deleted_at:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User is not archived")
|
||||
|
||||
user.deleted_at = None
|
||||
user.deleted_by = None
|
||||
user.is_active = True
|
||||
|
||||
await log_audit(db, current_user.id, "user.restore", "user", user.id)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/hard-delete-check", response_model=HardDeleteCheckResponse)
|
||||
async def hard_delete_check(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Check if a user can be hard-deleted (super admin only). Returns blockers."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
blockers: dict = {}
|
||||
|
||||
# Check if user owns any accounts with OTHER members (true blocker).
|
||||
# Sole-member accounts (e.g. personal accounts) are cleaned up during delete.
|
||||
owned_account_ids_result = await db.execute(
|
||||
select(Account.id).where(Account.owner_id == user_id)
|
||||
)
|
||||
owned_account_ids = [row[0] for row in owned_account_ids_result.all()]
|
||||
shared_accounts = 0
|
||||
for acc_id in owned_account_ids:
|
||||
other_members = (await db.execute(
|
||||
select(func.count()).select_from(User).where(
|
||||
User.account_id == acc_id, User.id != user_id
|
||||
)
|
||||
)).scalar() or 0
|
||||
if other_members > 0:
|
||||
shared_accounts += 1
|
||||
if shared_accounts > 0:
|
||||
blockers["owned_accounts_with_other_members"] = shared_accounts
|
||||
|
||||
# Check authored trees
|
||||
authored_trees = (await db.execute(
|
||||
select(func.count()).select_from(Tree).where(Tree.author_id == user_id)
|
||||
)).scalar() or 0
|
||||
if authored_trees > 0:
|
||||
blockers["authored_trees"] = authored_trees
|
||||
|
||||
# Check sessions
|
||||
sessions_count = (await db.execute(
|
||||
select(func.count()).select_from(Session).where(Session.user_id == user_id)
|
||||
)).scalar() or 0
|
||||
if sessions_count > 0:
|
||||
blockers["sessions"] = sessions_count
|
||||
|
||||
# Check audit logs
|
||||
audit_count = (await db.execute(
|
||||
select(func.count()).select_from(AuditLog).where(AuditLog.user_id == user_id)
|
||||
)).scalar() or 0
|
||||
if audit_count > 0:
|
||||
blockers["audit_logs"] = audit_count
|
||||
|
||||
# Check invite codes created
|
||||
invites_created = (await db.execute(
|
||||
select(func.count()).select_from(InviteCode).where(InviteCode.created_by_id == user_id)
|
||||
)).scalar() or 0
|
||||
if invites_created > 0:
|
||||
blockers["invite_codes_created"] = invites_created
|
||||
|
||||
# Check account invites
|
||||
account_invites = (await db.execute(
|
||||
select(func.count()).select_from(AccountInvite).where(AccountInvite.invited_by_id == user_id)
|
||||
)).scalar() or 0
|
||||
if account_invites > 0:
|
||||
blockers["account_invites_created"] = account_invites
|
||||
|
||||
return HardDeleteCheckResponse(
|
||||
can_delete=len(blockers) == 0,
|
||||
blockers=blockers,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}/hard-delete", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def hard_delete_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Permanently delete a user (super admin only). User must be archived first."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot delete yourself")
|
||||
|
||||
if user.is_super_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot hard-delete a super admin")
|
||||
|
||||
if not user.deleted_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User must be archived before hard-deleting"
|
||||
)
|
||||
|
||||
# Run precheck
|
||||
precheck = await hard_delete_check(user_id, db, current_user)
|
||||
if not precheck.can_delete:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot delete: user has dependencies ({', '.join(precheck.blockers.keys())})"
|
||||
)
|
||||
|
||||
# Audit BEFORE delete
|
||||
await log_audit(db, current_user.id, "user.hard_delete", "user", user.id,
|
||||
{"email": user.email, "name": user.name})
|
||||
|
||||
from sqlalchemy import delete as sa_delete
|
||||
|
||||
# Delete technical artifacts
|
||||
await db.execute(sa_delete(RefreshToken).where(RefreshToken.user_id == user_id))
|
||||
await db.execute(sa_delete(PasswordResetToken).where(PasswordResetToken.user_id == user_id))
|
||||
|
||||
# Clean up sole-member owned accounts (personal accounts)
|
||||
owned_accounts_result = await db.execute(
|
||||
select(Account).where(Account.owner_id == user_id)
|
||||
)
|
||||
for account in owned_accounts_result.scalars().all():
|
||||
# Null out owner_id first (RESTRICT FK)
|
||||
account.owner_id = None
|
||||
await db.flush()
|
||||
# Delete subscription if exists
|
||||
await db.execute(sa_delete(Subscription).where(Subscription.account_id == account.id))
|
||||
# Delete the account
|
||||
await db.delete(account)
|
||||
|
||||
# Delete the user
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/invites", status_code=status.HTTP_201_CREATED)
|
||||
async def admin_create_invite(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Quick-invite a user to an account (super admin only).
|
||||
|
||||
Body: {email, account_display_code, role}
|
||||
Creates an AccountInvite and sends the invite email.
|
||||
"""
|
||||
email = data.get("email")
|
||||
account_display_code = data.get("account_display_code")
|
||||
role = data.get("role", "engineer")
|
||||
|
||||
if not email or not account_display_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="email and account_display_code are required",
|
||||
)
|
||||
|
||||
if role not in ("engineer", "viewer"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="role must be 'engineer' or 'viewer'",
|
||||
)
|
||||
|
||||
# Resolve account
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.display_code == account_display_code)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Account with display code '{account_display_code}' not found",
|
||||
)
|
||||
|
||||
# Check if email already has a pending invite to this account
|
||||
existing = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.account_id == account.id,
|
||||
AccountInvite.email == email,
|
||||
AccountInvite.accepted_by_id.is_(None),
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="A pending invite already exists for this email and account",
|
||||
)
|
||||
|
||||
# Generate invite code
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account.id,
|
||||
invited_by_id=current_user.id,
|
||||
email=email,
|
||||
code=code,
|
||||
role=role,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
db.add(invite)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "user.invite_admin", "account_invite", invite.id,
|
||||
{"email": email, "account_id": str(account.id), "role": role},
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Send email (best-effort)
|
||||
email_sent = await EmailService.send_account_invite_email(
|
||||
to_email=email,
|
||||
code=code,
|
||||
account_name=account.name or account_display_code,
|
||||
role=role,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(invite.id),
|
||||
"email": email,
|
||||
"code": code,
|
||||
"role": role,
|
||||
"account_display_code": account_display_code,
|
||||
"email_sent": email_sent,
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.security import (
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_password_reset_token,
|
||||
decode_token,
|
||||
hash_token,
|
||||
)
|
||||
@@ -25,7 +26,17 @@ from app.models.subscription import Subscription
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.token import Token
|
||||
from app.schemas.auth_password import (
|
||||
ChangePasswordRequest,
|
||||
ForgotPasswordRequest,
|
||||
VerifyResetTokenRequest,
|
||||
VerifyResetTokenResponse,
|
||||
ResetPasswordRequest,
|
||||
)
|
||||
from app.models.password_reset_token import PasswordResetToken
|
||||
from app.core.email import EmailService
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
from app.core.audit import log_audit
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
@@ -241,7 +252,8 @@ async def login(
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer"
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
)
|
||||
|
||||
|
||||
@@ -274,7 +286,8 @@ async def login_json(
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer"
|
||||
token_type="bearer",
|
||||
must_change_password=user.must_change_password,
|
||||
)
|
||||
|
||||
|
||||
@@ -356,3 +369,177 @@ async def logout(
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
|
||||
|
||||
@router.post("/password/change")
|
||||
@limiter.limit("5/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
data: ChangePasswordRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Change the current user's password."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
if data.current_password == data.new_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password must be different from current password"
|
||||
)
|
||||
|
||||
current_user.password_hash = get_password_hash(data.new_password)
|
||||
current_user.must_change_password = False
|
||||
|
||||
# Revoke all refresh tokens for this user
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == current_user.id,
|
||||
RefreshToken.revoked_at.is_(None)
|
||||
)
|
||||
)
|
||||
active_tokens = result.scalars().all()
|
||||
for token in active_tokens:
|
||||
token.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
await log_audit(db, current_user.id, "auth.password_change", "user", current_user.id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
|
||||
@router.post("/password/forgot")
|
||||
@limiter.limit("3/minute")
|
||||
async def forgot_password(
|
||||
request: Request,
|
||||
data: ForgotPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
# Create reset token JWT
|
||||
raw_token = create_password_reset_token(str(user.id))
|
||||
payload = decode_token(raw_token)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = PasswordResetToken(
|
||||
token_hash=hash_token(payload["jti"]),
|
||||
user_id=user.id,
|
||||
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||
)
|
||||
db.add(token_record)
|
||||
await db.commit()
|
||||
|
||||
# Send email (best-effort)
|
||||
reset_url = f"{settings.FRONTEND_URL}/reset-password?token={raw_token}"
|
||||
await EmailService.send_password_reset_email(
|
||||
to_email=user.email,
|
||||
reset_url=reset_url,
|
||||
)
|
||||
|
||||
await log_audit(db, user.id, "auth.password_reset.request", "user", user.id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "If an account with that email exists, a reset link has been sent."}
|
||||
|
||||
|
||||
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
||||
async def verify_reset_token(
|
||||
data: VerifyResetTokenRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Verify a password reset token is valid."""
|
||||
payload = decode_token(data.token)
|
||||
if not payload or payload.get("type") != "password_reset":
|
||||
return VerifyResetTokenResponse(valid=False)
|
||||
|
||||
jti = payload.get("jti")
|
||||
if not jti:
|
||||
return VerifyResetTokenResponse(valid=False)
|
||||
|
||||
result = await db.execute(
|
||||
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if not token_record or not token_record.is_valid:
|
||||
return VerifyResetTokenResponse(valid=False)
|
||||
|
||||
# Get user email for display
|
||||
user_result = await db.execute(select(User.email).where(User.id == token_record.user_id))
|
||||
email = user_result.scalar_one_or_none()
|
||||
|
||||
return VerifyResetTokenResponse(valid=True, email=email)
|
||||
|
||||
|
||||
@router.post("/password/reset")
|
||||
@limiter.limit("5/minute")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
data: ResetPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Reset password using a valid reset token."""
|
||||
payload = decode_token(data.token)
|
||||
if not payload or payload.get("type") != "password_reset":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired reset token"
|
||||
)
|
||||
|
||||
jti = payload.get("jti")
|
||||
user_id = payload.get("sub")
|
||||
if not jti or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid reset token"
|
||||
)
|
||||
|
||||
# Validate token in DB (single-use)
|
||||
result = await db.execute(
|
||||
select(PasswordResetToken).where(PasswordResetToken.token_hash == hash_token(jti))
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if not token_record or not token_record.is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Reset token has already been used or has expired"
|
||||
)
|
||||
|
||||
# Get user
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid reset token"
|
||||
)
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(data.new_password)
|
||||
user.must_change_password = False
|
||||
|
||||
# Mark token as used
|
||||
token_record.used_at = datetime.now(timezone.utc)
|
||||
|
||||
# Revoke all refresh tokens
|
||||
rt_result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == user.id,
|
||||
RefreshToken.revoked_at.is_(None)
|
||||
)
|
||||
)
|
||||
for rt in rt_result.scalars().all():
|
||||
rt.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
await log_audit(db, user.id, "auth.password_reset.complete", "user", user.id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Password has been reset successfully"}
|
||||
|
||||
Reference in New Issue
Block a user