feat: update all endpoints and schemas for account-based model

Replace team_id with account_id across all API endpoints (trees,
categories, tags, steps, step_categories, admin, auth). Add new
accounts and webhooks endpoints. Registration now atomically creates
Account + Subscription, with account_invite_code bypassing the
platform invite gate.

Schemas updated for account_id/account_role. 82 tests passing
including 18 new tests for accounts, subscriptions, and permissions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-07 02:39:01 -05:00
parent 4ccb93ee31
commit e0089a9c5a
24 changed files with 1178 additions and 152 deletions

View File

@@ -0,0 +1,236 @@
from datetime import datetime, timezone, timedelta
from typing import Annotated, Optional
from uuid import UUID
import secrets
import string
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
from app.schemas.user import UserResponse, AccountRoleUpdate
from app.api.deps import get_current_active_user, require_account_owner
router = APIRouter(prefix="/accounts", tags=["accounts"])
@router.get("/me", response_model=AccountResponse)
async def get_my_account(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's account."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
return account
@router.get("/me/subscription", response_model=SubscriptionDetails)
async def get_my_subscription(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's subscription details including limits and usage."""
sub = await get_account_subscription(current_user.account_id, db)
if not sub:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No subscription found"
)
limits = await get_plan_limits(sub.plan, db)
if not limits:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Plan limits not configured"
)
usage = await get_account_usage(current_user.account_id, db)
return SubscriptionDetails(
subscription=SubscriptionResponse.model_validate(sub),
limits=PlanLimitsResponse.model_validate(limits),
usage=UsageResponse(**usage),
)
@router.get("/me/members", response_model=list[UserResponse])
async def get_my_members(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get members of current user's account."""
result = await db.execute(
select(User).where(User.account_id == current_user.account_id)
.order_by(User.created_at)
)
return result.scalars().all()
@router.patch("/me", response_model=AccountResponse)
async def update_my_account(
data: AccountUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Update account settings (owner only)."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(account, field, value)
await db.commit()
await db.refresh(account)
return account
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
async def update_member_role(
user_id: UUID,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Change a member's role within the account (owner only)."""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot change your own role"
)
user.account_role = data.account_role
await db.commit()
await db.refresh(user)
return user
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_member(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Remove a member from the account (owner only).
The removed user gets a new personal account.
"""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot remove yourself from your own account"
)
# Create a personal account for the removed user
chars = string.ascii_uppercase + string.digits
display_code = ''.join(secrets.choice(chars) for _ in range(8))
new_account = Account(
name=f"{user.name}'s Account",
display_code=display_code,
owner_id=user.id,
)
db.add(new_account)
await db.flush()
new_sub = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_sub)
user.account_id = new_account.id
user.account_role = "owner"
await db.commit()
return None
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
async def create_invite(
data: AccountInviteCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Create an invite to join this account (owner only)."""
code = secrets.token_urlsafe(16)
expires_at = None
if data.expires_in_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
invite = AccountInvite(
account_id=current_user.account_id,
invited_by_id=current_user.id,
email=data.email,
code=code,
role=data.role,
expires_at=expires_at,
)
db.add(invite)
await db.commit()
await db.refresh(invite)
return invite
@router.get("/me/invites", response_model=list[AccountInviteResponse])
async def list_invites(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""List invites for this account (owner only)."""
result = await db.execute(
select(AccountInvite)
.where(AccountInvite.account_id == current_user.account_id)
.order_by(AccountInvite.created_at.desc())
)
return result.scalars().all()

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select, func
from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
from app.api.deps import require_admin
router = APIRouter(prefix="/admin", tags=["admin"])
@@ -21,7 +21,7 @@ async def list_users(
limit: int = Query(100, ge=1, le=100),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
role: Optional[str] = Query(None, description="Filter by role"),
team_id: Optional[UUID] = Query(None, description="Filter by team")
account_id: Optional[UUID] = Query(None, description="Filter by account")
):
"""List all users (super admin only)."""
query = select(User)
@@ -30,8 +30,8 @@ async def list_users(
query = query.where(User.is_active == is_active)
if role:
query = query.where(User.role == role)
if team_id:
query = query.where(User.team_id == team_id)
if account_id:
query = query.where(User.account_id == account_id)
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
@@ -91,14 +91,14 @@ async def update_user_role(
return user
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
async def toggle_team_admin(
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
async def update_account_role(
user_id: UUID,
data: TeamAdminUpdate,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Toggle is_team_admin for a user (super admin only)."""
"""Change a user's account role (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -108,15 +108,10 @@ async def toggle_team_admin(
detail="User not found"
)
if data.is_team_admin and user.team_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User must belong to a team to be a team admin"
)
user.is_team_admin = data.is_team_admin
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
{"is_team_admin": data.is_team_admin})
old_role = user.account_role
user.account_role = data.account_role
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
{"old_account_role": old_role, "new_account_role": data.account_role})
await db.commit()
await db.refresh(user)
return user

View File

@@ -1,3 +1,5 @@
import secrets
import string
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
@@ -18,6 +20,9 @@ from app.core.security import (
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.api.deps import get_current_active_user, get_refresh_token_payload
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
db.add(token_record)
def _generate_display_code() -> str:
"""Generate a random 8-character alphanumeric display code."""
chars = string.ascii_uppercase + string.digits
return ''.join(secrets.choice(chars) for _ in range(8))
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
@@ -44,10 +55,46 @@ async def register(
user_data: UserCreate,
db: Annotated[AsyncSession, Depends(get_db)]
):
"""Register a new user."""
# Validate invite code if required
"""Register a new user.
Supports two flows:
- account_invite_code: Join an existing account (bypasses platform invite gate)
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
)
account_invite_record = result.scalar_one_or_none()
if not account_invite_record:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid account invite code"
)
if account_invite_record.is_used:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has already been used"
)
if account_invite_record.is_expired:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has expired"
)
# Validate platform invite code if required (skip if account invite was provided)
invite_code_record = None
if settings.REQUIRE_INVITE_CODE:
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
if not user_data.invite_code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -96,8 +143,37 @@ async def register(
invite_code_id=invite_code_record.id if invite_code_record else None
)
db.add(new_user)
await db.flush() # Get user ID before creating account
# Mark invite code as used
if account_invite_record:
# Join existing account via account invite
new_user.account_id = account_invite_record.account_id
new_user.account_role = account_invite_record.role
# Mark account invite as used
account_invite_record.accepted_by_id = new_user.id
account_invite_record.used_at = datetime.now(timezone.utc)
else:
# Create personal Account + free Subscription
new_account = Account(
name=f"{user_data.name}'s Account",
display_code=_generate_display_code(),
owner_id=new_user.id,
)
db.add(new_account)
await db.flush() # Get account ID
new_subscription = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_subscription)
new_user.account_id = new_account.id
new_user.account_role = "owner"
# Mark platform invite code as used
if invite_code_record:
invite_code_record.used_by_id = new_user.id
invite_code_record.used_at = datetime.now(timezone.utc)

