feat: update all endpoints and schemas for account-based model
Replace team_id with account_id across all API endpoints (trees, categories, tags, steps, step_categories, admin, auth). Add new accounts and webhooks endpoints. Registration now atomically creates Account + Subscription, with account_invite_code bypassing the platform invite gate. Schemas updated for account_id/account_role. 82 tests passing including 18 new tests for accounts, subscriptions, and permissions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
236
backend/app/api/endpoints/accounts.py
Normal file
236
backend/app/api/endpoints/accounts.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
import secrets
|
||||
import string
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.api.deps import get_current_active_user, require_account_owner
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["accounts"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=AccountResponse)
|
||||
async def get_my_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's account."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
return account
|
||||
|
||||
|
||||
@router.get("/me/subscription", response_model=SubscriptionDetails)
|
||||
async def get_my_subscription(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's subscription details including limits and usage."""
|
||||
sub = await get_account_subscription(current_user.account_id, db)
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No subscription found"
|
||||
)
|
||||
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if not limits:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Plan limits not configured"
|
||||
)
|
||||
|
||||
usage = await get_account_usage(current_user.account_id, db)
|
||||
|
||||
return SubscriptionDetails(
|
||||
subscription=SubscriptionResponse.model_validate(sub),
|
||||
limits=PlanLimitsResponse.model_validate(limits),
|
||||
usage=UsageResponse(**usage),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me/members", response_model=list[UserResponse])
|
||||
async def get_my_members(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get members of current user's account."""
|
||||
result = await db.execute(
|
||||
select(User).where(User.account_id == current_user.account_id)
|
||||
.order_by(User.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=AccountResponse)
|
||||
async def update_my_account(
|
||||
data: AccountUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Update account settings (owner only)."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(account, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
return account
|
||||
|
||||
|
||||
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
|
||||
async def update_member_role(
|
||||
user_id: UUID,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Change a member's role within the account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
user.account_role = data.account_role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_member(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Remove a member from the account (owner only).
|
||||
|
||||
The removed user gets a new personal account.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot remove yourself from your own account"
|
||||
)
|
||||
|
||||
# Create a personal account for the removed user
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
display_code = ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
new_account = Account(
|
||||
name=f"{user.name}'s Account",
|
||||
display_code=display_code,
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_sub = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_sub)
|
||||
|
||||
user.account_id = new_account.id
|
||||
user.account_role = "owner"
|
||||
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_invite(
|
||||
data: AccountInviteCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create an invite to join this account (owner only)."""
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
expires_at = None
|
||||
if data.expires_in_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=current_user.account_id,
|
||||
invited_by_id=current_user.id,
|
||||
email=data.email,
|
||||
code=code,
|
||||
role=data.role,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
return invite
|
||||
|
||||
|
||||
@router.get("/me/invites", response_model=list[AccountInviteResponse])
|
||||
async def list_invites(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""List invites for this account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.account_id == current_user.account_id)
|
||||
.order_by(AccountInvite.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import select, func
|
||||
from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
|
||||
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
|
||||
from app.api.deps import require_admin
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
@@ -21,7 +21,7 @@ async def list_users(
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
team_id: Optional[UUID] = Query(None, description="Filter by team")
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account")
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
@@ -30,8 +30,8 @@ async def list_users(
|
||||
query = query.where(User.is_active == is_active)
|
||||
if role:
|
||||
query = query.where(User.role == role)
|
||||
if team_id:
|
||||
query = query.where(User.team_id == team_id)
|
||||
if account_id:
|
||||
query = query.where(User.account_id == account_id)
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
@@ -91,14 +91,14 @@ async def update_user_role(
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
|
||||
async def toggle_team_admin(
|
||||
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
|
||||
async def update_account_role(
|
||||
user_id: UUID,
|
||||
data: TeamAdminUpdate,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Toggle is_team_admin for a user (super admin only)."""
|
||||
"""Change a user's account role (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -108,15 +108,10 @@ async def toggle_team_admin(
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if data.is_team_admin and user.team_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User must belong to a team to be a team admin"
|
||||
)
|
||||
|
||||
user.is_team_admin = data.is_team_admin
|
||||
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
|
||||
{"is_team_admin": data.is_team_admin})
|
||||
old_role = user.account_role
|
||||
user.account_role = data.account_role
|
||||
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
|
||||
{"old_account_role": old_role, "new_account_role": data.account_role})
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
@@ -18,6 +20,9 @@ from app.core.security import (
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.token import Token
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
|
||||
db.add(token_record)
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
"""Generate a random 8-character alphanumeric display code."""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("3/minute")
|
||||
async def register(
|
||||
@@ -44,10 +55,46 @@ async def register(
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Register a new user."""
|
||||
# Validate invite code if required
|
||||
"""Register a new user.
|
||||
|
||||
Supports two flows:
|
||||
- account_invite_code: Join an existing account (bypasses platform invite gate)
|
||||
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
|
||||
|
||||
After user creation, if no account invite was used, a personal Account
|
||||
and free Subscription are created automatically.
|
||||
"""
|
||||
# Check for account invite code FIRST — bypasses platform invite gate
|
||||
account_invite_record = None
|
||||
if user_data.account_invite_code:
|
||||
result = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.code == user_data.account_invite_code
|
||||
)
|
||||
)
|
||||
account_invite_record = result.scalar_one_or_none()
|
||||
|
||||
if not account_invite_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid account invite code"
|
||||
)
|
||||
|
||||
if account_invite_record.is_used:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has already been used"
|
||||
)
|
||||
|
||||
if account_invite_record.is_expired:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has expired"
|
||||
)
|
||||
|
||||
# Validate platform invite code if required (skip if account invite was provided)
|
||||
invite_code_record = None
|
||||
if settings.REQUIRE_INVITE_CODE:
|
||||
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
|
||||
if not user_data.invite_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -96,8 +143,37 @@ async def register(
|
||||
invite_code_id=invite_code_record.id if invite_code_record else None
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # Get user ID before creating account
|
||||
|
||||
# Mark invite code as used
|
||||
if account_invite_record:
|
||||
# Join existing account via account invite
|
||||
new_user.account_id = account_invite_record.account_id
|
||||
new_user.account_role = account_invite_record.role
|
||||
|
||||
# Mark account invite as used
|
||||
account_invite_record.accepted_by_id = new_user.id
|
||||
account_invite_record.used_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
# Create personal Account + free Subscription
|
||||
new_account = Account(
|
||||
name=f"{user_data.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
owner_id=new_user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush() # Get account ID
|
||||
|
||||
new_subscription = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_subscription)
|
||||
|
||||
new_user.account_id = new_account.id
|
||||
new_user.account_role = "owner"
|
||||
|
||||
# Mark platform invite code as used
|
||||
if invite_code_record:
|
||||
invite_code_record.used_by_id = new_user.id
|
||||
invite_code_record.used_at = datetime.now(timezone.utc)
|
||||
|
||||
@@ -28,11 +28,11 @@ async def list_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(TreeCategory)
|
||||
@@ -41,19 +41,19 @@ async def list_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(TreeCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(TreeCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(TreeCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeCategory.team_id.is_(None), # Global
|
||||
TreeCategory.team_id == current_user.team_id # User's team
|
||||
TreeCategory.account_id.is_(None), # Global
|
||||
TreeCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(TreeCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(TreeCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
|
||||
|
||||
@@ -76,7 +76,7 @@ async def list_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
tree_count=tree_count
|
||||
@@ -101,8 +101,8 @@ async def get_category(
|
||||
detail="Category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -121,7 +121,7 @@ async def get_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -138,10 +138,10 @@ async def create_category(
|
||||
):
|
||||
"""Create a new category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_category(current_user, category_data.team_id):
|
||||
if not can_create_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this category"
|
||||
@@ -150,10 +150,10 @@ async def create_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == slug,
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -164,7 +164,7 @@ async def create_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(TreeCategory.display_order)).where(
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -173,7 +173,7 @@ async def create_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -186,7 +186,7 @@ async def create_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -227,7 +227,7 @@ async def update_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == new_slug,
|
||||
TreeCategory.team_id == category.team_id,
|
||||
TreeCategory.account_id == category.account_id,
|
||||
TreeCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -257,7 +257,7 @@ async def update_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -25,11 +25,11 @@ async def list_step_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List step categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(StepCategory)
|
||||
@@ -38,19 +38,19 @@ async def list_step_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(StepCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(StepCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(StepCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
StepCategory.team_id.is_(None), # Global
|
||||
StepCategory.team_id == current_user.team_id # User's team
|
||||
StepCategory.account_id.is_(None), # Global
|
||||
StepCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(StepCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(StepCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(StepCategory.display_order, StepCategory.name)
|
||||
|
||||
@@ -66,7 +66,7 @@ async def list_step_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
step_count=0 # Will be computed when step_library exists
|
||||
@@ -91,8 +91,8 @@ async def get_step_category(
|
||||
detail="Step category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this step category"
|
||||
@@ -103,7 +103,7 @@ async def get_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -120,10 +120,10 @@ async def create_step_category(
|
||||
):
|
||||
"""Create a new step category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_step_category(current_user, category_data.team_id):
|
||||
if not can_create_step_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this step category"
|
||||
@@ -132,10 +132,10 @@ async def create_step_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == slug,
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -146,7 +146,7 @@ async def create_step_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(StepCategory.display_order)).where(
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -155,7 +155,7 @@ async def create_step_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -168,7 +168,7 @@ async def create_step_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -209,7 +209,7 @@ async def update_step_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == new_slug,
|
||||
StepCategory.team_id == category.team_id,
|
||||
StepCategory.account_id == category.account_id,
|
||||
StepCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -231,7 +231,7 @@ async def update_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -55,10 +55,10 @@ async def get_step_or_404(
|
||||
|
||||
def build_visibility_filter(user: User):
|
||||
"""Build SQLAlchemy filter for step visibility based on user."""
|
||||
if user.team_id:
|
||||
if user.account_id:
|
||||
return or_(
|
||||
StepLibrary.visibility == 'public',
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
|
||||
StepLibrary.created_by == user.id # Own private steps
|
||||
)
|
||||
else:
|
||||
@@ -249,7 +249,7 @@ async def get_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -296,10 +296,10 @@ async def create_step(
|
||||
if not cat_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Invalid category")
|
||||
|
||||
# Team validation: can only set team_id to own team
|
||||
team_id = step_data.team_id
|
||||
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another team")
|
||||
# Account validation: can only set account_id to own account
|
||||
account_id = step_data.account_id
|
||||
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another account")
|
||||
|
||||
step = StepLibrary(
|
||||
title=step_data.title,
|
||||
@@ -309,7 +309,7 @@ async def create_step(
|
||||
tags=step_data.tags,
|
||||
visibility=step_data.visibility,
|
||||
created_by=current_user.id,
|
||||
team_id=team_id or current_user.team_id,
|
||||
account_id=account_id or current_user.account_id,
|
||||
)
|
||||
|
||||
db.add(step)
|
||||
@@ -326,7 +326,7 @@ async def create_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -393,7 +393,7 @@ async def update_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
|
||||
@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
|
||||
async def list_tags(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""List tags visible to the user.
|
||||
|
||||
Returns global tags plus team-specific tags for the user's team.
|
||||
Returns global tags plus account-specific tags for the user's account.
|
||||
Tags are ordered by usage count (most used first).
|
||||
"""
|
||||
query = select(TreeTag)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if include_team and current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global
|
||||
TreeTag.team_id == current_user.team_id # User's team
|
||||
TreeTag.account_id.is_(None), # Global
|
||||
TreeTag.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only show global tags
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
|
||||
|
||||
@@ -55,7 +55,7 @@ async def search_tags(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""Search/autocomplete tags.
|
||||
|
||||
@@ -68,15 +68,15 @@ async def search_tags(
|
||||
)
|
||||
|
||||
# Filter by visibility
|
||||
if include_team and current_user.team_id:
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == current_user.team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
|
||||
|
||||
@@ -102,8 +102,8 @@ async def get_tag(
|
||||
detail="Tag not found"
|
||||
)
|
||||
|
||||
# Check access: global tags visible to all, team tags only to team members
|
||||
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global tags visible to all, account tags only to account members
|
||||
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tag"
|
||||
@@ -120,10 +120,10 @@ async def create_tag(
|
||||
):
|
||||
"""Create a new tag.
|
||||
|
||||
- Global admins can create global tags (team_id=None)
|
||||
- Team members can create team-specific tags for their team
|
||||
- Global admins can create global tags (account_id=None)
|
||||
- Account members can create account-specific tags for their account
|
||||
"""
|
||||
if not can_create_tag(current_user, tag_data.team_id):
|
||||
if not can_create_tag(current_user, tag_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this tag"
|
||||
@@ -132,10 +132,10 @@ async def create_tag(
|
||||
# Generate slug
|
||||
slug = TreeTag.slugify(tag_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
TreeTag.team_id == tag_data.team_id
|
||||
TreeTag.account_id == tag_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -147,7 +147,7 @@ async def create_tag(
|
||||
new_tag = TreeTag(
|
||||
name=tag_data.name,
|
||||
slug=slug,
|
||||
team_id=tag_data.team_id,
|
||||
account_id=tag_data.account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(new_tag)
|
||||
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
|
||||
continue
|
||||
|
||||
# Try to find existing tag
|
||||
# Determine scope: use tree's team, or global for admin-owned trees
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
# Determine scope: use tree's account, or global for admin-owned trees
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global tag
|
||||
TreeTag.team_id == tag_team_id # Team tag
|
||||
TreeTag.account_id.is_(None), # Global tag
|
||||
TreeTag.account_id == tag_account_id # Account tag
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
tag = tag_result.scalar_one_or_none()
|
||||
|
||||
if not tag:
|
||||
# Create new tag - prefer team scope unless admin creating on public tree
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
# Fall back to user's team if they can't create in tree's scope
|
||||
new_team_id = current_user.team_id
|
||||
# Create new tag - prefer account scope unless admin creating on public tree
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
# Fall back to user's account if they can't create in tree's scope
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -331,7 +331,7 @@ async def replace_tree_tags(
|
||||
tree.tags.clear()
|
||||
|
||||
# Add new tags
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
for tag_name in tag_data.tags:
|
||||
slug = TreeTag.slugify(tag_name)
|
||||
@@ -340,8 +340,8 @@ async def replace_tree_tags(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tag_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tag_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -349,14 +349,14 @@ async def replace_tree_tags(
|
||||
|
||||
if not tag:
|
||||
# Create new tag
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
new_team_id = current_user.team_id
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -397,7 +397,7 @@ async def get_tree_tags(
|
||||
# Check if user can view the tree
|
||||
if not tree.is_public:
|
||||
if tree.author_id != current_user.id:
|
||||
if tree.team_id != current_user.team_id:
|
||||
if tree.account_id != current_user.account_id:
|
||||
if not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
from app.core.audit import log_audit
|
||||
|
||||
router = APIRouter(prefix="/trees", tags=["trees"])
|
||||
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
|
||||
Tree.is_public == True,
|
||||
Tree.author_id == current_user.id,
|
||||
]
|
||||
if current_user.team_id:
|
||||
conditions.append(Tree.team_id == current_user.team_id)
|
||||
if current_user.account_id:
|
||||
conditions.append(Tree.account_id == current_user.account_id)
|
||||
return or_(*conditions)
|
||||
|
||||
|
||||
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
|
||||
category_info=category_info,
|
||||
tags=tree.tag_names,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
tags=tree.tag_names,
|
||||
tree_structure=tree.tree_structure,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -289,7 +290,7 @@ async def create_tree(
|
||||
detail="Category not found"
|
||||
)
|
||||
# Check category access
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -302,16 +303,25 @@ async def create_tree(
|
||||
category_id=tree_data.category_id,
|
||||
tree_structure=tree_data.tree_structure,
|
||||
author_id=None if is_default else current_user.id, # Default trees have no author
|
||||
team_id=None if is_default else current_user.team_id,
|
||||
account_id=None if is_default else current_user.account_id,
|
||||
is_public=True if is_default else tree_data.is_public, # Default trees are always public
|
||||
is_default=is_default
|
||||
)
|
||||
# Check subscription tree limit
|
||||
if not is_default and current_user.account_id:
|
||||
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
|
||||
if not can_create:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
db.add(new_tree)
|
||||
await db.flush() # Get the ID
|
||||
|
||||
# Handle tags
|
||||
if tree_data.tags:
|
||||
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
# Collect tags to add
|
||||
tags_to_add = []
|
||||
@@ -322,8 +332,8 @@ async def create_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -334,7 +344,7 @@ async def create_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -420,7 +430,7 @@ async def update_tree(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Category not found"
|
||||
)
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -450,7 +460,7 @@ async def update_tree(
|
||||
)
|
||||
|
||||
# Add new tags
|
||||
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
added_tag_ids = set()
|
||||
|
||||
for tag_name in tags_data:
|
||||
@@ -459,8 +469,8 @@ async def update_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -470,7 +480,7 @@ async def update_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
|
||||
62
backend/app/api/endpoints/webhooks.py
Normal file
62
backend/app/api/endpoints/webhooks.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Request, HTTPException, status, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
from app.core.stripe_handlers import WEBHOOK_HANDLERS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
@router.post("/stripe")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
Returns 200 for all events to prevent Stripe retries.
|
||||
Actual processing happens only when Stripe is configured.
|
||||
"""
|
||||
if not settings.stripe_enabled:
|
||||
return {"status": "ok", "message": "Stripe not configured, event ignored"}
|
||||
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
if not sig_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing stripe-signature header"
|
||||
)
|
||||
|
||||
# Verify webhook signature
|
||||
try:
|
||||
import stripe
|
||||
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("stripe package not installed, cannot verify webhook")
|
||||
return {"status": "ok", "message": "stripe package not installed"}
|
||||
except Exception as e:
|
||||
logger.error("Stripe webhook signature verification failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid signature"
|
||||
)
|
||||
|
||||
event_type = event.get("type", "")
|
||||
handler = WEBHOOK_HANDLERS.get(event_type)
|
||||
|
||||
if handler:
|
||||
try:
|
||||
await handler(event, db)
|
||||
except Exception:
|
||||
logger.exception("Error handling Stripe event %s", event_type)
|
||||
|
||||
return {"status": "ok"}
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
|
||||
api_router.include_router(step_categories.router)
|
||||
api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
|
||||
39
backend/app/schemas/account.py
Normal file
39
backend/app/schemas/account.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
display_code: str
|
||||
owner_id: UUID
|
||||
stripe_customer_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
|
||||
|
||||
class AccountInviteCreate(BaseModel):
|
||||
email: str = Field(..., max_length=255)
|
||||
role: str = Field("engineer", pattern="^(engineer|viewer)$")
|
||||
expires_in_days: Optional[int] = Field(None, ge=1, le=30)
|
||||
|
||||
|
||||
class AccountInviteResponse(BaseModel):
|
||||
id: UUID
|
||||
account_id: UUID
|
||||
email: str
|
||||
code: str
|
||||
role: str
|
||||
expires_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -20,7 +20,7 @@ class CategoryBase(BaseModel):
|
||||
|
||||
|
||||
class CategoryCreate(CategoryBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
|
||||
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
@@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel):
|
||||
class CategoryResponse(CategoryBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
@@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel):
|
||||
name: str
|
||||
slug: str
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
tree_count: int = 0
|
||||
|
||||
@@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel):
|
||||
|
||||
|
||||
class StepCategoryCreate(StepCategoryBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
|
||||
|
||||
|
||||
class StepCategoryUpdate(BaseModel):
|
||||
@@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel):
|
||||
class StepCategoryResponse(StepCategoryBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
@@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel):
|
||||
name: str
|
||||
slug: str
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
step_count: int = 0
|
||||
|
||||
@@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel):
|
||||
|
||||
|
||||
class StepLibraryCreate(StepLibraryBase):
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class StepLibraryUpdate(BaseModel):
|
||||
@@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel):
|
||||
class StepLibraryResponse(StepLibraryBase):
|
||||
id: UUID
|
||||
created_by: UUID
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
rating_average: Decimal
|
||||
rating_count: int
|
||||
|
||||
40
backend/app/schemas/subscription.py
Normal file
40
backend/app/schemas/subscription.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SubscriptionResponse(BaseModel):
|
||||
id: UUID
|
||||
plan: str
|
||||
status: str
|
||||
billing_interval: Optional[str] = None
|
||||
current_period_start: Optional[datetime] = None
|
||||
current_period_end: Optional[datetime] = None
|
||||
cancel_at_period_end: bool = False
|
||||
stripe_subscription_id: Optional[str] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class PlanLimitsResponse(BaseModel):
|
||||
plan: str
|
||||
max_trees: Optional[int] = None
|
||||
max_sessions_per_month: Optional[int] = None
|
||||
max_users: Optional[int] = None
|
||||
custom_branding: bool = False
|
||||
priority_support: bool = False
|
||||
export_formats: list[str] = ["markdown", "text"]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UsageResponse(BaseModel):
|
||||
tree_count: int
|
||||
session_count_this_month: int
|
||||
|
||||
|
||||
class SubscriptionDetails(BaseModel):
|
||||
subscription: SubscriptionResponse
|
||||
limits: PlanLimitsResponse
|
||||
usage: UsageResponse
|
||||
@@ -19,13 +19,13 @@ class TagBase(BaseModel):
|
||||
|
||||
|
||||
class TagCreate(TagBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.")
|
||||
|
||||
|
||||
class TagResponse(TagBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
created_at: datetime
|
||||
|
||||
@@ -37,7 +37,7 @@ class TagListResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
|
||||
class Config:
|
||||
@@ -53,4 +53,4 @@ class TagSearchParams(BaseModel):
|
||||
"""Query parameters for tag search/autocomplete."""
|
||||
q: str = Field(..., min_length=1, description="Search query")
|
||||
limit: int = Field(10, ge=1, le=50)
|
||||
include_team: bool = Field(True, description="Include team-specific tags")
|
||||
include_account: bool = Field(True, description="Include account-specific tags")
|
||||
|
||||
@@ -44,7 +44,7 @@ class TreeResponse(TreeBase):
|
||||
id: UUID
|
||||
tree_structure: dict[str, Any]
|
||||
author_id: Optional[UUID] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
category_id: Optional[UUID] = None
|
||||
category_info: Optional[CategoryInfo] = None
|
||||
tags: list[str] = [] # List of tag names
|
||||
@@ -69,7 +69,7 @@ class TreeListResponse(BaseModel):
|
||||
category_info: Optional[CategoryInfo] = None
|
||||
tags: list[str] = [] # List of tag names
|
||||
author_id: Optional[UUID] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
is_active: bool
|
||||
is_public: bool
|
||||
is_default: bool
|
||||
|
||||
@@ -13,6 +13,7 @@ class UserBase(BaseModel):
|
||||
class UserCreate(UserBase):
|
||||
password: str = Field(..., min_length=10, description="Password must be at least 10 characters")
|
||||
invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)")
|
||||
account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account")
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
@@ -38,11 +39,11 @@ class UserLogin(BaseModel):
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
role: str
|
||||
role: str = "engineer"
|
||||
account_id: UUID
|
||||
account_role: str
|
||||
is_super_admin: bool = False
|
||||
is_team_admin: bool = False
|
||||
is_active: bool = True
|
||||
team_id: Optional[UUID] = None
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
@@ -54,5 +55,5 @@ class RoleUpdate(BaseModel):
|
||||
role: Literal["engineer", "viewer"]
|
||||
|
||||
|
||||
class TeamAdminUpdate(BaseModel):
|
||||
is_team_admin: bool
|
||||
class AccountRoleUpdate(BaseModel):
|
||||
account_role: str = Field(..., pattern="^(engineer|viewer)$")
|
||||
|
||||
@@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await conn.execute(sa.text("CREATE SCHEMA public"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Seed plan_limits for subscription checks
|
||||
await conn.execute(sa.text("""
|
||||
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
|
||||
VALUES
|
||||
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
|
||||
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
|
||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
||||
"""))
|
||||
|
||||
# Create async session maker
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
|
||||
170
backend/tests/test_account_management.py
Normal file
170
backend/tests/test_account_management.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Integration tests for account management endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAccountEndpoints:
|
||||
"""Test suite for account management endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_account(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting current user's account."""
|
||||
response = await client.get("/api/v1/accounts/me", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "name" in data
|
||||
assert "display_code" in data
|
||||
assert "owner_id" in data
|
||||
assert len(data["display_code"]) == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting current user's subscription details."""
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "subscription" in data
|
||||
assert "limits" in data
|
||||
assert "usage" in data
|
||||
assert data["subscription"]["plan"] == "free"
|
||||
assert data["subscription"]["status"] == "active"
|
||||
assert data["limits"]["max_trees"] == 3
|
||||
assert data["limits"]["max_sessions_per_month"] == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting members of current user's account."""
|
||||
response = await client.get("/api/v1/accounts/me/members", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
# Current user should be in members list
|
||||
assert any(m["account_role"] == "owner" for m in data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_my_account(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test updating account name."""
|
||||
response = await client.patch(
|
||||
"/api/v1/accounts/me",
|
||||
json={"name": "Updated Account Name"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Account Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_requires_owner(self, client: AsyncClient):
|
||||
"""Test that non-owners cannot update account settings."""
|
||||
# Register two users
|
||||
owner_data = {
|
||||
"email": "owner@example.com",
|
||||
"password": "OwnerPass123!",
|
||||
"name": "Owner"
|
||||
}
|
||||
await client.post("/api/v1/auth/register", json=owner_data)
|
||||
|
||||
# Login as owner and create an invite
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "owner@example.com",
|
||||
"password": "OwnerPass123!"
|
||||
})
|
||||
owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "member@example.com", "role": "engineer"},
|
||||
headers=owner_headers
|
||||
)
|
||||
assert invite_resp.status_code == 201
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register member with account invite code
|
||||
member_data = {
|
||||
"email": "member@example.com",
|
||||
"password": "MemberPass123!",
|
||||
"name": "Member",
|
||||
"account_invite_code": invite_code
|
||||
}
|
||||
reg_resp = await client.post("/api/v1/auth/register", json=member_data)
|
||||
assert reg_resp.status_code == 201
|
||||
assert reg_resp.json()["account_role"] == "engineer"
|
||||
|
||||
# Login as member
|
||||
member_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "member@example.com",
|
||||
"password": "MemberPass123!"
|
||||
})
|
||||
member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"}
|
||||
|
||||
# Member should not be able to update account
|
||||
response = await client.patch(
|
||||
"/api/v1/accounts/me",
|
||||
json={"name": "Hacked Name"},
|
||||
headers=member_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test creating and listing account invites."""
|
||||
# Create invite
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "invitee@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["email"] == "invitee@example.com"
|
||||
assert data["role"] == "engineer"
|
||||
assert "code" in data
|
||||
|
||||
# List invites
|
||||
list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers)
|
||||
assert list_response.status_code == 200
|
||||
invites = list_response.json()
|
||||
assert len(invites) >= 1
|
||||
assert any(i["email"] == "invitee@example.com" for i in invites)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that account invite code joins user to existing account."""
|
||||
# Get current account
|
||||
account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers)
|
||||
account_id = account_resp.json()["id"]
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "joiner@example.com", "role": "viewer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register with account invite code
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "joiner@example.com",
|
||||
"password": "JoinerPass123!",
|
||||
"name": "Joiner",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
data = reg_resp.json()
|
||||
assert data["account_id"] == account_id
|
||||
assert data["account_role"] == "viewer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_invalid_invite_code(self, client: AsyncClient):
|
||||
"""Test that invalid account invite code is rejected."""
|
||||
response = await client.post("/api/v1/auth/register", json={
|
||||
"email": "bad@example.com",
|
||||
"password": "BadPassword123!",
|
||||
"name": "Bad User",
|
||||
"account_invite_code": "INVALID_CODE"
|
||||
})
|
||||
assert response.status_code == 400
|
||||
assert "invalid" in response.json()["detail"].lower()
|
||||
@@ -169,3 +169,51 @@ class TestAdminEndpoints:
|
||||
log = result.scalar_one_or_none()
|
||||
assert log is not None
|
||||
assert str(log.resource_id) == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_account_role(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test changing a user's account role."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["account_role"] == "viewer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_account_role_invalid(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test that invalid account_role values are rejected."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "owner"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_log_created_on_account_role_change(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession
|
||||
):
|
||||
"""Test that changing account role creates an audit log entry."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(AuditLog.action == "user.account_role_change")
|
||||
)
|
||||
log = result.scalar_one_or_none()
|
||||
assert log is not None
|
||||
assert str(log.resource_id) == user_id
|
||||
assert log.details["old_account_role"] == "owner"
|
||||
assert log.details["new_account_role"] == "viewer"
|
||||
|
||||
@@ -23,6 +23,8 @@ class TestAuthentication:
|
||||
assert data["email"] == user_data["email"]
|
||||
assert data["name"] == user_data["name"]
|
||||
assert data["role"] == "engineer"
|
||||
assert "account_id" in data
|
||||
assert data["account_role"] == "owner"
|
||||
assert "id" in data
|
||||
assert "password" not in data # Password should not be returned
|
||||
|
||||
@@ -107,6 +109,7 @@ class TestAuthentication:
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["role"] == "engineer"
|
||||
assert data["account_role"] == "owner"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_default_role_is_engineer(self, client: AsyncClient):
|
||||
@@ -121,6 +124,7 @@ class TestAuthentication:
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["role"] == "engineer"
|
||||
assert response.json()["account_role"] == "owner"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_rejects_no_uppercase(self, client: AsyncClient):
|
||||
|
||||
205
backend/tests/test_permissions_account.py
Normal file
205
backend/tests/test_permissions_account.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Integration tests for account-based permissions."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAccountPermissions:
|
||||
"""Test suite for account-based permission checks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db):
|
||||
"""Test that viewers cannot create trees."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Register a user
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "viewer@example.com",
|
||||
"password": "ViewerPass123!",
|
||||
"name": "Viewer User"
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
user_id = PyUUID(reg_resp.json()["id"])
|
||||
|
||||
# Demote to viewer via ORM
|
||||
result = await test_db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "viewer"
|
||||
await test_db.commit()
|
||||
|
||||
# Login as viewer
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "viewer@example.com",
|
||||
"password": "ViewerPass123!"
|
||||
})
|
||||
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Try to create tree
|
||||
response = await client.post("/api/v1/trees", json={
|
||||
"name": "Viewer Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"}
|
||||
}, headers=viewer_headers)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""Test that viewers can browse/list trees."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Create a public tree as the regular user first
|
||||
await client.post("/api/v1/trees", json={
|
||||
"name": "Public Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": True
|
||||
}, headers=auth_headers)
|
||||
|
||||
# Register viewer
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "viewer2@example.com",
|
||||
"password": "ViewerPass123!",
|
||||
"name": "Viewer 2"
|
||||
})
|
||||
user_id = PyUUID(reg_resp.json()["id"])
|
||||
|
||||
result = await test_db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "viewer"
|
||||
await test_db.commit()
|
||||
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "viewer2@example.com",
|
||||
"password": "ViewerPass123!"
|
||||
})
|
||||
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Viewer can list trees
|
||||
response = await client.get("/api/v1/trees", headers=viewer_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""Test that account owner can edit trees created by account members."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Get owner's account
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id = me_resp.json()["account_id"]
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "engineer@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register engineer in same account
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "engineer@example.com",
|
||||
"password": "EngineerPass123!",
|
||||
"name": "Engineer",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
assert reg_resp.json()["account_id"] == account_id
|
||||
|
||||
# Login as engineer
|
||||
eng_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "engineer@example.com",
|
||||
"password": "EngineerPass123!"
|
||||
})
|
||||
eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"}
|
||||
|
||||
# Engineer creates a tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Engineer's Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}
|
||||
}, headers=eng_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Owner can edit engineer's tree
|
||||
update_resp = await client.put(
|
||||
f"/api/v1/trees/{tree_id}",
|
||||
json={"name": "Owner Updated Name"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
assert update_resp.json()["name"] == "Owner Updated Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that account members can see each other's non-public trees."""
|
||||
# Get owner's account
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id = me_resp.json()["account_id"]
|
||||
|
||||
# Owner creates a private tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Private Account Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": False
|
||||
}, headers=auth_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Create invite and add member
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "teammate@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "teammate@example.com",
|
||||
"password": "TeammatePass123!",
|
||||
"name": "Teammate",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
|
||||
mate_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "teammate@example.com",
|
||||
"password": "TeammatePass123!"
|
||||
})
|
||||
mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"}
|
||||
|
||||
# Teammate should see the private tree (same account)
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Private Account Tree"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that users from different accounts cannot see private trees."""
|
||||
# Owner creates a private tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Secret Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": False
|
||||
}, headers=auth_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Register a completely separate user (different account)
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "outsider@example.com",
|
||||
"password": "OutsiderPass123!",
|
||||
"name": "Outsider"
|
||||
})
|
||||
|
||||
outsider_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "outsider@example.com",
|
||||
"password": "OutsiderPass123!"
|
||||
})
|
||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||
|
||||
# Outsider should NOT see the private tree
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||
assert response.status_code == 403
|
||||
129
backend/tests/test_subscription_limits.py
Normal file
129
backend/tests/test_subscription_limits.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for subscription limits."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class TestSubscriptionLimits:
|
||||
"""Test suite for subscription plan limits."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that free plan has tree creation limit of 3."""
|
||||
tree_template = {
|
||||
"name": "Limit Test Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test tree"
|
||||
}
|
||||
}
|
||||
|
||||
# Create trees up to the limit
|
||||
for i in range(3):
|
||||
tree_data = {**tree_template, "name": f"Tree {i+1}"}
|
||||
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}"
|
||||
|
||||
# 4th tree should fail with 402
|
||||
tree_data = {**tree_template, "name": "Tree 4 Over Limit"}
|
||||
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert response.status_code == 402
|
||||
assert "limit" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that subscription details reflect actual usage."""
|
||||
# Check initial usage
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
initial_usage = response.json()["usage"]
|
||||
assert initial_usage["tree_count"] == 0
|
||||
|
||||
# Create a tree
|
||||
tree_data = {
|
||||
"name": "Usage Test Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test"
|
||||
}
|
||||
}
|
||||
create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert create_resp.status_code == 201
|
||||
|
||||
# Check usage increased
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
updated_usage = response.json()["usage"]
|
||||
assert updated_usage["tree_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_super_admin_bypasses_limits(
|
||||
self, client: AsyncClient, admin_auth_headers: dict
|
||||
):
|
||||
"""Test that super admin can create trees without limit checks."""
|
||||
tree_template = {
|
||||
"name": "Admin Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test tree"
|
||||
},
|
||||
"is_default": True # Default trees skip limit check
|
||||
}
|
||||
|
||||
# Super admin creating default trees should always work
|
||||
for i in range(5):
|
||||
tree_data = {**tree_template, "name": f"Admin Tree {i+1}"}
|
||||
response = await client.post(
|
||||
"/api/v1/trees", json=tree_data, headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that free plan limits are correct."""
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
limits = response.json()["limits"]
|
||||
assert limits["plan"] == "free"
|
||||
assert limits["max_trees"] == 3
|
||||
assert limits["max_sessions_per_month"] == 20
|
||||
assert limits["max_users"] == 1
|
||||
assert limits["custom_branding"] is False
|
||||
assert limits["priority_support"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgraded_plan_has_higher_limits(
|
||||
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
|
||||
):
|
||||
"""Test that upgrading plan increases limits."""
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
|
||||
# Get the user's subscription and upgrade it
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id_str = me_resp.json()["account_id"]
|
||||
|
||||
from uuid import UUID
|
||||
account_id = UUID(account_id_str)
|
||||
result = await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account_id)
|
||||
)
|
||||
sub = result.scalar_one()
|
||||
sub.plan = "pro"
|
||||
await test_db.commit()
|
||||
|
||||
# Check limits are now pro
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
limits = response.json()["limits"]
|
||||
assert limits["plan"] == "pro"
|
||||
assert limits["max_trees"] == 25
|
||||
assert limits["max_sessions_per_month"] == 200
|
||||
Reference in New Issue
Block a user