feat: update all endpoints and schemas for account-based model
Replace team_id with account_id across all API endpoints (trees, categories, tags, steps, step_categories, admin, auth). Add new accounts and webhooks endpoints. Registration now atomically creates Account + Subscription, with account_invite_code bypassing the platform invite gate. Schemas updated for account_id/account_role. 82 tests passing including 18 new tests for accounts, subscriptions, and permissions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
236
backend/app/api/endpoints/accounts.py
Normal file
236
backend/app/api/endpoints/accounts.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
import secrets
|
||||
import string
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.api.deps import get_current_active_user, require_account_owner
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["accounts"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=AccountResponse)
|
||||
async def get_my_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's account."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
return account
|
||||
|
||||
|
||||
@router.get("/me/subscription", response_model=SubscriptionDetails)
|
||||
async def get_my_subscription(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's subscription details including limits and usage."""
|
||||
sub = await get_account_subscription(current_user.account_id, db)
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No subscription found"
|
||||
)
|
||||
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if not limits:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Plan limits not configured"
|
||||
)
|
||||
|
||||
usage = await get_account_usage(current_user.account_id, db)
|
||||
|
||||
return SubscriptionDetails(
|
||||
subscription=SubscriptionResponse.model_validate(sub),
|
||||
limits=PlanLimitsResponse.model_validate(limits),
|
||||
usage=UsageResponse(**usage),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me/members", response_model=list[UserResponse])
|
||||
async def get_my_members(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get members of current user's account."""
|
||||
result = await db.execute(
|
||||
select(User).where(User.account_id == current_user.account_id)
|
||||
.order_by(User.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=AccountResponse)
|
||||
async def update_my_account(
|
||||
data: AccountUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Update account settings (owner only)."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(account, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
return account
|
||||
|
||||
|
||||
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
|
||||
async def update_member_role(
|
||||
user_id: UUID,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Change a member's role within the account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
user.account_role = data.account_role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_member(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Remove a member from the account (owner only).
|
||||
|
||||
The removed user gets a new personal account.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot remove yourself from your own account"
|
||||
)
|
||||
|
||||
# Create a personal account for the removed user
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
display_code = ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
new_account = Account(
|
||||
name=f"{user.name}'s Account",
|
||||
display_code=display_code,
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_sub = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_sub)
|
||||
|
||||
user.account_id = new_account.id
|
||||
user.account_role = "owner"
|
||||
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_invite(
|
||||
data: AccountInviteCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create an invite to join this account (owner only)."""
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
expires_at = None
|
||||
if data.expires_in_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=current_user.account_id,
|
||||
invited_by_id=current_user.id,
|
||||
email=data.email,
|
||||
code=code,
|
||||
role=data.role,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
return invite
|
||||
|
||||
|
||||
@router.get("/me/invites", response_model=list[AccountInviteResponse])
|
||||
async def list_invites(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""List invites for this account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.account_id == current_user.account_id)
|
||||
.order_by(AccountInvite.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import select, func
|
||||
from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
|
||||
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
|
||||
from app.api.deps import require_admin
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
@@ -21,7 +21,7 @@ async def list_users(
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
team_id: Optional[UUID] = Query(None, description="Filter by team")
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account")
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
@@ -30,8 +30,8 @@ async def list_users(
|
||||
query = query.where(User.is_active == is_active)
|
||||
if role:
|
||||
query = query.where(User.role == role)
|
||||
if team_id:
|
||||
query = query.where(User.team_id == team_id)
|
||||
if account_id:
|
||||
query = query.where(User.account_id == account_id)
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
@@ -91,14 +91,14 @@ async def update_user_role(
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
|
||||
async def toggle_team_admin(
|
||||
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
|
||||
async def update_account_role(
|
||||
user_id: UUID,
|
||||
data: TeamAdminUpdate,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Toggle is_team_admin for a user (super admin only)."""
|
||||
"""Change a user's account role (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -108,15 +108,10 @@ async def toggle_team_admin(
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if data.is_team_admin and user.team_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User must belong to a team to be a team admin"
|
||||
)
|
||||
|
||||
user.is_team_admin = data.is_team_admin
|
||||
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
|
||||
{"is_team_admin": data.is_team_admin})
|
||||
old_role = user.account_role
|
||||
user.account_role = data.account_role
|
||||
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
|
||||
{"old_account_role": old_role, "new_account_role": data.account_role})
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
@@ -18,6 +20,9 @@ from app.core.security import (
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.token import Token
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
|
||||
db.add(token_record)
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
"""Generate a random 8-character alphanumeric display code."""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("3/minute")
|
||||
async def register(
|
||||
@@ -44,10 +55,46 @@ async def register(
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Register a new user."""
|
||||
# Validate invite code if required
|
||||
"""Register a new user.
|
||||
|
||||
Supports two flows:
|
||||
- account_invite_code: Join an existing account (bypasses platform invite gate)
|
||||
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
|
||||
|
||||
After user creation, if no account invite was used, a personal Account
|
||||
and free Subscription are created automatically.
|
||||
"""
|
||||
# Check for account invite code FIRST — bypasses platform invite gate
|
||||
account_invite_record = None
|
||||
if user_data.account_invite_code:
|
||||
result = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.code == user_data.account_invite_code
|
||||
)
|
||||
)
|
||||
account_invite_record = result.scalar_one_or_none()
|
||||
|
||||
if not account_invite_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid account invite code"
|
||||
)
|
||||
|
||||
if account_invite_record.is_used:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has already been used"
|
||||
)
|
||||
|
||||
if account_invite_record.is_expired:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has expired"
|
||||
)
|
||||
|
||||
# Validate platform invite code if required (skip if account invite was provided)
|
||||
invite_code_record = None
|
||||
if settings.REQUIRE_INVITE_CODE:
|
||||
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
|
||||
if not user_data.invite_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -96,8 +143,37 @@ async def register(
|
||||
invite_code_id=invite_code_record.id if invite_code_record else None
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # Get user ID before creating account
|
||||
|
||||
# Mark invite code as used
|
||||
if account_invite_record:
|
||||
# Join existing account via account invite
|
||||
new_user.account_id = account_invite_record.account_id
|
||||
new_user.account_role = account_invite_record.role
|
||||
|
||||
# Mark account invite as used
|
||||
account_invite_record.accepted_by_id = new_user.id
|
||||
account_invite_record.used_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
# Create personal Account + free Subscription
|
||||
new_account = Account(
|
||||
name=f"{user_data.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
owner_id=new_user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush() # Get account ID
|
||||
|
||||
new_subscription = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_subscription)
|
||||
|
||||
new_user.account_id = new_account.id
|
||||
new_user.account_role = "owner"
|
||||
|
||||
# Mark platform invite code as used
|
||||
if invite_code_record:
|
||||
invite_code_record.used_by_id = new_user.id
|
||||
invite_code_record.used_at = datetime.now(timezone.utc)
|
||||
|
||||
@@ -28,11 +28,11 @@ async def list_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(TreeCategory)
|
||||
@@ -41,19 +41,19 @@ async def list_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(TreeCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(TreeCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(TreeCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeCategory.team_id.is_(None), # Global
|
||||
TreeCategory.team_id == current_user.team_id # User's team
|
||||
TreeCategory.account_id.is_(None), # Global
|
||||
TreeCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(TreeCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(TreeCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
|
||||
|
||||
@@ -76,7 +76,7 @@ async def list_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
tree_count=tree_count
|
||||
@@ -101,8 +101,8 @@ async def get_category(
|
||||
detail="Category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -121,7 +121,7 @@ async def get_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -138,10 +138,10 @@ async def create_category(
|
||||
):
|
||||
"""Create a new category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_category(current_user, category_data.team_id):
|
||||
if not can_create_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this category"
|
||||
@@ -150,10 +150,10 @@ async def create_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == slug,
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -164,7 +164,7 @@ async def create_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(TreeCategory.display_order)).where(
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -173,7 +173,7 @@ async def create_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -186,7 +186,7 @@ async def create_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -227,7 +227,7 @@ async def update_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == new_slug,
|
||||
TreeCategory.team_id == category.team_id,
|
||||
TreeCategory.account_id == category.account_id,
|
||||
TreeCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -257,7 +257,7 @@ async def update_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -25,11 +25,11 @@ async def list_step_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List step categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(StepCategory)
|
||||
@@ -38,19 +38,19 @@ async def list_step_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(StepCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(StepCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(StepCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
StepCategory.team_id.is_(None), # Global
|
||||
StepCategory.team_id == current_user.team_id # User's team
|
||||
StepCategory.account_id.is_(None), # Global
|
||||
StepCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(StepCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(StepCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(StepCategory.display_order, StepCategory.name)
|
||||
|
||||
@@ -66,7 +66,7 @@ async def list_step_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
step_count=0 # Will be computed when step_library exists
|
||||
@@ -91,8 +91,8 @@ async def get_step_category(
|
||||
detail="Step category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this step category"
|
||||
@@ -103,7 +103,7 @@ async def get_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -120,10 +120,10 @@ async def create_step_category(
|
||||
):
|
||||
"""Create a new step category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_step_category(current_user, category_data.team_id):
|
||||
if not can_create_step_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this step category"
|
||||
@@ -132,10 +132,10 @@ async def create_step_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == slug,
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -146,7 +146,7 @@ async def create_step_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(StepCategory.display_order)).where(
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -155,7 +155,7 @@ async def create_step_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -168,7 +168,7 @@ async def create_step_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -209,7 +209,7 @@ async def update_step_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == new_slug,
|
||||
StepCategory.team_id == category.team_id,
|
||||
StepCategory.account_id == category.account_id,
|
||||
StepCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -231,7 +231,7 @@ async def update_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -55,10 +55,10 @@ async def get_step_or_404(
|
||||
|
||||
def build_visibility_filter(user: User):
|
||||
"""Build SQLAlchemy filter for step visibility based on user."""
|
||||
if user.team_id:
|
||||
if user.account_id:
|
||||
return or_(
|
||||
StepLibrary.visibility == 'public',
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
|
||||
StepLibrary.created_by == user.id # Own private steps
|
||||
)
|
||||
else:
|
||||
@@ -249,7 +249,7 @@ async def get_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -296,10 +296,10 @@ async def create_step(
|
||||
if not cat_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Invalid category")
|
||||
|
||||
# Team validation: can only set team_id to own team
|
||||
team_id = step_data.team_id
|
||||
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another team")
|
||||
# Account validation: can only set account_id to own account
|
||||
account_id = step_data.account_id
|
||||
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another account")
|
||||
|
||||
step = StepLibrary(
|
||||
title=step_data.title,
|
||||
@@ -309,7 +309,7 @@ async def create_step(
|
||||
tags=step_data.tags,
|
||||
visibility=step_data.visibility,
|
||||
created_by=current_user.id,
|
||||
team_id=team_id or current_user.team_id,
|
||||
account_id=account_id or current_user.account_id,
|
||||
)
|
||||
|
||||
db.add(step)
|
||||
@@ -326,7 +326,7 @@ async def create_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -393,7 +393,7 @@ async def update_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
|
||||
@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
|
||||
async def list_tags(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""List tags visible to the user.
|
||||
|
||||
Returns global tags plus team-specific tags for the user's team.
|
||||
Returns global tags plus account-specific tags for the user's account.
|
||||
Tags are ordered by usage count (most used first).
|
||||
"""
|
||||
query = select(TreeTag)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if include_team and current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global
|
||||
TreeTag.team_id == current_user.team_id # User's team
|
||||
TreeTag.account_id.is_(None), # Global
|
||||
TreeTag.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only show global tags
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
|
||||
|
||||
@@ -55,7 +55,7 @@ async def search_tags(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""Search/autocomplete tags.
|
||||
|
||||
@@ -68,15 +68,15 @@ async def search_tags(
|
||||
)
|
||||
|
||||
# Filter by visibility
|
||||
if include_team and current_user.team_id:
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == current_user.team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
|
||||
|
||||
@@ -102,8 +102,8 @@ async def get_tag(
|
||||
detail="Tag not found"
|
||||
)
|
||||
|
||||
# Check access: global tags visible to all, team tags only to team members
|
||||
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global tags visible to all, account tags only to account members
|
||||
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tag"
|
||||
@@ -120,10 +120,10 @@ async def create_tag(
|
||||
):
|
||||
"""Create a new tag.
|
||||
|
||||
- Global admins can create global tags (team_id=None)
|
||||
- Team members can create team-specific tags for their team
|
||||
- Global admins can create global tags (account_id=None)
|
||||
- Account members can create account-specific tags for their account
|
||||
"""
|
||||
if not can_create_tag(current_user, tag_data.team_id):
|
||||
if not can_create_tag(current_user, tag_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this tag"
|
||||
@@ -132,10 +132,10 @@ async def create_tag(
|
||||
# Generate slug
|
||||
slug = TreeTag.slugify(tag_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
TreeTag.team_id == tag_data.team_id
|
||||
TreeTag.account_id == tag_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -147,7 +147,7 @@ async def create_tag(
|
||||
new_tag = TreeTag(
|
||||
name=tag_data.name,
|
||||
slug=slug,
|
||||
team_id=tag_data.team_id,
|
||||
account_id=tag_data.account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(new_tag)
|
||||
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
|
||||
continue
|
||||
|
||||
# Try to find existing tag
|
||||
# Determine scope: use tree's team, or global for admin-owned trees
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
# Determine scope: use tree's account, or global for admin-owned trees
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global tag
|
||||
TreeTag.team_id == tag_team_id # Team tag
|
||||
TreeTag.account_id.is_(None), # Global tag
|
||||
TreeTag.account_id == tag_account_id # Account tag
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
tag = tag_result.scalar_one_or_none()
|
||||
|
||||
if not tag:
|
||||
# Create new tag - prefer team scope unless admin creating on public tree
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
# Fall back to user's team if they can't create in tree's scope
|
||||
new_team_id = current_user.team_id
|
||||
# Create new tag - prefer account scope unless admin creating on public tree
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
# Fall back to user's account if they can't create in tree's scope
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -331,7 +331,7 @@ async def replace_tree_tags(
|
||||
tree.tags.clear()
|
||||
|
||||
# Add new tags
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
for tag_name in tag_data.tags:
|
||||
slug = TreeTag.slugify(tag_name)
|
||||
@@ -340,8 +340,8 @@ async def replace_tree_tags(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tag_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tag_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -349,14 +349,14 @@ async def replace_tree_tags(
|
||||
|
||||
if not tag:
|
||||
# Create new tag
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
new_team_id = current_user.team_id
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -397,7 +397,7 @@ async def get_tree_tags(
|
||||
# Check if user can view the tree
|
||||
if not tree.is_public:
|
||||
if tree.author_id != current_user.id:
|
||||
if tree.team_id != current_user.team_id:
|
||||
if tree.account_id != current_user.account_id:
|
||||
if not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
from app.core.audit import log_audit
|
||||
|
||||
router = APIRouter(prefix="/trees", tags=["trees"])
|
||||
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
|
||||
Tree.is_public == True,
|
||||
Tree.author_id == current_user.id,
|
||||
]
|
||||
if current_user.team_id:
|
||||
conditions.append(Tree.team_id == current_user.team_id)
|
||||
if current_user.account_id:
|
||||
conditions.append(Tree.account_id == current_user.account_id)
|
||||
return or_(*conditions)
|
||||
|
||||
|
||||
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
|
||||
category_info=category_info,
|
||||
tags=tree.tag_names,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
tags=tree.tag_names,
|
||||
tree_structure=tree.tree_structure,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -289,7 +290,7 @@ async def create_tree(
|
||||
detail="Category not found"
|
||||
)
|
||||
# Check category access
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -302,16 +303,25 @@ async def create_tree(
|
||||
category_id=tree_data.category_id,
|
||||
tree_structure=tree_data.tree_structure,
|
||||
author_id=None if is_default else current_user.id, # Default trees have no author
|
||||
team_id=None if is_default else current_user.team_id,
|
||||
account_id=None if is_default else current_user.account_id,
|
||||
is_public=True if is_default else tree_data.is_public, # Default trees are always public
|
||||
is_default=is_default
|
||||
)
|
||||
# Check subscription tree limit
|
||||
if not is_default and current_user.account_id:
|
||||
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
|
||||
if not can_create:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
db.add(new_tree)
|
||||
await db.flush() # Get the ID
|
||||
|
||||
# Handle tags
|
||||
if tree_data.tags:
|
||||
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
# Collect tags to add
|
||||
tags_to_add = []
|
||||
@@ -322,8 +332,8 @@ async def create_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -334,7 +344,7 @@ async def create_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -420,7 +430,7 @@ async def update_tree(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Category not found"
|
||||
)
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -450,7 +460,7 @@ async def update_tree(
|
||||
)
|
||||
|
||||
# Add new tags
|
||||
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
added_tag_ids = set()
|
||||
|
||||
for tag_name in tags_data:
|
||||
@@ -459,8 +469,8 @@ async def update_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -470,7 +480,7 @@ async def update_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
|
||||
62
backend/app/api/endpoints/webhooks.py
Normal file
62
backend/app/api/endpoints/webhooks.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Request, HTTPException, status, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
from app.core.stripe_handlers import WEBHOOK_HANDLERS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
@router.post("/stripe")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
Returns 200 for all events to prevent Stripe retries.
|
||||
Actual processing happens only when Stripe is configured.
|
||||
"""
|
||||
if not settings.stripe_enabled:
|
||||
return {"status": "ok", "message": "Stripe not configured, event ignored"}
|
||||
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
if not sig_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing stripe-signature header"
|
||||
)
|
||||
|
||||
# Verify webhook signature
|
||||
try:
|
||||
import stripe
|
||||
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("stripe package not installed, cannot verify webhook")
|
||||
return {"status": "ok", "message": "stripe package not installed"}
|
||||
except Exception as e:
|
||||
logger.error("Stripe webhook signature verification failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid signature"
|
||||
)
|
||||
|
||||
event_type = event.get("type", "")
|
||||
handler = WEBHOOK_HANDLERS.get(event_type)
|
||||
|
||||
if handler:
|
||||
try:
|
||||
await handler(event, db)
|
||||
except Exception:
|
||||
logger.exception("Error handling Stripe event %s", event_type)
|
||||
|
||||
return {"status": "ok"}
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
|
||||
api_router.include_router(step_categories.router)
|
||||
api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
|
||||
Reference in New Issue
Block a user