feat: update all endpoints and schemas for account-based model

Replace team_id with account_id across all API endpoints (trees,
categories, tags, steps, step_categories, admin, auth). Add new
accounts and webhooks endpoints. Registration now atomically creates
Account + Subscription, with account_invite_code bypassing the
platform invite gate.

Schemas updated for account_id/account_role. 82 tests passing
including 18 new tests for accounts, subscriptions, and permissions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-07 02:39:01 -05:00
parent 4ccb93ee31
commit e0089a9c5a
24 changed files with 1178 additions and 152 deletions

View File

@@ -0,0 +1,236 @@
from datetime import datetime, timezone, timedelta
from typing import Annotated, Optional
from uuid import UUID
import secrets
import string
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
from app.schemas.user import UserResponse, AccountRoleUpdate
from app.api.deps import get_current_active_user, require_account_owner
router = APIRouter(prefix="/accounts", tags=["accounts"])
@router.get("/me", response_model=AccountResponse)
async def get_my_account(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's account."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
return account
@router.get("/me/subscription", response_model=SubscriptionDetails)
async def get_my_subscription(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's subscription details including limits and usage."""
sub = await get_account_subscription(current_user.account_id, db)
if not sub:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No subscription found"
)
limits = await get_plan_limits(sub.plan, db)
if not limits:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Plan limits not configured"
)
usage = await get_account_usage(current_user.account_id, db)
return SubscriptionDetails(
subscription=SubscriptionResponse.model_validate(sub),
limits=PlanLimitsResponse.model_validate(limits),
usage=UsageResponse(**usage),
)
@router.get("/me/members", response_model=list[UserResponse])
async def get_my_members(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get members of current user's account."""
result = await db.execute(
select(User).where(User.account_id == current_user.account_id)
.order_by(User.created_at)
)
return result.scalars().all()
@router.patch("/me", response_model=AccountResponse)
async def update_my_account(
data: AccountUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Update account settings (owner only)."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(account, field, value)
await db.commit()
await db.refresh(account)
return account
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
async def update_member_role(
user_id: UUID,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Change a member's role within the account (owner only)."""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot change your own role"
)
user.account_role = data.account_role
await db.commit()
await db.refresh(user)
return user
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_member(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Remove a member from the account (owner only).
The removed user gets a new personal account.
"""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot remove yourself from your own account"
)
# Create a personal account for the removed user
chars = string.ascii_uppercase + string.digits
display_code = ''.join(secrets.choice(chars) for _ in range(8))
new_account = Account(
name=f"{user.name}'s Account",
display_code=display_code,
owner_id=user.id,
)
db.add(new_account)
await db.flush()
new_sub = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_sub)
user.account_id = new_account.id
user.account_role = "owner"
await db.commit()
return None
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
async def create_invite(
data: AccountInviteCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Create an invite to join this account (owner only)."""
code = secrets.token_urlsafe(16)
expires_at = None
if data.expires_in_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
invite = AccountInvite(
account_id=current_user.account_id,
invited_by_id=current_user.id,
email=data.email,
code=code,
role=data.role,
expires_at=expires_at,
)
db.add(invite)
await db.commit()
await db.refresh(invite)
return invite
@router.get("/me/invites", response_model=list[AccountInviteResponse])
async def list_invites(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""List invites for this account (owner only)."""
result = await db.execute(
select(AccountInvite)
.where(AccountInvite.account_id == current_user.account_id)
.order_by(AccountInvite.created_at.desc())
)
return result.scalars().all()

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select, func
from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
from app.api.deps import require_admin
router = APIRouter(prefix="/admin", tags=["admin"])
@@ -21,7 +21,7 @@ async def list_users(
limit: int = Query(100, ge=1, le=100),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
role: Optional[str] = Query(None, description="Filter by role"),
team_id: Optional[UUID] = Query(None, description="Filter by team")
account_id: Optional[UUID] = Query(None, description="Filter by account")
):
"""List all users (super admin only)."""
query = select(User)
@@ -30,8 +30,8 @@ async def list_users(
query = query.where(User.is_active == is_active)
if role:
query = query.where(User.role == role)
if team_id:
query = query.where(User.team_id == team_id)
if account_id:
query = query.where(User.account_id == account_id)
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
@@ -91,14 +91,14 @@ async def update_user_role(
return user
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
async def toggle_team_admin(
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
async def update_account_role(
user_id: UUID,
data: TeamAdminUpdate,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Toggle is_team_admin for a user (super admin only)."""
"""Change a user's account role (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -108,15 +108,10 @@ async def toggle_team_admin(
detail="User not found"
)
if data.is_team_admin and user.team_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User must belong to a team to be a team admin"
)
user.is_team_admin = data.is_team_admin
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
{"is_team_admin": data.is_team_admin})
old_role = user.account_role
user.account_role = data.account_role
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
{"old_account_role": old_role, "new_account_role": data.account_role})
await db.commit()
await db.refresh(user)
return user

View File

@@ -1,3 +1,5 @@
import secrets
import string
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
@@ -18,6 +20,9 @@ from app.core.security import (
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.api.deps import get_current_active_user, get_refresh_token_payload
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
db.add(token_record)
def _generate_display_code() -> str:
"""Generate a random 8-character alphanumeric display code."""
chars = string.ascii_uppercase + string.digits
return ''.join(secrets.choice(chars) for _ in range(8))
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
@@ -44,10 +55,46 @@ async def register(
user_data: UserCreate,
db: Annotated[AsyncSession, Depends(get_db)]
):
"""Register a new user."""
# Validate invite code if required
"""Register a new user.
Supports two flows:
- account_invite_code: Join an existing account (bypasses platform invite gate)
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
)
account_invite_record = result.scalar_one_or_none()
if not account_invite_record:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid account invite code"
)
if account_invite_record.is_used:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has already been used"
)
if account_invite_record.is_expired:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has expired"
)
# Validate platform invite code if required (skip if account invite was provided)
invite_code_record = None
if settings.REQUIRE_INVITE_CODE:
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
if not user_data.invite_code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -96,8 +143,37 @@ async def register(
invite_code_id=invite_code_record.id if invite_code_record else None
)
db.add(new_user)
await db.flush() # Get user ID before creating account
# Mark invite code as used
if account_invite_record:
# Join existing account via account invite
new_user.account_id = account_invite_record.account_id
new_user.account_role = account_invite_record.role
# Mark account invite as used
account_invite_record.accepted_by_id = new_user.id
account_invite_record.used_at = datetime.now(timezone.utc)
else:
# Create personal Account + free Subscription
new_account = Account(
name=f"{user_data.name}'s Account",
display_code=_generate_display_code(),
owner_id=new_user.id,
)
db.add(new_account)
await db.flush() # Get account ID
new_subscription = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_subscription)
new_user.account_id = new_account.id
new_user.account_role = "owner"
# Mark platform invite code as used
if invite_code_record:
invite_code_record.used_by_id = new_user.id
invite_code_record.used_at = datetime.now(timezone.utc)

