feat: account-based subscription tiers (Free/Pro/Team) #32

Merged
chihlasm merged 5 commits from feat/subscription-tiers into main 2026-02-07 08:18:39 +00:00
73 changed files with 2944 additions and 246 deletions

View File

@@ -0,0 +1,110 @@
"""add accounts, subscriptions, plan_limits, and account_invites tables
Revision ID: 016
Revises: 015
Create Date: 2026-02-07
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
# revision identifiers, used by Alembic.
revision: str = '016'
down_revision: Union[str, None] = '015'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. accounts table
op.create_table(
'accounts',
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
sa.Column('name', sa.String(255), nullable=False),
sa.Column('display_code', sa.String(8), nullable=False),
sa.Column('owner_id', UUID(as_uuid=True), nullable=True), # nullable until user created
sa.Column('stripe_customer_id', sa.String(255), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('display_code', name='uq_accounts_display_code'),
)
# 2. subscriptions table
op.create_table(
'subscriptions',
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
sa.Column('stripe_subscription_id', sa.String(255), nullable=True),
sa.Column('stripe_price_id', sa.String(255), nullable=True),
sa.Column('plan', sa.String(50), nullable=False, server_default='free'),
sa.Column('billing_interval', sa.String(20), nullable=True), # 'monthly' or 'annual'
sa.Column('status', sa.String(50), nullable=False, server_default='active'),
sa.Column('seat_limit', sa.Integer, nullable=True),
sa.Column('current_period_start', sa.DateTime(timezone=True), nullable=True),
sa.Column('current_period_end', sa.DateTime(timezone=True), nullable=True),
sa.Column('cancel_at_period_end', sa.Boolean, nullable=False, server_default='false'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
sa.UniqueConstraint('account_id', name='uq_subscriptions_account_id'),
)
op.create_index('ix_subscriptions_account_id', 'subscriptions', ['account_id'])
op.create_index('ix_subscriptions_plan', 'subscriptions', ['plan'])
# 3. plan_limits table (configuration — seeded with 3 rows)
op.create_table(
'plan_limits',
sa.Column('plan', sa.String(50), nullable=False),
sa.Column('max_trees', sa.Integer, nullable=True), # NULL = unlimited
sa.Column('max_sessions_per_month', sa.Integer, nullable=True),
sa.Column('max_users', sa.Integer, nullable=True),
sa.Column('custom_branding', sa.Boolean, nullable=False, server_default='false'),
sa.Column('priority_support', sa.Boolean, nullable=False, server_default='false'),
sa.Column('export_formats', JSONB, nullable=False, server_default='["markdown", "text"]'),
sa.PrimaryKeyConstraint('plan'),
)
# Seed plan_limits
op.execute("""
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
VALUES
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
('pro', 25, 200, 1, false, true, '["markdown", "text", "html", "pdf"]'),
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html", "pdf"]')
""")
# 4. account_invites table
op.create_table(
'account_invites',
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
sa.Column('invited_by_id', UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('code', sa.String(32), nullable=False),
sa.Column('role', sa.String(50), nullable=False, server_default='engineer'),
sa.Column('accepted_by_id', UUID(as_uuid=True), nullable=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
sa.Column('used_at', sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['invited_by_id'], ['users.id']),
sa.ForeignKeyConstraint(['accepted_by_id'], ['users.id']),
sa.UniqueConstraint('code', name='uq_account_invites_code'),
sa.CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'),
)
op.create_index('ix_account_invites_account_id', 'account_invites', ['account_id'])
op.create_index('ix_account_invites_email', 'account_invites', ['email'])
def downgrade() -> None:
op.drop_table('account_invites')
op.drop_table('plan_limits')
op.drop_table('subscriptions')
op.drop_table('accounts')

View File

@@ -0,0 +1,31 @@
"""add account_id and account_role columns to users
Revision ID: 017
Revises: 016
Create Date: 2026-02-07
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '017'
down_revision: Union[str, None] = '016'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('users', sa.Column('account_id', UUID(as_uuid=True), nullable=True))
op.add_column('users', sa.Column('account_role', sa.String(50), nullable=True))
op.create_index('ix_users_account_id', 'users', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_users_account_id', table_name='users')
op.drop_column('users', 'account_role')
op.drop_column('users', 'account_id')

View File

@@ -0,0 +1,187 @@
"""migrate existing users and teams to accounts
Revision ID: 018
Revises: 017
Create Date: 2026-02-07
This is the most critical migration. It creates a _team_account_mapping table
for deterministic cross-migration lookups, then migrates all users to accounts.
Three paths:
A) Teams with users → Account with deterministic owner
B) Teams with zero users → Account with owner_id=NULL, subscription status='orphaned'
C) Users without a team → Personal account, user is owner
"""
import secrets
import string
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '018'
down_revision: Union[str, None] = '017'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# Characters for display codes — exclude confusing chars
DISPLAY_CODE_CHARS = string.ascii_uppercase + string.digits
DISPLAY_CODE_CHARS = DISPLAY_CODE_CHARS.replace('0', '').replace('O', '').replace('I', '').replace('1', '').replace('L', '')
def _generate_display_code(existing_codes: set) -> str:
"""Generate a unique 8-character display code."""
for _ in range(100):
code = ''.join(secrets.choice(DISPLAY_CODE_CHARS) for _ in range(8))
if code not in existing_codes:
existing_codes.add(code)
return code
raise RuntimeError("Failed to generate unique display code after 100 attempts")
def upgrade() -> None:
conn = op.get_bind()
# Create mapping table for deterministic cross-migration lookups
op.create_table(
'_team_account_mapping',
sa.Column('team_id', UUID(as_uuid=True), nullable=False),
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
sa.Column('owner_user_id', UUID(as_uuid=True), nullable=True),
sa.PrimaryKeyConstraint('team_id'),
)
existing_codes: set = set()
# --- Path A & B: Process all teams ---
teams = conn.execute(sa.text("SELECT id, name FROM teams")).fetchall()
for team in teams:
team_id = team[0]
team_name = team[1]
display_code = _generate_display_code(existing_codes)
# Find deterministic owner: team admin first, then earliest user
owner_row = conn.execute(sa.text("""
SELECT id FROM users
WHERE team_id = :tid
ORDER BY is_team_admin DESC, created_at ASC, id ASC
LIMIT 1
"""), {"tid": team_id}).fetchone()
owner_user_id = owner_row[0] if owner_row else None
# Create account
conn.execute(sa.text("""
INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at)
VALUES (gen_random_uuid(), :name, :code, :owner_id, NOW(), NOW())
"""), {"name": team_name, "code": display_code, "owner_id": owner_user_id})
# Get the account we just created
account_row = conn.execute(sa.text(
"SELECT id FROM accounts WHERE display_code = :code"
), {"code": display_code}).fetchone()
account_id = account_row[0]
# Insert mapping
conn.execute(sa.text("""
INSERT INTO _team_account_mapping (team_id, account_id, owner_user_id)
VALUES (:tid, :aid, :uid)
"""), {"tid": team_id, "aid": account_id, "uid": owner_user_id})
if owner_user_id is not None:
# Path A: Team with users
# Create active subscription
conn.execute(sa.text("""
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW())
"""), {"aid": account_id})
# Update all users in this team
# Team admins become owners, others keep their role
conn.execute(sa.text("""
UPDATE users SET
account_id = :aid,
account_role = CASE
WHEN is_team_admin = true THEN 'owner'
ELSE role
END
WHERE team_id = :tid
"""), {"aid": account_id, "tid": team_id})
else:
# Path B: Team with zero users (orphan)
conn.execute(sa.text("""
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
VALUES (gen_random_uuid(), :aid, 'free', 'orphaned', NOW(), NOW())
"""), {"aid": account_id})
# --- Path C: Users without a team ---
teamless_users = conn.execute(sa.text(
"SELECT id, name FROM users WHERE team_id IS NULL AND account_id IS NULL"
)).fetchall()
for user in teamless_users:
user_id = user[0]
user_name = user[1]
display_code = _generate_display_code(existing_codes)
# Create personal account (owner_id set to NULL initially)
conn.execute(sa.text("""
INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at)
VALUES (gen_random_uuid(), :name, :code, NULL, NOW(), NOW())
"""), {"name": f"{user_name}'s Account", "code": display_code})
account_row = conn.execute(sa.text(
"SELECT id FROM accounts WHERE display_code = :code"
), {"code": display_code}).fetchone()
account_id = account_row[0]
# Update user
conn.execute(sa.text("""
UPDATE users SET account_id = :aid, account_role = 'owner'
WHERE id = :uid
"""), {"aid": account_id, "uid": user_id})
# Set owner
conn.execute(sa.text("""
UPDATE accounts SET owner_id = :uid WHERE id = :aid
"""), {"uid": user_id, "aid": account_id})
# Create free subscription
conn.execute(sa.text("""
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW())
"""), {"aid": account_id})
# --- Validation ---
orphaned_users = conn.execute(sa.text(
"SELECT COUNT(*) FROM users WHERE account_id IS NULL"
)).scalar()
if orphaned_users > 0:
raise RuntimeError(
f"Migration 018 failed validation: {orphaned_users} users still have NULL account_id"
)
team_count = conn.execute(sa.text("SELECT COUNT(*) FROM teams")).scalar()
mapping_count = conn.execute(sa.text("SELECT COUNT(*) FROM _team_account_mapping")).scalar()
if mapping_count != team_count:
raise RuntimeError(
f"Migration 018 failed: mapping count ({mapping_count}) != team count ({team_count})"
)
def downgrade() -> None:
conn = op.get_bind()
# Clear account data from users
conn.execute(sa.text("UPDATE users SET account_id = NULL, account_role = NULL"))
# Delete all subscriptions and accounts created by this migration
conn.execute(sa.text("DELETE FROM subscriptions"))
conn.execute(sa.text("DELETE FROM accounts"))
# Drop mapping table
op.drop_table('_team_account_mapping')

View File

@@ -0,0 +1,56 @@
"""add account_id to content tables and backfill from _team_account_mapping
Revision ID: 019
Revises: 018
Create Date: 2026-02-07
Uses the _team_account_mapping table from migration 018 for deterministic
FK backfill instead of non-deterministic LIMIT 1 subqueries.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '019'
down_revision: Union[str, None] = '018'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# Tables that have team_id and need account_id
CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories']
def upgrade() -> None:
conn = op.get_bind()
for table in CONTENT_TABLES:
# Add account_id column
op.add_column(table, sa.Column('account_id', UUID(as_uuid=True), nullable=True))
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
# Backfill from mapping table (deterministic)
conn.execute(sa.text(f"""
UPDATE {table} SET account_id = m.account_id
FROM _team_account_mapping m
WHERE {table}.team_id = m.team_id
"""))
# Validate: no rows with team_id but missing account_id
orphaned = conn.execute(sa.text(f"""
SELECT COUNT(*) FROM {table}
WHERE team_id IS NOT NULL AND account_id IS NULL
""")).scalar()
if orphaned > 0:
raise RuntimeError(
f"Migration 019 failed: {table} has {orphaned} rows with team_id but no account_id"
)
def downgrade() -> None:
for table in reversed(CONTENT_TABLES):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,104 @@
"""finalize account migration — add constraints, clean up orphans
Revision ID: 020
Revises: 019
Create Date: 2026-02-07
Adds NOT NULL constraints, foreign keys, and CHECK constraints.
Cleans up orphan accounts (zero-user teams with no content).
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '020'
down_revision: Union[str, None] = '019'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories']
def upgrade() -> None:
conn = op.get_bind()
# 1. Clean up orphan accounts (zero-user teams with no content)
conn.execute(sa.text("""
DELETE FROM subscriptions WHERE account_id IN (
SELECT a.id FROM accounts a
WHERE a.owner_id IS NULL
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id)
)
"""))
# Also remove the mapping entries for these orphans
conn.execute(sa.text("""
DELETE FROM _team_account_mapping WHERE account_id IN (
SELECT a.id FROM accounts a
WHERE a.owner_id IS NULL
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id)
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id)
)
"""))
conn.execute(sa.text("""
DELETE FROM accounts
WHERE owner_id IS NULL
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = accounts.id)
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = accounts.id)
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = accounts.id)
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = accounts.id)
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = accounts.id)
"""))
# 2. Users: enforce NOT NULL and add FK + CHECK
op.alter_column('users', 'account_id', nullable=False)
op.alter_column('users', 'account_role', nullable=False)
op.create_foreign_key(
'fk_users_account_id', 'users', 'accounts',
['account_id'], ['id'], ondelete='CASCADE'
)
op.create_check_constraint(
'ck_users_account_role_enum', 'users',
"account_role IN ('owner', 'engineer', 'viewer')"
)
# 3. Content tables: add FK on account_id (nullable OK — NULL means global)
for table in CONTENT_TABLES:
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE'
)
# 4. Accounts: add owner FK (owner_id stays nullable due to circular FK with users)
op.create_foreign_key(
'fk_accounts_owner_id', 'accounts', 'users',
['owner_id'], ['id'], ondelete='RESTRICT'
)
def downgrade() -> None:
# Remove account owner FK and nullable constraint
op.drop_constraint('fk_accounts_owner_id', 'accounts', type_='foreignkey')
op.alter_column('accounts', 'owner_id', nullable=True)
# Remove content table FKs
for table in reversed(CONTENT_TABLES):
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
# Remove user constraints
op.drop_constraint('ck_users_account_role_enum', 'users', type_='check')
op.drop_constraint('fk_users_account_id', 'users', type_='foreignkey')
op.alter_column('users', 'account_role', nullable=True)
op.alter_column('users', 'account_id', nullable=True)

