fix: high-severity security hardening (Phase B permissions audit)
Phase B addresses 7 high-severity gaps from the permissions audit: - B1: Enforce tree access check on session start via can_access_tree - B2: Replace all inline permission helpers with centralized permissions.py - B3: Fix require_engineer_or_admin to check is_team_admin before role - B4: Add is_active field on User with enforcement in get_current_active_user - B5: Add admin user management endpoints (list, get, role, team-admin, deactivate, activate) - B6: Add rate limiting on auth/invite endpoints via slowapi (disabled in DEBUG) - B7: Implement refresh token rotation with JTI-based revocation and meaningful logout Also reduces access token TTL from 15 to 5 minutes and updates CLAUDE.md with SaaS/MSP context for future planning sessions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
29
backend/alembic/versions/012_add_user_is_active.py
Normal file
29
backend/alembic/versions/012_add_user_is_active.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""add is_active field to users
|
||||
|
||||
Revision ID: 012
|
||||
Revises: 011
|
||||
Create Date: 2026-02-05
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '012'
|
||||
down_revision: Union[str, None] = '011'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'users',
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('users', 'is_active')
|
||||
35
backend/alembic/versions/013_add_refresh_tokens.py
Normal file
35
backend/alembic/versions/013_add_refresh_tokens.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""add refresh_tokens table
|
||||
|
||||
Revision ID: 013
|
||||
Revises: 012
|
||||
Create Date: 2026-02-05
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '013'
|
||||
down_revision: Union[str, None] = '012'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'refresh_tokens',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column('token_hash', sa.String(64), unique=True, nullable=False, index=True),
|
||||
sa.Column('user_id', UUID(as_uuid=True), sa.ForeignKey('users.id'), nullable=False, index=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('revoked_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('refresh_tokens')
|
||||
@@ -67,8 +67,11 @@ async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
"""Ensure user is active (not disabled)."""
|
||||
# For now, all users are considered active
|
||||
# Add logic here if you add an is_active field to User
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account has been deactivated"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
@@ -90,6 +93,8 @@ async def require_engineer_or_admin(
|
||||
"""Require engineer, team admin, or super admin role (blocks viewers)."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
if current_user.is_team_admin and current_user.team_id is not None:
|
||||
return current_user
|
||||
if current_user.role not in ("engineer",):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
166
backend/app/api/endpoints/admin.py
Normal file
166
backend/app/api/endpoints/admin.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
|
||||
from app.api.deps import require_admin
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserResponse])
|
||||
async def list_users(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
team_id: Optional[UUID] = Query(None, description="Filter by team")
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(User.is_active == is_active)
|
||||
if role:
|
||||
query = query.where(User.role == role)
|
||||
if team_id:
|
||||
query = query.where(User.team_id == team_id)
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
users = result.scalars().all()
|
||||
return users
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Get user details (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/role", response_model=UserResponse)
|
||||
async def update_user_role(
|
||||
user_id: UUID,
|
||||
role_data: RoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Change user role (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
user.role = role_data.role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
|
||||
async def toggle_team_admin(
|
||||
user_id: UUID,
|
||||
data: TeamAdminUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Toggle is_team_admin for a user (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if data.is_team_admin and user.team_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User must belong to a team to be a team admin"
|
||||
)
|
||||
|
||||
user.is_team_admin = data.is_team_admin
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/deactivate", response_model=UserResponse)
|
||||
async def deactivate_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Deactivate a user account (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot deactivate your own account"
|
||||
)
|
||||
|
||||
user.is_active = False
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/activate", response_model=UserResponse)
|
||||
async def activate_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Reactivate a user account (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
user.is_active = True
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
@@ -1,29 +1,46 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.security import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_token,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.token import Token
|
||||
from app.api.deps import get_current_user, get_refresh_token_payload
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id) -> None:
|
||||
"""Decode a refresh token JWT and store its hash in the database."""
|
||||
payload = decode_token(refresh_token_str)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = RefreshToken(
|
||||
token_hash=hash_token(payload["jti"]),
|
||||
user_id=user_id,
|
||||
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||
)
|
||||
db.add(token_record)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("3/minute")
|
||||
async def register(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
@@ -92,7 +109,9 @@ async def register(
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
@limiter.limit("5/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
@@ -110,21 +129,26 @@ async def login(
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
refresh_token = create_refresh_token(data={"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await _store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/json", response_model=Token)
|
||||
@limiter.limit("5/minute")
|
||||
async def login_json(
|
||||
request: Request,
|
||||
credentials: UserLogin,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
@@ -139,25 +163,50 @@ async def login_json(
|
||||
)
|
||||
|
||||
user.last_login = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
refresh_token = create_refresh_token(data={"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store refresh token hash in DB
|
||||
await _store_refresh_token(db, refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
@limiter.limit("10/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Refresh access token using refresh token."""
|
||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
||||
user_id = payload.get("sub")
|
||||
jti = payload.get("jti")
|
||||
|
||||
# Validate refresh token hasn't been revoked
|
||||
if jti:
|
||||
token_hash = hash_token(jti)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
stored_token = result.scalar_one_or_none()
|
||||
|
||||
if stored_token and stored_token.is_revoked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has been revoked"
|
||||
)
|
||||
|
||||
# Revoke the old refresh token (token rotation)
|
||||
if stored_token:
|
||||
stored_token.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -168,26 +217,42 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
new_refresh_token = create_refresh_token(data={"sub": str(user.id)})
|
||||
new_refresh_token_str = create_refresh_token(data={"sub": str(user.id)})
|
||||
|
||||
# Store new refresh token
|
||||
await _store_refresh_token(db, new_refresh_token_str, user.id)
|
||||
await db.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
refresh_token=new_refresh_token_str,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout():
|
||||
"""Logout user (client should discard tokens)."""
|
||||
# JWT tokens are stateless, so logout is handled client-side
|
||||
# In a production app, you might want to blacklist the token
|
||||
async def logout(
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Logout user by revoking the refresh token."""
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
token_hash = hash_token(jti)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
stored_token = result.scalar_one_or_none()
|
||||
if stored_token and not stored_token.is_revoked:
|
||||
stored_token.revoked_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
|
||||
@@ -10,7 +10,8 @@ from app.models.category import TreeCategory
|
||||
from app.models.tree import Tree
|
||||
from app.models.user import User
|
||||
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_category, can_create_category
|
||||
|
||||
router = APIRouter(prefix="/categories", tags=["categories"])
|
||||
|
||||
@@ -22,32 +23,10 @@ def slugify(name: str) -> str:
|
||||
return slug
|
||||
|
||||
|
||||
def can_manage_category(user: User, category: TreeCategory) -> bool:
|
||||
"""Check if user can manage (edit/delete) a category."""
|
||||
# Global admins can manage any category
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
# Team admins can manage their team's categories
|
||||
if user.is_team_admin and category.team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Check if user can create a category for the given team."""
|
||||
# Global admins can create global categories (team_id=None) or any team's categories
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
# Team admins can only create categories for their own team
|
||||
if user.is_team_admin and team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=list[CategoryListResponse])
|
||||
async def list_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
):
|
||||
@@ -110,7 +89,7 @@ async def list_categories(
|
||||
async def get_category(
|
||||
category_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific category by ID."""
|
||||
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))
|
||||
@@ -155,7 +134,7 @@ async def get_category(
|
||||
async def create_category(
|
||||
category_data: CategoryCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Create a new category.
|
||||
|
||||
@@ -221,7 +200,7 @@ async def update_category(
|
||||
category_id: UUID,
|
||||
category_data: CategoryUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Update a category."""
|
||||
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))
|
||||
@@ -291,7 +270,7 @@ async def update_category(
|
||||
async def delete_category(
|
||||
category_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Soft delete (archive) a category."""
|
||||
result = await db.execute(select(TreeCategory).where(TreeCategory.id == category_id))
|
||||
|
||||
@@ -17,7 +17,8 @@ from app.schemas.folder import (
|
||||
FolderReorderRequest,
|
||||
FolderTreeRequest
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_access_tree
|
||||
|
||||
router = APIRouter(prefix="/folders", tags=["folders"])
|
||||
|
||||
@@ -63,30 +64,10 @@ async def is_descendant(db: AsyncSession, potential_descendant_id: UUID, ancesto
|
||||
return False
|
||||
|
||||
|
||||
def can_access_tree(user: User, tree: Tree) -> bool:
|
||||
"""Check if user can access a tree (to add to folder).
|
||||
|
||||
User can access tree if:
|
||||
- Tree is public
|
||||
- User is the author
|
||||
- Tree belongs to user's team
|
||||
- User is a global admin
|
||||
"""
|
||||
if tree.is_public:
|
||||
return True
|
||||
if user.id == tree.author_id:
|
||||
return True
|
||||
if tree.team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=list[FolderListResponse])
|
||||
async def list_folders(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""List all folders for the current user.
|
||||
|
||||
@@ -120,7 +101,7 @@ async def list_folders(
|
||||
async def get_folder(
|
||||
folder_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific folder by ID."""
|
||||
result = await db.execute(
|
||||
@@ -160,7 +141,7 @@ async def get_folder(
|
||||
async def create_folder(
|
||||
folder_data: FolderCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Create a new folder for the current user.
|
||||
|
||||
@@ -241,7 +222,7 @@ async def update_folder(
|
||||
folder_id: UUID,
|
||||
folder_data: FolderUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Update a folder.
|
||||
|
||||
@@ -352,7 +333,7 @@ async def update_folder(
|
||||
async def delete_folder(
|
||||
folder_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Delete a folder.
|
||||
|
||||
@@ -384,7 +365,7 @@ async def delete_folder(
|
||||
async def reorder_folders(
|
||||
reorder_data: FolderReorderRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Reorder folders by providing folder IDs in desired order."""
|
||||
# Get all user's folders
|
||||
@@ -414,7 +395,7 @@ async def add_tree_to_folder(
|
||||
folder_id: UUID,
|
||||
request: FolderTreeRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Add a tree to a folder."""
|
||||
# Get folder with trees
|
||||
@@ -474,7 +455,7 @@ async def remove_tree_from_folder(
|
||||
folder_id: UUID,
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Remove a tree from a folder."""
|
||||
# Get folder with trees
|
||||
@@ -519,7 +500,7 @@ async def remove_tree_from_folder(
|
||||
async def get_folder_tree_ids(
|
||||
folder_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get all tree IDs in a folder.
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.schemas.invite_code import InviteCodeCreate, InviteCodeResponse, InviteCodeValidation
|
||||
@@ -74,7 +74,9 @@ async def revoke_invite_code(
|
||||
|
||||
|
||||
@router.get("/validate/{code}", response_model=InviteCodeValidation)
|
||||
@limiter.limit("5/minute")
|
||||
async def validate_invite_code(
|
||||
request: Request,
|
||||
code: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
|
||||
@@ -12,7 +12,8 @@ from app.models.tree import Tree
|
||||
from app.models.session import Session
|
||||
from app.models.user import User
|
||||
from app.schemas.session import SessionCreate, SessionUpdate, SessionResponse, SessionExport, ScratchpadUpdate
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_access_tree
|
||||
|
||||
router = APIRouter(prefix="/sessions", tags=["sessions"])
|
||||
|
||||
@@ -20,7 +21,7 @@ router = APIRouter(prefix="/sessions", tags=["sessions"])
|
||||
@router.get("", response_model=list[SessionResponse])
|
||||
async def list_sessions(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
completed: Optional[bool] = Query(None, description="Filter by completion status"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100)
|
||||
@@ -46,7 +47,7 @@ async def list_sessions(
|
||||
async def get_session(
|
||||
session_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific session."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
@@ -71,7 +72,7 @@ async def get_session(
|
||||
async def start_session(
|
||||
session_data: SessionCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Start a new troubleshooting session."""
|
||||
# Get the tree
|
||||
@@ -90,6 +91,12 @@ async def start_session(
|
||||
detail="Cannot start session with inactive tree"
|
||||
)
|
||||
|
||||
if not can_access_tree(current_user, tree):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
)
|
||||
|
||||
# Create session with tree snapshot
|
||||
new_session = Session(
|
||||
tree_id=tree.id,
|
||||
@@ -115,7 +122,7 @@ async def update_session(
|
||||
session_id: UUID,
|
||||
session_data: SessionUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Update session (add decisions, notes, etc.)."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
@@ -154,7 +161,7 @@ async def update_session(
|
||||
async def complete_session(
|
||||
session_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Mark session as complete."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
@@ -189,7 +196,7 @@ async def update_scratchpad(
|
||||
session_id: UUID,
|
||||
data: ScratchpadUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Update session scratchpad. Accepts updates on both active and completed sessions."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
@@ -218,7 +225,7 @@ async def export_session(
|
||||
session_id: UUID,
|
||||
export_options: SessionExport,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Export session to formatted notes."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
|
||||
@@ -14,37 +14,16 @@ from app.schemas.step_category import (
|
||||
StepCategoryListResponse,
|
||||
slugify
|
||||
)
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_step_category, can_create_step_category
|
||||
|
||||
router = APIRouter(prefix="/step-categories", tags=["step-categories"])
|
||||
|
||||
|
||||
def can_manage_step_category(user: User, category: StepCategory) -> bool:
|
||||
"""Check if user can manage (edit/delete) a step category."""
|
||||
# Global admins can manage any category
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
# Team admins can manage their team's categories
|
||||
if user.is_team_admin and category.team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Check if user can create a step category for the given team."""
|
||||
# Global admins can create global categories (team_id=None) or any team's categories
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
# Team admins can only create categories for their own team
|
||||
if user.is_team_admin and team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=list[StepCategoryListResponse])
|
||||
async def list_step_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
):
|
||||
@@ -100,7 +79,7 @@ async def list_step_categories(
|
||||
async def get_step_category(
|
||||
category_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific step category by ID."""
|
||||
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))
|
||||
@@ -137,7 +116,7 @@ async def get_step_category(
|
||||
async def create_step_category(
|
||||
category_data: StepCategoryCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Create a new step category.
|
||||
|
||||
@@ -203,7 +182,7 @@ async def update_step_category(
|
||||
category_id: UUID,
|
||||
category_data: StepCategoryUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Update a step category."""
|
||||
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))
|
||||
@@ -265,7 +244,7 @@ async def update_step_category(
|
||||
async def delete_step_category(
|
||||
category_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Soft delete (archive) a step category."""
|
||||
result = await db.execute(select(StepCategory).where(StepCategory.id == category_id))
|
||||
|
||||
@@ -7,7 +7,8 @@ from sqlalchemy import select, or_, and_, func, desc, Integer, case
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_user, require_engineer_or_admin
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin
|
||||
from app.core.permissions import can_view_step, can_edit_step
|
||||
from app.models.user import User
|
||||
from app.models.step_library import StepLibrary, StepRating
|
||||
from app.models.step_category import StepCategory
|
||||
@@ -25,27 +26,6 @@ from app.schemas.step_library import (
|
||||
router = APIRouter(prefix="/steps", tags=["steps"])
|
||||
|
||||
|
||||
# Permission helpers
|
||||
def can_view_step(user: User, step: StepLibrary) -> bool:
|
||||
"""Check if user can view a step based on visibility."""
|
||||
if step.visibility == 'public':
|
||||
return True
|
||||
if step.visibility == 'private':
|
||||
return step.created_by == user.id
|
||||
if step.visibility == 'team':
|
||||
return step.team_id == user.team_id or user.is_super_admin
|
||||
return False
|
||||
|
||||
|
||||
def can_edit_step(user: User, step: StepLibrary) -> bool:
|
||||
"""Check if user can edit/delete a step."""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.role == 'viewer':
|
||||
return False
|
||||
return step.created_by == user.id
|
||||
|
||||
|
||||
async def get_step_or_404(
|
||||
step_id: UUID,
|
||||
db: AsyncSession,
|
||||
@@ -99,7 +79,7 @@ async def list_steps(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""List steps with filters and pagination."""
|
||||
query = select(StepLibrary).where(
|
||||
@@ -177,7 +157,7 @@ async def search_steps(
|
||||
q: str = Query(..., min_length=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Full-text search for steps."""
|
||||
# Use PostgreSQL full-text search
|
||||
@@ -229,7 +209,7 @@ async def search_steps(
|
||||
async def get_popular_tags(
|
||||
limit: int = Query(20, ge=1, le=50),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get popular tags with usage counts."""
|
||||
# Use unnest to expand arrays and count occurrences
|
||||
@@ -255,7 +235,7 @@ async def get_popular_tags(
|
||||
async def get_step(
|
||||
step_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get a step by ID."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_view=True)
|
||||
@@ -374,7 +354,7 @@ async def update_step(
|
||||
step_id: UUID,
|
||||
step_data: StepLibraryUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Update a step (owner or admin only)."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_edit=True)
|
||||
@@ -444,7 +424,7 @@ async def update_step(
|
||||
async def delete_step(
|
||||
step_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Soft delete a step (owner or admin only)."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_edit=True)
|
||||
@@ -462,7 +442,7 @@ async def rate_step(
|
||||
step_id: UUID,
|
||||
rating_data: StepRatingCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Rate a step (1-5 stars with optional review)."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_view=True)
|
||||
@@ -516,7 +496,7 @@ async def update_rating(
|
||||
step_id: UUID,
|
||||
rating_data: StepRatingUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Update your rating for a step."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_view=True)
|
||||
@@ -563,7 +543,7 @@ async def update_rating(
|
||||
async def delete_rating(
|
||||
step_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Delete your rating for a step."""
|
||||
step = await get_step_or_404(step_id, db, current_user, check_view=True)
|
||||
@@ -593,7 +573,7 @@ async def get_reviews(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
):
|
||||
"""Get reviews for a step."""
|
||||
await get_step_or_404(step_id, db, current_user, check_view=True)
|
||||
|
||||
@@ -10,52 +10,16 @@ from app.models.tag import TreeTag, tree_tag_assignments
|
||||
from app.models.tree import Tree
|
||||
from app.models.user import User
|
||||
from app.schemas.tag import TagCreate, TagResponse, TagListResponse, TagAssignment
|
||||
from app.api.deps import get_current_user
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_tree_tags, can_create_tag
|
||||
|
||||
router = APIRouter(prefix="/tags", tags=["tags"])
|
||||
|
||||
|
||||
def can_manage_tree_tags(user: User, tree: Tree) -> bool:
|
||||
"""Check if user can manage tags on a tree.
|
||||
|
||||
Allowed:
|
||||
- Tree author (engineer+)
|
||||
- Super admins
|
||||
- Team admins for their team's trees
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.role == "viewer":
|
||||
return False
|
||||
if user.id == tree.author_id:
|
||||
return True
|
||||
if user.is_team_admin and tree.team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Check if user can create a tag for the given scope.
|
||||
|
||||
- Super admins can create global tags (team_id=None)
|
||||
- Team admins and super admins can create team-specific tags
|
||||
- Engineers can create team tags for their own team
|
||||
- Viewers cannot create tags
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.role == "viewer":
|
||||
return False
|
||||
# For team-specific tags, user must belong to that team
|
||||
if team_id is not None and team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("", response_model=list[TagListResponse])
|
||||
async def list_tags(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
):
|
||||
"""List tags visible to the user.
|
||||
@@ -88,7 +52,7 @@ async def list_tags(
|
||||
@router.get("/search", response_model=list[TagListResponse])
|
||||
async def search_tags(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
@@ -125,7 +89,7 @@ async def search_tags(
|
||||
async def get_tag(
|
||||
tag_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific tag by ID."""
|
||||
result = await db.execute(select(TreeTag).where(TreeTag.id == tag_id))
|
||||
@@ -151,7 +115,7 @@ async def get_tag(
|
||||
async def create_tag(
|
||||
tag_data: TagCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Create a new tag.
|
||||
|
||||
@@ -197,7 +161,7 @@ async def add_tags_to_tree(
|
||||
tree_id: UUID,
|
||||
tag_data: TagAssignment,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Add tags to a tree.
|
||||
|
||||
@@ -281,7 +245,7 @@ async def remove_tag_from_tree(
|
||||
tree_id: UUID,
|
||||
tag_slug: str,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Remove a tag from a tree."""
|
||||
# Get tree with tags
|
||||
@@ -330,7 +294,7 @@ async def replace_tree_tags(
|
||||
tree_id: UUID,
|
||||
tag_data: TagAssignment,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Replace all tags on a tree.
|
||||
|
||||
@@ -412,7 +376,7 @@ async def replace_tree_tags(
|
||||
async def get_tree_tags(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get all tags assigned to a tree."""
|
||||
# Get tree with tags
|
||||
|
||||
@@ -12,7 +12,8 @@ from app.models.category import TreeCategory
|
||||
from app.models.tag import TreeTag
|
||||
from app.models.folder import UserFolder
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.api.deps import get_current_user, require_engineer_or_admin, require_admin
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
|
||||
router = APIRouter(prefix="/trees", tags=["trees"])
|
||||
|
||||
@@ -98,7 +99,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
@router.get("", response_model=list[TreeListResponse])
|
||||
async def list_trees(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
category: Optional[str] = Query(None, description="Filter by legacy category string"),
|
||||
category_id: Optional[UUID] = Query(None, description="Filter by category ID"),
|
||||
tags: Optional[str] = Query(None, description="Comma-separated tag slugs to filter by"),
|
||||
@@ -176,7 +177,7 @@ async def list_trees(
|
||||
@router.get("/categories", response_model=list[str])
|
||||
async def list_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""List all unique categories from trees the user can access.
|
||||
|
||||
@@ -196,7 +197,7 @@ async def list_categories(
|
||||
@router.get("/search", response_model=list[TreeListResponse])
|
||||
async def search_trees(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
q: str = Query(..., min_length=2, description="Search query"),
|
||||
limit: int = Query(20, ge=1, le=50)
|
||||
):
|
||||
@@ -226,7 +227,7 @@ async def search_trees(
|
||||
async def get_tree(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get a specific tree by ID."""
|
||||
result = await db.execute(
|
||||
@@ -245,15 +246,7 @@ async def get_tree(
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
# Check access: tree must be active AND (default OR public OR author OR same team)
|
||||
can_access = (
|
||||
tree.is_default or
|
||||
tree.is_public or
|
||||
tree.author_id == current_user.id or
|
||||
(tree.team_id == current_user.team_id and current_user.team_id is not None) or
|
||||
current_user.is_super_admin
|
||||
)
|
||||
if not tree.is_active or not can_access:
|
||||
if not tree.is_active or not can_access_tree(current_user, tree):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
@@ -399,13 +392,7 @@ async def update_tree(
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
# Check if user can edit: must be author, team admin for team trees, or global admin
|
||||
can_edit = (
|
||||
tree.author_id == current_user.id or
|
||||
current_user.is_super_admin or
|
||||
(current_user.is_team_admin and tree.team_id == current_user.team_id)
|
||||
)
|
||||
if not can_edit:
|
||||
if not can_edit_tree(current_user, tree):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only edit your own trees"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -12,3 +12,4 @@ api_router.include_router(tags.router)
|
||||
api_router.include_router(folders.router)
|
||||
api_router.include_router(step_categories.router)
|
||||
api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
|
||||
@@ -43,7 +43,7 @@ class Settings(BaseSettings):
|
||||
)
|
||||
return v
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 5
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# Security
|
||||
|
||||
@@ -9,13 +9,15 @@ Role hierarchy: super_admin > team_admin > engineer > viewer
|
||||
- viewer: role='viewer', read-only (can browse, run sessions, rate steps)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.models.step_library import StepLibrary
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.step_category import StepCategory
|
||||
|
||||
ROLE_HIERARCHY = {
|
||||
"super_admin": 4,
|
||||
@@ -92,3 +94,78 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool:
|
||||
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_access_tree(user: User, tree: Tree) -> bool:
|
||||
"""Can the user access (view) this tree?"""
|
||||
if tree.is_default or tree.is_public:
|
||||
return True
|
||||
if tree.author_id == user.id:
|
||||
return True
|
||||
if tree.team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_view_step(user: User, step: StepLibrary) -> bool:
|
||||
"""Can the user view this step based on its visibility?"""
|
||||
if step.visibility == "public":
|
||||
return True
|
||||
if step.visibility == "private":
|
||||
return step.created_by == user.id
|
||||
if step.visibility == "team":
|
||||
return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin
|
||||
return False
|
||||
|
||||
|
||||
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a tag for the given scope?
|
||||
|
||||
- Super admins can create global tags (team_id=None) or any team's tags
|
||||
- Engineers can create team tags for their own team
|
||||
- Viewers cannot create tags
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if not can_create_content(user):
|
||||
return False
|
||||
if team_id is not None and team_id == user.team_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a category for the given team?
|
||||
|
||||
- Super admins can create global or any team's categories
|
||||
- Team admins can create categories for their own team
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_manage_step_category(user: User, category: StepCategory) -> bool:
|
||||
"""Can the user edit/delete this step category?"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a step category for the given team?
|
||||
|
||||
- Super admins can create global or any team's step categories
|
||||
- Team admins can create step categories for their own team
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
6
backend/app/core/rate_limit.py
Normal file
6
backend/app/core/rate_limit.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address, enabled=not settings.DEBUG)
|
||||
@@ -1,3 +1,5 @@
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
@@ -30,14 +32,20 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
"""Create a JWT refresh token."""
|
||||
"""Create a JWT refresh token with a unique jti for revocation tracking."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
jti = str(uuid.uuid4())
|
||||
to_encode.update({"exp": expire, "type": "refresh", "jti": jti})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def hash_token(jti: str) -> str:
|
||||
"""Hash a token JTI for secure storage."""
|
||||
return hashlib.sha256(jti.encode()).hexdigest()
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
"""Decode and validate a JWT token."""
|
||||
try:
|
||||
|
||||
@@ -2,11 +2,14 @@ import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.core.logging_config import setup_logging
|
||||
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||
from app.core.rate_limit import limiter
|
||||
from app.api.router import api_router
|
||||
|
||||
# Initialize logging configuration
|
||||
@@ -38,6 +41,9 @@ app = FastAPI(
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Add logging middleware (BEFORE CORS to log all requests)
|
||||
app.add_middleware(ErrorLoggingMiddleware)
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
@@ -9,6 +9,7 @@ from .tag import TreeTag, tree_tag_assignments
|
||||
from .folder import UserFolder, user_folder_trees
|
||||
from .step_category import StepCategory
|
||||
from .step_library import StepLibrary, StepRating, StepUsageLog
|
||||
from .refresh_token import RefreshToken
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"StepLibrary",
|
||||
"StepRating",
|
||||
"StepUsageLog",
|
||||
"RefreshToken",
|
||||
]
|
||||
|
||||
38
backend/app/models/refresh_token.py
Normal file
38
backend/app/models/refresh_token.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_revoked(self) -> bool:
|
||||
return self.revoked_at is not None
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
@@ -33,6 +33,7 @@ class User(Base):
|
||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
|
||||
is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("teams.id"),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
@@ -29,9 +29,18 @@ class UserResponse(UserBase):
|
||||
role: str
|
||||
is_super_admin: bool = False
|
||||
is_team_admin: bool = False
|
||||
is_active: bool = True
|
||||
team_id: Optional[UUID] = None
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class RoleUpdate(BaseModel):
|
||||
role: Literal["engineer", "viewer"]
|
||||
|
||||
|
||||
class TeamAdminUpdate(BaseModel):
|
||||
is_team_admin: bool
|
||||
|
||||
@@ -19,5 +19,8 @@ pydantic==2.6.1
|
||||
pydantic-settings==2.1.0
|
||||
email-validator==2.1.0
|
||||
|
||||
# Rate Limiting
|
||||
slowapi==0.1.9
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
|
||||
129
backend/tests/test_admin.py
Normal file
129
backend/tests/test_admin.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for admin user management endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAdminEndpoints:
|
||||
"""Test suite for admin user management endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test listing users as a super admin."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users", headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
users = response.json()
|
||||
assert len(users) >= 2 # admin + test_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_non_admin(
|
||||
self, client: AsyncClient, auth_headers: dict
|
||||
):
|
||||
"""Test that non-admin users cannot list users."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test getting user details as admin."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/users/{user_id}", headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == test_user["email"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_user_role(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test changing a user's role to viewer."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/role",
|
||||
json={"role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["role"] == "viewer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_role_invalid(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test that invalid role values are rejected."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/role",
|
||||
json={"role": "admin"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_change_own_role(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_admin: dict
|
||||
):
|
||||
"""Test that admin cannot change their own role."""
|
||||
admin_id = test_admin["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{admin_id}/role",
|
||||
json={"role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "own role" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_user(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test deactivating a user."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/deactivate",
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["is_active"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_activate_user(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test reactivating a user."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
# Deactivate first
|
||||
await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/deactivate",
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
# Then reactivate
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/activate",
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["is_active"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_deactivate_self(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_admin: dict
|
||||
):
|
||||
"""Test that admin cannot deactivate themselves."""
|
||||
admin_id = test_admin["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{admin_id}/deactivate",
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "own account" in response.json()["detail"].lower()
|
||||
@@ -684,3 +684,40 @@ class TestSessions:
|
||||
content = response.text
|
||||
assert '<script>' not in content
|
||||
assert '<script>' in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_session_on_others_private_tree_forbidden(
|
||||
self, client: AsyncClient, auth_headers: dict, test_tree: dict
|
||||
):
|
||||
"""Test that a user cannot start a session on another user's private tree."""
|
||||
# Register a second user
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "other@example.com",
|
||||
"password": "OtherPassword123!",
|
||||
"name": "Other User"
|
||||
})
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "other@example.com",
|
||||
"password": "OtherPassword123!"
|
||||
})
|
||||
other_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# test_tree is owned by test_user (not public, not default)
|
||||
response = await client.post(
|
||||
"/api/v1/sessions",
|
||||
json={"tree_id": test_tree["id"]},
|
||||
headers=other_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_session_super_admin_any_tree(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_tree: dict
|
||||
):
|
||||
"""Test that a super admin can start a session on any tree."""
|
||||
response = await client.post(
|
||||
"/api/v1/sessions",
|
||||
json={"tree_id": test_tree["id"]},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
Reference in New Issue
Block a user