fix: high-severity security hardening (Phase B permissions audit)

Phase B addresses 7 high-severity gaps from the permissions audit:

- B1: Enforce tree access check on session start via can_access_tree
- B2: Replace all inline permission helpers with centralized permissions.py
- B3: Fix require_engineer_or_admin to check is_team_admin before role
- B4: Add is_active field on User with enforcement in get_current_active_user
- B5: Add admin user management endpoints (list, get, role, team-admin, deactivate, activate)
- B6: Add rate limiting on auth/invite endpoints via slowapi (disabled in DEBUG)
- B7: Implement refresh token rotation with JTI-based revocation and meaningful logout

Also reduces access token TTL from 15 to 5 minutes and updates CLAUDE.md
with SaaS/MSP context for future planning sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-05 22:44:05 -05:00
parent 3e0fb92012
commit 71ba0b95a5
27 changed files with 743 additions and 229 deletions

View File

@@ -0,0 +1,29 @@
"""add is_active field to users
Revision ID: 012
Revises: 011
Create Date: 2026-02-05
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '012'
down_revision: Union[str, None] = '011'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
'users',
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true')
)
def downgrade() -> None:
op.drop_column('users', 'is_active')

View File

@@ -0,0 +1,35 @@
"""add refresh_tokens table
Revision ID: 013
Revises: 012
Create Date: 2026-02-05
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '013'
down_revision: Union[str, None] = '012'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
'refresh_tokens',
sa.Column('id', UUID(as_uuid=True), primary_key=True),
sa.Column('token_hash', sa.String(64), unique=True, nullable=False, index=True),
sa.Column('user_id', UUID(as_uuid=True), sa.ForeignKey('users.id'), nullable=False, index=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table('refresh_tokens')

View File

@@ -67,8 +67,11 @@ async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
"""Ensure user is active (not disabled)."""
# For now, all users are considered active
# Add logic here if you add an is_active field to User
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account has been deactivated"
)
return current_user
@@ -90,6 +93,8 @@ async def require_engineer_or_admin(
"""Require engineer, team admin, or super admin role (blocks viewers)."""
if current_user.is_super_admin:
return current_user
if current_user.is_team_admin and current_user.team_id is not None:
return current_user
if current_user.role not in ("engineer",):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -0,0 +1,166 @@
from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.core.database import get_db
from app.models.user import User
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
from app.api.deps import require_admin
router = APIRouter(prefix="/admin", tags=["admin"])
@router.get("/users", response_model=list[UserResponse])
async def list_users(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
role: Optional[str] = Query(None, description="Filter by role"),
team_id: Optional[UUID] = Query(None, description="Filter by team")
):
"""List all users (super admin only)."""
query = select(User)
if is_active is not None:
query = query.where(User.is_active == is_active)
if role:
query = query.where(User.role == role)
if team_id:
query = query.where(User.team_id == team_id)
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
users = result.scalars().all()
return users
@router.get("/users/{user_id}", response_model=UserResponse)
async def get_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Get user details (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return user
@router.put("/users/{user_id}/role", response_model=UserResponse)
async def update_user_role(
user_id: UUID,
role_data: RoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Change user role (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot change your own role"
)
user.role = role_data.role
await db.commit()
await db.refresh(user)
return user
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
async def toggle_team_admin(
user_id: UUID,
data: TeamAdminUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Toggle is_team_admin for a user (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if data.is_team_admin and user.team_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User must belong to a team to be a team admin"
)
user.is_team_admin = data.is_team_admin
await db.commit()
await db.refresh(user)
return user
@router.put("/users/{user_id}/deactivate", response_model=UserResponse)
async def deactivate_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Deactivate a user account (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot deactivate your own account"
)
user.is_active = False
await db.commit()
await db.refresh(user)
return user
@router.put("/users/{user_id}/activate", response_model=UserResponse)
async def activate_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Reactivate a user account (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
user.is_active = True
await db.commit()
await db.refresh(user)
return user

View File

@@ -1,29 +1,46 @@
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.config import settings
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.core.security import (
verify_password,
get_password_hash,
create_access_token,
create_refresh_token,
decode_token,
hash_token,
)
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.api.deps import get_current_user, get_refresh_token_payload
from app.api.deps import get_current_active_user, get_refresh_token_payload
router = APIRouter(prefix="/auth", tags=["authentication"])
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
"""Decode a refresh token JWT and store its hash in the database."""
payload = decode_token(refresh_token_str)
if payload and payload.get("jti"):
token_record = RefreshToken(
token_hash=hash_token(payload["jti"]),
user_id=user_id,
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
)
db.add(token_record)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
request: Request,
user_data: UserCreate,
db: Annotated[AsyncSession, Depends(get_db)]
):
@@ -92,7 +109,9 @@ async def register(
@router.post("/login", response_model=Token)
@limiter.limit("5/minute")
async def login(
request: Request,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[AsyncSession, Depends(get_db)]
):
@@ -110,21 +129,26 @@ async def login(
# Update last login
user.last_login = datetime.now(timezone.utc)
await db.commit()
# Create tokens
access_token = create_access_token(data={"sub": str(user.id)})
refresh_token = create_refresh_token(data={"sub": str(user.id)})
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
access_token=access_token,
refresh_token=refresh_token,
refresh_token=refresh_token_str,
token_type="bearer"
)
@router.post("/login/json", response_model=Token)
@limiter.limit("5/minute")
async def login_json(
request: Request,
credentials: UserLogin,
db: Annotated[AsyncSession, Depends(get_db)]
):
@@ -139,25 +163,50 @@ async def login_json(
)
user.last_login = datetime.now(timezone.utc)
await db.commit()
access_token = create_access_token(data={"sub": str(user.id)})
refresh_token = create_refresh_token(data={"sub": str(user.id)})
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store refresh token hash in DB
await _store_refresh_token(db, refresh_token_str, user.id)
await db.commit()
return Token(
access_token=access_token,
refresh_token=refresh_token,
refresh_token=refresh_token_str,
token_type="bearer"
)
@router.post("/refresh", response_model=Token)
@limiter.limit("10/minute")
async def refresh_token(
request: Request,
payload: Annotated[dict, Depends(get_refresh_token_payload)],
db: Annotated[AsyncSession, Depends(get_db)]
):
"""Refresh access token using refresh token."""
"""Refresh access token using refresh token (rotation: old token is revoked)."""
user_id = payload.get("sub")
jti = payload.get("jti")
# Validate refresh token hasn't been revoked
if jti:
token_hash = hash_token(jti)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
stored_token = result.scalar_one_or_none()
if stored_token and stored_token.is_revoked:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has been revoked"
)
# Revoke the old refresh token (token rotation)
if stored_token:
stored_token.revoked_at = datetime.now(timezone.utc)
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -168,26 +217,42 @@ async def refresh_token(
)
access_token = create_access_token(data={"sub": str(user.id)})
new_refresh_token = create_refresh_token(data={"sub": str(user.id)})
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
# Store new refresh token
await _store_refresh_token(db, new_refresh_token_str, user.id)
await db.commit()
return Token(
access_token=access_token,
refresh_token=new_refresh_token,
refresh_token=new_refresh_token_str,
token_type="bearer"
)
@router.get("/me", response_model=UserResponse)
async def get_me(
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current authenticated user."""
return current_user
@router.post("/logout")
async def logout():
"""Logout user (client should discard tokens)."""
# JWT tokens are stateless, so logout is handled client-side
# In a production app, you might want to blacklist the token
async def logout(
payload: Annotated[dict, Depends(get_refresh_token_payload)],
db: Annotated[AsyncSession, Depends(get_db)]
):
"""Logout user by revoking the refresh token."""
jti = payload.get("jti")
if jti:
token_hash = hash_token(jti)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
stored_token = result.scalar_one_or_none()
if stored_token and not stored_token.is_revoked:
stored_token.revoked_at = datetime.now(timezone.utc)
await db.commit()
return {"message": "Successfully logged out"}

View File

@@ -10,7 +10,8 @@ from app.models.category import TreeCategory
from app.models.tree import Tree
from app.models.user import User
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse
from app.api.deps import get_current_user
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_category, can_create_category
router = APIRouter(prefix="/categories", tags=["categories"])
@@ -22,32 +23,10 @@ def slugify(name: str) -> str:
return slug
def can_manage_category(user: User, category: TreeCategory) -> bool:
"""Check if user can manage (edit/delete) a category."""
# Global admins can manage any category
if user.is_super_admin:
return True
# Team admins can manage their team's categories
if user.is_team_admin and category.team_id == user.team_id:
return True
return False
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
"""Check if user can create a category for the given team."""
# Global admins can create global categories (team_id=None) or any team's categories
if user.is_super_admin:
return True
# Team admins can only create categories for their own team
if user.is_team_admin and team_id == user.team_id:
return True
return False
@router.get("", response_model=list[CategoryListResponse])
async def list_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
):
@@ -110,7 +89,7 @@ async def list_categories(
async def get_category(
category_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific category by ID."""
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))
@@ -155,7 +134,7 @@ async def get_category(
async def create_category(
category_data: CategoryCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Create a new category.
@@ -221,7 +200,7 @@ async def update_category(
category_id: UUID,
category_data: CategoryUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Update a category."""
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))
@@ -291,7 +270,7 @@ async def update_category(
async def delete_category(
category_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Soft delete (archive) a category."""
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))

View File

@@ -17,7 +17,8 @@ from app.schemas.folder import (
FolderReorderRequest,
FolderTreeRequest
)
from app.api.deps import get_current_user
from app.api.deps import get_current_active_user
from app.core.permissions import can_access_tree
router = APIRouter(prefix="/folders", tags=["folders"])
@@ -63,30 +64,10 @@ async def is_descendant(db: AsyncSession, potential_descendant_id: UUID, ancesto
return False
def can_access_tree(user: User, tree: Tree) -> bool:
"""Check if user can access a tree (to add to folder).
User can access tree if:
- Tree is public
- User is the author
- Tree belongs to user's team
- User is a global admin
"""
if tree.is_public:
return True
if user.id == tree.author_id:
return True
if tree.team_id == user.team_id and user.team_id is not None:
return True
if user.is_super_admin:
return True
return False
@router.get("", response_model=list[FolderListResponse])
async def list_folders(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""List all folders for the current user.
@@ -120,7 +101,7 @@ async def list_folders(
async def get_folder(
folder_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific folder by ID."""
result = await db.execute(
@@ -160,7 +141,7 @@ async def get_folder(
async def create_folder(
folder_data: FolderCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Create a new folder for the current user.
@@ -241,7 +222,7 @@ async def update_folder(
folder_id: UUID,
folder_data: FolderUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Update a folder.
@@ -352,7 +333,7 @@ async def update_folder(
async def delete_folder(
folder_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Delete a folder.
@@ -384,7 +365,7 @@ async def delete_folder(
async def reorder_folders(
reorder_data: FolderReorderRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Reorder folders by providing folder IDs in desired order."""
# Get all user's folders
@@ -414,7 +395,7 @@ async def add_tree_to_folder(
folder_id: UUID,
request: FolderTreeRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Add a tree to a folder."""
# Get folder with trees
@@ -474,7 +455,7 @@ async def remove_tree_from_folder(
folder_id: UUID,
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Remove a tree from a folder."""
# Get folder with trees
@@ -519,7 +500,7 @@ async def remove_tree_from_folder(
async def get_folder_tree_ids(
folder_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get all tree IDs in a folder.

View File

@@ -1,10 +1,10 @@
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.rate_limit import limiter
from app.models.user import User
from app.models.invite_code import InviteCode
from app.schemas.invite_code import InviteCodeCreate, InviteCodeResponse, InviteCodeValidation
@@ -74,7 +74,9 @@ async def revoke_invite_code(
@router.get("/validate/{code}", response_model=InviteCodeValidation)
@limiter.limit("5/minute")
async def validate_invite_code(
request: Request,
code: str,
db: Annotated[AsyncSession, Depends(get_db)]
):

View File

@@ -12,7 +12,8 @@ from app.models.tree import Tree
from app.models.session import Session
from app.models.user import User
from app.schemas.session import SessionCreate, SessionUpdate, SessionResponse, SessionExport, ScratchpadUpdate
from app.api.deps import get_current_user
from app.api.deps import get_current_active_user
from app.core.permissions import can_access_tree
router = APIRouter(prefix="/sessions", tags=["sessions"])
@@ -20,7 +21,7 @@ router = APIRouter(prefix="/sessions", tags=["sessions"])
@router.get("", response_model=list[SessionResponse])
async def list_sessions(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
completed: Optional[bool] = Query(None, description="Filter by completion status"),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100)
@@ -46,7 +47,7 @@ async def list_sessions(
async def get_session(
session_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific session."""
result = await db.execute(select(Session).where(Session.id == session_id))
@@ -71,7 +72,7 @@ async def get_session(
async def start_session(
session_data: SessionCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Start a new troubleshooting session."""
# Get the tree
@@ -90,6 +91,12 @@ async def start_session(
detail="Cannot start session with inactive tree"
)
if not can_access_tree(current_user, tree):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
)
# Create session with tree snapshot
new_session = Session(
tree_id=tree.id,
@@ -115,7 +122,7 @@ async def update_session(
session_id: UUID,
session_data: SessionUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Update session (add decisions, notes, etc.)."""
result = await db.execute(select(Session).where(Session.id == session_id))
@@ -154,7 +161,7 @@ async def update_session(
async def complete_session(
session_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Mark session as complete."""
result = await db.execute(select(Session).where(Session.id == session_id))
@@ -189,7 +196,7 @@ async def update_scratchpad(
session_id: UUID,
data: ScratchpadUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Update session scratchpad. Accepts updates on both active and completed sessions."""
result = await db.execute(select(Session).where(Session.id == session_id))
@@ -218,7 +225,7 @@ async def export_session(
session_id: UUID,
export_options: SessionExport,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Export session to formatted notes."""
result = await db.execute(select(Session).where(Session.id == session_id))

View File

@@ -14,37 +14,16 @@ from app.schemas.step_category import (
StepCategoryListResponse,
slugify
)
from app.api.deps import get_current_user
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_step_category, can_create_step_category
router = APIRouter(prefix="/step-categories", tags=["step-categories"])
def can_manage_step_category(user: User, category: StepCategory) -> bool:
"""Check if user can manage (edit/delete) a step category."""
# Global admins can manage any category
if user.is_super_admin:
return True
# Team admins can manage their team's categories
if user.is_team_admin and category.team_id == user.team_id:
return True
return False
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
"""Check if user can create a step category for the given team."""
# Global admins can create global categories (team_id=None) or any team's categories
if user.is_super_admin:
return True
# Team admins can only create categories for their own team
if user.is_team_admin and team_id == user.team_id:
return True
return False
@router.get("", response_model=list[StepCategoryListResponse])
async def list_step_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
):
@@ -100,7 +79,7 @@ async def list_step_categories(
async def get_step_category(
category_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific step category by ID."""
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))
@@ -137,7 +116,7 @@ async def get_step_category(
async def create_step_category(
category_data: StepCategoryCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Create a new step category.
@@ -203,7 +182,7 @@ async def update_step_category(
category_id: UUID,
category_data: StepCategoryUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Update a step category."""
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))
@@ -265,7 +244,7 @@ async def update_step_category(
async def delete_step_category(
category_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Soft delete (archive) a step category."""
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))

View File

@@ -7,7 +7,8 @@ from sqlalchemy import select, or_, and_, func, desc, Integer, case
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_user, require_engineer_or_admin
from app.api.deps import get_current_active_user, require_engineer_or_admin
from app.core.permissions import can_view_step, can_edit_step
from app.models.user import User
from app.models.step_library import StepLibrary, StepRating
from app.models.step_category import StepCategory
@@ -25,27 +26,6 @@ from app.schemas.step_library import (
router = APIRouter(prefix="/steps", tags=["steps"])
# Permission helpers
def can_view_step(user: User, step: StepLibrary) -> bool:
"""Check if user can view a step based on visibility."""
if step.visibility == 'public':
return True
if step.visibility == 'private':
return step.created_by == user.id
if step.visibility == 'team':
return step.team_id == user.team_id or user.is_super_admin
return False
def can_edit_step(user: User, step: StepLibrary) -> bool:
"""Check if user can edit/delete a step."""
if user.is_super_admin:
return True
if user.role == 'viewer':
return False
return step.created_by == user.id
async def get_step_or_404(
step_id: UUID,
db: AsyncSession,
@@ -99,7 +79,7 @@ async def list_steps(
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""List steps with filters and pagination."""
query = select(StepLibrary).where(
@@ -177,7 +157,7 @@ async def search_steps(
q: str = Query(..., min_length=1),
limit: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Full-text search for steps."""
# Use PostgreSQL full-text search
@@ -229,7 +209,7 @@ async def search_steps(
async def get_popular_tags(
limit: int = Query(20, ge=1, le=50),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Get popular tags with usage counts."""
# Use unnest to expand arrays and count occurrences
@@ -255,7 +235,7 @@ async def get_popular_tags(
async def get_step(
step_id: UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Get a step by ID."""
step = await get_step_or_404(step_id, db, current_user, check_view=True)
@@ -374,7 +354,7 @@ async def update_step(
step_id: UUID,
step_data: StepLibraryUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Update a step (owner or admin only)."""
step = await get_step_or_404(step_id, db, current_user, check_edit=True)
@@ -444,7 +424,7 @@ async def update_step(
async def delete_step(
step_id: UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Soft delete a step (owner or admin only)."""
step = await get_step_or_404(step_id, db, current_user, check_edit=True)
@@ -462,7 +442,7 @@ async def rate_step(
step_id: UUID,
rating_data: StepRatingCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Rate a step (1-5 stars with optional review)."""
step = await get_step_or_404(step_id, db, current_user, check_view=True)
@@ -516,7 +496,7 @@ async def update_rating(
step_id: UUID,
rating_data: StepRatingUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Update your rating for a step."""
step = await get_step_or_404(step_id, db, current_user, check_view=True)
@@ -563,7 +543,7 @@ async def update_rating(
async def delete_rating(
step_id: UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Delete your rating for a step."""
step = await get_step_or_404(step_id, db, current_user, check_view=True)
@@ -593,7 +573,7 @@ async def get_reviews(
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_active_user)
):
"""Get reviews for a step."""
await get_step_or_404(step_id, db, current_user, check_view=True)

View File

@@ -10,52 +10,16 @@ from app.models.tag import TreeTag, tree_tag_assignments
from app.models.tree import Tree
from app.models.user import User
from app.schemas.tag import TagCreate, TagResponse, TagListResponse, TagAssignment
from app.api.deps import get_current_user
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_tree_tags, can_create_tag
router = APIRouter(prefix="/tags", tags=["tags"])
def can_manage_tree_tags(user: User, tree: Tree) -> bool:
"""Check if user can manage tags on a tree.
Allowed:
- Tree author (engineer+)
- Super admins
- Team admins for their team's trees
"""
if user.is_super_admin:
return True
if user.role == "viewer":
return False
if user.id == tree.author_id:
return True
if user.is_team_admin and tree.team_id == user.team_id:
return True
return False
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
"""Check if user can create a tag for the given scope.
- Super admins can create global tags (team_id=None)
- Team admins and super admins can create team-specific tags
- Engineers can create team tags for their own team
- Viewers cannot create tags
"""
if user.is_super_admin:
return True
if user.role == "viewer":
return False
# For team-specific tags, user must belong to that team
if team_id is not None and team_id == user.team_id:
return True
return False
@router.get("", response_model=list[TagListResponse])
async def list_tags(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_team: bool = Query(True, description="Include team-specific tags")
):
"""List tags visible to the user.
@@ -88,7 +52,7 @@ async def list_tags(
@router.get("/search", response_model=list[TagListResponse])
async def search_tags(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
q: str = Query(..., min_length=1, description="Search query"),
limit: int = Query(10, ge=1, le=50),
include_team: bool = Query(True, description="Include team-specific tags")
@@ -125,7 +89,7 @@ async def search_tags(
async def get_tag(
tag_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific tag by ID."""
result = await db.execute(select(TreeTag).where(TreeTag.id == tag_id))
@@ -151,7 +115,7 @@ async def get_tag(
async def create_tag(
tag_data: TagCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Create a new tag.
@@ -197,7 +161,7 @@ async def add_tags_to_tree(
tree_id: UUID,
tag_data: TagAssignment,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Add tags to a tree.
@@ -281,7 +245,7 @@ async def remove_tag_from_tree(
tree_id: UUID,
tag_slug: str,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Remove a tag from a tree."""
# Get tree with tags
@@ -330,7 +294,7 @@ async def replace_tree_tags(
tree_id: UUID,
tag_data: TagAssignment,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Replace all tags on a tree.
@@ -412,7 +376,7 @@ async def replace_tree_tags(
async def get_tree_tags(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get all tags assigned to a tree."""
# Get tree with tags

View File

@@ -12,7 +12,8 @@ from app.models.category import TreeCategory
from app.models.tag import TreeTag
from app.models.folder import UserFolder
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.api.deps import get_current_user, require_engineer_or_admin, require_admin
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
router = APIRouter(prefix="/trees", tags=["trees"])
@@ -98,7 +99,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
@router.get("", response_model=list[TreeListResponse])
async def list_trees(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
category: Optional[str] = Query(None, description="Filter by legacy category string"),
category_id: Optional[UUID] = Query(None, description="Filter by category ID"),
tags: Optional[str] = Query(None, description="Comma-separated tag slugs to filter by"),
@@ -176,7 +177,7 @@ async def list_trees(
@router.get("/categories", response_model=list[str])
async def list_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""List all unique categories from trees the user can access.
@@ -196,7 +197,7 @@ async def list_categories(
@router.get("/search", response_model=list[TreeListResponse])
async def search_trees(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
current_user: Annotated[User, Depends(get_current_active_user)],
q: str = Query(..., min_length=2, description="Search query"),
limit: int = Query(20, ge=1, le=50)
):
@@ -226,7 +227,7 @@ async def search_trees(
async def get_tree(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get a specific tree by ID."""
result = await db.execute(
@@ -245,15 +246,7 @@ async def get_tree(
detail="Tree not found"
)
# Check access: tree must be active AND (default OR public OR author OR same team)
can_access = (
tree.is_default or
tree.is_public or
tree.author_id == current_user.id or
(tree.team_id == current_user.team_id and current_user.team_id is not None) or
current_user.is_super_admin
)
if not tree.is_active or not can_access:
if not tree.is_active or not can_access_tree(current_user, tree):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
@@ -399,13 +392,7 @@ async def update_tree(
detail="Tree not found"
)
# Check if user can edit: must be author, team admin for team trees, or global admin
can_edit = (
tree.author_id == current_user.id or
current_user.is_super_admin or
(current_user.is_team_admin and tree.team_id == current_user.team_id)
)
if not can_edit:
if not can_edit_tree(current_user, tree):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only edit your own trees"

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
api_router = APIRouter()
@@ -12,3 +12,4 @@ api_router.include_router(tags.router)
api_router.include_router(folders.router)
api_router.include_router(step_categories.router)
api_router.include_router(steps.router)
api_router.include_router(admin.router)

View File

@@ -43,7 +43,7 @@ class Settings(BaseSettings):
)
return v
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
ACCESS_TOKEN_EXPIRE_MINUTES: int = 5
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
# Security

View File

@@ -9,13 +9,15 @@ Role hierarchy: super_admin > team_admin > engineer > viewer
- viewer: role='viewer', read-only (can browse, run sessions, rate steps)
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
from uuid import UUID
if TYPE_CHECKING:
from app.models.user import User
from app.models.tree import Tree
from app.models.step_library import StepLibrary
from app.models.category import TreeCategory
from app.models.step_category import StepCategory
ROLE_HIERARCHY = {
"super_admin": 4,
@@ -92,3 +94,78 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool:
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
return True
return False
def can_access_tree(user: User, tree: Tree) -> bool:
"""Can the user access (view) this tree?"""
if tree.is_default or tree.is_public:
return True
if tree.author_id == user.id:
return True
if tree.team_id == user.team_id and user.team_id is not None:
return True
if user.is_super_admin:
return True
return False
def can_view_step(user: User, step: StepLibrary) -> bool:
"""Can the user view this step based on its visibility?"""
if step.visibility == "public":
return True
if step.visibility == "private":
return step.created_by == user.id
if step.visibility == "team":
return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin
return False
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
"""Can the user create a tag for the given scope?
- Super admins can create global tags (team_id=None) or any team's tags
- Engineers can create team tags for their own team
- Viewers cannot create tags
"""
if user.is_super_admin:
return True
if not can_create_content(user):
return False
if team_id is not None and team_id == user.team_id:
return True
return False
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
"""Can the user create a category for the given team?
- Super admins can create global or any team's categories
- Team admins can create categories for their own team
"""
if user.is_super_admin:
return True
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
return True
return False
def can_manage_step_category(user: User, category: StepCategory) -> bool:
"""Can the user edit/delete this step category?"""
if user.is_super_admin:
return True
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
return True
return False
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
"""Can the user create a step category for the given team?
- Super admins can create global or any team's step categories
- Team admins can create step categories for their own team
"""
if user.is_super_admin:
return True
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
return True
return False

View File

@@ -0,0 +1,6 @@
from slowapi import Limiter
from slowapi.util import get_remote_address
from app.core.config import settings
limiter = Limiter(key_func=get_remote_address, enabled=not settings.DEBUG)

View File

@@ -1,3 +1,5 @@
import hashlib
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
@@ -30,14 +32,20 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
def create_refresh_token(data: dict) -> str:
"""Create a JWT refresh token."""
"""Create a JWT refresh token with a unique jti for revocation tracking."""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
jti = str(uuid.uuid4())
to_encode.update({"exp": expire, "type": "refresh", "jti": jti})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
def hash_token(jti: str) -> str:
"""Hash a token JTI for secure storage."""
return hashlib.sha256(jti.encode()).hexdigest()
def decode_token(token: str) -> Optional[dict]:
"""Decode and validate a JWT token."""
try:

View File

@@ -2,11 +2,14 @@ import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from app.core.config import settings
from app.core.database import init_db
from app.core.logging_config import setup_logging
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
from app.core.rate_limit import limiter
from app.api.router import api_router
# Initialize logging configuration
@@ -38,6 +41,9 @@ app = FastAPI(
lifespan=lifespan
)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add logging middleware (BEFORE CORS to log all requests)
app.add_middleware(ErrorLoggingMiddleware)
app.add_middleware(RequestLoggingMiddleware)

View File

@@ -9,6 +9,7 @@ from .tag import TreeTag, tree_tag_assignments
from .folder import UserFolder, user_folder_trees
from .step_category import StepCategory
from .step_library import StepLibrary, StepRating, StepUsageLog
from .refresh_token import RefreshToken
__all__ = [
"User",
@@ -26,4 +27,5 @@ __all__ = [
"StepLibrary",
"StepRating",
"StepUsageLog",
"RefreshToken",
]

View File

@@ -0,0 +1,38 @@
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import String, DateTime, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=False,
index=True
)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc)
)
@property
def is_revoked(self) -> bool:
return self.revoked_at is not None
@property
def is_expired(self) -> bool:
return datetime.now(timezone.utc) > self.expires_at

View File

@@ -33,6 +33,7 @@ class User(Base):
role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id"),

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional
from typing import Literal, Optional
from uuid import UUID
from pydantic import BaseModel, EmailStr, Field
@@ -29,9 +29,18 @@ class UserResponse(UserBase):
role: str
is_super_admin: bool = False
is_team_admin: bool = False
is_active: bool = True
team_id: Optional[UUID] = None
created_at: datetime
last_login: Optional[datetime] = None
class Config:
from_attributes = True
class RoleUpdate(BaseModel):
role: Literal["engineer", "viewer"]
class TeamAdminUpdate(BaseModel):
is_team_admin: bool

View File

@@ -19,5 +19,8 @@ pydantic==2.6.1
pydantic-settings==2.1.0
email-validator==2.1.0
# Rate Limiting
slowapi==0.1.9
# Utilities
python-dotenv==1.0.1

129
backend/tests/test_admin.py Normal file
View File

@@ -0,0 +1,129 @@
"""Integration tests for admin user management endpoints."""
import pytest
from httpx import AsyncClient
class TestAdminEndpoints:
"""Test suite for admin user management endpoints."""
@pytest.mark.asyncio
async def test_list_users_as_admin(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test listing users as a super admin."""
response = await client.get(
"/api/v1/admin/users", headers=admin_auth_headers
)
assert response.status_code == 200
users = response.json()
assert len(users) >= 2 # admin + test_user
@pytest.mark.asyncio
async def test_list_users_as_non_admin(
self, client: AsyncClient, auth_headers: dict
):
"""Test that non-admin users cannot list users."""
response = await client.get(
"/api/v1/admin/users", headers=auth_headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_get_user_as_admin(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test getting user details as admin."""
user_id = test_user["user_data"]["id"]
response = await client.get(
f"/api/v1/admin/users/{user_id}", headers=admin_auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["email"] == test_user["email"]
@pytest.mark.asyncio
async def test_change_user_role(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test changing a user's role to viewer."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/role",
json={"role": "viewer"},
headers=admin_auth_headers
)
assert response.status_code == 200
assert response.json()["role"] == "viewer"
@pytest.mark.asyncio
async def test_change_role_invalid(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test that invalid role values are rejected."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/role",
json={"role": "admin"},
headers=admin_auth_headers
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_cannot_change_own_role(
self, client: AsyncClient, admin_auth_headers: dict, test_admin: dict
):
"""Test that admin cannot change their own role."""
admin_id = test_admin["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{admin_id}/role",
json={"role": "viewer"},
headers=admin_auth_headers
)
assert response.status_code == 400
assert "own role" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_deactivate_user(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test deactivating a user."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/deactivate",
headers=admin_auth_headers
)
assert response.status_code == 200
assert response.json()["is_active"] is False
@pytest.mark.asyncio
async def test_activate_user(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test reactivating a user."""
user_id = test_user["user_data"]["id"]
# Deactivate first
await client.put(
f"/api/v1/admin/users/{user_id}/deactivate",
headers=admin_auth_headers
)
# Then reactivate
response = await client.put(
f"/api/v1/admin/users/{user_id}/activate",
headers=admin_auth_headers
)
assert response.status_code == 200
assert response.json()["is_active"] is True
@pytest.mark.asyncio
async def test_cannot_deactivate_self(
self, client: AsyncClient, admin_auth_headers: dict, test_admin: dict
):
"""Test that admin cannot deactivate themselves."""
admin_id = test_admin["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{admin_id}/deactivate",
headers=admin_auth_headers
)
assert response.status_code == 400
assert "own account" in response.json()["detail"].lower()

View File

@@ -684,3 +684,40 @@ class TestSessions:
content = response.text
assert '<script>' not in content
assert '&lt;script&gt;' in content
@pytest.mark.asyncio
async def test_start_session_on_others_private_tree_forbidden(
self, client: AsyncClient, auth_headers: dict, test_tree: dict
):
"""Test that a user cannot start a session on another user's private tree."""
# Register a second user
await client.post("/api/v1/auth/register", json={
"email": "other@example.com",
"password": "OtherPassword123!",
"name": "Other User"
})
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "other@example.com",
"password": "OtherPassword123!"
})
other_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# test_tree is owned by test_user (not public, not default)
response = await client.post(
"/api/v1/sessions",
json={"tree_id": test_tree["id"]},
headers=other_headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_start_session_super_admin_any_tree(
self, client: AsyncClient, admin_auth_headers: dict, test_tree: dict
):
"""Test that a super admin can start a session on any tree."""
response = await client.post(
"/api/v1/sessions",
json={"tree_id": test_tree["id"]},
headers=admin_auth_headers
)
assert response.status_code == 201