View File

@@ -0,0 +1,30 @@
"""fix accounts.owner_id to be nullable (circular FK with users)
Revision ID: 021
Revises: 020
Create Date: 2026-02-07
The original migration 020 enforced NOT NULL on owner_id, but this creates
a circular FK problem: Account needs owner_id (User) and User needs
account_id (Account). Making owner_id nullable resolves this — we create
the Account first, then the User, then set owner_id.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '021'
down_revision: Union[str, None] = '020'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.alter_column('accounts', 'owner_id', nullable=True)
def downgrade() -> None:
op.alter_column('accounts', 'owner_id', nullable=False)

View File

@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Optional
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
@@ -8,6 +8,7 @@ from sqlalchemy import select
from app.core.database import get_db
from app.core.security import decode_token
from app.models.user import User
from app.models.plan_limits import PlanLimits
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -90,14 +91,35 @@ async def require_admin(
async def require_engineer_or_admin(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> User:
"""Require engineer, team admin, or super admin role (blocks viewers)."""
"""Require engineer, account owner, or super admin role (blocks viewers)."""
if current_user.is_super_admin:
return current_user
if current_user.is_team_admin and current_user.team_id is not None:
if current_user.account_role in ("owner", "engineer"):
return current_user
if current_user.role not in ("engineer",):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Engineer or admin access required"
)
return current_user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Engineer or admin access required"
)
async def require_account_owner(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> User:
"""Require account owner or super admin access."""
if current_user.is_super_admin:
return current_user
if current_user.account_role == "owner":
return current_user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account owner access required"
)
async def get_plan_limits_for_user(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> Optional[PlanLimits]:
"""Get plan limits for the current user's account."""
from app.core.subscriptions import get_user_plan_limits
return await get_user_plan_limits(current_user.account_id, db)

View File

@@ -0,0 +1,236 @@
from datetime import datetime, timezone, timedelta
from typing import Annotated, Optional
from uuid import UUID
import secrets
import string
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.database import get_db
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
from app.models.account import Account
from app.models.account_invite import AccountInvite
from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
from app.schemas.user import UserResponse, AccountRoleUpdate
from app.api.deps import get_current_active_user, require_account_owner
router = APIRouter(prefix="/accounts", tags=["accounts"])
@router.get("/me", response_model=AccountResponse)
async def get_my_account(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's account."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
return account
@router.get("/me/subscription", response_model=SubscriptionDetails)
async def get_my_subscription(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get current user's subscription details including limits and usage."""
sub = await get_account_subscription(current_user.account_id, db)
if not sub:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No subscription found"
)
limits = await get_plan_limits(sub.plan, db)
if not limits:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Plan limits not configured"
)
usage = await get_account_usage(current_user.account_id, db)
return SubscriptionDetails(
subscription=SubscriptionResponse.model_validate(sub),
limits=PlanLimitsResponse.model_validate(limits),
usage=UsageResponse(**usage),
)
@router.get("/me/members", response_model=list[UserResponse])
async def get_my_members(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get members of current user's account."""
result = await db.execute(
select(User).where(User.account_id == current_user.account_id)
.order_by(User.created_at)
)
return result.scalars().all()
@router.patch("/me", response_model=AccountResponse)
async def update_my_account(
data: AccountUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Update account settings (owner only)."""
result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = result.scalar_one_or_none()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Account not found"
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(account, field, value)
await db.commit()
await db.refresh(account)
return account
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
async def update_member_role(
user_id: UUID,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Change a member's role within the account (owner only)."""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot change your own role"
)
user.account_role = data.account_role
await db.commit()
await db.refresh(user)
return user
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_member(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Remove a member from the account (owner only).
The removed user gets a new personal account.
"""
result = await db.execute(
select(User).where(
User.id == user_id,
User.account_id == current_user.account_id
)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found in your account"
)
if user.id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot remove yourself from your own account"
)
# Create a personal account for the removed user
chars = string.ascii_uppercase + string.digits
display_code = ''.join(secrets.choice(chars) for _ in range(8))
new_account = Account(
name=f"{user.name}'s Account",
display_code=display_code,
owner_id=user.id,
)
db.add(new_account)
await db.flush()
new_sub = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_sub)
user.account_id = new_account.id
user.account_role = "owner"
await db.commit()
return None
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
async def create_invite(
data: AccountInviteCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""Create an invite to join this account (owner only)."""
code = secrets.token_urlsafe(16)
expires_at = None
if data.expires_in_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
invite = AccountInvite(
account_id=current_user.account_id,
invited_by_id=current_user.id,
email=data.email,
code=code,
role=data.role,
expires_at=expires_at,
)
db.add(invite)
await db.commit()
await db.refresh(invite)
return invite
@router.get("/me/invites", response_model=list[AccountInviteResponse])
async def list_invites(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_account_owner)]
):
"""List invites for this account (owner only)."""
result = await db.execute(
select(AccountInvite)
.where(AccountInvite.account_id == current_user.account_id)
.order_by(AccountInvite.created_at.desc())
)
return result.scalars().all()

View File

@@ -7,7 +7,7 @@ from sqlalchemy import select, func
from app.core.database import get_db
from app.core.audit import log_audit
from app.models.user import User
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
from app.api.deps import require_admin
router = APIRouter(prefix="/admin", tags=["admin"])
@@ -21,7 +21,7 @@ async def list_users(
limit: int = Query(100, ge=1, le=100),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
role: Optional[str] = Query(None, description="Filter by role"),
team_id: Optional[UUID] = Query(None, description="Filter by team")
account_id: Optional[UUID] = Query(None, description="Filter by account")
):
"""List all users (super admin only)."""
query = select(User)
@@ -30,8 +30,8 @@ async def list_users(
query = query.where(User.is_active == is_active)
if role:
query = query.where(User.role == role)
if team_id:
query = query.where(User.team_id == team_id)
if account_id:
query = query.where(User.account_id == account_id)
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
@@ -91,14 +91,14 @@ async def update_user_role(
return user
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
async def toggle_team_admin(
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
async def update_account_role(
user_id: UUID,
data: TeamAdminUpdate,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Toggle is_team_admin for a user (super admin only)."""
"""Change a user's account role (super admin only)."""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@@ -108,15 +108,10 @@ async def toggle_team_admin(
detail="User not found"
)
if data.is_team_admin and user.team_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User must belong to a team to be a team admin"
)
user.is_team_admin = data.is_team_admin
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
{"is_team_admin": data.is_team_admin})
old_role = user.account_role
user.account_role = data.account_role
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
{"old_account_role": old_role, "new_account_role": data.account_role})
await db.commit()
await db.refresh(user)
return user

View File

@@ -1,3 +1,5 @@
import secrets
import string
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
@@ -18,6 +20,9 @@ from app.core.security import (
from app.models.user import User
from app.models.invite_code import InviteCode
from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.schemas.user import UserCreate, UserResponse, UserLogin
from app.schemas.token import Token
from app.api.deps import get_current_active_user, get_refresh_token_payload
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
db.add(token_record)
def _generate_display_code() -> str:
"""Generate a random 8-character alphanumeric display code."""
chars = string.ascii_uppercase + string.digits
return ''.join(secrets.choice(chars) for _ in range(8))
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("3/minute")
async def register(
@@ -44,10 +55,46 @@ async def register(
user_data: UserCreate,
db: Annotated[AsyncSession, Depends(get_db)]
):
"""Register a new user."""
# Validate invite code if required
"""Register a new user.
Supports two flows:
- account_invite_code: Join an existing account (bypasses platform invite gate)
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
After user creation, if no account invite was used, a personal Account
and free Subscription are created automatically.
"""
# Check for account invite code FIRST — bypasses platform invite gate
account_invite_record = None
if user_data.account_invite_code:
result = await db.execute(
select(AccountInvite).where(
AccountInvite.code == user_data.account_invite_code
)
)
account_invite_record = result.scalar_one_or_none()
if not account_invite_record:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid account invite code"
)
if account_invite_record.is_used:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has already been used"
)
if account_invite_record.is_expired:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Account invite code has expired"
)
# Validate platform invite code if required (skip if account invite was provided)
invite_code_record = None
if settings.REQUIRE_INVITE_CODE:
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
if not user_data.invite_code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -87,17 +134,55 @@ async def register(
detail="Email already registered"
)
# Create new user
new_user = User(
email=user_data.email,
password_hash=get_password_hash(user_data.password),
name=user_data.name,
role="engineer",
invite_code_id=invite_code_record.id if invite_code_record else None
)
db.add(new_user)
if account_invite_record:
# Join existing account via account invite
new_user = User(
email=user_data.email,
password_hash=get_password_hash(user_data.password),
name=user_data.name,
role="engineer",
invite_code_id=invite_code_record.id if invite_code_record else None,
account_id=account_invite_record.account_id,
account_role=account_invite_record.role,
)
db.add(new_user)
await db.flush()
# Mark invite code as used
# Mark account invite as used
account_invite_record.accepted_by_id = new_user.id
account_invite_record.used_at = datetime.now(timezone.utc)
else:
# Create personal Account first (user needs account_id for NOT NULL constraint)
new_account = Account(
name=f"{user_data.name}'s Account",
display_code=_generate_display_code(),
)
db.add(new_account)
await db.flush() # Get account ID
new_user = User(
email=user_data.email,
password_hash=get_password_hash(user_data.password),
name=user_data.name,
role="engineer",
invite_code_id=invite_code_record.id if invite_code_record else None,
account_id=new_account.id,
account_role="owner",
)
db.add(new_user)
await db.flush() # Get user ID
# Now set account owner and create subscription
new_account.owner_id = new_user.id
new_subscription = Subscription(
account_id=new_account.id,
plan="free",
status="active",
)
db.add(new_subscription)
# Mark platform invite code as used
if invite_code_record:
invite_code_record.used_by_id = new_user.id
invite_code_record.used_at = datetime.now(timezone.utc)

View File

