From e0089a9c5aeb72a7e988b031ad8ec92e2d292ed8 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:39:01 -0500 Subject: [PATCH] feat: update all endpoints and schemas for account-based model Replace team_id with account_id across all API endpoints (trees, categories, tags, steps, step_categories, admin, auth). Add new accounts and webhooks endpoints. Registration now atomically creates Account + Subscription, with account_invite_code bypassing the platform invite gate. Schemas updated for account_id/account_role. 82 tests passing including 18 new tests for accounts, subscriptions, and permissions. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/accounts.py | 236 +++++++++++++++++++ backend/app/api/endpoints/admin.py | 29 +-- backend/app/api/endpoints/auth.py | 84 ++++++- backend/app/api/endpoints/categories.py | 48 ++-- backend/app/api/endpoints/step_categories.py | 48 ++-- backend/app/api/endpoints/steps.py | 20 +- backend/app/api/endpoints/tags.py | 76 +++--- backend/app/api/endpoints/trees.py | 40 ++-- backend/app/api/endpoints/webhooks.py | 62 +++++ backend/app/api/router.py | 4 +- backend/app/schemas/account.py | 39 +++ backend/app/schemas/category.py | 6 +- backend/app/schemas/step_category.py | 6 +- backend/app/schemas/step_library.py | 4 +- backend/app/schemas/subscription.py | 40 ++++ backend/app/schemas/tag.py | 8 +- backend/app/schemas/tree.py | 4 +- backend/app/schemas/user.py | 11 +- backend/tests/conftest.py | 9 + backend/tests/test_account_management.py | 170 +++++++++++++ backend/tests/test_admin.py | 48 ++++ backend/tests/test_auth.py | 4 + backend/tests/test_permissions_account.py | 205 ++++++++++++++++ backend/tests/test_subscription_limits.py | 129 ++++++++++ 24 files changed, 1178 insertions(+), 152 deletions(-) create mode 100644 backend/app/api/endpoints/accounts.py create mode 100644 backend/app/api/endpoints/webhooks.py create mode 100644 backend/app/schemas/account.py create mode 100644 backend/app/schemas/subscription.py create mode 100644 backend/tests/test_account_management.py create mode 100644 backend/tests/test_permissions_account.py create mode 100644 backend/tests/test_subscription_limits.py diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py new file mode 100644 index 00000000..c2228584 --- /dev/null +++ b/backend/app/api/endpoints/accounts.py @@ -0,0 +1,236 @@ +from datetime import datetime, timezone, timedelta +from typing import Annotated, Optional +from uuid import UUID +import secrets +import string +from fastapi import APIRouter, Depends, HTTPException, status, Query +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.core.database import get_db +from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage +from app.models.account import Account +from app.models.account_invite import AccountInvite +from app.models.subscription import Subscription +from app.models.user import User +from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse +from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails +from app.schemas.user import UserResponse, AccountRoleUpdate +from app.api.deps import get_current_active_user, require_account_owner + +router = APIRouter(prefix="/accounts", tags=["accounts"]) + + +@router.get("/me", response_model=AccountResponse) +async def get_my_account( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get current user's account.""" + result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = result.scalar_one_or_none() + if not account: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Account not found" + ) + return account + + +@router.get("/me/subscription", response_model=SubscriptionDetails) +async def get_my_subscription( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get current user's subscription details including limits and usage.""" + sub = await get_account_subscription(current_user.account_id, db) + if not sub: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No subscription found" + ) + + limits = await get_plan_limits(sub.plan, db) + if not limits: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Plan limits not configured" + ) + + usage = await get_account_usage(current_user.account_id, db) + + return SubscriptionDetails( + subscription=SubscriptionResponse.model_validate(sub), + limits=PlanLimitsResponse.model_validate(limits), + usage=UsageResponse(**usage), + ) + + +@router.get("/me/members", response_model=list[UserResponse]) +async def get_my_members( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get members of current user's account.""" + result = await db.execute( + select(User).where(User.account_id == current_user.account_id) + .order_by(User.created_at) + ) + return result.scalars().all() + + +@router.patch("/me", response_model=AccountResponse) +async def update_my_account( + data: AccountUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Update account settings (owner only).""" + result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = result.scalar_one_or_none() + if not account: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Account not found" + ) + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(account, field, value) + + await db.commit() + await db.refresh(account) + return account + + +@router.patch("/me/members/{user_id}/role", response_model=UserResponse) +async def update_member_role( + user_id: UUID, + data: AccountRoleUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Change a member's role within the account (owner only).""" + result = await db.execute( + select(User).where( + User.id == user_id, + User.account_id == current_user.account_id + ) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found in your account" + ) + + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot change your own role" + ) + + user.account_role = data.account_role + await db.commit() + await db.refresh(user) + return user + + +@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def remove_member( + user_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Remove a member from the account (owner only). + + The removed user gets a new personal account. + """ + result = await db.execute( + select(User).where( + User.id == user_id, + User.account_id == current_user.account_id + ) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found in your account" + ) + + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot remove yourself from your own account" + ) + + # Create a personal account for the removed user + chars = string.ascii_uppercase + string.digits + display_code = ''.join(secrets.choice(chars) for _ in range(8)) + + new_account = Account( + name=f"{user.name}'s Account", + display_code=display_code, + owner_id=user.id, + ) + db.add(new_account) + await db.flush() + + new_sub = Subscription( + account_id=new_account.id, + plan="free", + status="active", + ) + db.add(new_sub) + + user.account_id = new_account.id + user.account_role = "owner" + + await db.commit() + return None + + +@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED) +async def create_invite( + data: AccountInviteCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Create an invite to join this account (owner only).""" + code = secrets.token_urlsafe(16) + + expires_at = None + if data.expires_in_days: + expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days) + + invite = AccountInvite( + account_id=current_user.account_id, + invited_by_id=current_user.id, + email=data.email, + code=code, + role=data.role, + expires_at=expires_at, + ) + db.add(invite) + await db.commit() + await db.refresh(invite) + return invite + + +@router.get("/me/invites", response_model=list[AccountInviteResponse]) +async def list_invites( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """List invites for this account (owner only).""" + result = await db.execute( + select(AccountInvite) + .where(AccountInvite.account_id == current_user.account_id) + .order_by(AccountInvite.created_at.desc()) + ) + return result.scalars().all() diff --git a/backend/app/api/endpoints/admin.py b/backend/app/api/endpoints/admin.py index daa04166..e6bde866 100644 --- a/backend/app/api/endpoints/admin.py +++ b/backend/app/api/endpoints/admin.py @@ -7,7 +7,7 @@ from sqlalchemy import select, func from app.core.database import get_db from app.core.audit import log_audit from app.models.user import User -from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate +from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate from app.api.deps import require_admin router = APIRouter(prefix="/admin", tags=["admin"]) @@ -21,7 +21,7 @@ async def list_users( limit: int = Query(100, ge=1, le=100), is_active: Optional[bool] = Query(None, description="Filter by active status"), role: Optional[str] = Query(None, description="Filter by role"), - team_id: Optional[UUID] = Query(None, description="Filter by team") + account_id: Optional[UUID] = Query(None, description="Filter by account") ): """List all users (super admin only).""" query = select(User) @@ -30,8 +30,8 @@ async def list_users( query = query.where(User.is_active == is_active) if role: query = query.where(User.role == role) - if team_id: - query = query.where(User.team_id == team_id) + if account_id: + query = query.where(User.account_id == account_id) query = query.order_by(User.created_at.desc()).offset(skip).limit(limit) @@ -91,14 +91,14 @@ async def update_user_role( return user -@router.put("/users/{user_id}/team-admin", response_model=UserResponse) -async def toggle_team_admin( +@router.put("/users/{user_id}/account-role", response_model=UserResponse) +async def update_account_role( user_id: UUID, - data: TeamAdminUpdate, + data: AccountRoleUpdate, db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(require_admin)] ): - """Toggle is_team_admin for a user (super admin only).""" + """Change a user's account role (super admin only).""" result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() @@ -108,15 +108,10 @@ async def toggle_team_admin( detail="User not found" ) - if data.is_team_admin and user.team_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User must belong to a team to be a team admin" - ) - - user.is_team_admin = data.is_team_admin - await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id, - {"is_team_admin": data.is_team_admin}) + old_role = user.account_role + user.account_role = data.account_role + await log_audit(db, current_user.id, "user.account_role_change", "user", user.id, + {"old_account_role": old_role, "new_account_role": data.account_role}) await db.commit() await db.refresh(user) return user diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 5450860a..385c31fd 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -1,3 +1,5 @@ +import secrets +import string from datetime import datetime, timezone from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status, Request @@ -18,6 +20,9 @@ from app.core.security import ( from app.models.user import User from app.models.invite_code import InviteCode from app.models.refresh_token import RefreshToken +from app.models.account import Account +from app.models.subscription import Subscription +from app.models.account_invite import AccountInvite from app.schemas.user import UserCreate, UserResponse, UserLogin from app.schemas.token import Token from app.api.deps import get_current_active_user, get_refresh_token_payload @@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id db.add(token_record) +def _generate_display_code() -> str: + """Generate a random 8-character alphanumeric display code.""" + chars = string.ascii_uppercase + string.digits + return ''.join(secrets.choice(chars) for _ in range(8)) + + @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @limiter.limit("3/minute") async def register( @@ -44,10 +55,46 @@ async def register( user_data: UserCreate, db: Annotated[AsyncSession, Depends(get_db)] ): - """Register a new user.""" - # Validate invite code if required + """Register a new user. + + Supports two flows: + - account_invite_code: Join an existing account (bypasses platform invite gate) + - invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled) + + After user creation, if no account invite was used, a personal Account + and free Subscription are created automatically. + """ + # Check for account invite code FIRST — bypasses platform invite gate + account_invite_record = None + if user_data.account_invite_code: + result = await db.execute( + select(AccountInvite).where( + AccountInvite.code == user_data.account_invite_code + ) + ) + account_invite_record = result.scalar_one_or_none() + + if not account_invite_record: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid account invite code" + ) + + if account_invite_record.is_used: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Account invite code has already been used" + ) + + if account_invite_record.is_expired: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Account invite code has expired" + ) + + # Validate platform invite code if required (skip if account invite was provided) invite_code_record = None - if settings.REQUIRE_INVITE_CODE: + if not account_invite_record and settings.REQUIRE_INVITE_CODE: if not user_data.invite_code: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -96,8 +143,37 @@ async def register( invite_code_id=invite_code_record.id if invite_code_record else None ) db.add(new_user) + await db.flush() # Get user ID before creating account - # Mark invite code as used + if account_invite_record: + # Join existing account via account invite + new_user.account_id = account_invite_record.account_id + new_user.account_role = account_invite_record.role + + # Mark account invite as used + account_invite_record.accepted_by_id = new_user.id + account_invite_record.used_at = datetime.now(timezone.utc) + else: + # Create personal Account + free Subscription + new_account = Account( + name=f"{user_data.name}'s Account", + display_code=_generate_display_code(), + owner_id=new_user.id, + ) + db.add(new_account) + await db.flush() # Get account ID + + new_subscription = Subscription( + account_id=new_account.id, + plan="free", + status="active", + ) + db.add(new_subscription) + + new_user.account_id = new_account.id + new_user.account_role = "owner" + + # Mark platform invite code as used if invite_code_record: invite_code_record.used_by_id = new_user.id invite_code_record.used_at = datetime.now(timezone.utc) diff --git a/backend/app/api/endpoints/categories.py b/backend/app/api/endpoints/categories.py index 22341ef8..73505c05 100644 --- a/backend/app/api/endpoints/categories.py +++ b/backend/app/api/endpoints/categories.py @@ -28,11 +28,11 @@ async def list_categories( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], include_inactive: bool = Query(False, description="Include inactive categories"), - team_only: bool = Query(False, description="Only show team-specific categories") + account_only: bool = Query(False, description="Only show account-specific categories") ): """List categories visible to the user. - Returns global categories plus team-specific categories for the user's team. + Returns global categories plus account-specific categories for the user's account. """ # Build query for accessible categories query = select(TreeCategory) @@ -41,19 +41,19 @@ async def list_categories( if not include_inactive: query = query.where(TreeCategory.is_active == True) - # Filter by visibility: global OR user's team - if team_only and current_user.team_id: - query = query.where(TreeCategory.team_id == current_user.team_id) - elif current_user.team_id: + # Filter by visibility: global OR user's account + if account_only and current_user.account_id: + query = query.where(TreeCategory.account_id == current_user.account_id) + elif current_user.account_id: query = query.where( or_( - TreeCategory.team_id.is_(None), # Global - TreeCategory.team_id == current_user.team_id # User's team + TreeCategory.account_id.is_(None), # Global + TreeCategory.account_id == current_user.account_id # User's account ) ) else: - # User has no team, only show global categories - query = query.where(TreeCategory.team_id.is_(None)) + # User has no account, only show global categories + query = query.where(TreeCategory.account_id.is_(None)) query = query.order_by(TreeCategory.display_order, TreeCategory.name) @@ -76,7 +76,7 @@ async def list_categories( name=cat.name, slug=cat.slug, description=cat.description, - team_id=cat.team_id, + account_id=cat.account_id, display_order=cat.display_order, is_active=cat.is_active, tree_count=tree_count @@ -101,8 +101,8 @@ async def get_category( detail="Category not found" ) - # Check access: global categories visible to all, team categories only to team members - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global categories visible to all, account categories only to account members + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -121,7 +121,7 @@ async def get_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, @@ -138,10 +138,10 @@ async def create_category( ): """Create a new category. - - Global admins can create global categories (team_id=None) - - Team admins can create team-specific categories for their team + - Global admins can create global categories (account_id=None) + - Account admins can create account-specific categories for their account """ - if not can_create_category(current_user, category_data.team_id): + if not can_create_category(current_user, category_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this category" @@ -150,10 +150,10 @@ async def create_category( # Generate slug slug = slugify(category_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(TreeCategory).where( TreeCategory.slug == slug, - TreeCategory.team_id == category_data.team_id + TreeCategory.account_id == category_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -164,7 +164,7 @@ async def create_category( # Get next display order order_query = select(func.max(TreeCategory.display_order)).where( - TreeCategory.team_id == category_data.team_id + TreeCategory.account_id == category_data.account_id ) order_result = await db.execute(order_query) max_order = order_result.scalar() or 0 @@ -173,7 +173,7 @@ async def create_category( name=category_data.name, slug=slug, description=category_data.description, - team_id=category_data.team_id, + account_id=category_data.account_id, display_order=max_order + 1, created_by=current_user.id ) @@ -186,7 +186,7 @@ async def create_category( name=new_category.name, slug=new_category.slug, description=new_category.description, - team_id=new_category.team_id, + account_id=new_category.account_id, display_order=new_category.display_order, is_active=new_category.is_active, created_at=new_category.created_at, @@ -227,7 +227,7 @@ async def update_category( # Check for duplicate slug existing_query = select(TreeCategory).where( TreeCategory.slug == new_slug, - TreeCategory.team_id == category.team_id, + TreeCategory.account_id == category.account_id, TreeCategory.id != category_id ) existing = await db.execute(existing_query) @@ -257,7 +257,7 @@ async def update_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, diff --git a/backend/app/api/endpoints/step_categories.py b/backend/app/api/endpoints/step_categories.py index 3480929e..5d890225 100644 --- a/backend/app/api/endpoints/step_categories.py +++ b/backend/app/api/endpoints/step_categories.py @@ -25,11 +25,11 @@ async def list_step_categories( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], include_inactive: bool = Query(False, description="Include inactive categories"), - team_only: bool = Query(False, description="Only show team-specific categories") + account_only: bool = Query(False, description="Only show account-specific categories") ): """List step categories visible to the user. - Returns global categories plus team-specific categories for the user's team. + Returns global categories plus account-specific categories for the user's account. """ # Build query for accessible categories query = select(StepCategory) @@ -38,19 +38,19 @@ async def list_step_categories( if not include_inactive: query = query.where(StepCategory.is_active == True) - # Filter by visibility: global OR user's team - if team_only and current_user.team_id: - query = query.where(StepCategory.team_id == current_user.team_id) - elif current_user.team_id: + # Filter by visibility: global OR user's account + if account_only and current_user.account_id: + query = query.where(StepCategory.account_id == current_user.account_id) + elif current_user.account_id: query = query.where( or_( - StepCategory.team_id.is_(None), # Global - StepCategory.team_id == current_user.team_id # User's team + StepCategory.account_id.is_(None), # Global + StepCategory.account_id == current_user.account_id # User's account ) ) else: - # User has no team, only show global categories - query = query.where(StepCategory.team_id.is_(None)) + # User has no account, only show global categories + query = query.where(StepCategory.account_id.is_(None)) query = query.order_by(StepCategory.display_order, StepCategory.name) @@ -66,7 +66,7 @@ async def list_step_categories( name=cat.name, slug=cat.slug, description=cat.description, - team_id=cat.team_id, + account_id=cat.account_id, display_order=cat.display_order, is_active=cat.is_active, step_count=0 # Will be computed when step_library exists @@ -91,8 +91,8 @@ async def get_step_category( detail="Step category not found" ) - # Check access: global categories visible to all, team categories only to team members - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global categories visible to all, account categories only to account members + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this step category" @@ -103,7 +103,7 @@ async def get_step_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, @@ -120,10 +120,10 @@ async def create_step_category( ): """Create a new step category. - - Global admins can create global categories (team_id=None) - - Team admins can create team-specific categories for their team + - Global admins can create global categories (account_id=None) + - Account admins can create account-specific categories for their account """ - if not can_create_step_category(current_user, category_data.team_id): + if not can_create_step_category(current_user, category_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this step category" @@ -132,10 +132,10 @@ async def create_step_category( # Generate slug slug = slugify(category_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(StepCategory).where( StepCategory.slug == slug, - StepCategory.team_id == category_data.team_id + StepCategory.account_id == category_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -146,7 +146,7 @@ async def create_step_category( # Get next display order order_query = select(func.max(StepCategory.display_order)).where( - StepCategory.team_id == category_data.team_id + StepCategory.account_id == category_data.account_id ) order_result = await db.execute(order_query) max_order = order_result.scalar() or 0 @@ -155,7 +155,7 @@ async def create_step_category( name=category_data.name, slug=slug, description=category_data.description, - team_id=category_data.team_id, + account_id=category_data.account_id, display_order=max_order + 1, created_by=current_user.id ) @@ -168,7 +168,7 @@ async def create_step_category( name=new_category.name, slug=new_category.slug, description=new_category.description, - team_id=new_category.team_id, + account_id=new_category.account_id, display_order=new_category.display_order, is_active=new_category.is_active, created_at=new_category.created_at, @@ -209,7 +209,7 @@ async def update_step_category( # Check for duplicate slug existing_query = select(StepCategory).where( StepCategory.slug == new_slug, - StepCategory.team_id == category.team_id, + StepCategory.account_id == category.account_id, StepCategory.id != category_id ) existing = await db.execute(existing_query) @@ -231,7 +231,7 @@ async def update_step_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, diff --git a/backend/app/api/endpoints/steps.py b/backend/app/api/endpoints/steps.py index 49605379..d1e5a160 100644 --- a/backend/app/api/endpoints/steps.py +++ b/backend/app/api/endpoints/steps.py @@ -55,10 +55,10 @@ async def get_step_or_404( def build_visibility_filter(user: User): """Build SQLAlchemy filter for step visibility based on user.""" - if user.team_id: + if user.account_id: return or_( StepLibrary.visibility == 'public', - and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id), + and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id), StepLibrary.created_by == user.id # Own private steps ) else: @@ -249,7 +249,7 @@ async def get_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, @@ -296,10 +296,10 @@ async def create_step( if not cat_result.scalar_one_or_none(): raise HTTPException(status_code=400, detail="Invalid category") - # Team validation: can only set team_id to own team - team_id = step_data.team_id - if team_id and team_id != current_user.team_id and not current_user.is_super_admin: - raise HTTPException(status_code=403, detail="Cannot create step for another team") + # Account validation: can only set account_id to own account + account_id = step_data.account_id + if account_id and account_id != current_user.account_id and not current_user.is_super_admin: + raise HTTPException(status_code=403, detail="Cannot create step for another account") step = StepLibrary( title=step_data.title, @@ -309,7 +309,7 @@ async def create_step( tags=step_data.tags, visibility=step_data.visibility, created_by=current_user.id, - team_id=team_id or current_user.team_id, + account_id=account_id or current_user.account_id, ) db.add(step) @@ -326,7 +326,7 @@ async def create_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, @@ -393,7 +393,7 @@ async def update_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, diff --git a/backend/app/api/endpoints/tags.py b/backend/app/api/endpoints/tags.py index 4f764544..334e33f8 100644 --- a/backend/app/api/endpoints/tags.py +++ b/backend/app/api/endpoints/tags.py @@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"]) async def list_tags( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], - include_team: bool = Query(True, description="Include team-specific tags") + include_account: bool = Query(True, description="Include account-specific tags") ): """List tags visible to the user. - Returns global tags plus team-specific tags for the user's team. + Returns global tags plus account-specific tags for the user's account. Tags are ordered by usage count (most used first). """ query = select(TreeTag) - # Filter by visibility: global OR user's team - if include_team and current_user.team_id: + # Filter by visibility: global OR user's account + if include_account and current_user.account_id: query = query.where( or_( - TreeTag.team_id.is_(None), # Global - TreeTag.team_id == current_user.team_id # User's team + TreeTag.account_id.is_(None), # Global + TreeTag.account_id == current_user.account_id # User's account ) ) else: # Only show global tags - query = query.where(TreeTag.team_id.is_(None)) + query = query.where(TreeTag.account_id.is_(None)) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name) @@ -55,7 +55,7 @@ async def search_tags( current_user: Annotated[User, Depends(get_current_active_user)], q: str = Query(..., min_length=1, description="Search query"), limit: int = Query(10, ge=1, le=50), - include_team: bool = Query(True, description="Include team-specific tags") + include_account: bool = Query(True, description="Include account-specific tags") ): """Search/autocomplete tags. @@ -68,15 +68,15 @@ async def search_tags( ) # Filter by visibility - if include_team and current_user.team_id: + if include_account and current_user.account_id: query = query.where( or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == current_user.team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == current_user.account_id ) ) else: - query = query.where(TreeTag.team_id.is_(None)) + query = query.where(TreeTag.account_id.is_(None)) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit) @@ -102,8 +102,8 @@ async def get_tag( detail="Tag not found" ) - # Check access: global tags visible to all, team tags only to team members - if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global tags visible to all, account tags only to account members + if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this tag" @@ -120,10 +120,10 @@ async def create_tag( ): """Create a new tag. - - Global admins can create global tags (team_id=None) - - Team members can create team-specific tags for their team + - Global admins can create global tags (account_id=None) + - Account members can create account-specific tags for their account """ - if not can_create_tag(current_user, tag_data.team_id): + if not can_create_tag(current_user, tag_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this tag" @@ -132,10 +132,10 @@ async def create_tag( # Generate slug slug = TreeTag.slugify(tag_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(TreeTag).where( TreeTag.slug == slug, - TreeTag.team_id == tag_data.team_id + TreeTag.account_id == tag_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -147,7 +147,7 @@ async def create_tag( new_tag = TreeTag( name=tag_data.name, slug=slug, - team_id=tag_data.team_id, + account_id=tag_data.account_id, created_by=current_user.id ) db.add(new_tag) @@ -200,30 +200,30 @@ async def add_tags_to_tree( continue # Try to find existing tag - # Determine scope: use tree's team, or global for admin-owned trees - tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + # Determine scope: use tree's account, or global for admin-owned trees + tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), # Global tag - TreeTag.team_id == tag_team_id # Team tag + TreeTag.account_id.is_(None), # Global tag + TreeTag.account_id == tag_account_id # Account tag ) ) tag_result = await db.execute(tag_query) tag = tag_result.scalar_one_or_none() if not tag: - # Create new tag - prefer team scope unless admin creating on public tree - new_team_id = tag_team_id - if not can_create_tag(current_user, new_team_id): - # Fall back to user's team if they can't create in tree's scope - new_team_id = current_user.team_id + # Create new tag - prefer account scope unless admin creating on public tree + new_account_id = tag_account_id + if not can_create_tag(current_user, new_account_id): + # Fall back to user's account if they can't create in tree's scope + new_account_id = current_user.account_id tag = TreeTag( name=tag_name, slug=slug, - team_id=new_team_id, + account_id=new_account_id, created_by=current_user.id ) db.add(tag) @@ -331,7 +331,7 @@ async def replace_tree_tags( tree.tags.clear() # Add new tags - tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) for tag_name in tag_data.tags: slug = TreeTag.slugify(tag_name) @@ -340,8 +340,8 @@ async def replace_tree_tags( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tag_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tag_account_id ) ) tag_result = await db.execute(tag_query) @@ -349,14 +349,14 @@ async def replace_tree_tags( if not tag: # Create new tag - new_team_id = tag_team_id - if not can_create_tag(current_user, new_team_id): - new_team_id = current_user.team_id + new_account_id = tag_account_id + if not can_create_tag(current_user, new_account_id): + new_account_id = current_user.account_id tag = TreeTag( name=tag_name, slug=slug, - team_id=new_team_id, + account_id=new_account_id, created_by=current_user.id ) db.add(tag) @@ -397,7 +397,7 @@ async def get_tree_tags( # Check if user can view the tree if not tree.is_public: if tree.author_id != current_user.id: - if tree.team_id != current_user.team_id: + if tree.account_id != current_user.account_id: if not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 8f238f74..0a0fb6d0 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin from app.core.permissions import can_edit_tree, can_access_tree +from app.core.subscriptions import check_tree_limit from app.core.audit import log_audit router = APIRouter(prefix="/trees", tags=["trees"]) @@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User): Tree.is_public == True, Tree.author_id == current_user.id, ] - if current_user.team_id: - conditions.append(Tree.team_id == current_user.team_id) + if current_user.account_id: + conditions.append(Tree.account_id == current_user.account_id) return or_(*conditions) @@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse: category_info=category_info, tags=tree.tag_names, author_id=tree.author_id, - team_id=tree.team_id, + account_id=tree.account_id, is_active=tree.is_active, is_public=tree.is_public, is_default=tree.is_default, @@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse: tags=tree.tag_names, tree_structure=tree.tree_structure, author_id=tree.author_id, - team_id=tree.team_id, + account_id=tree.account_id, is_active=tree.is_active, is_public=tree.is_public, is_default=tree.is_default, @@ -289,7 +290,7 @@ async def create_tree( detail="Category not found" ) # Check category access - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -302,16 +303,25 @@ async def create_tree( category_id=tree_data.category_id, tree_structure=tree_data.tree_structure, author_id=None if is_default else current_user.id, # Default trees have no author - team_id=None if is_default else current_user.team_id, + account_id=None if is_default else current_user.account_id, is_public=True if is_default else tree_data.is_public, # Default trees are always public is_default=is_default ) + # Check subscription tree limit + if not is_default and current_user.account_id: + can_create, limit, count = await check_tree_limit(current_user.account_id, db) + if not can_create: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees." + ) + db.add(new_tree) await db.flush() # Get the ID # Handle tags if tree_data.tags: - tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) # Collect tags to add tags_to_add = [] @@ -322,8 +332,8 @@ async def create_tree( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tree_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tree_account_id ) ) tag_result = await db.execute(tag_query) @@ -334,7 +344,7 @@ async def create_tree( tag = TreeTag( name=tag_name, slug=slug, - team_id=tree_team_id, + account_id=tree_account_id, created_by=current_user.id ) db.add(tag) @@ -420,7 +430,7 @@ async def update_tree( status_code=status.HTTP_404_NOT_FOUND, detail="Category not found" ) - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -450,7 +460,7 @@ async def update_tree( ) # Add new tags - tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) added_tag_ids = set() for tag_name in tags_data: @@ -459,8 +469,8 @@ async def update_tree( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tree_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tree_account_id ) ) tag_result = await db.execute(tag_query) @@ -470,7 +480,7 @@ async def update_tree( tag = TreeTag( name=tag_name, slug=slug, - team_id=tree_team_id, + account_id=tree_account_id, created_by=current_user.id ) db.add(tag) diff --git a/backend/app/api/endpoints/webhooks.py b/backend/app/api/endpoints/webhooks.py new file mode 100644 index 00000000..1773ec22 --- /dev/null +++ b/backend/app/api/endpoints/webhooks.py @@ -0,0 +1,62 @@ +import logging +from fastapi import APIRouter, Request, HTTPException, status, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.config import settings +from app.core.stripe_handlers import WEBHOOK_HANDLERS + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/webhooks", tags=["webhooks"]) + + +@router.post("/stripe") +async def stripe_webhook( + request: Request, + db: AsyncSession = Depends(get_db), +): + """Handle Stripe webhook events. + + Returns 200 for all events to prevent Stripe retries. + Actual processing happens only when Stripe is configured. + """ + if not settings.stripe_enabled: + return {"status": "ok", "message": "Stripe not configured, event ignored"} + + payload = await request.body() + sig_header = request.headers.get("stripe-signature") + + if not sig_header: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing stripe-signature header" + ) + + # Verify webhook signature + try: + import stripe + stripe.api_key = settings.STRIPE_SECRET_KEY + event = stripe.Webhook.construct_event( + payload, sig_header, settings.STRIPE_WEBHOOK_SECRET + ) + except ImportError: + logger.warning("stripe package not installed, cannot verify webhook") + return {"status": "ok", "message": "stripe package not installed"} + except Exception as e: + logger.error("Stripe webhook signature verification failed: %s", e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid signature" + ) + + event_type = event.get("type", "") + handler = WEBHOOK_HANDLERS.get(event_type) + + if handler: + try: + await handler(event, db) + except Exception: + logger.exception("Error handling Stripe event %s", event_type) + + return {"status": "ok"} diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 5940ea3f..05a773bc 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin +from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks api_router = APIRouter() @@ -13,3 +13,5 @@ api_router.include_router(folders.router) api_router.include_router(step_categories.router) api_router.include_router(steps.router) api_router.include_router(admin.router) +api_router.include_router(accounts.router) +api_router.include_router(webhooks.router) diff --git a/backend/app/schemas/account.py b/backend/app/schemas/account.py new file mode 100644 index 00000000..8a9a101e --- /dev/null +++ b/backend/app/schemas/account.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field + + +class AccountResponse(BaseModel): + id: UUID + name: str + display_code: str + owner_id: UUID + stripe_customer_id: Optional[str] = None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class AccountUpdate(BaseModel): + name: Optional[str] = Field(None, min_length=1, max_length=255) + + +class AccountInviteCreate(BaseModel): + email: str = Field(..., max_length=255) + role: str = Field("engineer", pattern="^(engineer|viewer)$") + expires_in_days: Optional[int] = Field(None, ge=1, le=30) + + +class AccountInviteResponse(BaseModel): + id: UUID + account_id: UUID + email: str + code: str + role: str + expires_at: Optional[datetime] = None + created_at: datetime + used_at: Optional[datetime] = None + + model_config = {"from_attributes": True} diff --git a/backend/app/schemas/category.py b/backend/app/schemas/category.py index 2cca0694..13e9955c 100644 --- a/backend/app/schemas/category.py +++ b/backend/app/schemas/category.py @@ -20,7 +20,7 @@ class CategoryBase(BaseModel): class CategoryCreate(CategoryBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.") class CategoryUpdate(BaseModel): @@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel): class CategoryResponse(CategoryBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool created_at: datetime @@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel): name: str slug: str description: Optional[str] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool tree_count: int = 0 diff --git a/backend/app/schemas/step_category.py b/backend/app/schemas/step_category.py index 9d03667e..106c32c9 100644 --- a/backend/app/schemas/step_category.py +++ b/backend/app/schemas/step_category.py @@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel): class StepCategoryCreate(StepCategoryBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.") class StepCategoryUpdate(BaseModel): @@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel): class StepCategoryResponse(StepCategoryBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool created_at: datetime @@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel): name: str slug: str description: Optional[str] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool step_count: int = 0 diff --git a/backend/app/schemas/step_library.py b/backend/app/schemas/step_library.py index dbd7357a..93390c60 100644 --- a/backend/app/schemas/step_library.py +++ b/backend/app/schemas/step_library.py @@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel): class StepLibraryCreate(StepLibraryBase): - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None class StepLibraryUpdate(BaseModel): @@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel): class StepLibraryResponse(StepLibraryBase): id: UUID created_by: UUID - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int rating_average: Decimal rating_count: int diff --git a/backend/app/schemas/subscription.py b/backend/app/schemas/subscription.py new file mode 100644 index 00000000..9b832926 --- /dev/null +++ b/backend/app/schemas/subscription.py @@ -0,0 +1,40 @@ +from typing import Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +class SubscriptionResponse(BaseModel): + id: UUID + plan: str + status: str + billing_interval: Optional[str] = None + current_period_start: Optional[datetime] = None + current_period_end: Optional[datetime] = None + cancel_at_period_end: bool = False + stripe_subscription_id: Optional[str] = None + + model_config = {"from_attributes": True} + + +class PlanLimitsResponse(BaseModel): + plan: str + max_trees: Optional[int] = None + max_sessions_per_month: Optional[int] = None + max_users: Optional[int] = None + custom_branding: bool = False + priority_support: bool = False + export_formats: list[str] = ["markdown", "text"] + + model_config = {"from_attributes": True} + + +class UsageResponse(BaseModel): + tree_count: int + session_count_this_month: int + + +class SubscriptionDetails(BaseModel): + subscription: SubscriptionResponse + limits: PlanLimitsResponse + usage: UsageResponse diff --git a/backend/app/schemas/tag.py b/backend/app/schemas/tag.py index 2de4bfcf..47912057 100644 --- a/backend/app/schemas/tag.py +++ b/backend/app/schemas/tag.py @@ -19,13 +19,13 @@ class TagBase(BaseModel): class TagCreate(TagBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.") class TagResponse(TagBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int created_at: datetime @@ -37,7 +37,7 @@ class TagListResponse(BaseModel): id: UUID name: str slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int class Config: @@ -53,4 +53,4 @@ class TagSearchParams(BaseModel): """Query parameters for tag search/autocomplete.""" q: str = Field(..., min_length=1, description="Search query") limit: int = Field(10, ge=1, le=50) - include_team: bool = Field(True, description="Include team-specific tags") + include_account: bool = Field(True, description="Include account-specific tags") diff --git a/backend/app/schemas/tree.py b/backend/app/schemas/tree.py index 8810f0d5..86aa7567 100644 --- a/backend/app/schemas/tree.py +++ b/backend/app/schemas/tree.py @@ -44,7 +44,7 @@ class TreeResponse(TreeBase): id: UUID tree_structure: dict[str, Any] author_id: Optional[UUID] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None category_id: Optional[UUID] = None category_info: Optional[CategoryInfo] = None tags: list[str] = [] # List of tag names @@ -69,7 +69,7 @@ class TreeListResponse(BaseModel): category_info: Optional[CategoryInfo] = None tags: list[str] = [] # List of tag names author_id: Optional[UUID] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None is_active: bool is_public: bool is_default: bool diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index b2f5d81e..6777f668 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -13,6 +13,7 @@ class UserBase(BaseModel): class UserCreate(UserBase): password: str = Field(..., min_length=10, description="Password must be at least 10 characters") invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)") + account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account") @field_validator('password') @classmethod @@ -38,11 +39,11 @@ class UserLogin(BaseModel): class UserResponse(UserBase): id: UUID - role: str + role: str = "engineer" + account_id: UUID + account_role: str is_super_admin: bool = False - is_team_admin: bool = False is_active: bool = True - team_id: Optional[UUID] = None created_at: datetime last_login: Optional[datetime] = None @@ -54,5 +55,5 @@ class RoleUpdate(BaseModel): role: Literal["engineer", "viewer"] -class TeamAdminUpdate(BaseModel): - is_team_admin: bool +class AccountRoleUpdate(BaseModel): + account_role: str = Field(..., pattern="^(engineer|viewer)$") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7845ba82..e7b7667a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: await conn.execute(sa.text("CREATE SCHEMA public")) await conn.run_sync(Base.metadata.create_all) + # Seed plan_limits for subscription checks + await conn.execute(sa.text(""" + INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) + VALUES + ('free', 3, 20, 1, false, false, '["markdown", "text"]'), + ('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'), + ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') + """)) + # Create async session maker async_session_maker = async_sessionmaker( engine, diff --git a/backend/tests/test_account_management.py b/backend/tests/test_account_management.py new file mode 100644 index 00000000..a8fed198 --- /dev/null +++ b/backend/tests/test_account_management.py @@ -0,0 +1,170 @@ +"""Integration tests for account management endpoints.""" + +import pytest +from httpx import AsyncClient + + +class TestAccountEndpoints: + """Test suite for account management endpoints.""" + + @pytest.mark.asyncio + async def test_get_my_account(self, client: AsyncClient, auth_headers: dict): + """Test getting current user's account.""" + response = await client.get("/api/v1/accounts/me", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "name" in data + assert "display_code" in data + assert "owner_id" in data + assert len(data["display_code"]) == 8 + + @pytest.mark.asyncio + async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict): + """Test getting current user's subscription details.""" + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert "subscription" in data + assert "limits" in data + assert "usage" in data + assert data["subscription"]["plan"] == "free" + assert data["subscription"]["status"] == "active" + assert data["limits"]["max_trees"] == 3 + assert data["limits"]["max_sessions_per_month"] == 20 + + @pytest.mark.asyncio + async def test_get_my_members(self, client: AsyncClient, auth_headers: dict): + """Test getting members of current user's account.""" + response = await client.get("/api/v1/accounts/me/members", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) >= 1 + # Current user should be in members list + assert any(m["account_role"] == "owner" for m in data) + + @pytest.mark.asyncio + async def test_update_my_account(self, client: AsyncClient, auth_headers: dict): + """Test updating account name.""" + response = await client.patch( + "/api/v1/accounts/me", + json={"name": "Updated Account Name"}, + headers=auth_headers + ) + assert response.status_code == 200 + assert response.json()["name"] == "Updated Account Name" + + @pytest.mark.asyncio + async def test_update_account_requires_owner(self, client: AsyncClient): + """Test that non-owners cannot update account settings.""" + # Register two users + owner_data = { + "email": "owner@example.com", + "password": "OwnerPass123!", + "name": "Owner" + } + await client.post("/api/v1/auth/register", json=owner_data) + + # Login as owner and create an invite + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "owner@example.com", + "password": "OwnerPass123!" + }) + owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "member@example.com", "role": "engineer"}, + headers=owner_headers + ) + assert invite_resp.status_code == 201 + invite_code = invite_resp.json()["code"] + + # Register member with account invite code + member_data = { + "email": "member@example.com", + "password": "MemberPass123!", + "name": "Member", + "account_invite_code": invite_code + } + reg_resp = await client.post("/api/v1/auth/register", json=member_data) + assert reg_resp.status_code == 201 + assert reg_resp.json()["account_role"] == "engineer" + + # Login as member + member_login = await client.post("/api/v1/auth/login/json", json={ + "email": "member@example.com", + "password": "MemberPass123!" + }) + member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"} + + # Member should not be able to update account + response = await client.patch( + "/api/v1/accounts/me", + json={"name": "Hacked Name"}, + headers=member_headers + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict): + """Test creating and listing account invites.""" + # Create invite + response = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "invitee@example.com", "role": "engineer"}, + headers=auth_headers + ) + assert response.status_code == 201 + data = response.json() + assert data["email"] == "invitee@example.com" + assert data["role"] == "engineer" + assert "code" in data + + # List invites + list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers) + assert list_response.status_code == 200 + invites = list_response.json() + assert len(invites) >= 1 + assert any(i["email"] == "invitee@example.com" for i in invites) + + @pytest.mark.asyncio + async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict): + """Test that account invite code joins user to existing account.""" + # Get current account + account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers) + account_id = account_resp.json()["id"] + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "joiner@example.com", "role": "viewer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + # Register with account invite code + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "joiner@example.com", + "password": "JoinerPass123!", + "name": "Joiner", + "account_invite_code": invite_code + }) + assert reg_resp.status_code == 201 + data = reg_resp.json() + assert data["account_id"] == account_id + assert data["account_role"] == "viewer" + + @pytest.mark.asyncio + async def test_register_with_invalid_invite_code(self, client: AsyncClient): + """Test that invalid account invite code is rejected.""" + response = await client.post("/api/v1/auth/register", json={ + "email": "bad@example.com", + "password": "BadPassword123!", + "name": "Bad User", + "account_invite_code": "INVALID_CODE" + }) + assert response.status_code == 400 + assert "invalid" in response.json()["detail"].lower() diff --git a/backend/tests/test_admin.py b/backend/tests/test_admin.py index aebc9ebc..1c96c03b 100644 --- a/backend/tests/test_admin.py +++ b/backend/tests/test_admin.py @@ -169,3 +169,51 @@ class TestAdminEndpoints: log = result.scalar_one_or_none() assert log is not None assert str(log.resource_id) == user_id + + @pytest.mark.asyncio + async def test_change_account_role( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict + ): + """Test changing a user's account role.""" + user_id = test_user["user_data"]["id"] + response = await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "viewer"}, + headers=admin_auth_headers + ) + assert response.status_code == 200 + assert response.json()["account_role"] == "viewer" + + @pytest.mark.asyncio + async def test_change_account_role_invalid( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict + ): + """Test that invalid account_role values are rejected.""" + user_id = test_user["user_data"]["id"] + response = await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "owner"}, + headers=admin_auth_headers + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_audit_log_created_on_account_role_change( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession + ): + """Test that changing account role creates an audit log entry.""" + user_id = test_user["user_data"]["id"] + await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "viewer"}, + headers=admin_auth_headers + ) + + result = await test_db.execute( + select(AuditLog).where(AuditLog.action == "user.account_role_change") + ) + log = result.scalar_one_or_none() + assert log is not None + assert str(log.resource_id) == user_id + assert log.details["old_account_role"] == "owner" + assert log.details["new_account_role"] == "viewer" diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 5e578900..f1463949 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -23,6 +23,8 @@ class TestAuthentication: assert data["email"] == user_data["email"] assert data["name"] == user_data["name"] assert data["role"] == "engineer" + assert "account_id" in data + assert data["account_role"] == "owner" assert "id" in data assert "password" not in data # Password should not be returned @@ -107,6 +109,7 @@ class TestAuthentication: assert response.status_code == 201 data = response.json() assert data["role"] == "engineer" + assert data["account_role"] == "owner" @pytest.mark.asyncio async def test_register_default_role_is_engineer(self, client: AsyncClient): @@ -121,6 +124,7 @@ class TestAuthentication: assert response.status_code == 201 assert response.json()["role"] == "engineer" + assert response.json()["account_role"] == "owner" @pytest.mark.asyncio async def test_register_rejects_no_uppercase(self, client: AsyncClient): diff --git a/backend/tests/test_permissions_account.py b/backend/tests/test_permissions_account.py new file mode 100644 index 00000000..211a7340 --- /dev/null +++ b/backend/tests/test_permissions_account.py @@ -0,0 +1,205 @@ +"""Integration tests for account-based permissions.""" + +import pytest +from httpx import AsyncClient + + +class TestAccountPermissions: + """Test suite for account-based permission checks.""" + + @pytest.mark.asyncio + async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db): + """Test that viewers cannot create trees.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Register a user + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "viewer@example.com", + "password": "ViewerPass123!", + "name": "Viewer User" + }) + assert reg_resp.status_code == 201 + user_id = PyUUID(reg_resp.json()["id"]) + + # Demote to viewer via ORM + result = await test_db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one() + user.account_role = "viewer" + await test_db.commit() + + # Login as viewer + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "viewer@example.com", + "password": "ViewerPass123!" + }) + viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Try to create tree + response = await client.post("/api/v1/trees", json={ + "name": "Viewer Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"} + }, headers=viewer_headers) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db): + """Test that viewers can browse/list trees.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Create a public tree as the regular user first + await client.post("/api/v1/trees", json={ + "name": "Public Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": True + }, headers=auth_headers) + + # Register viewer + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "viewer2@example.com", + "password": "ViewerPass123!", + "name": "Viewer 2" + }) + user_id = PyUUID(reg_resp.json()["id"]) + + result = await test_db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one() + user.account_role = "viewer" + await test_db.commit() + + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "viewer2@example.com", + "password": "ViewerPass123!" + }) + viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Viewer can list trees + response = await client.get("/api/v1/trees", headers=viewer_headers) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db): + """Test that account owner can edit trees created by account members.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Get owner's account + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id = me_resp.json()["account_id"] + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "engineer@example.com", "role": "engineer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + # Register engineer in same account + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "engineer@example.com", + "password": "EngineerPass123!", + "name": "Engineer", + "account_invite_code": invite_code + }) + assert reg_resp.status_code == 201 + assert reg_resp.json()["account_id"] == account_id + + # Login as engineer + eng_login = await client.post("/api/v1/auth/login/json", json={ + "email": "engineer@example.com", + "password": "EngineerPass123!" + }) + eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"} + + # Engineer creates a tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Engineer's Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"} + }, headers=eng_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Owner can edit engineer's tree + update_resp = await client.put( + f"/api/v1/trees/{tree_id}", + json={"name": "Owner Updated Name"}, + headers=auth_headers + ) + assert update_resp.status_code == 200 + assert update_resp.json()["name"] == "Owner Updated Name" + + @pytest.mark.asyncio + async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict): + """Test that account members can see each other's non-public trees.""" + # Get owner's account + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id = me_resp.json()["account_id"] + + # Owner creates a private tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Private Account Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": False + }, headers=auth_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Create invite and add member + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "teammate@example.com", "role": "engineer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + await client.post("/api/v1/auth/register", json={ + "email": "teammate@example.com", + "password": "TeammatePass123!", + "name": "Teammate", + "account_invite_code": invite_code + }) + + mate_login = await client.post("/api/v1/auth/login/json", json={ + "email": "teammate@example.com", + "password": "TeammatePass123!" + }) + mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"} + + # Teammate should see the private tree (same account) + response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers) + assert response.status_code == 200 + assert response.json()["name"] == "Private Account Tree" + + @pytest.mark.asyncio + async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict): + """Test that users from different accounts cannot see private trees.""" + # Owner creates a private tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Secret Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": False + }, headers=auth_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Register a completely separate user (different account) + await client.post("/api/v1/auth/register", json={ + "email": "outsider@example.com", + "password": "OutsiderPass123!", + "name": "Outsider" + }) + + outsider_login = await client.post("/api/v1/auth/login/json", json={ + "email": "outsider@example.com", + "password": "OutsiderPass123!" + }) + outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"} + + # Outsider should NOT see the private tree + response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers) + assert response.status_code == 403 diff --git a/backend/tests/test_subscription_limits.py b/backend/tests/test_subscription_limits.py new file mode 100644 index 00000000..540e42af --- /dev/null +++ b/backend/tests/test_subscription_limits.py @@ -0,0 +1,129 @@ +"""Integration tests for subscription limits.""" + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + + +class TestSubscriptionLimits: + """Test suite for subscription plan limits.""" + + @pytest.mark.asyncio + async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict): + """Test that free plan has tree creation limit of 3.""" + tree_template = { + "name": "Limit Test Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test tree" + } + } + + # Create trees up to the limit + for i in range(3): + tree_data = {**tree_template, "name": f"Tree {i+1}"} + response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}" + + # 4th tree should fail with 402 + tree_data = {**tree_template, "name": "Tree 4 Over Limit"} + response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert response.status_code == 402 + assert "limit" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict): + """Test that subscription details reflect actual usage.""" + # Check initial usage + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + initial_usage = response.json()["usage"] + assert initial_usage["tree_count"] == 0 + + # Create a tree + tree_data = { + "name": "Usage Test Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test" + } + } + create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert create_resp.status_code == 201 + + # Check usage increased + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + updated_usage = response.json()["usage"] + assert updated_usage["tree_count"] == 1 + + @pytest.mark.asyncio + async def test_super_admin_bypasses_limits( + self, client: AsyncClient, admin_auth_headers: dict + ): + """Test that super admin can create trees without limit checks.""" + tree_template = { + "name": "Admin Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test tree" + }, + "is_default": True # Default trees skip limit check + } + + # Super admin creating default trees should always work + for i in range(5): + tree_data = {**tree_template, "name": f"Admin Tree {i+1}"} + response = await client.post( + "/api/v1/trees", json=tree_data, headers=admin_auth_headers + ) + assert response.status_code == 201 + + @pytest.mark.asyncio + async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict): + """Test that free plan limits are correct.""" + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + limits = response.json()["limits"] + assert limits["plan"] == "free" + assert limits["max_trees"] == 3 + assert limits["max_sessions_per_month"] == 20 + assert limits["max_users"] == 1 + assert limits["custom_branding"] is False + assert limits["priority_support"] is False + + @pytest.mark.asyncio + async def test_upgraded_plan_has_higher_limits( + self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession + ): + """Test that upgrading plan increases limits.""" + from app.models.subscription import Subscription + from app.models.user import User + + # Get the user's subscription and upgrade it + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id_str = me_resp.json()["account_id"] + + from uuid import UUID + account_id = UUID(account_id_str) + result = await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + sub = result.scalar_one() + sub.plan = "pro" + await test_db.commit() + + # Check limits are now pro + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + limits = response.json()["limits"] + assert limits["plan"] == "pro" + assert limits["max_trees"] == 25 + assert limits["max_sessions_per_month"] == 200