View File

@@ -28,11 +28,11 @@ async def list_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(TreeCategory)
@@ -41,19 +41,19 @@ async def list_categories(
if not include_inactive:
query = query.where(TreeCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(TreeCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(TreeCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
TreeCategory.team_id.is_(None), # Global
TreeCategory.team_id == current_user.team_id # User's team
TreeCategory.account_id.is_(None), # Global
TreeCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(TreeCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(TreeCategory.account_id.is_(None))
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
@@ -76,7 +76,7 @@ async def list_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
tree_count=tree_count
@@ -101,8 +101,8 @@ async def get_category(
detail="Category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -121,7 +121,7 @@ async def get_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -138,10 +138,10 @@ async def create_category(
):
"""Create a new category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_category(current_user, category_data.team_id):
if not can_create_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this category"
@@ -150,10 +150,10 @@ async def create_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeCategory).where(
TreeCategory.slug == slug,
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -164,7 +164,7 @@ async def create_category(
# Get next display order
order_query = select(func.max(TreeCategory.display_order)).where(
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -173,7 +173,7 @@ async def create_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -186,7 +186,7 @@ async def create_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -227,7 +227,7 @@ async def update_category(
# Check for duplicate slug
existing_query = select(TreeCategory).where(
TreeCategory.slug == new_slug,
TreeCategory.team_id == category.team_id,
TreeCategory.account_id == category.account_id,
TreeCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -257,7 +257,7 @@ async def update_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -25,11 +25,11 @@ async def list_step_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List step categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(StepCategory)
@@ -38,19 +38,19 @@ async def list_step_categories(
if not include_inactive:
query = query.where(StepCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(StepCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(StepCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
StepCategory.team_id.is_(None), # Global
StepCategory.team_id == current_user.team_id # User's team
StepCategory.account_id.is_(None), # Global
StepCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(StepCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(StepCategory.account_id.is_(None))
query = query.order_by(StepCategory.display_order, StepCategory.name)
@@ -66,7 +66,7 @@ async def list_step_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
step_count=0 # Will be computed when step_library exists
@@ -91,8 +91,8 @@ async def get_step_category(
detail="Step category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this step category"
@@ -103,7 +103,7 @@ async def get_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -120,10 +120,10 @@ async def create_step_category(
):
"""Create a new step category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_step_category(current_user, category_data.team_id):
if not can_create_step_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this step category"
@@ -132,10 +132,10 @@ async def create_step_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(StepCategory).where(
StepCategory.slug == slug,
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -146,7 +146,7 @@ async def create_step_category(
# Get next display order
order_query = select(func.max(StepCategory.display_order)).where(
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -155,7 +155,7 @@ async def create_step_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -168,7 +168,7 @@ async def create_step_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -209,7 +209,7 @@ async def update_step_category(
# Check for duplicate slug
existing_query = select(StepCategory).where(
StepCategory.slug == new_slug,
StepCategory.team_id == category.team_id,
StepCategory.account_id == category.account_id,
StepCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -231,7 +231,7 @@ async def update_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -55,10 +55,10 @@ async def get_step_or_404(
def build_visibility_filter(user: User):
"""Build SQLAlchemy filter for step visibility based on user."""
if user.team_id:
if user.account_id:
return or_(
StepLibrary.visibility == 'public',
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
StepLibrary.created_by == user.id # Own private steps
)
else:
@@ -249,7 +249,7 @@ async def get_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -296,10 +296,10 @@ async def create_step(
if not cat_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Invalid category")
# Team validation: can only set team_id to own team
team_id = step_data.team_id
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another team")
# Account validation: can only set account_id to own account
account_id = step_data.account_id
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another account")
step = StepLibrary(
title=step_data.title,
@@ -309,7 +309,7 @@ async def create_step(
tags=step_data.tags,
visibility=step_data.visibility,
created_by=current_user.id,
team_id=team_id or current_user.team_id,
account_id=account_id or current_user.account_id,
)
db.add(step)
@@ -326,7 +326,7 @@ async def create_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -393,7 +393,7 @@ async def update_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,

View File

@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
async def list_tags(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""List tags visible to the user.
Returns global tags plus team-specific tags for the user's team.
Returns global tags plus account-specific tags for the user's account.
Tags are ordered by usage count (most used first).
"""
query = select(TreeTag)
# Filter by visibility: global OR user's team
if include_team and current_user.team_id:
# Filter by visibility: global OR user's account
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None), # Global
TreeTag.team_id == current_user.team_id # User's team
TreeTag.account_id.is_(None), # Global
TreeTag.account_id == current_user.account_id # User's account
)
)
else:
# Only show global tags
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
@@ -55,7 +55,7 @@ async def search_tags(
current_user: Annotated[User, Depends(get_current_active_user)],
q: str = Query(..., min_length=1, description="Search query"),
limit: int = Query(10, ge=1, le=50),
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""Search/autocomplete tags.
@@ -68,15 +68,15 @@ async def search_tags(
)
# Filter by visibility
if include_team and current_user.team_id:
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == current_user.team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == current_user.account_id
)
)
else:
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
@@ -102,8 +102,8 @@ async def get_tag(
detail="Tag not found"
)
# Check access: global tags visible to all, team tags only to team members
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global tags visible to all, account tags only to account members
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tag"
@@ -120,10 +120,10 @@ async def create_tag(
):
"""Create a new tag.
- Global admins can create global tags (team_id=None)
- Team members can create team-specific tags for their team
- Global admins can create global tags (account_id=None)
- Account members can create account-specific tags for their account
"""
if not can_create_tag(current_user, tag_data.team_id):
if not can_create_tag(current_user, tag_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this tag"
@@ -132,10 +132,10 @@ async def create_tag(
# Generate slug
slug = TreeTag.slugify(tag_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeTag).where(
TreeTag.slug == slug,
TreeTag.team_id == tag_data.team_id
TreeTag.account_id == tag_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -147,7 +147,7 @@ async def create_tag(
new_tag = TreeTag(
name=tag_data.name,
slug=slug,
team_id=tag_data.team_id,
account_id=tag_data.account_id,
created_by=current_user.id
)
db.add(new_tag)
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
continue
# Try to find existing tag
# Determine scope: use tree's team, or global for admin-owned trees
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
# Determine scope: use tree's account, or global for admin-owned trees
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None), # Global tag
TreeTag.team_id == tag_team_id # Team tag
TreeTag.account_id.is_(None), # Global tag
TreeTag.account_id == tag_account_id # Account tag
)
)
tag_result = await db.execute(tag_query)
tag = tag_result.scalar_one_or_none()
if not tag:
# Create new tag - prefer team scope unless admin creating on public tree
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
# Fall back to user's team if they can't create in tree's scope
new_team_id = current_user.team_id
# Create new tag - prefer account scope unless admin creating on public tree
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
# Fall back to user's account if they can't create in tree's scope
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -331,7 +331,7 @@ async def replace_tree_tags(
tree.tags.clear()
# Add new tags
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
for tag_name in tag_data.tags:
slug = TreeTag.slugify(tag_name)
@@ -340,8 +340,8 @@ async def replace_tree_tags(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tag_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tag_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -349,14 +349,14 @@ async def replace_tree_tags(
if not tag:
# Create new tag
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
new_team_id = current_user.team_id
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -397,7 +397,7 @@ async def get_tree_tags(
# Check if user can view the tree
if not tree.is_public:
if tree.author_id != current_user.id:
if tree.team_id != current_user.team_id:
if tree.account_id != current_user.account_id:
if not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.subscriptions import check_tree_limit
from app.core.audit import log_audit
router = APIRouter(prefix="/trees", tags=["trees"])
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
Tree.is_public == True,
Tree.author_id == current_user.id,
]
if current_user.team_id:
conditions.append(Tree.team_id == current_user.team_id)
if current_user.account_id:
conditions.append(Tree.account_id == current_user.account_id)
return or_(*conditions)
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
category_info=category_info,
tags=tree.tag_names,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
tags=tree.tag_names,
tree_structure=tree.tree_structure,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -289,7 +290,7 @@ async def create_tree(
detail="Category not found"
)
# Check category access
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -302,16 +303,25 @@ async def create_tree(
category_id=tree_data.category_id,
tree_structure=tree_data.tree_structure,
author_id=None if is_default else current_user.id, # Default trees have no author
team_id=None if is_default else current_user.team_id,
account_id=None if is_default else current_user.account_id,
is_public=True if is_default else tree_data.is_public, # Default trees are always public
is_default=is_default
)
# Check subscription tree limit
if not is_default and current_user.account_id:
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
if not can_create:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
)
db.add(new_tree)
await db.flush() # Get the ID
# Handle tags
if tree_data.tags:
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
# Collect tags to add
tags_to_add = []
@@ -322,8 +332,8 @@ async def create_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -334,7 +344,7 @@ async def create_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -420,7 +430,7 @@ async def update_tree(
status_code=status.HTTP_404_NOT_FOUND,
detail="Category not found"
)
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -450,7 +460,7 @@ async def update_tree(
)
# Add new tags
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
added_tag_ids = set()
for tag_name in tags_data:
@@ -459,8 +469,8 @@ async def update_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -470,7 +480,7 @@ async def update_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)

View File

@@ -0,0 +1,62 @@
import logging
from fastapi import APIRouter, Request, HTTPException, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.config import settings
from app.core.stripe_handlers import WEBHOOK_HANDLERS
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
@router.post("/stripe")
async def stripe_webhook(
request: Request,
db: AsyncSession = Depends(get_db),
):
"""Handle Stripe webhook events.
Returns 200 for all events to prevent Stripe retries.
Actual processing happens only when Stripe is configured.
"""
if not settings.stripe_enabled:
return {"status": "ok", "message": "Stripe not configured, event ignored"}
payload = await request.body()
sig_header = request.headers.get("stripe-signature")
if not sig_header:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing stripe-signature header"
)
# Verify webhook signature
try:
import stripe
stripe.api_key = settings.STRIPE_SECRET_KEY
event = stripe.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except ImportError:
logger.warning("stripe package not installed, cannot verify webhook")
return {"status": "ok", "message": "stripe package not installed"}
except Exception as e:
logger.error("Stripe webhook signature verification failed: %s", e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid signature"
)
event_type = event.get("type", "")
handler = WEBHOOK_HANDLERS.get(event_type)
if handler:
try:
await handler(event, db)
except Exception:
logger.exception("Error handling Stripe event %s", event_type)
return {"status": "ok"}

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
api_router.include_router(step_categories.router)
api_router.include_router(steps.router)
api_router.include_router(admin.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)