@@ -28,11 +28,11 @@ async def list_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(TreeCategory)
@@ -41,19 +41,19 @@ async def list_categories(
if not include_inactive:
query = query.where(TreeCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(TreeCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(TreeCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
TreeCategory.team_id.is_(None), # Global
TreeCategory.team_id == current_user.team_id # User's team
TreeCategory.account_id.is_(None), # Global
TreeCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(TreeCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(TreeCategory.account_id.is_(None))
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
@@ -76,7 +76,7 @@ async def list_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
tree_count=tree_count
@@ -101,8 +101,8 @@ async def get_category(
detail="Category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -121,7 +121,7 @@ async def get_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -138,10 +138,10 @@ async def create_category(
):
"""Create a new category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_category(current_user, category_data.team_id):
if not can_create_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this category"
@@ -150,10 +150,10 @@ async def create_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeCategory).where(
TreeCategory.slug == slug,
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -164,7 +164,7 @@ async def create_category(
# Get next display order
order_query = select(func.max(TreeCategory.display_order)).where(
TreeCategory.team_id == category_data.team_id
TreeCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -173,7 +173,7 @@ async def create_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -186,7 +186,7 @@ async def create_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -227,7 +227,7 @@ async def update_category(
# Check for duplicate slug
existing_query = select(TreeCategory).where(
TreeCategory.slug == new_slug,
TreeCategory.team_id == category.team_id,
TreeCategory.account_id == category.account_id,
TreeCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -257,7 +257,7 @@ async def update_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -25,11 +25,11 @@ async def list_step_categories(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_inactive: bool = Query(False, description="Include inactive categories"),
team_only: bool = Query(False, description="Only show team-specific categories")
account_only: bool = Query(False, description="Only show account-specific categories")
):
"""List step categories visible to the user.
Returns global categories plus team-specific categories for the user's team.
Returns global categories plus account-specific categories for the user's account.
"""
# Build query for accessible categories
query = select(StepCategory)
@@ -38,19 +38,19 @@ async def list_step_categories(
if not include_inactive:
query = query.where(StepCategory.is_active == True)
# Filter by visibility: global OR user's team
if team_only and current_user.team_id:
query = query.where(StepCategory.team_id == current_user.team_id)
elif current_user.team_id:
# Filter by visibility: global OR user's account
if account_only and current_user.account_id:
query = query.where(StepCategory.account_id == current_user.account_id)
elif current_user.account_id:
query = query.where(
or_(
StepCategory.team_id.is_(None), # Global
StepCategory.team_id == current_user.team_id # User's team
StepCategory.account_id.is_(None), # Global
StepCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no team, only show global categories
query = query.where(StepCategory.team_id.is_(None))
# User has no account, only show global categories
query = query.where(StepCategory.account_id.is_(None))
query = query.order_by(StepCategory.display_order, StepCategory.name)
@@ -66,7 +66,7 @@ async def list_step_categories(
name=cat.name,
slug=cat.slug,
description=cat.description,
team_id=cat.team_id,
account_id=cat.account_id,
display_order=cat.display_order,
is_active=cat.is_active,
step_count=0 # Will be computed when step_library exists
@@ -91,8 +91,8 @@ async def get_step_category(
detail="Step category not found"
)
# Check access: global categories visible to all, team categories only to team members
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this step category"
@@ -103,7 +103,7 @@ async def get_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,
@@ -120,10 +120,10 @@ async def create_step_category(
):
"""Create a new step category.
- Global admins can create global categories (team_id=None)
- Team admins can create team-specific categories for their team
- Global admins can create global categories (account_id=None)
- Account admins can create account-specific categories for their account
"""
if not can_create_step_category(current_user, category_data.team_id):
if not can_create_step_category(current_user, category_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this step category"
@@ -132,10 +132,10 @@ async def create_step_category(
# Generate slug
slug = slugify(category_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(StepCategory).where(
StepCategory.slug == slug,
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -146,7 +146,7 @@ async def create_step_category(
# Get next display order
order_query = select(func.max(StepCategory.display_order)).where(
StepCategory.team_id == category_data.team_id
StepCategory.account_id == category_data.account_id
)
order_result = await db.execute(order_query)
max_order = order_result.scalar() or 0
@@ -155,7 +155,7 @@ async def create_step_category(
name=category_data.name,
slug=slug,
description=category_data.description,
team_id=category_data.team_id,
account_id=category_data.account_id,
display_order=max_order + 1,
created_by=current_user.id
)
@@ -168,7 +168,7 @@ async def create_step_category(
name=new_category.name,
slug=new_category.slug,
description=new_category.description,
team_id=new_category.team_id,
account_id=new_category.account_id,
display_order=new_category.display_order,
is_active=new_category.is_active,
created_at=new_category.created_at,
@@ -209,7 +209,7 @@ async def update_step_category(
# Check for duplicate slug
existing_query = select(StepCategory).where(
StepCategory.slug == new_slug,
StepCategory.team_id == category.team_id,
StepCategory.account_id == category.account_id,
StepCategory.id != category_id
)
existing = await db.execute(existing_query)
@@ -231,7 +231,7 @@ async def update_step_category(
name=category.name,
slug=category.slug,
description=category.description,
team_id=category.team_id,
account_id=category.account_id,
display_order=category.display_order,
is_active=category.is_active,
created_at=category.created_at,

View File

@@ -55,10 +55,10 @@ async def get_step_or_404(
def build_visibility_filter(user: User):
"""Build SQLAlchemy filter for step visibility based on user."""
if user.team_id:
if user.account_id:
return or_(
StepLibrary.visibility == 'public',
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
StepLibrary.created_by == user.id # Own private steps
)
else:
@@ -249,7 +249,7 @@ async def get_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -296,10 +296,10 @@ async def create_step(
if not cat_result.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Invalid category")
# Team validation: can only set team_id to own team
team_id = step_data.team_id
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another team")
# Account validation: can only set account_id to own account
account_id = step_data.account_id
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="Cannot create step for another account")
step = StepLibrary(
title=step_data.title,
@@ -309,7 +309,7 @@ async def create_step(
tags=step_data.tags,
visibility=step_data.visibility,
created_by=current_user.id,
team_id=team_id or current_user.team_id,
account_id=account_id or current_user.account_id,
)
db.add(step)
@@ -326,7 +326,7 @@ async def create_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,
@@ -393,7 +393,7 @@ async def update_step(
"tags": step.tags,
"visibility": step.visibility,
"created_by": step.created_by,
"team_id": step.team_id,
"account_id": step.account_id,
"usage_count": step.usage_count,
"rating_average": step.rating_average,
"rating_count": step.rating_count,

View File

@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
async def list_tags(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""List tags visible to the user.
Returns global tags plus team-specific tags for the user's team.
Returns global tags plus account-specific tags for the user's account.
Tags are ordered by usage count (most used first).
"""
query = select(TreeTag)
# Filter by visibility: global OR user's team
if include_team and current_user.team_id:
# Filter by visibility: global OR user's account
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None), # Global
TreeTag.team_id == current_user.team_id # User's team
TreeTag.account_id.is_(None), # Global
TreeTag.account_id == current_user.account_id # User's account
)
)
else:
# Only show global tags
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
@@ -55,7 +55,7 @@ async def search_tags(
current_user: Annotated[User, Depends(get_current_active_user)],
q: str = Query(..., min_length=1, description="Search query"),
limit: int = Query(10, ge=1, le=50),
include_team: bool = Query(True, description="Include team-specific tags")
include_account: bool = Query(True, description="Include account-specific tags")
):
"""Search/autocomplete tags.
@@ -68,15 +68,15 @@ async def search_tags(
)
# Filter by visibility
if include_team and current_user.team_id:
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == current_user.team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == current_user.account_id
)
)
else:
query = query.where(TreeTag.team_id.is_(None))
query = query.where(TreeTag.account_id.is_(None))
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
@@ -102,8 +102,8 @@ async def get_tag(
detail="Tag not found"
)
# Check access: global tags visible to all, team tags only to team members
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
# Check access: global tags visible to all, account tags only to account members
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tag"
@@ -120,10 +120,10 @@ async def create_tag(
):
"""Create a new tag.
- Global admins can create global tags (team_id=None)
- Team members can create team-specific tags for their team
- Global admins can create global tags (account_id=None)
- Account members can create account-specific tags for their account
"""
if not can_create_tag(current_user, tag_data.team_id):
if not can_create_tag(current_user, tag_data.account_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to create this tag"
@@ -132,10 +132,10 @@ async def create_tag(
# Generate slug
slug = TreeTag.slugify(tag_data.name)
# Check for duplicate slug within same scope (global or team)
# Check for duplicate slug within same scope (global or account)
existing_query = select(TreeTag).where(
TreeTag.slug == slug,
TreeTag.team_id == tag_data.team_id
TreeTag.account_id == tag_data.account_id
)
existing = await db.execute(existing_query)
if existing.scalar_one_or_none():
@@ -147,7 +147,7 @@ async def create_tag(
new_tag = TreeTag(
name=tag_data.name,
slug=slug,
team_id=tag_data.team_id,
account_id=tag_data.account_id,
created_by=current_user.id
)
db.add(new_tag)
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
continue
# Try to find existing tag
# Determine scope: use tree's team, or global for admin-owned trees
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
# Determine scope: use tree's account, or global for admin-owned trees
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None), # Global tag
TreeTag.team_id == tag_team_id # Team tag
TreeTag.account_id.is_(None), # Global tag
TreeTag.account_id == tag_account_id # Account tag
)
)
tag_result = await db.execute(tag_query)
tag = tag_result.scalar_one_or_none()
if not tag:
# Create new tag - prefer team scope unless admin creating on public tree
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
# Fall back to user's team if they can't create in tree's scope
new_team_id = current_user.team_id
# Create new tag - prefer account scope unless admin creating on public tree
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
# Fall back to user's account if they can't create in tree's scope
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -331,7 +331,7 @@ async def replace_tree_tags(
tree.tags.clear()
# Add new tags
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
for tag_name in tag_data.tags:
slug = TreeTag.slugify(tag_name)
@@ -340,8 +340,8 @@ async def replace_tree_tags(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tag_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tag_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -349,14 +349,14 @@ async def replace_tree_tags(
if not tag:
# Create new tag
new_team_id = tag_team_id
if not can_create_tag(current_user, new_team_id):
new_team_id = current_user.team_id
new_account_id = tag_account_id
if not can_create_tag(current_user, new_account_id):
new_account_id = current_user.account_id
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=new_team_id,
account_id=new_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -397,7 +397,7 @@ async def get_tree_tags(
# Check if user can view the tree
if not tree.is_public:
if tree.author_id != current_user.id:
if tree.team_id != current_user.team_id:
if tree.account_id != current_user.account_id:
if not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.subscriptions import check_tree_limit
from app.core.audit import log_audit
router = APIRouter(prefix="/trees", tags=["trees"])
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
Tree.is_public == True,
Tree.author_id == current_user.id,
]
if current_user.team_id:
conditions.append(Tree.team_id == current_user.team_id)
if current_user.account_id:
conditions.append(Tree.account_id == current_user.account_id)
return or_(*conditions)
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
category_info=category_info,
tags=tree.tag_names,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
tags=tree.tag_names,
tree_structure=tree.tree_structure,
author_id=tree.author_id,
team_id=tree.team_id,
account_id=tree.account_id,
is_active=tree.is_active,
is_public=tree.is_public,
is_default=tree.is_default,
@@ -289,7 +290,7 @@ async def create_tree(
detail="Category not found"
)
# Check category access
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -302,16 +303,25 @@ async def create_tree(
category_id=tree_data.category_id,
tree_structure=tree_data.tree_structure,
author_id=None if is_default else current_user.id, # Default trees have no author
team_id=None if is_default else current_user.team_id,
account_id=None if is_default else current_user.account_id,
is_public=True if is_default else tree_data.is_public, # Default trees are always public
is_default=is_default
)
# Check subscription tree limit
if not is_default and current_user.account_id:
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
if not can_create:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
)
db.add(new_tree)
await db.flush() # Get the ID
# Handle tags
if tree_data.tags:
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
# Collect tags to add
tags_to_add = []
@@ -322,8 +332,8 @@ async def create_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -334,7 +344,7 @@ async def create_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)
@@ -420,7 +430,7 @@ async def update_tree(
status_code=status.HTTP_404_NOT_FOUND,
detail="Category not found"
)
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this category"
@@ -450,7 +460,7 @@ async def update_tree(
)
# Add new tags
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
added_tag_ids = set()
for tag_name in tags_data:
@@ -459,8 +469,8 @@ async def update_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.team_id.is_(None),
TreeTag.team_id == tree_team_id
TreeTag.account_id.is_(None),
TreeTag.account_id == tree_account_id
)
)
tag_result = await db.execute(tag_query)
@@ -470,7 +480,7 @@ async def update_tree(
tag = TreeTag(
name=tag_name,
slug=slug,
team_id=tree_team_id,
account_id=tree_account_id,
created_by=current_user.id
)
db.add(tag)

View File

@@ -0,0 +1,62 @@
import logging
from fastapi import APIRouter, Request, HTTPException, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.core.config import settings
from app.core.stripe_handlers import WEBHOOK_HANDLERS
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
@router.post("/stripe")
async def stripe_webhook(
request: Request,
db: AsyncSession = Depends(get_db),
):
"""Handle Stripe webhook events.
Returns 200 for all events to prevent Stripe retries.
Actual processing happens only when Stripe is configured.
"""
if not settings.stripe_enabled:
return {"status": "ok", "message": "Stripe not configured, event ignored"}
payload = await request.body()
sig_header = request.headers.get("stripe-signature")
if not sig_header:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing stripe-signature header"
)
# Verify webhook signature
try:
import stripe
stripe.api_key = settings.STRIPE_SECRET_KEY
event = stripe.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except ImportError:
logger.warning("stripe package not installed, cannot verify webhook")
return {"status": "ok", "message": "stripe package not installed"}
except Exception as e:
logger.error("Stripe webhook signature verification failed: %s", e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid signature"
)
event_type = event.get("type", "")
handler = WEBHOOK_HANDLERS.get(event_type)
if handler:
try:
await handler(event, db)
except Exception:
logger.exception("Error handling Stripe event %s", event_type)
return {"status": "ok"}

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
api_router.include_router(step_categories.router)
api_router.include_router(steps.router)
api_router.include_router(admin.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)

View File

@@ -52,6 +52,16 @@ class Settings(BaseSettings):
# Registration
REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration
# Stripe
STRIPE_SECRET_KEY: Optional[str] = None
STRIPE_PUBLISHABLE_KEY: Optional[str] = None
STRIPE_WEBHOOK_SECRET: Optional[str] = None
@property
def stripe_enabled(self) -> bool:
"""Check if Stripe is configured."""
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
# CORS - set FRONTEND_URL in production (e.g., https://patherly.up.railway.app)
CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173", "http://localhost:5174"]
FRONTEND_URL: Optional[str] = None

View File

@@ -1,12 +1,12 @@
"""
Centralized permission checks for ResolutionFlow.
Role hierarchy: super_admin > team_admin > engineer > viewer
Role hierarchy: super_admin > owner > engineer > viewer
- super_admin: is_super_admin=True, full system access
- team_admin: is_team_admin=True + valid team_id, manage team resources
- engineer: role='engineer' (default), CRUD own trees/steps
- viewer: role='viewer', read-only (can browse, run sessions, rate steps)
- owner: account_role='owner', manage account resources
- engineer: account_role='engineer' (default), CRUD own trees/steps
- viewer: account_role='viewer', read-only (can browse, run sessions, rate steps)
"""
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
@@ -21,19 +21,19 @@ if TYPE_CHECKING:
ROLE_HIERARCHY = {
"super_admin": 4,
"team_admin": 3,
"owner": 3,
"engineer": 2,
"viewer": 1,
}
def get_effective_role(user: User) -> str:
"""Get the effective role considering is_super_admin and is_team_admin flags."""
"""Get the effective role considering is_super_admin and account_role."""
if user.is_super_admin:
return "super_admin"
if user.is_team_admin and user.team_id is not None:
return "team_admin"
return user.role # "engineer" or "viewer"
if user.account_role == "owner":
return "owner"
return user.account_role # "engineer" or "viewer"
def has_minimum_role(user: User, minimum_role: str) -> bool:
@@ -55,7 +55,7 @@ def can_edit_tree(user: User, tree: Tree) -> bool:
return False
if tree.author_id == user.id:
return True
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None:
return True
return False
@@ -78,7 +78,7 @@ def can_manage_category(user: User, category: TreeCategory) -> bool:
"""Can the user edit/delete this category?"""
if user.is_super_admin:
return True
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None:
return True
return False
@@ -91,7 +91,7 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool:
return False
if tree.author_id == user.id:
return True
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None:
return True
return False
@@ -102,7 +102,7 @@ def can_access_tree(user: User, tree: Tree) -> bool:
return True
if tree.author_id == user.id:
return True
if tree.team_id == user.team_id and user.team_id is not None:
if tree.account_id == user.account_id and user.account_id is not None:
return True
if user.is_super_admin:
return True
@@ -116,35 +116,35 @@ def can_view_step(user: User, step: StepLibrary) -> bool:
if step.visibility == "private":
return step.created_by == user.id
if step.visibility == "team":
return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin
return (step.account_id == user.account_id and user.account_id is not None) or user.is_super_admin
return False
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
def can_create_tag(user: User, account_id: Optional[UUID]) -> bool:
"""Can the user create a tag for the given scope?
- Super admins can create global tags (team_id=None) or any team's tags
- Engineers can create team tags for their own team
- Super admins can create global tags (account_id=None) or any account's tags
- Engineers can create account tags for their own account
- Viewers cannot create tags
"""
if user.is_super_admin:
return True
if not can_create_content(user):
return False
if team_id is not None and team_id == user.team_id:
if account_id is not None and account_id == user.account_id:
return True
return False
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
"""Can the user create a category for the given team?
def can_create_category(user: User, account_id: Optional[UUID]) -> bool:
"""Can the user create a category for the given account?
- Super admins can create global or any team's categories
- Team admins can create categories for their own team
- Super admins can create global or any account's categories
- Account owners can create categories for their own account
"""
if user.is_super_admin:
return True
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None:
return True
return False
@@ -153,19 +153,19 @@ def can_manage_step_category(user: User, category: StepCategory) -> bool:
"""Can the user edit/delete this step category?"""
if user.is_super_admin:
return True
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None:
return True
return False
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
"""Can the user create a step category for the given team?
def can_create_step_category(user: User, account_id: Optional[UUID]) -> bool:
"""Can the user create a step category for the given account?
- Super admins can create global or any team's step categories
- Team admins can create step categories for their own team
- Super admins can create global or any account's step categories
- Account owners can create step categories for their own account
"""
if user.is_super_admin:
return True
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None:
return True
return False

View File

@@ -0,0 +1,37 @@
"""Stripe webhook event handlers (stub implementations).
These handlers log events but don't process them until Stripe is fully configured.
"""
import logging
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
async def handle_checkout_completed(event: dict, db: AsyncSession) -> None:
logger.info("Stripe: checkout.session.completed — %s", event.get("id"))
async def handle_invoice_paid(event: dict, db: AsyncSession) -> None:
logger.info("Stripe: invoice.paid — %s", event.get("id"))
async def handle_invoice_payment_failed(event: dict, db: AsyncSession) -> None:
logger.warning("Stripe: invoice.payment_failed — %s", event.get("id"))
async def handle_subscription_updated(event: dict, db: AsyncSession) -> None:
logger.info("Stripe: customer.subscription.updated — %s", event.get("id"))
async def handle_subscription_deleted(event: dict, db: AsyncSession) -> None:
logger.info("Stripe: customer.subscription.deleted — %s", event.get("id"))
WEBHOOK_HANDLERS = {
"checkout.session.completed": handle_checkout_completed,
"invoice.paid": handle_invoice_paid,
"invoice.payment_failed": handle_invoice_payment_failed,
"customer.subscription.updated": handle_subscription_updated,
"customer.subscription.deleted": handle_subscription_deleted,
}

View File

@@ -0,0 +1,113 @@
"""Subscription limit checks and plan helpers."""
from typing import Optional
from uuid import UUID
from datetime import datetime, timezone
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.subscription import Subscription
from app.models.plan_limits import PlanLimits
from app.models.tree import Tree
from app.models.session import Session
async def get_account_subscription(account_id: UUID, db: AsyncSession) -> Optional[Subscription]:
result = await db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)
return result.scalar_one_or_none()
async def get_plan_limits(plan: str, db: AsyncSession) -> Optional[PlanLimits]:
result = await db.execute(
select(PlanLimits).where(PlanLimits.plan == plan)
)
return result.scalar_one_or_none()
async def get_user_plan_limits(user_account_id: UUID, db: AsyncSession) -> Optional[PlanLimits]:
sub = await get_account_subscription(user_account_id, db)
if sub is None:
return await get_plan_limits("free", db)
return await get_plan_limits(sub.plan, db)
async def check_tree_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
"""Check if account can create a new tree.
Returns: (can_create, limit, current_count)
"""
sub = await get_account_subscription(account_id, db)
if sub is None:
return False, 0, 0
limits = await get_plan_limits(sub.plan, db)
if limits is None or limits.max_trees is None:
return True, None, 0 # unlimited
current_count = await db.scalar(
select(func.count(Tree.id)).where(
Tree.account_id == account_id,
Tree.deleted_at.is_(None),
)
)
current_count = current_count or 0
return current_count < limits.max_trees, limits.max_trees, current_count
async def check_session_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
"""Check if account can create a new session this month.
Returns: (can_create, limit, current_count)
"""
sub = await get_account_subscription(account_id, db)
if sub is None:
return False, 0, 0
limits = await get_plan_limits(sub.plan, db)
if limits is None or limits.max_sessions_per_month is None:
return True, None, 0 # unlimited
# Count sessions this calendar month for all users in this account
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
from app.models.user import User
current_count = await db.scalar(
select(func.count(Session.id)).where(
Session.user_id.in_(
select(User.id).where(User.account_id == account_id)
),
Session.started_at >= month_start,
)
)
current_count = current_count or 0
return current_count < limits.max_sessions_per_month, limits.max_sessions_per_month, current_count
async def get_account_usage(account_id: UUID, db: AsyncSession) -> dict:
"""Get current usage stats for an account."""
tree_count = await db.scalar(
select(func.count(Tree.id)).where(
Tree.account_id == account_id,
Tree.deleted_at.is_(None),
)
) or 0
now = datetime.now(timezone.utc)
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
from app.models.user import User
session_count = await db.scalar(
select(func.count(Session.id)).where(
Session.user_id.in_(
select(User.id).where(User.account_id == account_id)
),
Session.started_at >= month_start,
)
) or 0
return {"tree_count": tree_count, "session_count_this_month": session_count}

View File

@@ -1,5 +1,9 @@
from .user import User
from .team import Team
from .account import Account
from .subscription import Subscription
from .plan_limits import PlanLimits
from .account_invite import AccountInvite
from .tree import Tree
from .session import Session
from .attachment import Attachment
@@ -15,6 +19,10 @@ from .audit_log import AuditLog
__all__ = [
"User",
"Team",
"Account",
"Subscription",
"PlanLimits",
"AccountInvite",
"Tree",
"Session",
"Attachment",

View File

@@ -0,0 +1,38 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
from app.models.subscription import Subscription
from app.models.tree import Tree
from app.models.category import TreeCategory
from app.models.tag import TreeTag
from app.models.step_category import StepCategory
from app.models.step_library import StepLibrary
class Account(Base):
__tablename__ = "accounts"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(255), nullable=False)
display_code: Mapped[str] = mapped_column(String(8), unique=True, nullable=False)
owner_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=True)
stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# Relationships
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")
users: Mapped[list["User"]] = relationship("User", foreign_keys="[User.account_id]", back_populates="account")
subscription: Mapped[Optional["Subscription"]] = relationship("Subscription", back_populates="account", uselist=False)
trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.account_id]", back_populates="account")
categories: Mapped[list["TreeCategory"]] = relationship("TreeCategory", foreign_keys="[TreeCategory.account_id]", back_populates="account")
tags: Mapped[list["TreeTag"]] = relationship("TreeTag", foreign_keys="[TreeTag.account_id]", back_populates="account")
step_categories: Mapped[list["StepCategory"]] = relationship("StepCategory", foreign_keys="[StepCategory.account_id]", back_populates="account")
step_library: Mapped[list["StepLibrary"]] = relationship("StepLibrary", foreign_keys="[StepLibrary.account_id]", back_populates="account")

View File

@@ -0,0 +1,48 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, CheckConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.account import Account
from app.models.user import User
class AccountInvite(Base):
__tablename__ = "account_invites"
__table_args__ = (
CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'),
)
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False)
invited_by_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
email: Mapped[str] = mapped_column(String(255), nullable=False)
code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
accepted_by_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
# Relationships
account: Mapped["Account"] = relationship("Account")
invited_by: Mapped["User"] = relationship("User", foreign_keys=[invited_by_id])
accepted_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[accepted_by_id])
@property
def is_used(self) -> bool:
return self.accepted_by_id is not None
@property
def is_expired(self) -> bool:
if self.expires_at is None:
return False
return datetime.now(timezone.utc) > self.expires_at
@property
def is_valid(self) -> bool:
return not self.is_used and not self.is_expired

View File

@@ -9,6 +9,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.tree import Tree
from app.models.team import Team
from app.models.account import Account
from app.models.user import User
@@ -38,6 +39,12 @@ class TreeCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
@@ -57,10 +64,11 @@ class TreeCategory(Base):
# Relationships
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="categories")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="categories")
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
trees: Mapped[list["Tree"]] = relationship("Tree", back_populates="category_rel")
@property
def is_global(self) -> bool:
"""Returns True if this is a global category (not team-specific)."""
return self.team_id is None
"""Returns True if this is a global category (not team or account-specific)."""
return self.team_id is None and self.account_id is None

View File

@@ -0,0 +1,16 @@
from sqlalchemy import String, Integer, Boolean
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import JSONB
from app.core.database import Base
class PlanLimits(Base):
__tablename__ = "plan_limits"
plan: Mapped[str] = mapped_column(String(50), primary_key=True)
max_trees: Mapped[int | None] = mapped_column(Integer, nullable=True)
max_sessions_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True)
max_users: Mapped[int | None] = mapped_column(Integer, nullable=True)
custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"])

View File

@@ -8,6 +8,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.team import Team
from app.models.account import Account
from app.models.user import User
@@ -37,6 +38,12 @@ class StepCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
@@ -56,9 +63,10 @@ class StepCategory(Base):
# Relationships
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="step_categories")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_categories")
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
@property
def is_global(self) -> bool:
"""Returns True if this is a global category (not team-specific)."""
return self.team_id is None
"""Returns True if this is a global category (not team or account-specific)."""
return self.team_id is None and self.account_id is None

View File

@@ -10,6 +10,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
from app.models.team import Team
from app.models.account import Account
from app.models.step_category import StepCategory
from app.models.session import Session
@@ -43,6 +44,12 @@ class StepLibrary(Base):
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
index=True
)
# Organization
category_id: Mapped[Optional[uuid.UUID]] = mapped_column(
@@ -91,6 +98,7 @@ class StepLibrary(Base):
# Relationships
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
team: Mapped[Optional["Team"]] = relationship("Team")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_library")
category: Mapped[Optional["StepCategory"]] = relationship("StepCategory")
ratings: Mapped[list["StepRating"]] = relationship("StepRating", back_populates="step", cascade="all, delete-orphan")
usage_logs: Mapped[list["StepUsageLog"]] = relationship("StepUsageLog", back_populates="step", cascade="all, delete-orphan")

View File

@@ -0,0 +1,39 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.account import Account
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), unique=True, nullable=False)
stripe_subscription_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
stripe_price_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
plan: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
billing_interval: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
status: Mapped[str] = mapped_column(String(50), nullable=False, default="active")
seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
current_period_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
current_period_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# Relationships
account: Mapped["Account"] = relationship("Account", back_populates="subscription")
@property
def is_active(self) -> bool:
return self.status in ("active", "trialing")
@property
def is_paid(self) -> bool:
return self.plan in ("pro", "team")

View File

@@ -9,6 +9,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.tree import Tree
from app.models.team import Team
from app.models.account import Account
from app.models.user import User
@@ -50,6 +51,12 @@ class TreeTag(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
index=True
)
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
@@ -63,6 +70,7 @@ class TreeTag(Base):
# Relationships
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="tags")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="tags")
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
trees: Mapped[list["Tree"]] = relationship(
"Tree",
@@ -72,8 +80,8 @@ class TreeTag(Base):
@property
def is_global(self) -> bool:
"""Returns True if this is a global tag (not team-specific)."""
return self.team_id is None
"""Returns True if this is a global tag (not team or account-specific)."""
return self.team_id is None and self.account_id is None
@classmethod
def slugify(cls, name: str) -> str:

View File

@@ -9,6 +9,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
from app.models.team import Team
from app.models.account import Account
from app.models.session import Session
from app.models.category import TreeCategory
from app.models.tag import TreeTag
@@ -47,6 +48,12 @@ class Tree(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
index=True
)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
is_public: Mapped[bool] = mapped_column(Boolean, default=False, index=True)
is_default: Mapped[bool] = mapped_column(Boolean, default=False, index=True)
@@ -75,6 +82,7 @@ class Tree(Base):
# Relationships
author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees")
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees")
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree")
# New organization relationships

View File

@@ -8,6 +8,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.team import Team
from app.models.account import Account
from app.models.tree import Tree
from app.models.session import Session
from app.models.folder import UserFolder
@@ -20,6 +21,10 @@ class User(Base):
"role IN ('engineer', 'viewer')",
name='ck_users_role_enum'
),
CheckConstraint(
"account_role IN ('owner', 'admin', 'engineer', 'viewer')",
name='ck_users_account_role_enum'
),
)
id: Mapped[uuid.UUID] = mapped_column(
@@ -34,6 +39,17 @@ class User(Base):
is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
# Account-based multi-tenancy (new)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="RESTRICT"),
nullable=True,
index=True
)
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
# Legacy team columns (kept for PR A coexistence)
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id"),
@@ -51,6 +67,8 @@ class User(Base):
last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
# Relationships
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="users")
owned_account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys="[Account.owner_id]", back_populates="owner", uselist=False)
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="users")
trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.author_id]", back_populates="author")
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="user")
@@ -62,6 +80,11 @@ class User(Base):
return self.is_super_admin
@property
def can_manage_team(self) -> bool:
"""Returns True if user can manage their team (team admin or super admin)."""
return self.is_super_admin or (self.is_team_admin and self.team_id is not None)
def is_account_owner(self) -> bool:
"""Returns True if user owns their account."""
return self.account_role == "owner"
@property
def can_manage_account(self) -> bool:
"""Returns True if user can manage their account (owner, admin, or super admin)."""
return self.is_super_admin or self.account_role in ("owner", "admin")

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, Field
class AccountResponse(BaseModel):
id: UUID
name: str
display_code: str
owner_id: UUID
stripe_customer_id: Optional[str] = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class AccountUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=255)
class AccountInviteCreate(BaseModel):
email: str = Field(..., max_length=255)
role: str = Field("engineer", pattern="^(engineer|viewer)$")
expires_in_days: Optional[int] = Field(None, ge=1, le=30)
class AccountInviteResponse(BaseModel):
id: UUID
account_id: UUID
email: str
code: str
role: str
expires_at: Optional[datetime] = None
created_at: datetime
used_at: Optional[datetime] = None
model_config = {"from_attributes": True}

View File

@@ -20,7 +20,7 @@ class CategoryBase(BaseModel):
class CategoryCreate(CategoryBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
class CategoryUpdate(BaseModel):
@@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel):
class CategoryResponse(CategoryBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
created_at: datetime
@@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel):
name: str
slug: str
description: Optional[str] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
tree_count: int = 0

View File

@@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel):
class StepCategoryCreate(StepCategoryBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
class StepCategoryUpdate(BaseModel):
@@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel):
class StepCategoryResponse(StepCategoryBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
created_at: datetime
@@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel):
name: str
slug: str
description: Optional[str] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
display_order: int
is_active: bool
step_count: int = 0

View File

@@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel):
class StepLibraryCreate(StepLibraryBase):
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
class StepLibraryUpdate(BaseModel):
@@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel):
class StepLibraryResponse(StepLibraryBase):
id: UUID
created_by: UUID
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
rating_average: Decimal
rating_count: int

View File

@@ -0,0 +1,40 @@
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel
class SubscriptionResponse(BaseModel):
id: UUID
plan: str
status: str
billing_interval: Optional[str] = None
current_period_start: Optional[datetime] = None
current_period_end: Optional[datetime] = None
cancel_at_period_end: bool = False
stripe_subscription_id: Optional[str] = None
model_config = {"from_attributes": True}
class PlanLimitsResponse(BaseModel):
plan: str
max_trees: Optional[int] = None
max_sessions_per_month: Optional[int] = None
max_users: Optional[int] = None
custom_branding: bool = False
priority_support: bool = False
export_formats: list[str] = ["markdown", "text"]
model_config = {"from_attributes": True}
class UsageResponse(BaseModel):
tree_count: int
session_count_this_month: int
class SubscriptionDetails(BaseModel):
subscription: SubscriptionResponse
limits: PlanLimitsResponse
usage: UsageResponse

View File

@@ -19,13 +19,13 @@ class TagBase(BaseModel):
class TagCreate(TagBase):
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.")
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.")
class TagResponse(TagBase):
id: UUID
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
created_at: datetime
@@ -37,7 +37,7 @@ class TagListResponse(BaseModel):
id: UUID
name: str
slug: str
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
usage_count: int
class Config:
@@ -53,4 +53,4 @@ class TagSearchParams(BaseModel):
"""Query parameters for tag search/autocomplete."""
q: str = Field(..., min_length=1, description="Search query")
limit: int = Field(10, ge=1, le=50)
include_team: bool = Field(True, description="Include team-specific tags")
include_account: bool = Field(True, description="Include account-specific tags")

View File

@@ -44,7 +44,7 @@ class TreeResponse(TreeBase):
id: UUID
tree_structure: dict[str, Any]
author_id: Optional[UUID] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
category_id: Optional[UUID] = None
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
@@ -69,7 +69,7 @@ class TreeListResponse(BaseModel):
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
author_id: Optional[UUID] = None
team_id: Optional[UUID] = None
account_id: Optional[UUID] = None
is_active: bool
is_public: bool
is_default: bool

View File

@@ -13,6 +13,7 @@ class UserBase(BaseModel):
class UserCreate(UserBase):
password: str = Field(..., min_length=10, description="Password must be at least 10 characters")
invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)")
account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account")
@field_validator('password')
@classmethod
@@ -38,11 +39,11 @@ class UserLogin(BaseModel):
class UserResponse(UserBase):
id: UUID
role: str
role: str = "engineer"
account_id: UUID
account_role: str
is_super_admin: bool = False
is_team_admin: bool = False
is_active: bool = True
team_id: Optional[UUID] = None
created_at: datetime
last_login: Optional[datetime] = None
@@ -54,5 +55,5 @@ class RoleUpdate(BaseModel):
role: Literal["engineer", "viewer"]
class TeamAdminUpdate(BaseModel):
is_team_admin: bool
class AccountRoleUpdate(BaseModel):
account_role: str = Field(..., pattern="^(engineer|viewer)$")

View File

@@ -22,5 +22,8 @@ email-validator==2.1.0
# Rate Limiting
slowapi==0.1.9
# Payments
stripe==14.3.0
# Utilities
python-dotenv==1.0.1

View File

@@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
await conn.execute(sa.text("CREATE SCHEMA public"))
await conn.run_sync(Base.metadata.create_all)
# Seed plan_limits for subscription checks
await conn.execute(sa.text("""
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
VALUES
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
"""))
# Create async session maker
async_session_maker = async_sessionmaker(
engine,

View File

@@ -0,0 +1,170 @@
"""Integration tests for account management endpoints."""
import pytest
from httpx import AsyncClient
class TestAccountEndpoints:
"""Test suite for account management endpoints."""
@pytest.mark.asyncio
async def test_get_my_account(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's account."""
response = await client.get("/api/v1/accounts/me", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "name" in data
assert "display_code" in data
assert "owner_id" in data
assert len(data["display_code"]) == 8
@pytest.mark.asyncio
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
"""Test getting current user's subscription details."""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert "subscription" in data
assert "limits" in data
assert "usage" in data
assert data["subscription"]["plan"] == "free"
assert data["subscription"]["status"] == "active"
assert data["limits"]["max_trees"] == 3
assert data["limits"]["max_sessions_per_month"] == 20
@pytest.mark.asyncio
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
"""Test getting members of current user's account."""
response = await client.get("/api/v1/accounts/me/members", headers=auth_headers)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
# Current user should be in members list
assert any(m["account_role"] == "owner" for m in data)
@pytest.mark.asyncio
async def test_update_my_account(self, client: AsyncClient, auth_headers: dict):
"""Test updating account name."""
response = await client.patch(
"/api/v1/accounts/me",
json={"name": "Updated Account Name"},
headers=auth_headers
)
assert response.status_code == 200
assert response.json()["name"] == "Updated Account Name"
@pytest.mark.asyncio
async def test_update_account_requires_owner(self, client: AsyncClient):
"""Test that non-owners cannot update account settings."""
# Register two users
owner_data = {
"email": "owner@example.com",
"password": "OwnerPass123!",
"name": "Owner"
}
await client.post("/api/v1/auth/register", json=owner_data)
# Login as owner and create an invite
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "owner@example.com",
"password": "OwnerPass123!"
})
owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "member@example.com", "role": "engineer"},
headers=owner_headers
)
assert invite_resp.status_code == 201
invite_code = invite_resp.json()["code"]
# Register member with account invite code
member_data = {
"email": "member@example.com",
"password": "MemberPass123!",
"name": "Member",
"account_invite_code": invite_code
}
reg_resp = await client.post("/api/v1/auth/register", json=member_data)
assert reg_resp.status_code == 201
assert reg_resp.json()["account_role"] == "engineer"
# Login as member
member_login = await client.post("/api/v1/auth/login/json", json={
"email": "member@example.com",
"password": "MemberPass123!"
})
member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"}
# Member should not be able to update account
response = await client.patch(
"/api/v1/accounts/me",
json={"name": "Hacked Name"},
headers=member_headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict):
"""Test creating and listing account invites."""
# Create invite
response = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "invitee@example.com", "role": "engineer"},
headers=auth_headers
)
assert response.status_code == 201
data = response.json()
assert data["email"] == "invitee@example.com"
assert data["role"] == "engineer"
assert "code" in data
# List invites
list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers)
assert list_response.status_code == 200
invites = list_response.json()
assert len(invites) >= 1
assert any(i["email"] == "invitee@example.com" for i in invites)
@pytest.mark.asyncio
async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict):
"""Test that account invite code joins user to existing account."""
# Get current account
account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers)
account_id = account_resp.json()["id"]
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "joiner@example.com", "role": "viewer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
# Register with account invite code
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "joiner@example.com",
"password": "JoinerPass123!",
"name": "Joiner",
"account_invite_code": invite_code
})
assert reg_resp.status_code == 201
data = reg_resp.json()
assert data["account_id"] == account_id
assert data["account_role"] == "viewer"
@pytest.mark.asyncio
async def test_register_with_invalid_invite_code(self, client: AsyncClient):
"""Test that invalid account invite code is rejected."""
response = await client.post("/api/v1/auth/register", json={
"email": "bad@example.com",
"password": "BadPassword123!",
"name": "Bad User",
"account_invite_code": "INVALID_CODE"
})
assert response.status_code == 400
assert "invalid" in response.json()["detail"].lower()