View File

@@ -28,11 +28,11 @@ async def list_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(TreeCategory)
@@ -41,19 +41,19 @@ async def list_categories(
if not include_inactive:
query = query.where(TreeCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(TreeCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(TreeCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
TreeCategory.team_id.is_(None), # Global
TreeCategory.team_id == current_user.team_id # User's team
TreeCategory.account_id.is_(None), # Global
TreeCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(TreeCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(TreeCategory.account_id.is_(None))
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
@@ -76,7 +76,7 @@ async def list_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
tree_count=tree_count
@@ -101,8 +101,8 @@ async def get_category(
detail="Category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -121,7 +121,7 @@ async def get_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -138,10 +138,10 @@ async def create_category(
):
"""Create a new category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_category(current_user, category_data.team_id):
if not can_create_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this category"
@@ -150,10 +150,10 @@ async def create_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeCategory).where(
TreeCategory.slug == slug,
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -164,7 +164,7 @@ async def create_category(
# Get next display order
order_query = select(func.max(TreeCategory.display_order)).where(
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -173,7 +173,7 @@ async def create_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -186,7 +186,7 @@ async def create_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -227,7 +227,7 @@ async def update_category(
# Check for duplicate slug
existing_query = select(TreeCategory).where(
TreeCategory.slug == new_slug,
TreeCategory.team_id == category.team_id,
TreeCategory.account_id == category.account_id,
TreeCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -257,7 +257,7 @@ async def update_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -25,11 +25,11 @@ async def list_step_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List step categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(StepCategory)
@@ -38,19 +38,19 @@ async def list_step_categories(
if not include_inactive:
query = query.where(StepCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(StepCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(StepCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
StepCategory.team_id.is_(None), # Global
StepCategory.team_id == current_user.team_id # User's team
StepCategory.account_id.is_(None), # Global
StepCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(StepCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(StepCategory.account_id.is_(None))
query = query.order_by(StepCategory.display_order, StepCategory.name)
@@ -66,7 +66,7 @@ async def list_step_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
step_count=0 # Will be computed when step_library exists
@@ -91,8 +91,8 @@ async def get_step_category(
detail="Step category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this step category"
@@ -103,7 +103,7 @@ async def get_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -120,10 +120,10 @@ async def create_step_category(
):
"""Create a new step category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_step_category(current_user, category_data.team_id):
if not can_create_step_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this step category"
@@ -132,10 +132,10 @@ async def create_step_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(StepCategory).where(
StepCategory.slug == slug,
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -146,7 +146,7 @@ async def create_step_category(
# Get next display order
order_query = select(func.max(StepCategory.display_order)).where(
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -155,7 +155,7 @@ async def create_step_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -168,7 +168,7 @@ async def create_step_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -209,7 +209,7 @@ async def update_step_category(
# Check for duplicate slug
existing_query = select(StepCategory).where(
StepCategory.slug == new_slug,
StepCategory.team_id == category.team_id,
StepCategory.account_id == category.account_id,
StepCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -231,7 +231,7 @@ async def update_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -55,10 +55,10 @@ async def get_step_or_404(
def build_visibility_filter(user: User):
"""Build SQLAlchemy filter for step visibility based on user."""
if user.team_id:
if user.account_id:
return or_(
StepLibrary.visibility == 'public',
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
StepLibrary.created_by == user.id # Own private steps
)
else:
@@ -249,7 +249,7 @@ async def get_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -296,10 +296,10 @@ async def create_step(
if not cat_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Invalid category")
# Team validation: can only set team_id to own team
team_id = step_data.team_id
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another team")
# Account validation: can only set account_id to own account
account_id = step_data.account_id
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another account")
step = StepLibrary(
title=step_data.title,
@@ -309,7 +309,7 @@ async def create_step(
tags=step_data.tags,
visibility=step_data.visibility,
created_by=current_user.id,
team_id=team_id or current_user.team_id,
account_id=account_id or current_user.account_id,
)
db.add(step)
@@ -326,7 +326,7 @@ async def create_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -393,7 +393,7 @@ async def update_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,

View File

@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
async def list_tags(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""List tags visible to the user.
Returns global tags plus team-specific tags for the user's team.
Returns global tags plus account-specific tags for the user's account.
Tags are ordered by usage count (most used first).
"""
query = select(TreeTag)
# Filter by visibility: global OR user's team
if include_team and current_user.team_id:
# Filter by visibility: global OR user's account
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None), # Global
TreeTag.team_id == current_user.team_id # User's team
TreeTag.account_id.is_(None), # Global
TreeTag.account_id == current_user.account_id # User's account
)
)
else:
# Only show global tags
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
@@ -55,7 +55,7 @@ async def search_tags(
current_user: Annotated[User, Depends(get_current_active_user)],
q: str = Query(..., min_length=1, description="Search query"),
limit: int = Query(10, ge=1, le=50),
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""Search/autocomplete tags.
@@ -68,15 +68,15 @@ async def search_tags(
)
# Filter by visibility
if include_team and current_user.team_id:
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == current_user.team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == current_user.account_id
)
)
else:
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
@@ -102,8 +102,8 @@ async def get_tag(
detail="Tag not found"
)
# Check access: global tags visible to all, team tags only to team members
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global tags visible to all, account tags only to account members
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tag"
@@ -120,10 +120,10 @@ async def create_tag(
):
"""Create a new tag.
- Global admins can create global tags (team_id=None)
- Team members can create team-specific tags for their team
- Global admins can create global tags (account_id=None)
- Account members can create account-specific tags for their account
"""
if not can_create_tag(current_user, tag_data.team_id):
if not can_create_tag(current_user, tag_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this tag"
@@ -132,10 +132,10 @@ async def create_tag(
# Generate slug
slug = TreeTag.slugify(tag_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeTag).where(
TreeTag.slug == slug,
TreeTag.team_id == tag_data.team_id
TreeTag.account_id == tag_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -147,7 +147,7 @@ async def create_tag(
new_tag = TreeTag(
name=tag_data.name,
slug=slug,
team_id=tag_data.team_id,
account_id=tag_data.account_id,
created_by=current_user.id
)
db.add(new_tag)
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
continue
# Try to find existing tag
# Determine scope: use tree's team, or global for admin-owned trees
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
# Determine scope: use tree's account, or global for admin-owned trees
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None), # Global tag
TreeTag.team_id == tag_team_id # Team tag
TreeTag.account_id.is_(None), # Global tag
TreeTag.account_id == tag_account_id # Account tag
)
)
tag_result = await db.execute(tag_query)
tag = tag_result.scalar_one_or_none()
if not tag:
# Create new tag - prefer team scope unless admin creating on public tree
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
# Fall back to user's team if they can't create in tree's scope
new_team_id = current_user.team_id
# Create new tag - prefer account scope unless admin creating on public tree
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
# Fall back to user's account if they can't create in tree's scope
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -331,7 +331,7 @@ async def replace_tree_tags(
tree.tags.clear()
# Add new tags
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
for tag_name in tag_data.tags:
slug = TreeTag.slugify(tag_name)
@@ -340,8 +340,8 @@ async def replace_tree_tags(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tag_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tag_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -349,14 +349,14 @@ async def replace_tree_tags(
if not tag:
# Create new tag
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
new_team_id = current_user.team_id
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -397,7 +397,7 @@ async def get_tree_tags(
# Check if user can view the tree
if not tree.is_public:
if tree.author_id != current_user.id:
if tree.team_id != current_user.team_id:
if tree.account_id != current_user.account_id:
if not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.subscriptions import check_tree_limit
from app.core.audit import log_audit
router = APIRouter(prefix="/trees", tags=["trees"])
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
Tree.is_public == True,
Tree.author_id == current_user.id,
]
if current_user.team_id:
conditions.append(Tree.team_id == current_user.team_id)
if current_user.account_id:
conditions.append(Tree.account_id == current_user.account_id)
return or_(*conditions)
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
category_info=category_info,
tags=tree.tag_names,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
tags=tree.tag_names,
tree_structure=tree.tree_structure,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -289,7 +290,7 @@ async def create_tree(
detail="Category not found"
)
# Check category access
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -302,16 +303,25 @@ async def create_tree(
category_id=tree_data.category_id,
tree_structure=tree_data.tree_structure,
author_id=None if is_default else current_user.id, # Default trees have no author
team_id=None if is_default else current_user.team_id,
account_id=None if is_default else current_user.account_id,
is_public=True if is_default else tree_data.is_public, # Default trees are always public
is_default=is_default
)
# Check subscription tree limit
if not is_default and current_user.account_id:
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
if not can_create:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
)
db.add(new_tree)
await db.flush() # Get the ID
# Handle tags
if tree_data.tags:
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
# Collect tags to add
tags_to_add = []
@@ -322,8 +332,8 @@ async def create_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -334,7 +344,7 @@ async def create_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -420,7 +430,7 @@ async def update_tree(
status_code=status.HTTP_404_NOT_FOUND,
detail="Category not found"
)
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -450,7 +460,7 @@ async def update_tree(
)
# Add new tags
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
added_tag_ids = set()
for tag_name in tags_data:
@@ -459,8 +469,8 @@ async def update_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -470,7 +480,7 @@ async def update_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)

View File

@@ -0,0 +1,62 @@
import logging
from fastapi import APIRouter, Request, HTTPException, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.config import settings
from app.core.stripe_handlers import WEBHOOK_HANDLERS
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
@router.post("/stripe")
async def stripe_webhook(
request: Request,
db: AsyncSession = Depends(get_db),
):
"""Handle Stripe webhook events.
Returns 200 for all events to prevent Stripe retries.
Actual processing happens only when Stripe is configured.
"""
if not settings.stripe_enabled:
return {"status": "ok", "message": "Stripe not configured, event ignored"}
payload = await request.body()
sig_header = request.headers.get("stripe-signature")
if not sig_header:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing stripe-signature header"
)
# Verify webhook signature
try:
import stripe
stripe.api_key = settings.STRIPE_SECRET_KEY
event = stripe.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except ImportError:
logger.warning("stripe package not installed, cannot verify webhook")
return {"status": "ok", "message": "stripe package not installed"}
except Exception as e:
logger.error("Stripe webhook signature verification failed: %s", e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid signature"
)
event_type = event.get("type", "")
handler = WEBHOOK_HANDLERS.get(event_type)
if handler:
try:
await handler(event, db)
except Exception:
logger.exception("Error handling Stripe event %s", event_type)
return {"status": "ok"}

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
api_router.include_router(step_categories.router)
api_router.include_router(steps.router)
api_router.include_router(admin.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
class AccountResponse(BaseModel):
id: UUID
name: str
display_code: str
owner_id: UUID
stripe_customer_id: Optional[str] = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class AccountUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=255)
class AccountInviteCreate(BaseModel):
email: str = Field(..., max_length=255)
role: str = Field("engineer", pattern="^(engineer|viewer)$")
expires_in_days: Optional[int] = Field(None, ge=1, le=30)
class AccountInviteResponse(BaseModel):
id: UUID
account_id: UUID
email: str
code: str
role: str
expires_at: Optional[datetime] = None
created_at: datetime
used_at: Optional[datetime] = None
model_config = {"from_attributes": True}

View File

@@ -20,7 +20,7 @@ class CategoryBase(BaseModel):
class CategoryCreate(CategoryBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
class CategoryUpdate(BaseModel):
@@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel):
class CategoryResponse(CategoryBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
created_at: datetime
@@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel):
name: str
slug: str
description: Optional[str] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
tree_count: int = 0

View File

@@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel):
class StepCategoryCreate(StepCategoryBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
class StepCategoryUpdate(BaseModel):
@@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel):
class StepCategoryResponse(StepCategoryBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
created_at: datetime
@@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel):
name: str
slug: str
description: Optional[str] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
step_count: int = 0

View File

@@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel):
class StepLibraryCreate(StepLibraryBase):
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
class StepLibraryUpdate(BaseModel):
@@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel):
class StepLibraryResponse(StepLibraryBase):
id: UUID
created_by: UUID
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
rating_average: Decimal
rating_count: int

View File

@@ -0,0 +1,40 @@
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel
class SubscriptionResponse(BaseModel):
id: UUID
plan: str
status: str
billing_interval: Optional[str] = None
current_period_start: Optional[datetime] = None
current_period_end: Optional[datetime] = None
cancel_at_period_end: bool = False
stripe_subscription_id: Optional[str] = None
model_config = {"from_attributes": True}
class PlanLimitsResponse(BaseModel):
plan: str
max_trees: Optional[int] = None
max_sessions_per_month: Optional[int] = None
max_users: Optional[int] = None
custom_branding: bool = False
priority_support: bool = False
export_formats: list[str] = ["markdown", "text"]
model_config = {"from_attributes": True}
class UsageResponse(BaseModel):
tree_count: int
session_count_this_month: int
class SubscriptionDetails(BaseModel):
subscription: SubscriptionResponse
limits: PlanLimitsResponse
usage: UsageResponse

View File

@@ -19,13 +19,13 @@ class TagBase(BaseModel):
class TagCreate(TagBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.")
class TagResponse(TagBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
created_at: datetime
@@ -37,7 +37,7 @@ class TagListResponse(BaseModel):
id: UUID
name: str
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
class Config:
@@ -53,4 +53,4 @@ class TagSearchParams(BaseModel):
"""Query parameters for tag search/autocomplete."""
q: str = Field(..., min_length=1, description="Search query")
limit: int = Field(10, ge=1, le=50)
include_team: bool = Field(True, description="Include team-specific tags")
include_account: bool = Field(True, description="Include account-specific tags")

View File

@@ -44,7 +44,7 @@ class TreeResponse(TreeBase):
id: UUID
tree_structure: dict[str, Any]
author_id: Optional[UUID] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
category_id: Optional[UUID] = None
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
@@ -69,7 +69,7 @@ class TreeListResponse(BaseModel):
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
author_id: Optional[UUID] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
is_active: bool
is_public: bool
is_default: bool

View File

@@ -13,6 +13,7 @@ class UserBase(BaseModel):
class UserCreate(UserBase):
password: str = Field(..., min_length=10, description="Password must be at least 10 characters")
invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)")
account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account")
@field_validator('password')
@classmethod
@@ -38,11 +39,11 @@ class UserLogin(BaseModel):
class UserResponse(UserBase):
id: UUID
role: str
role: str = "engineer"
account_id: UUID
account_role: str
is_super_admin: bool = False
is_team_admin: bool = False
is_active: bool = True
team_id: Optional[UUID] = None
created_at: datetime
last_login: Optional[datetime] = None
@@ -54,5 +55,5 @@ class RoleUpdate(BaseModel):
role: Literal["engineer", "viewer"]
class TeamAdminUpdate(BaseModel):
is_team_admin: bool
class AccountRoleUpdate(BaseModel):
account_role: str = Field(..., pattern="^(engineer|viewer)$")

View File

@@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
await conn.execute(sa.text("CREATE SCHEMA public"))
await conn.run_sync(Base.metadata.create_all)
# Seed plan_limits for subscription checks
await conn.execute(sa.text("""
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
VALUES
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
"""))
# Create async session maker
async_session_maker = async_sessionmaker(
engine,

View File

@@ -0,0 +1,170 @@
"""Integration tests for account management endpoints."""
import pytest
from httpx import AsyncClient
class TestAccountEndpoints:
"""Test suite for account management endpoints."""
@pytest.mark.asyncio
async def test_get_my_account(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's account."""
response = await client.get("/api/v1/accounts/me", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "name" in data
assert "display_code" in data
assert "owner_id" in data
assert len(data["display_code"]) == 8
@pytest.mark.asyncio
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's subscription details."""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "subscription" in data
assert "limits" in data
assert "usage" in data
assert data["subscription"]["plan"] == "free"
assert data["subscription"]["status"] == "active"
assert data["limits"]["max_trees"] == 3
assert data["limits"]["max_sessions_per_month"] == 20
@pytest.mark.asyncio
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
"""Test getting members of current user's account."""
response = await client.get("/api/v1/accounts/me/members", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
# Current user should be in members list
assert any(m["account_role"] == "owner" for m in data)
@pytest.mark.asyncio
async def test_update_my_account(self, client: AsyncClient, auth_headers: dict):
"""Test updating account name."""
response = await client.patch(
"/api/v1/accounts/me",
json={"name": "Updated Account Name"},
headers=auth_headers
)
assert response.status_code == 200
assert response.json()["name"] == "Updated Account Name"
@pytest.mark.asyncio
async def test_update_account_requires_owner(self, client: AsyncClient):
"""Test that non-owners cannot update account settings."""
# Register two users
owner_data = {
"email": "owner@example.com",
"password": "OwnerPass123!",
"name": "Owner"
}
await client.post("/api/v1/auth/register", json=owner_data)
# Login as owner and create an invite
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "owner@example.com",
"password": "OwnerPass123!"
})
owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "member@example.com", "role": "engineer"},
headers=owner_headers
)
assert invite_resp.status_code == 201
invite_code = invite_resp.json()["code"]
# Register member with account invite code
member_data = {
"email": "member@example.com",
"password": "MemberPass123!",
"name": "Member",
"account_invite_code": invite_code
}
reg_resp = await client.post("/api/v1/auth/register", json=member_data)
assert reg_resp.status_code == 201
assert reg_resp.json()["account_role"] == "engineer"
# Login as member
member_login = await client.post("/api/v1/auth/login/json", json={
"email": "member@example.com",
"password": "MemberPass123!"
})
member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"}
# Member should not be able to update account
response = await client.patch(
"/api/v1/accounts/me",
json={"name": "Hacked Name"},
headers=member_headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict):
"""Test creating and listing account invites."""
# Create invite
response = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "invitee@example.com", "role": "engineer"},
headers=auth_headers
)
assert response.status_code == 201
data = response.json()
assert data["email"] == "invitee@example.com"
assert data["role"] == "engineer"
assert "code" in data
# List invites
list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers)
assert list_response.status_code == 200
invites = list_response.json()
assert len(invites) >= 1
assert any(i["email"] == "invitee@example.com" for i in invites)
@pytest.mark.asyncio
async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict):
"""Test that account invite code joins user to existing account."""
# Get current account
account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers)
account_id = account_resp.json()["id"]
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "joiner@example.com", "role": "viewer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
# Register with account invite code
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "joiner@example.com",
"password": "JoinerPass123!",
"name": "Joiner",
"account_invite_code": invite_code
})
assert reg_resp.status_code == 201
data = reg_resp.json()
assert data["account_id"] == account_id
assert data["account_role"] == "viewer"
@pytest.mark.asyncio
async def test_register_with_invalid_invite_code(self, client: AsyncClient):
"""Test that invalid account invite code is rejected."""
response = await client.post("/api/v1/auth/register", json={
"email": "bad@example.com",
"password": "BadPassword123!",
"name": "Bad User",
"account_invite_code": "INVALID_CODE"
})
assert response.status_code == 400
assert "invalid" in response.json()["detail"].lower()

View File

@@ -169,3 +169,51 @@ class TestAdminEndpoints:
log = result.scalar_one_or_none()
assert log is not None
assert str(log.resource_id) == user_id
@pytest.mark.asyncio
async def test_change_account_role(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test changing a user's account role."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "viewer"},
headers=admin_auth_headers
)
assert response.status_code == 200
assert response.json()["account_role"] == "viewer"
@pytest.mark.asyncio
async def test_change_account_role_invalid(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test that invalid account_role values are rejected."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "owner"},
headers=admin_auth_headers
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_audit_log_created_on_account_role_change(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession
):
"""Test that changing account role creates an audit log entry."""
user_id = test_user["user_data"]["id"]
await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "viewer"},
headers=admin_auth_headers
)
result = await test_db.execute(
select(AuditLog).where(AuditLog.action == "user.account_role_change")
)
log = result.scalar_one_or_none()
assert log is not None
assert str(log.resource_id) == user_id
assert log.details["old_account_role"] == "owner"
assert log.details["new_account_role"] == "viewer"

View File

@@ -23,6 +23,8 @@ class TestAuthentication:
assert data["email"] == user_data["email"]
assert data["name"] == user_data["name"]
assert data["role"] == "engineer"
assert "account_id" in data
assert data["account_role"] == "owner"
assert "id" in data
assert "password" not in data # Password should not be returned
@@ -107,6 +109,7 @@ class TestAuthentication:
assert response.status_code == 201
data = response.json()
assert data["role"] == "engineer"
assert data["account_role"] == "owner"
@pytest.mark.asyncio
async def test_register_default_role_is_engineer(self, client: AsyncClient):
@@ -121,6 +124,7 @@ class TestAuthentication:
assert response.status_code == 201
assert response.json()["role"] == "engineer"
assert response.json()["account_role"] == "owner"
@pytest.mark.asyncio
async def test_register_rejects_no_uppercase(self, client: AsyncClient):

View File

@@ -0,0 +1,205 @@
"""Integration tests for account-based permissions."""
import pytest
from httpx import AsyncClient
class TestAccountPermissions:
"""Test suite for account-based permission checks."""
@pytest.mark.asyncio
async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db):
"""Test that viewers cannot create trees."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Register a user
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "viewer@example.com",
"password": "ViewerPass123!",
"name": "Viewer User"
})
assert reg_resp.status_code == 201
user_id = PyUUID(reg_resp.json()["id"])
# Demote to viewer via ORM
result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one()
user.account_role = "viewer"
await test_db.commit()
# Login as viewer
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "viewer@example.com",
"password": "ViewerPass123!"
})
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Try to create tree
response = await client.post("/api/v1/trees", json={
"name": "Viewer Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"}
}, headers=viewer_headers)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db):
"""Test that viewers can browse/list trees."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Create a public tree as the regular user first
await client.post("/api/v1/trees", json={
"name": "Public Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": True
}, headers=auth_headers)
# Register viewer
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "viewer2@example.com",
"password": "ViewerPass123!",
"name": "Viewer 2"
})
user_id = PyUUID(reg_resp.json()["id"])
result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one()
user.account_role = "viewer"
await test_db.commit()
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "viewer2@example.com",
"password": "ViewerPass123!"
})
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Viewer can list trees
response = await client.get("/api/v1/trees", headers=viewer_headers)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db):
"""Test that account owner can edit trees created by account members."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Get owner's account
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id = me_resp.json()["account_id"]
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "engineer@example.com", "role": "engineer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
# Register engineer in same account
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "engineer@example.com",
"password": "EngineerPass123!",
"name": "Engineer",
"account_invite_code": invite_code
})
assert reg_resp.status_code == 201
assert reg_resp.json()["account_id"] == account_id
# Login as engineer
eng_login = await client.post("/api/v1/auth/login/json", json={
"email": "engineer@example.com",
"password": "EngineerPass123!"
})
eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"}
# Engineer creates a tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Engineer's Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}
}, headers=eng_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Owner can edit engineer's tree
update_resp = await client.put(
f"/api/v1/trees/{tree_id}",
json={"name": "Owner Updated Name"},
headers=auth_headers
)
assert update_resp.status_code == 200
assert update_resp.json()["name"] == "Owner Updated Name"
@pytest.mark.asyncio
async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict):
"""Test that account members can see each other's non-public trees."""
# Get owner's account
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id = me_resp.json()["account_id"]
# Owner creates a private tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Private Account Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": False
}, headers=auth_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Create invite and add member
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "teammate@example.com", "role": "engineer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
await client.post("/api/v1/auth/register", json={
"email": "teammate@example.com",
"password": "TeammatePass123!",
"name": "Teammate",
"account_invite_code": invite_code
})
mate_login = await client.post("/api/v1/auth/login/json", json={
"email": "teammate@example.com",
"password": "TeammatePass123!"
})
mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"}
# Teammate should see the private tree (same account)
response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers)
assert response.status_code == 200
assert response.json()["name"] == "Private Account Tree"
@pytest.mark.asyncio
async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict):
"""Test that users from different accounts cannot see private trees."""
# Owner creates a private tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Secret Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": False
}, headers=auth_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Register a completely separate user (different account)
await client.post("/api/v1/auth/register", json={
"email": "outsider@example.com",
"password": "OutsiderPass123!",
"name": "Outsider"
})
outsider_login = await client.post("/api/v1/auth/login/json", json={
"email": "outsider@example.com",
"password": "OutsiderPass123!"
})
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
# Outsider should NOT see the private tree
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
assert response.status_code == 403