View File

@@ -169,3 +169,51 @@ class TestAdminEndpoints:
log = result.scalar_one_or_none()
assert log is not None
assert str(log.resource_id) == user_id
@pytest.mark.asyncio
async def test_change_account_role(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test changing a user's account role."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "viewer"},
headers=admin_auth_headers
)
assert response.status_code == 200
assert response.json()["account_role"] == "viewer"
@pytest.mark.asyncio
async def test_change_account_role_invalid(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test that invalid account_role values are rejected."""
user_id = test_user["user_data"]["id"]
response = await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "owner"},
headers=admin_auth_headers
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_audit_log_created_on_account_role_change(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession
):
"""Test that changing account role creates an audit log entry."""
user_id = test_user["user_data"]["id"]
await client.put(
f"/api/v1/admin/users/{user_id}/account-role",
json={"account_role": "viewer"},
headers=admin_auth_headers
)
result = await test_db.execute(
select(AuditLog).where(AuditLog.action == "user.account_role_change")
)
log = result.scalar_one_or_none()
assert log is not None
assert str(log.resource_id) == user_id
assert log.details["old_account_role"] == "owner"
assert log.details["new_account_role"] == "viewer"

View File

@@ -23,6 +23,8 @@ class TestAuthentication:
assert data["email"] == user_data["email"]
assert data["name"] == user_data["name"]
assert data["role"] == "engineer"
assert "account_id" in data
assert data["account_role"] == "owner"
assert "id" in data
assert "password" not in data # Password should not be returned
@@ -107,6 +109,7 @@ class TestAuthentication:
assert response.status_code == 201
data = response.json()
assert data["role"] == "engineer"
assert data["account_role"] == "owner"
@pytest.mark.asyncio
async def test_register_default_role_is_engineer(self, client: AsyncClient):
@@ -121,6 +124,7 @@ class TestAuthentication:
assert response.status_code == 201
assert response.json()["role"] == "engineer"
assert response.json()["account_role"] == "owner"
@pytest.mark.asyncio
async def test_register_rejects_no_uppercase(self, client: AsyncClient):

View File

@@ -0,0 +1,205 @@
"""Integration tests for account-based permissions."""
import pytest
from httpx import AsyncClient
class TestAccountPermissions:
"""Test suite for account-based permission checks."""
@pytest.mark.asyncio
async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db):
"""Test that viewers cannot create trees."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Register a user
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "viewer@example.com",
"password": "ViewerPass123!",
"name": "Viewer User"
})
assert reg_resp.status_code == 201
user_id = PyUUID(reg_resp.json()["id"])
# Demote to viewer via ORM
result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one()
user.account_role = "viewer"
await test_db.commit()
# Login as viewer
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "viewer@example.com",
"password": "ViewerPass123!"
})
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Try to create tree
response = await client.post("/api/v1/trees", json={
"name": "Viewer Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"}
}, headers=viewer_headers)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db):
"""Test that viewers can browse/list trees."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Create a public tree as the regular user first
await client.post("/api/v1/trees", json={
"name": "Public Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": True
}, headers=auth_headers)
# Register viewer
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "viewer2@example.com",
"password": "ViewerPass123!",
"name": "Viewer 2"
})
user_id = PyUUID(reg_resp.json()["id"])
result = await test_db.execute(select(User).where(User.id == user_id))
user = result.scalar_one()
user.account_role = "viewer"
await test_db.commit()
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "viewer2@example.com",
"password": "ViewerPass123!"
})
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Viewer can list trees
response = await client.get("/api/v1/trees", headers=viewer_headers)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db):
"""Test that account owner can edit trees created by account members."""
from sqlalchemy import select
from app.models.user import User
from uuid import UUID as PyUUID
# Get owner's account
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id = me_resp.json()["account_id"]
# Create invite
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "engineer@example.com", "role": "engineer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
# Register engineer in same account
reg_resp = await client.post("/api/v1/auth/register", json={
"email": "engineer@example.com",
"password": "EngineerPass123!",
"name": "Engineer",
"account_invite_code": invite_code
})
assert reg_resp.status_code == 201
assert reg_resp.json()["account_id"] == account_id
# Login as engineer
eng_login = await client.post("/api/v1/auth/login/json", json={
"email": "engineer@example.com",
"password": "EngineerPass123!"
})
eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"}
# Engineer creates a tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Engineer's Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}
}, headers=eng_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Owner can edit engineer's tree
update_resp = await client.put(
f"/api/v1/trees/{tree_id}",
json={"name": "Owner Updated Name"},
headers=auth_headers
)
assert update_resp.status_code == 200
assert update_resp.json()["name"] == "Owner Updated Name"
@pytest.mark.asyncio
async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict):
"""Test that account members can see each other's non-public trees."""
# Get owner's account
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id = me_resp.json()["account_id"]
# Owner creates a private tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Private Account Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": False
}, headers=auth_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Create invite and add member
invite_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "teammate@example.com", "role": "engineer"},
headers=auth_headers
)
invite_code = invite_resp.json()["code"]
await client.post("/api/v1/auth/register", json={
"email": "teammate@example.com",
"password": "TeammatePass123!",
"name": "Teammate",
"account_invite_code": invite_code
})
mate_login = await client.post("/api/v1/auth/login/json", json={
"email": "teammate@example.com",
"password": "TeammatePass123!"
})
mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"}
# Teammate should see the private tree (same account)
response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers)
assert response.status_code == 200
assert response.json()["name"] == "Private Account Tree"
@pytest.mark.asyncio
async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict):
"""Test that users from different accounts cannot see private trees."""
# Owner creates a private tree
tree_resp = await client.post("/api/v1/trees", json={
"name": "Secret Tree",
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
"is_public": False
}, headers=auth_headers)
assert tree_resp.status_code == 201
tree_id = tree_resp.json()["id"]
# Register a completely separate user (different account)
await client.post("/api/v1/auth/register", json={
"email": "outsider@example.com",
"password": "OutsiderPass123!",
"name": "Outsider"
})
outsider_login = await client.post("/api/v1/auth/login/json", json={
"email": "outsider@example.com",
"password": "OutsiderPass123!"
})
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
# Outsider should NOT see the private tree
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
assert response.status_code == 403