View File

@@ -0,0 +1,129 @@
"""Integration tests for subscription limits."""
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
class TestSubscriptionLimits:
"""Test suite for subscription plan limits."""
@pytest.mark.asyncio
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
"""Test that free plan has tree creation limit of 3."""
tree_template = {
"name": "Limit Test Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test tree"
}
}
# Create trees up to the limit
for i in range(3):
tree_data = {**tree_template, "name": f"Tree {i+1}"}
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}"
# 4th tree should fail with 402
tree_data = {**tree_template, "name": "Tree 4 Over Limit"}
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert response.status_code == 402
assert "limit" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict):
"""Test that subscription details reflect actual usage."""
# Check initial usage
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
initial_usage = response.json()["usage"]
assert initial_usage["tree_count"] == 0
# Create a tree
tree_data = {
"name": "Usage Test Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test"
}
}
create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert create_resp.status_code == 201
# Check usage increased
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
updated_usage = response.json()["usage"]
assert updated_usage["tree_count"] == 1
@pytest.mark.asyncio
async def test_super_admin_bypasses_limits(
self, client: AsyncClient, admin_auth_headers: dict
):
"""Test that super admin can create trees without limit checks."""
tree_template = {
"name": "Admin Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test tree"
},
"is_default": True # Default trees skip limit check
}
# Super admin creating default trees should always work
for i in range(5):
tree_data = {**tree_template, "name": f"Admin Tree {i+1}"}
response = await client.post(
"/api/v1/trees", json=tree_data, headers=admin_auth_headers
)
assert response.status_code == 201
@pytest.mark.asyncio
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
"""Test that free plan limits are correct."""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
limits = response.json()["limits"]
assert limits["plan"] == "free"
assert limits["max_trees"] == 3
assert limits["max_sessions_per_month"] == 20
assert limits["max_users"] == 1
assert limits["custom_branding"] is False
assert limits["priority_support"] is False
@pytest.mark.asyncio
async def test_upgraded_plan_has_higher_limits(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that upgrading plan increases limits."""
from app.models.subscription import Subscription
from app.models.user import User
# Get the user's subscription and upgrade it
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id_str = me_resp.json()["account_id"]
from uuid import UUID
account_id = UUID(account_id_str)
result = await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)
sub = result.scalar_one()
sub.plan = "pro"
await test_db.commit()
# Check limits are now pro
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
limits = response.json()["limits"]
assert limits["plan"] == "pro"
assert limits["max_trees"] == 25
assert limits["max_sessions_per_month"] == 200