View File

@@ -0,0 +1,129 @@
"""Integration tests for subscription limits."""
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
class TestSubscriptionLimits:
"""Test suite for subscription plan limits."""
@pytest.mark.asyncio
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
"""Test that free plan has tree creation limit of 3."""
tree_template = {
"name": "Limit Test Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test tree"
}
}
# Create trees up to the limit
for i in range(3):
tree_data = {**tree_template, "name": f"Tree {i+1}"}
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}"
# 4th tree should fail with 402
tree_data = {**tree_template, "name": "Tree 4 Over Limit"}
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert response.status_code == 402
assert "limit" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict):
"""Test that subscription details reflect actual usage."""
# Check initial usage
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
initial_usage = response.json()["usage"]
assert initial_usage["tree_count"] == 0
# Create a tree
tree_data = {
"name": "Usage Test Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test"
}
}
create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
assert create_resp.status_code == 201
# Check usage increased
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
updated_usage = response.json()["usage"]
assert updated_usage["tree_count"] == 1
@pytest.mark.asyncio
async def test_super_admin_bypasses_limits(
self, client: AsyncClient, admin_auth_headers: dict
):
"""Test that super admin can create trees without limit checks."""
tree_template = {
"name": "Admin Tree",
"tree_structure": {
"id": "root",
"type": "solution",
"title": "Test",
"description": "Test tree"
},
"is_default": True # Default trees skip limit check
}
# Super admin creating default trees should always work
for i in range(5):
tree_data = {**tree_template, "name": f"Admin Tree {i+1}"}
response = await client.post(
"/api/v1/trees", json=tree_data, headers=admin_auth_headers
)
assert response.status_code == 201
@pytest.mark.asyncio
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
"""Test that free plan limits are correct."""
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
limits = response.json()["limits"]
assert limits["plan"] == "free"
assert limits["max_trees"] == 3
assert limits["max_sessions_per_month"] == 20
assert limits["max_users"] == 1
assert limits["custom_branding"] is False
assert limits["priority_support"] is False
@pytest.mark.asyncio
async def test_upgraded_plan_has_higher_limits(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Test that upgrading plan increases limits."""
from app.models.subscription import Subscription
from app.models.user import User
# Get the user's subscription and upgrade it
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
account_id_str = me_resp.json()["account_id"]
from uuid import UUID
account_id = UUID(account_id_str)
result = await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)
sub = result.scalar_one()
sub.plan = "pro"
await test_db.commit()
# Check limits are now pro
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
assert response.status_code == 200
limits = response.json()["limits"]
assert limits["plan"] == "pro"
assert limits["max_trees"] == 25
assert limits["max_sessions_per_month"] == 200

View File

@@ -8,6 +8,7 @@
"name": "frontend",
"version": "0.0.0",
"dependencies": {
"@stripe/stripe-js": "^8.7.0",
"axios": "^1.13.4",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
@@ -1430,6 +1431,15 @@
"win32"
]
},
"node_modules/@stripe/stripe-js": {
"version": "8.7.0",
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-8.7.0.tgz",
"integrity": "sha512-tNUerSstwNC1KuHgX4CASGO0Md3CB26IJzSXmVlSuFvhsBP4ZaEPpY4jxWOn9tfdDscuVT4Kqb8cZ2o9nLCgRQ==",
"license": "MIT",
"engines": {
"node": ">=12.16"
}
},
"node_modules/@types/babel__core": {
"version": "7.20.5",
"resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",

View File

@@ -10,6 +10,7 @@
"preview": "vite preview"
},
"dependencies": {
"@stripe/stripe-js": "^8.7.0",
"axios": "^1.13.4",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",

View File

@@ -0,0 +1,48 @@
import apiClient from './client'
import type { Account, SubscriptionDetails, AccountMember, AccountInvite } from '@/types'
export const accountsApi = {
async getMyAccount(): Promise<Account> {
const response = await apiClient.get<Account>('/accounts/me')
return response.data
},
async getMySubscription(): Promise<SubscriptionDetails> {
const response = await apiClient.get<SubscriptionDetails>('/accounts/me/subscription')
return response.data
},
async updateMyAccount(data: { name?: string }): Promise<Account> {
const response = await apiClient.patch<Account>('/accounts/me', data)
return response.data
},
async getMembers(): Promise<AccountMember[]> {
const response = await apiClient.get<AccountMember[]>('/accounts/me/members')
return response.data
},
async updateMemberRole(userId: string, role: string): Promise<AccountMember> {
const response = await apiClient.patch<AccountMember>(
`/accounts/me/members/${userId}/role`,
{ role }
)
return response.data
},
async removeMember(userId: string): Promise<void> {
await apiClient.delete(`/accounts/me/members/${userId}`)
},
async createInvite(data: { email: string; role: string }): Promise<AccountInvite> {
const response = await apiClient.post<AccountInvite>('/accounts/me/invites', data)
return response.data
},
async getInvites(): Promise<AccountInvite[]> {
const response = await apiClient.get<AccountInvite[]>('/accounts/me/invites')
return response.data
},
}
export default accountsApi

View File

@@ -2,9 +2,9 @@ import apiClient from './client'
import type { Category, CategoryListItem, CategoryCreate, CategoryUpdate } from '@/types'
export const categoriesApi = {
async list(includeInactive = false, teamOnly = false): Promise<CategoryListItem[]> {
async list(includeInactive = false, accountOnly = false): Promise<CategoryListItem[]> {
const response = await apiClient.get<CategoryListItem[]>('/categories', {
params: { include_inactive: includeInactive, team_only: teamOnly },
params: { include_inactive: includeInactive, account_only: accountOnly },
})
return response.data
},

View File

@@ -8,3 +8,4 @@ export { default as categoriesApi } from './categories'
export { default as foldersApi } from './folders'
export { default as stepsApi } from './steps'
export { default as stepCategoriesApi } from './stepCategories'
export { default as accountsApi } from './accounts'

View File

@@ -2,16 +2,16 @@ import apiClient from './client'
import type { Tag, TagListItem, TagCreate, TagAssignment } from '@/types'
export const tagsApi = {
async list(includeTeam = true): Promise<TagListItem[]> {
async list(includeAccount = true): Promise<TagListItem[]> {
const response = await apiClient.get<TagListItem[]>('/tags', {
params: { include_team: includeTeam },
params: { include_account: includeAccount },
})
return response.data
},
async search(query: string, limit = 10, includeTeam = true): Promise<TagListItem[]> {
async search(query: string, limit = 10, includeAccount = true): Promise<TagListItem[]> {
const response = await apiClient.get<TagListItem[]>('/tags/search', {
params: { q: query, limit, include_team: includeTeam },
params: { q: query, limit, include_account: includeAccount },
})
return response.data
},

View File

@@ -0,0 +1,32 @@
import { cn } from '@/lib/utils'
import { useSubscription } from '@/hooks/useSubscription'
interface UpgradePromptProps {
feature: string // e.g., "create more trees", "start more sessions"
className?: string
}
export function UpgradePrompt({ feature, className }: UpgradePromptProps) {
const { plan } = useSubscription()
return (
<div className={cn(
'rounded-lg border border-yellow-500/30 bg-yellow-500/10 p-4',
className
)}>
<h3 className="font-semibold text-foreground">Plan Limit Reached</h3>
<p className="mt-1 text-sm text-muted-foreground">
Your {plan} plan doesn't allow you to {feature}. Upgrade your plan to continue.
</p>
<button
className={cn(
'mt-3 rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
'hover:bg-primary/90'
)}
onClick={() => window.location.href = '/account'}
>
View Plans
</button>
</div>
)
}

View File

@@ -49,6 +49,7 @@ export function AppLayout() {
const navItems = [
{ path: '/trees', label: 'Trees' },
{ path: '/sessions', label: 'Sessions' },
{ path: '/account', label: 'Account' },
{ path: '/settings', label: 'Settings' },
]
@@ -98,12 +99,12 @@ export function AppLayout() {
className={cn(
'hidden rounded-full px-2 py-0.5 text-xs font-medium sm:inline-block',
effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400',
effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400'
)}
>
{effectiveRole === 'super_admin' ? 'Super Admin' :
effectiveRole === 'team_admin' ? 'Team Admin' :
effectiveRole === 'owner' ? 'Owner' :
'Viewer'}
</span>
)}
@@ -158,12 +159,12 @@ export function AppLayout() {
className={cn(
'mt-1 inline-block rounded-full px-2 py-0.5 text-xs font-medium',
effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400',
effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400'
)}
>
{effectiveRole === 'super_admin' ? 'Super Admin' :
effectiveRole === 'team_admin' ? 'Team Admin' :
effectiveRole === 'owner' ? 'Owner' :
'Viewer'}
</span>
)}

View File

@@ -27,7 +27,7 @@ export function ProtectedRoute({ requiredRole, children }: ProtectedRouteProps)
if (requiredRole) {
const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
super_admin: 4,
team_admin: 3,
owner: 3,
engineer: 2,
viewer: 1,
}

View File

@@ -0,0 +1,24 @@
import { cn } from '@/lib/utils'
interface CheckoutButtonProps {
plan: 'pro' | 'team'
className?: string
}
export function CheckoutButton({ plan, className }: CheckoutButtonProps) {
const planLabels = { pro: 'Pro', team: 'Team' }
return (
<button
disabled
title="Billing coming soon"
className={cn(
'rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
'disabled:opacity-50 disabled:cursor-not-allowed',
className
)}
>
Upgrade to {planLabels[plan]} (Coming Soon)
</button>
)
}

View File

@@ -119,7 +119,7 @@ export function TreeMetadataForm() {
{categories.map((cat) => (
<option key={cat.id} value={cat.id}>
{cat.name}
{cat.team_id ? ' (Team)' : ''}
{cat.account_id ? ' (Account)' : ''}
</option>
))}
<option value="__custom__">+ Add custom category</option>

View File

@@ -1,18 +1,18 @@
/**
* Centralized permissions hook for ResolutionFlow.
*
* Role hierarchy: super_admin > team_admin > engineer > viewer
* Role hierarchy: super_admin > owner > engineer > viewer
*
* Mirrors backend logic in backend/app/core/permissions.py
*/
import { useAuthStore } from '@/store/authStore'
import type { User } from '@/types'
export type EffectiveRole = 'super_admin' | 'team_admin' | 'engineer' | 'viewer'
export type EffectiveRole = 'super_admin' | 'owner' | 'engineer' | 'viewer'
const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
super_admin: 4,
team_admin: 3,
owner: 3,
engineer: 2,
viewer: 1,
}
@@ -20,7 +20,7 @@ const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
function getEffectiveRole(user: User | null): EffectiveRole {
if (!user) return 'viewer'
if (user.is_super_admin) return 'super_admin'
if (user.is_team_admin && user.team_id) return 'team_admin'
if (user.account_role === 'owner') return 'owner'
return user.role as EffectiveRole
}
@@ -37,7 +37,7 @@ export function usePermissions() {
return {
effectiveRole,
isSuperAdmin: effectiveRole === 'super_admin',
isTeamAdmin: effectiveRole === 'team_admin' || effectiveRole === 'super_admin',
isAccountOwner: effectiveRole === 'owner' || effectiveRole === 'super_admin',
isEngineer: hasMinimumRole(user, 'engineer'),
isViewer: effectiveRole === 'viewer',
@@ -46,12 +46,12 @@ export function usePermissions() {
canCreateSteps: hasMinimumRole(user, 'engineer'),
// Resource-specific checks
canEditTree: (tree: { author_id: string | null; team_id?: string | null }) => {
canEditTree: (tree: { author_id: string | null; account_id?: string | null }) => {
if (!user) return false
if (user.is_super_admin) return true
if (!hasMinimumRole(user, 'engineer')) return false
if (tree.author_id && tree.author_id === user.id) return true
if (user.is_team_admin && tree.team_id === user.team_id && user.team_id) return true
if (user.account_role === 'owner' && tree.account_id === user.account_id && user.account_id) return true
return false
},
@@ -68,8 +68,8 @@ export function usePermissions() {
},
// Management permissions
canManageCategories: hasMinimumRole(user, 'team_admin'),
canManageCategories: hasMinimumRole(user, 'owner'),
canManageGlobalCategories: effectiveRole === 'super_admin',
canManageTeam: effectiveRole === 'super_admin' || (effectiveRole === 'team_admin'),
canManageAccount: effectiveRole === 'super_admin' || effectiveRole === 'owner',
}
}

View File

@@ -0,0 +1,45 @@
import { useAuthStore } from '@/store/authStore'
export function useSubscription() {
const subscription = useAuthStore((s) => s.subscription)
const plan = subscription?.subscription.plan ?? 'free'
const limits = subscription?.limits ?? null
const usage = subscription?.usage ?? null
const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing'
const isPaidPlan = plan === 'pro' || plan === 'team'
const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => {
if (!limits) return false
return limits[feature]
}
const isAtTreeLimit = (): boolean => {
if (!limits || !usage) return false
if (limits.max_trees === null) return false // unlimited
return usage.tree_count >= limits.max_trees
}
const isAtSessionLimit = (): boolean => {
if (!limits || !usage) return false
if (limits.max_sessions_per_month === null) return false
return usage.session_count_this_month >= limits.max_sessions_per_month
}
const formatLimit = (value: number | null): string => {
return value === null ? 'Unlimited' : String(value)
}
return {
plan,
limits,
usage,
isActive,
isPaidPlan,
canUseFeature,
isAtTreeLimit,
isAtSessionLimit,
formatLimit,
}
}

View File

@@ -0,0 +1,494 @@
import { useEffect, useState } from 'react'
import { Building2, Users, Mail, Crown, Loader2, AlertCircle, Check, X } from 'lucide-react'
import { accountsApi } from '@/api'
import type { Account, AccountMember, AccountInvite } from '@/types'
import { cn } from '@/lib/utils'
import { usePermissions } from '@/hooks/usePermissions'
import { useSubscription } from '@/hooks/useSubscription'
import { useAuthStore } from '@/store/authStore'
import { CheckoutButton } from '@/components/subscription/CheckoutButton'
export function AccountSettingsPage() {
const { isAccountOwner } = usePermissions()
const { plan, limits, usage } = useSubscription()
const subscription = useAuthStore((s) => s.subscription)
const [account, setAccount] = useState<Account | null>(null)
const [members, setMembers] = useState<AccountMember[]>([])
const [invites, setInvites] = useState<AccountInvite[]>([])
const [isLoading, setIsLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
// Account name editing
const [isEditingName, setIsEditingName] = useState(false)
const [editedName, setEditedName] = useState('')
const [isSavingName, setIsSavingName] = useState(false)
// Invite form
const [inviteEmail, setInviteEmail] = useState('')
const [inviteRole, setInviteRole] = useState('engineer')
const [isInviting, setIsInviting] = useState(false)
const [inviteError, setInviteError] = useState<string | null>(null)
const [inviteSuccess, setInviteSuccess] = useState<string | null>(null)
useEffect(() => {
loadData()
}, [])
const loadData = async () => {
setIsLoading(true)
setError(null)
try {
const accountData = await accountsApi.getMyAccount()
setAccount(accountData)
setEditedName(accountData.name)
if (isAccountOwner) {
const [membersData, invitesData] = await Promise.all([
accountsApi.getMembers(),
accountsApi.getInvites(),
])
setMembers(membersData)
setInvites(invitesData)
}
} catch (err) {
setError('Failed to load account information')
console.error(err)
} finally {
setIsLoading(false)
}
}
const handleSaveName = async () => {
if (!editedName.trim() || editedName === account?.name) {
setIsEditingName(false)
return
}
setIsSavingName(true)
try {
const updated = await accountsApi.updateMyAccount({ name: editedName.trim() })
setAccount(updated)
setIsEditingName(false)
} catch (err) {
console.error('Failed to update account name:', err)
} finally {
setIsSavingName(false)
}
}
const handleInvite = async (e: React.FormEvent) => {
e.preventDefault()
if (!inviteEmail.trim()) return
setIsInviting(true)
setInviteError(null)
setInviteSuccess(null)
try {
await accountsApi.createInvite({ email: inviteEmail.trim(), role: inviteRole })
setInviteSuccess(`Invitation sent to ${inviteEmail}`)
setInviteEmail('')
// Refresh invites list
const invitesData = await accountsApi.getInvites()
setInvites(invitesData)
} catch (err) {
setInviteError('Failed to send invitation')
console.error(err)
} finally {
setIsInviting(false)
}
}
const handleRemoveMember = async (userId: string) => {
try {
await accountsApi.removeMember(userId)
setMembers(members.filter((m) => m.id !== userId))
} catch (err) {
console.error('Failed to remove member:', err)
}
}
if (isLoading) {
return (
<div className="flex justify-center py-12">
<div className="h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent" />
</div>
)
}
if (error) {
return (
<div className="container mx-auto px-4 py-6 sm:px-6 sm:py-8">
<div className="rounded-md bg-destructive/10 p-4 text-destructive">
<div className="flex items-center gap-2">
<AlertCircle className="h-5 w-5" />
{error}
</div>
</div>
</div>
)
}
const sub = subscription?.subscription
return (
<div className="container mx-auto px-4 py-6 sm:px-6 sm:py-8">
<div className="mb-8">
<div className="flex items-center gap-3">
<Building2 className="h-8 w-8 text-primary" />
<h1 className="text-2xl font-bold text-foreground sm:text-3xl">Account Settings</h1>
</div>
<p className="mt-2 text-muted-foreground">
Manage your account, subscription, and team
</p>
</div>
<div className="max-w-3xl space-y-6">
{/* Account Info Section */}
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
<h2 className="text-lg font-semibold text-card-foreground">Account Information</h2>
<div className="mt-4 space-y-4">
{/* Account Name */}
<div>
<label className="block font-label text-sm font-medium text-card-foreground">
Account Name
</label>
{isEditingName ? (
<div className="mt-1 flex items-center gap-2">
<input
type="text"
value={editedName}
onChange={(e) => setEditedName(e.target.value)}
className={cn(
'flex-1 rounded-md border border-input bg-background px-3 py-2',
'text-foreground placeholder:text-muted-foreground',
'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
)}
autoFocus
onKeyDown={(e) => {
if (e.key === 'Enter') handleSaveName()
if (e.key === 'Escape') {
setEditedName(account?.name ?? '')
setIsEditingName(false)
}
}}
/>
<button
onClick={handleSaveName}
disabled={isSavingName}
className={cn(
'rounded-md bg-primary p-2 text-primary-foreground',
'hover:bg-primary/90 disabled:opacity-50'
)}
>
{isSavingName ? (
<Loader2 className="h-4 w-4 animate-spin" />
) : (
<Check className="h-4 w-4" />
)}
</button>
<button
onClick={() => {
setEditedName(account?.name ?? '')
setIsEditingName(false)
}}
className="rounded-md border border-input p-2 text-muted-foreground hover:bg-accent"
>
<X className="h-4 w-4" />
</button>
</div>
) : (
<div className="mt-1 flex items-center gap-2">
<span className="text-sm text-foreground">{account?.name}</span>
{isAccountOwner && (
<button
onClick={() => setIsEditingName(true)}
className="text-xs text-primary hover:underline"
>
Edit
</button>
)}
</div>
)}
</div>
{/* Display Code */}
<div>
<label className="block font-label text-sm font-medium text-card-foreground">
Display Code
</label>
<p className="mt-1 text-sm font-mono text-muted-foreground">
{account?.display_code}
</p>
</div>
</div>
</div>
{/* Subscription Section */}
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
<h2 className="text-lg font-semibold text-card-foreground">Subscription</h2>
<div className="mt-4 space-y-4">
{/* Plan & Status */}
<div className="flex items-center gap-3">
<span
className={cn(
'inline-flex items-center gap-1.5 rounded-full px-3 py-1 text-sm font-medium',
plan === 'free' && 'bg-secondary text-secondary-foreground',
plan === 'pro' && 'bg-primary/10 text-primary',
plan === 'team' && 'bg-primary/20 text-primary'
)}
>
<Crown className="h-3.5 w-3.5" />
{plan.charAt(0).toUpperCase() + plan.slice(1)} Plan
</span>
{sub && (
<span
className={cn(
'inline-flex rounded-full px-2.5 py-0.5 text-xs font-medium',
sub.status === 'active' && 'bg-green-500/10 text-green-600',
sub.status === 'trialing' && 'bg-blue-500/10 text-blue-600',
sub.status === 'past_due' && 'bg-yellow-500/10 text-yellow-600',
sub.status === 'canceled' && 'bg-destructive/10 text-destructive',
sub.status === 'orphaned' && 'bg-muted text-muted-foreground'
)}
>
{sub.status.charAt(0).toUpperCase() + sub.status.slice(1).replace('_', ' ')}
</span>
)}
</div>
{sub?.current_period_end && (
<p className="text-sm text-muted-foreground">
Current period ends: {new Date(sub.current_period_end).toLocaleDateString()}
</p>
)}
{/* Usage Stats */}
{limits && usage && (
<div className="mt-4 grid gap-3 sm:grid-cols-3">
<UsageStat
label="Trees"
current={usage.tree_count}
max={limits.max_trees}
/>
<UsageStat
label="Sessions / month"
current={usage.session_count_this_month}
max={limits.max_sessions_per_month}
/>
<UsageStat
label="Team members"
current={usage.user_count}
max={limits.max_users}
/>
</div>
)}
{/* Upgrade buttons */}
{plan === 'free' && (
<div className="mt-4 flex gap-3">
<CheckoutButton plan="pro" />
<CheckoutButton plan="team" />
</div>
)}
{plan === 'pro' && (
<div className="mt-4">
<CheckoutButton plan="team" />
</div>
)}
</div>
</div>
{/* Team Members Section (owners only) */}
{isAccountOwner && (
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
<div className="flex items-center gap-2">
<Users className="h-5 w-5 text-primary" />
<h2 className="text-lg font-semibold text-card-foreground">Team Members</h2>
</div>
{members.length === 0 ? (
<p className="mt-4 text-sm text-muted-foreground">No team members yet.</p>
) : (
<div className="mt-4 divide-y divide-border">
{members.map((member) => (
<div
key={member.id}
className="flex items-center justify-between py-3 first:pt-0 last:pb-0"
>
<div>
<p className="text-sm font-medium text-foreground">{member.name}</p>
<p className="text-xs text-muted-foreground">{member.email}</p>
</div>
<div className="flex items-center gap-3">
<span
className={cn(
'rounded-full px-2.5 py-0.5 text-xs font-medium',
member.account_role === 'owner' && 'bg-primary/10 text-primary',
member.account_role === 'engineer' && 'bg-secondary text-secondary-foreground',
member.account_role === 'viewer' && 'bg-muted text-muted-foreground'
)}
>
{member.account_role}
</span>
{!member.is_active && (
<span className="rounded-full bg-destructive/10 px-2 py-0.5 text-xs text-destructive">
Inactive
</span>
)}
{member.account_role !== 'owner' && (
<button
onClick={() => handleRemoveMember(member.id)}
className="text-muted-foreground hover:text-destructive"
title="Remove member"
>
<X className="h-4 w-4" />
</button>
)}
</div>
</div>
))}
</div>
)}
</div>
)}
{/* Invite Member Section (owners only) */}
{isAccountOwner && (
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
<div className="flex items-center gap-2">
<Mail className="h-5 w-5 text-primary" />
<h2 className="text-lg font-semibold text-card-foreground">Invite Member</h2>
</div>
<form onSubmit={handleInvite} className="mt-4 space-y-3">
<div className="flex flex-col gap-3 sm:flex-row">
<input
type="email"
placeholder="Email address"
value={inviteEmail}
onChange={(e) => setInviteEmail(e.target.value)}
required
className={cn(
'flex-1 rounded-md border border-input bg-background px-3 py-2',
'text-foreground placeholder:text-muted-foreground',
'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
)}
/>
<select
value={inviteRole}
onChange={(e) => setInviteRole(e.target.value)}
className={cn(
'rounded-md border border-input bg-background px-3 py-2',
'text-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
)}
>
<option value="engineer">Engineer</option>
<option value="viewer">Viewer</option>
</select>
<button
type="submit"
disabled={isInviting || !inviteEmail.trim()}
className={cn(
'rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
'hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed'
)}
>
{isInviting ? (
<span className="flex items-center gap-2">
<Loader2 className="h-4 w-4 animate-spin" />
Sending...
</span>
) : (
'Send Invite'
)}
</button>
</div>
{inviteError && (
<p className="text-sm text-destructive">{inviteError}</p>
)}
{inviteSuccess && (
<p className="text-sm text-green-600">{inviteSuccess}</p>
)}
</form>
{/* Pending Invites */}
{invites.length > 0 && (
<div className="mt-6">
<h3 className="text-sm font-medium text-card-foreground">Pending Invites</h3>
<div className="mt-2 divide-y divide-border">
{invites
.filter((inv) => !inv.used_at)
.map((invite) => (
<div
key={invite.id}
className="flex items-center justify-between py-2"
>
<div>
<p className="text-sm text-foreground">{invite.email}</p>
<p className="text-xs text-muted-foreground">
Expires {new Date(invite.expires_at).toLocaleDateString()}
</p>
</div>
<span className="rounded-full bg-secondary px-2.5 py-0.5 text-xs text-secondary-foreground">
{invite.role}
</span>
</div>
))}
</div>
</div>
)}
</div>
)}
</div>
</div>
)
}
/** Small helper component for usage stat display */
function UsageStat({
label,
current,
max,
}: {
label: string
current: number
max: number | null
}) {
const isUnlimited = max === null
const percentage = isUnlimited ? 0 : Math.min((current / max) * 100, 100)
const isNearLimit = !isUnlimited && percentage >= 80
const isAtLimit = !isUnlimited && current >= max
return (
<div className="rounded-md border border-border bg-background p-3">
<p className="text-xs font-medium text-muted-foreground">{label}</p>
<p
className={cn(
'mt-1 text-lg font-semibold',
isAtLimit ? 'text-destructive' : isNearLimit ? 'text-yellow-600' : 'text-foreground'
)}
>
{current}
<span className="text-sm font-normal text-muted-foreground">
{' '}/ {isUnlimited ? 'Unlimited' : max}
</span>
</p>
{!isUnlimited && (
<div className="mt-2 h-1.5 overflow-hidden rounded-full bg-muted">
<div
className={cn(
'h-full rounded-full transition-all',
isAtLimit ? 'bg-destructive' : isNearLimit ? 'bg-yellow-500' : 'bg-primary'
)}
style={{ width: `${percentage}%` }}
/>
</div>
)}
</div>
)
}
export default AccountSettingsPage

View File

@@ -102,7 +102,7 @@ export function TreeEditorPage() {
setLoading(true)
try {
const tree = await treesApi.get(id)
if (!canEditTree({ author_id: tree.author_id, team_id: tree.team_id })) {
if (!canEditTree({ author_id: tree.author_id, account_id: tree.account_id })) {
navigate('/trees')
return
}

View File

@@ -137,6 +137,7 @@ export function TreeLibraryPage() {
try {
await treesApi.delete(treeToDelete.id)
setTrees(trees.filter((t) => t.id !== treeToDelete.id))
window.dispatchEvent(new Event('folder-changed'))
} catch (err) {
console.error('Failed to delete tree:', err)
setError('Failed to delete tree')
@@ -352,7 +353,7 @@ export function TreeLibraryPage() {
</span>
<div className="flex items-center gap-2">
<AddToFolderMenu treeId={tree.id} onFolderCreated={handleCreateFolder} />
{canEditTree({ author_id: tree.author_id, team_id: tree.team_id }) && (
{canEditTree({ author_id: tree.author_id, account_id: tree.account_id }) && (
<Link
to={`/trees/${tree.id}/edit`}
className={cn(

View File

@@ -6,3 +6,4 @@ export { default as TreeEditorPage } from './TreeEditorPage'
export { default as SessionHistoryPage } from './SessionHistoryPage'
export { default as SessionDetailPage } from './SessionDetailPage'
export { default as SettingsPage } from './SettingsPage'
export { default as AccountSettingsPage } from './AccountSettingsPage'

View File

@@ -10,6 +10,7 @@ import {
SessionHistoryPage,
SessionDetailPage,
SettingsPage,
AccountSettingsPage,
} from '@/pages'
export const router = createBrowserRouter([
@@ -64,6 +65,10 @@ export const router = createBrowserRouter([
path: 'settings',
element: <SettingsPage />,
},
{
path: 'account',
element: <AccountSettingsPage />,
},
],
},
])

View File

@@ -1,11 +1,14 @@
import { create } from 'zustand'
import { persist } from 'zustand/middleware'
import type { User, Token, UserCreate, UserLogin } from '@/types'
import type { User, Token, UserCreate, UserLogin, Account, SubscriptionDetails } from '@/types'
import { authApi } from '@/api'
import { apiClient } from '@/api'
interface AuthState {
user: User | null
token: Token | null
account: Account | null
subscription: SubscriptionDetails | null
isAuthenticated: boolean
isLoading: boolean
error: string | null
@@ -25,6 +28,8 @@ export const useAuthStore = create<AuthState>()(
(set, get) => ({
user: null,
token: null,
account: null,
subscription: null,
isAuthenticated: false,
isLoading: false,
error: null,
@@ -70,15 +75,30 @@ export const useAuthStore = create<AuthState>()(
} finally {
localStorage.removeItem('access_token')
localStorage.removeItem('refresh_token')
set({ user: null, token: null, isAuthenticated: false, error: null })
set({ user: null, token: null, account: null, subscription: null, isAuthenticated: false, error: null })
}
},
fetchUser: async () => {
set({ isLoading: true })
try {
const user = await authApi.me()
set({ user, isLoading: false })
const [userResult, accountResult, subscriptionResult] = await Promise.allSettled([
authApi.me(),
apiClient.get<Account>('/accounts/me').then(r => r.data),
apiClient.get<SubscriptionDetails>('/accounts/me/subscription').then(r => r.data),
])
const user = userResult.status === 'fulfilled' ? userResult.value : null
const account = accountResult.status === 'fulfilled' ? accountResult.value : null
const subscription = subscriptionResult.status === 'fulfilled' ? subscriptionResult.value : null
if (!user) {
// User fetch failed — propagate the error
const reason = userResult.status === 'rejected' ? userResult.reason : new Error('Failed to fetch user')
throw reason
}
set({ user, account, subscription, isLoading: false })
} catch (error: unknown) {
const message = error instanceof Error ? error.message : 'Failed to fetch user'
set({ error: message, isLoading: false })
@@ -95,6 +115,8 @@ export const useAuthStore = create<AuthState>()(
partialize: (state) => ({
token: state.token,
isAuthenticated: state.isAuthenticated,
account: state.account,
subscription: state.subscription,
}),
}
)

View File

@@ -0,0 +1,62 @@
export interface Account {
id: string
name: string
display_code: string
owner_id: string
created_at: string
updated_at: string
}
export interface Subscription {
id: string
account_id: string
plan: 'free' | 'pro' | 'team'
status: 'active' | 'past_due' | 'canceled' | 'trialing' | 'orphaned'
current_period_start: string | null
current_period_end: string | null
created_at: string
updated_at: string
}
export interface PlanLimits {
plan: string
max_trees: number | null
max_sessions_per_month: number | null
max_users: number | null
custom_branding: boolean
priority_support: boolean
export_formats: string[]
}
export interface SubscriptionDetails {
subscription: Subscription
limits: PlanLimits
usage: {
tree_count: number
session_count_this_month: number
user_count: number
}
}
export interface AccountInvite {
id: string
account_id: string
email: string
role: 'engineer' | 'viewer'
code: string
invited_by_id: string
accepted_by_id: string | null
expires_at: string
used_at: string | null
created_at: string
}
export interface AccountMember {
id: string
email: string
name: string
account_role: string
is_active: boolean
created_at: string
last_login: string | null
}

View File

@@ -5,7 +5,7 @@ export interface Category {
name: string
slug: string
description: string | null
team_id: string | null
account_id: string | null
display_order: number
is_active: boolean
created_at: string
@@ -18,7 +18,7 @@ export interface CategoryListItem {
name: string
slug: string
description: string | null
team_id: string | null
account_id: string | null
display_order: number
is_active: boolean
tree_count: number
@@ -27,7 +27,7 @@ export interface CategoryListItem {
export interface CategoryCreate {
name: string
description?: string | null
team_id?: string | null
account_id?: string | null
}
export interface CategoryUpdate {

View File

@@ -7,6 +7,7 @@ export * from './tag'
export * from './category'
export * from './folder'
export * from './step'
export type { Account, Subscription, PlanLimits, SubscriptionDetails, AccountInvite, AccountMember } from './account'
// API response wrapper types
export interface PaginatedResponse<T> {

View File

@@ -56,7 +56,7 @@ export interface StepCategory {
name: string
description?: string
display_order: number
team_id?: string
account_id?: string
is_active: boolean
}

View File

@@ -4,7 +4,7 @@ export interface Tag {
id: string
name: string
slug: string
team_id: string | null
account_id: string | null
usage_count: number
created_at: string
}
@@ -13,13 +13,13 @@ export interface TagListItem {
id: string
name: string
slug: string
team_id: string | null
account_id: string | null
usage_count: number
}
export interface TagCreate {
name: string
team_id?: string | null
account_id?: string | null
}
export interface TagAssignment {

View File

@@ -67,7 +67,7 @@ export interface Tree {
tags: string[]
tree_structure: TreeStructure
author_id: string | null
team_id: string | null
account_id: string | null
is_active: boolean
is_public: boolean
is_default: boolean
@@ -86,7 +86,7 @@ export interface TreeListItem {
category_info: CategoryInfo | null
tags: string[]
author_id: string | null
team_id: string | null
account_id: string | null
is_active: boolean
is_public: boolean
is_default: boolean

View File

@@ -6,8 +6,8 @@ export interface User {
name: string
role: UserRole
is_super_admin: boolean
is_team_admin: boolean
team_id: string | null
account_id: string
account_role: 'owner' | 'engineer' | 'viewer'
created_at: string
last_login: string | null
}