Merge pull request #32 from patherly/feat/subscription-tiers
feat: account-based subscription tiers (Free/Pro/Team)
This commit was merged in pull request #32.
This commit is contained in:
110
backend/alembic/versions/016_add_subscription_tables.py
Normal file
110
backend/alembic/versions/016_add_subscription_tables.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""add accounts, subscriptions, plan_limits, and account_invites tables
|
||||
|
||||
Revision ID: 016
|
||||
Revises: 015
|
||||
Create Date: 2026-02-07
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '016'
|
||||
down_revision: Union[str, None] = '015'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. accounts table
|
||||
op.create_table(
|
||||
'accounts',
|
||||
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('display_code', sa.String(8), nullable=False),
|
||||
sa.Column('owner_id', UUID(as_uuid=True), nullable=True), # nullable until user created
|
||||
sa.Column('stripe_customer_id', sa.String(255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('display_code', name='uq_accounts_display_code'),
|
||||
)
|
||||
|
||||
# 2. subscriptions table
|
||||
op.create_table(
|
||||
'subscriptions',
|
||||
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
|
||||
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('stripe_subscription_id', sa.String(255), nullable=True),
|
||||
sa.Column('stripe_price_id', sa.String(255), nullable=True),
|
||||
sa.Column('plan', sa.String(50), nullable=False, server_default='free'),
|
||||
sa.Column('billing_interval', sa.String(20), nullable=True), # 'monthly' or 'annual'
|
||||
sa.Column('status', sa.String(50), nullable=False, server_default='active'),
|
||||
sa.Column('seat_limit', sa.Integer, nullable=True),
|
||||
sa.Column('current_period_start', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('current_period_end', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('cancel_at_period_end', sa.Boolean, nullable=False, server_default='false'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
|
||||
sa.UniqueConstraint('account_id', name='uq_subscriptions_account_id'),
|
||||
)
|
||||
op.create_index('ix_subscriptions_account_id', 'subscriptions', ['account_id'])
|
||||
op.create_index('ix_subscriptions_plan', 'subscriptions', ['plan'])
|
||||
|
||||
# 3. plan_limits table (configuration — seeded with 3 rows)
|
||||
op.create_table(
|
||||
'plan_limits',
|
||||
sa.Column('plan', sa.String(50), nullable=False),
|
||||
sa.Column('max_trees', sa.Integer, nullable=True), # NULL = unlimited
|
||||
sa.Column('max_sessions_per_month', sa.Integer, nullable=True),
|
||||
sa.Column('max_users', sa.Integer, nullable=True),
|
||||
sa.Column('custom_branding', sa.Boolean, nullable=False, server_default='false'),
|
||||
sa.Column('priority_support', sa.Boolean, nullable=False, server_default='false'),
|
||||
sa.Column('export_formats', JSONB, nullable=False, server_default='["markdown", "text"]'),
|
||||
sa.PrimaryKeyConstraint('plan'),
|
||||
)
|
||||
|
||||
# Seed plan_limits
|
||||
op.execute("""
|
||||
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
|
||||
VALUES
|
||||
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
|
||||
('pro', 25, 200, 1, false, true, '["markdown", "text", "html", "pdf"]'),
|
||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html", "pdf"]')
|
||||
""")
|
||||
|
||||
# 4. account_invites table
|
||||
op.create_table(
|
||||
'account_invites',
|
||||
sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
|
||||
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('invited_by_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('code', sa.String(32), nullable=False),
|
||||
sa.Column('role', sa.String(50), nullable=False, server_default='engineer'),
|
||||
sa.Column('accepted_by_id', UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False),
|
||||
sa.Column('used_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['invited_by_id'], ['users.id']),
|
||||
sa.ForeignKeyConstraint(['accepted_by_id'], ['users.id']),
|
||||
sa.UniqueConstraint('code', name='uq_account_invites_code'),
|
||||
sa.CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'),
|
||||
)
|
||||
op.create_index('ix_account_invites_account_id', 'account_invites', ['account_id'])
|
||||
op.create_index('ix_account_invites_email', 'account_invites', ['email'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('account_invites')
|
||||
op.drop_table('plan_limits')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_table('accounts')
|
||||
31
backend/alembic/versions/017_add_account_id_to_users.py
Normal file
31
backend/alembic/versions/017_add_account_id_to_users.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""add account_id and account_role columns to users
|
||||
|
||||
Revision ID: 017
|
||||
Revises: 016
|
||||
Create Date: 2026-02-07
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '017'
|
||||
down_revision: Union[str, None] = '016'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('users', sa.Column('account_id', UUID(as_uuid=True), nullable=True))
|
||||
op.add_column('users', sa.Column('account_role', sa.String(50), nullable=True))
|
||||
op.create_index('ix_users_account_id', 'users', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_users_account_id', table_name='users')
|
||||
op.drop_column('users', 'account_role')
|
||||
op.drop_column('users', 'account_id')
|
||||
187
backend/alembic/versions/018_migrate_users_to_accounts.py
Normal file
187
backend/alembic/versions/018_migrate_users_to_accounts.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""migrate existing users and teams to accounts
|
||||
|
||||
Revision ID: 018
|
||||
Revises: 017
|
||||
Create Date: 2026-02-07
|
||||
|
||||
This is the most critical migration. It creates a _team_account_mapping table
|
||||
for deterministic cross-migration lookups, then migrates all users to accounts.
|
||||
|
||||
Three paths:
|
||||
A) Teams with users → Account with deterministic owner
|
||||
B) Teams with zero users → Account with owner_id=NULL, subscription status='orphaned'
|
||||
C) Users without a team → Personal account, user is owner
|
||||
"""
|
||||
import secrets
|
||||
import string
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '018'
|
||||
down_revision: Union[str, None] = '017'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
# Characters for display codes — exclude confusing chars
|
||||
DISPLAY_CODE_CHARS = string.ascii_uppercase + string.digits
|
||||
DISPLAY_CODE_CHARS = DISPLAY_CODE_CHARS.replace('0', '').replace('O', '').replace('I', '').replace('1', '').replace('L', '')
|
||||
|
||||
|
||||
def _generate_display_code(existing_codes: set) -> str:
|
||||
"""Generate a unique 8-character display code."""
|
||||
for _ in range(100):
|
||||
code = ''.join(secrets.choice(DISPLAY_CODE_CHARS) for _ in range(8))
|
||||
if code not in existing_codes:
|
||||
existing_codes.add(code)
|
||||
return code
|
||||
raise RuntimeError("Failed to generate unique display code after 100 attempts")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Create mapping table for deterministic cross-migration lookups
|
||||
op.create_table(
|
||||
'_team_account_mapping',
|
||||
sa.Column('team_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('owner_user_id', UUID(as_uuid=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('team_id'),
|
||||
)
|
||||
|
||||
existing_codes: set = set()
|
||||
|
||||
# --- Path A & B: Process all teams ---
|
||||
teams = conn.execute(sa.text("SELECT id, name FROM teams")).fetchall()
|
||||
|
||||
for team in teams:
|
||||
team_id = team[0]
|
||||
team_name = team[1]
|
||||
display_code = _generate_display_code(existing_codes)
|
||||
|
||||
# Find deterministic owner: team admin first, then earliest user
|
||||
owner_row = conn.execute(sa.text("""
|
||||
SELECT id FROM users
|
||||
WHERE team_id = :tid
|
||||
ORDER BY is_team_admin DESC, created_at ASC, id ASC
|
||||
LIMIT 1
|
||||
"""), {"tid": team_id}).fetchone()
|
||||
|
||||
owner_user_id = owner_row[0] if owner_row else None
|
||||
|
||||
# Create account
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), :name, :code, :owner_id, NOW(), NOW())
|
||||
"""), {"name": team_name, "code": display_code, "owner_id": owner_user_id})
|
||||
|
||||
# Get the account we just created
|
||||
account_row = conn.execute(sa.text(
|
||||
"SELECT id FROM accounts WHERE display_code = :code"
|
||||
), {"code": display_code}).fetchone()
|
||||
account_id = account_row[0]
|
||||
|
||||
# Insert mapping
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO _team_account_mapping (team_id, account_id, owner_user_id)
|
||||
VALUES (:tid, :aid, :uid)
|
||||
"""), {"tid": team_id, "aid": account_id, "uid": owner_user_id})
|
||||
|
||||
if owner_user_id is not None:
|
||||
# Path A: Team with users
|
||||
# Create active subscription
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW())
|
||||
"""), {"aid": account_id})
|
||||
|
||||
# Update all users in this team
|
||||
# Team admins become owners, others keep their role
|
||||
conn.execute(sa.text("""
|
||||
UPDATE users SET
|
||||
account_id = :aid,
|
||||
account_role = CASE
|
||||
WHEN is_team_admin = true THEN 'owner'
|
||||
ELSE role
|
||||
END
|
||||
WHERE team_id = :tid
|
||||
"""), {"aid": account_id, "tid": team_id})
|
||||
else:
|
||||
# Path B: Team with zero users (orphan)
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), :aid, 'free', 'orphaned', NOW(), NOW())
|
||||
"""), {"aid": account_id})
|
||||
|
||||
# --- Path C: Users without a team ---
|
||||
teamless_users = conn.execute(sa.text(
|
||||
"SELECT id, name FROM users WHERE team_id IS NULL AND account_id IS NULL"
|
||||
)).fetchall()
|
||||
|
||||
for user in teamless_users:
|
||||
user_id = user[0]
|
||||
user_name = user[1]
|
||||
display_code = _generate_display_code(existing_codes)
|
||||
|
||||
# Create personal account (owner_id set to NULL initially)
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), :name, :code, NULL, NOW(), NOW())
|
||||
"""), {"name": f"{user_name}'s Account", "code": display_code})
|
||||
|
||||
account_row = conn.execute(sa.text(
|
||||
"SELECT id FROM accounts WHERE display_code = :code"
|
||||
), {"code": display_code}).fetchone()
|
||||
account_id = account_row[0]
|
||||
|
||||
# Update user
|
||||
conn.execute(sa.text("""
|
||||
UPDATE users SET account_id = :aid, account_role = 'owner'
|
||||
WHERE id = :uid
|
||||
"""), {"aid": account_id, "uid": user_id})
|
||||
|
||||
# Set owner
|
||||
conn.execute(sa.text("""
|
||||
UPDATE accounts SET owner_id = :uid WHERE id = :aid
|
||||
"""), {"uid": user_id, "aid": account_id})
|
||||
|
||||
# Create free subscription
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at)
|
||||
VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW())
|
||||
"""), {"aid": account_id})
|
||||
|
||||
# --- Validation ---
|
||||
orphaned_users = conn.execute(sa.text(
|
||||
"SELECT COUNT(*) FROM users WHERE account_id IS NULL"
|
||||
)).scalar()
|
||||
if orphaned_users > 0:
|
||||
raise RuntimeError(
|
||||
f"Migration 018 failed validation: {orphaned_users} users still have NULL account_id"
|
||||
)
|
||||
|
||||
team_count = conn.execute(sa.text("SELECT COUNT(*) FROM teams")).scalar()
|
||||
mapping_count = conn.execute(sa.text("SELECT COUNT(*) FROM _team_account_mapping")).scalar()
|
||||
if mapping_count != team_count:
|
||||
raise RuntimeError(
|
||||
f"Migration 018 failed: mapping count ({mapping_count}) != team count ({team_count})"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Clear account data from users
|
||||
conn.execute(sa.text("UPDATE users SET account_id = NULL, account_role = NULL"))
|
||||
|
||||
# Delete all subscriptions and accounts created by this migration
|
||||
conn.execute(sa.text("DELETE FROM subscriptions"))
|
||||
conn.execute(sa.text("DELETE FROM accounts"))
|
||||
|
||||
# Drop mapping table
|
||||
op.drop_table('_team_account_mapping')
|
||||
56
backend/alembic/versions/019_migrate_team_fks_to_account.py
Normal file
56
backend/alembic/versions/019_migrate_team_fks_to_account.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""add account_id to content tables and backfill from _team_account_mapping
|
||||
|
||||
Revision ID: 019
|
||||
Revises: 018
|
||||
Create Date: 2026-02-07
|
||||
|
||||
Uses the _team_account_mapping table from migration 018 for deterministic
|
||||
FK backfill instead of non-deterministic LIMIT 1 subqueries.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '019'
|
||||
down_revision: Union[str, None] = '018'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
# Tables that have team_id and need account_id
|
||||
CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories']
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for table in CONTENT_TABLES:
|
||||
# Add account_id column
|
||||
op.add_column(table, sa.Column('account_id', UUID(as_uuid=True), nullable=True))
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
# Backfill from mapping table (deterministic)
|
||||
conn.execute(sa.text(f"""
|
||||
UPDATE {table} SET account_id = m.account_id
|
||||
FROM _team_account_mapping m
|
||||
WHERE {table}.team_id = m.team_id
|
||||
"""))
|
||||
|
||||
# Validate: no rows with team_id but missing account_id
|
||||
orphaned = conn.execute(sa.text(f"""
|
||||
SELECT COUNT(*) FROM {table}
|
||||
WHERE team_id IS NOT NULL AND account_id IS NULL
|
||||
""")).scalar()
|
||||
if orphaned > 0:
|
||||
raise RuntimeError(
|
||||
f"Migration 019 failed: {table} has {orphaned} rows with team_id but no account_id"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in reversed(CONTENT_TABLES):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_column(table, 'account_id')
|
||||
104
backend/alembic/versions/020_finalize_account_migration.py
Normal file
104
backend/alembic/versions/020_finalize_account_migration.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""finalize account migration — add constraints, clean up orphans
|
||||
|
||||
Revision ID: 020
|
||||
Revises: 019
|
||||
Create Date: 2026-02-07
|
||||
|
||||
Adds NOT NULL constraints, foreign keys, and CHECK constraints.
|
||||
Cleans up orphan accounts (zero-user teams with no content).
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '020'
|
||||
down_revision: Union[str, None] = '019'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories']
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# 1. Clean up orphan accounts (zero-user teams with no content)
|
||||
conn.execute(sa.text("""
|
||||
DELETE FROM subscriptions WHERE account_id IN (
|
||||
SELECT a.id FROM accounts a
|
||||
WHERE a.owner_id IS NULL
|
||||
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id)
|
||||
)
|
||||
"""))
|
||||
|
||||
# Also remove the mapping entries for these orphans
|
||||
conn.execute(sa.text("""
|
||||
DELETE FROM _team_account_mapping WHERE account_id IN (
|
||||
SELECT a.id FROM accounts a
|
||||
WHERE a.owner_id IS NULL
|
||||
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id)
|
||||
)
|
||||
"""))
|
||||
|
||||
conn.execute(sa.text("""
|
||||
DELETE FROM accounts
|
||||
WHERE owner_id IS NULL
|
||||
AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = accounts.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = accounts.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = accounts.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = accounts.id)
|
||||
AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = accounts.id)
|
||||
"""))
|
||||
|
||||
# 2. Users: enforce NOT NULL and add FK + CHECK
|
||||
op.alter_column('users', 'account_id', nullable=False)
|
||||
op.alter_column('users', 'account_role', nullable=False)
|
||||
op.create_foreign_key(
|
||||
'fk_users_account_id', 'users', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
op.create_check_constraint(
|
||||
'ck_users_account_role_enum', 'users',
|
||||
"account_role IN ('owner', 'engineer', 'viewer')"
|
||||
)
|
||||
|
||||
# 3. Content tables: add FK on account_id (nullable OK — NULL means global)
|
||||
for table in CONTENT_TABLES:
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# 4. Accounts: add owner FK (owner_id stays nullable due to circular FK with users)
|
||||
op.create_foreign_key(
|
||||
'fk_accounts_owner_id', 'accounts', 'users',
|
||||
['owner_id'], ['id'], ondelete='RESTRICT'
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove account owner FK and nullable constraint
|
||||
op.drop_constraint('fk_accounts_owner_id', 'accounts', type_='foreignkey')
|
||||
op.alter_column('accounts', 'owner_id', nullable=True)
|
||||
|
||||
# Remove content table FKs
|
||||
for table in reversed(CONTENT_TABLES):
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
|
||||
# Remove user constraints
|
||||
op.drop_constraint('ck_users_account_role_enum', 'users', type_='check')
|
||||
op.drop_constraint('fk_users_account_id', 'users', type_='foreignkey')
|
||||
op.alter_column('users', 'account_role', nullable=True)
|
||||
op.alter_column('users', 'account_id', nullable=True)
|
||||
30
backend/alembic/versions/021_fix_owner_id_nullable.py
Normal file
30
backend/alembic/versions/021_fix_owner_id_nullable.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""fix accounts.owner_id to be nullable (circular FK with users)
|
||||
|
||||
Revision ID: 021
|
||||
Revises: 020
|
||||
Create Date: 2026-02-07
|
||||
|
||||
The original migration 020 enforced NOT NULL on owner_id, but this creates
|
||||
a circular FK problem: Account needs owner_id (User) and User needs
|
||||
account_id (Account). Making owner_id nullable resolves this — we create
|
||||
the Account first, then the User, then set owner_id.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '021'
|
||||
down_revision: Union[str, None] = '020'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column('accounts', 'owner_id', nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column('accounts', 'owner_id', nullable=False)
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy import select
|
||||
from app.core.database import get_db
|
||||
from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
@@ -90,14 +91,35 @@ async def require_admin(
|
||||
async def require_engineer_or_admin(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
"""Require engineer, team admin, or super admin role (blocks viewers)."""
|
||||
"""Require engineer, account owner, or super admin role (blocks viewers)."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
if current_user.is_team_admin and current_user.team_id is not None:
|
||||
if current_user.account_role in ("owner", "engineer"):
|
||||
return current_user
|
||||
if current_user.role not in ("engineer",):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Engineer or admin access required"
|
||||
)
|
||||
return current_user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Engineer or admin access required"
|
||||
)
|
||||
|
||||
|
||||
async def require_account_owner(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
) -> User:
|
||||
"""Require account owner or super admin access."""
|
||||
if current_user.is_super_admin:
|
||||
return current_user
|
||||
if current_user.account_role == "owner":
|
||||
return current_user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account owner access required"
|
||||
)
|
||||
|
||||
|
||||
async def get_plan_limits_for_user(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> Optional[PlanLimits]:
|
||||
"""Get plan limits for the current user's account."""
|
||||
from app.core.subscriptions import get_user_plan_limits
|
||||
return await get_user_plan_limits(current_user.account_id, db)
|
||||
|
||||
236
backend/app/api/endpoints/accounts.py
Normal file
236
backend/app/api/endpoints/accounts.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
import secrets
|
||||
import string
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.api.deps import get_current_active_user, require_account_owner
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["accounts"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=AccountResponse)
|
||||
async def get_my_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's account."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
return account
|
||||
|
||||
|
||||
@router.get("/me/subscription", response_model=SubscriptionDetails)
|
||||
async def get_my_subscription(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get current user's subscription details including limits and usage."""
|
||||
sub = await get_account_subscription(current_user.account_id, db)
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No subscription found"
|
||||
)
|
||||
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if not limits:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Plan limits not configured"
|
||||
)
|
||||
|
||||
usage = await get_account_usage(current_user.account_id, db)
|
||||
|
||||
return SubscriptionDetails(
|
||||
subscription=SubscriptionResponse.model_validate(sub),
|
||||
limits=PlanLimitsResponse.model_validate(limits),
|
||||
usage=UsageResponse(**usage),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me/members", response_model=list[UserResponse])
|
||||
async def get_my_members(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get members of current user's account."""
|
||||
result = await db.execute(
|
||||
select(User).where(User.account_id == current_user.account_id)
|
||||
.order_by(User.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.patch("/me", response_model=AccountResponse)
|
||||
async def update_my_account(
|
||||
data: AccountUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Update account settings (owner only)."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Account not found"
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(account, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
return account
|
||||
|
||||
|
||||
@router.patch("/me/members/{user_id}/role", response_model=UserResponse)
|
||||
async def update_member_role(
|
||||
user_id: UUID,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Change a member's role within the account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot change your own role"
|
||||
)
|
||||
|
||||
user.account_role = data.account_role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_member(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Remove a member from the account (owner only).
|
||||
|
||||
The removed user gets a new personal account.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
if user.id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot remove yourself from your own account"
|
||||
)
|
||||
|
||||
# Create a personal account for the removed user
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
display_code = ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
new_account = Account(
|
||||
name=f"{user.name}'s Account",
|
||||
display_code=display_code,
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_sub = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_sub)
|
||||
|
||||
user.account_id = new_account.id
|
||||
user.account_role = "owner"
|
||||
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_invite(
|
||||
data: AccountInviteCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create an invite to join this account (owner only)."""
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
expires_at = None
|
||||
if data.expires_in_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days)
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=current_user.account_id,
|
||||
invited_by_id=current_user.id,
|
||||
email=data.email,
|
||||
code=code,
|
||||
role=data.role,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
return invite
|
||||
|
||||
|
||||
@router.get("/me/invites", response_model=list[AccountInviteResponse])
|
||||
async def list_invites(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""List invites for this account (owner only)."""
|
||||
result = await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.account_id == current_user.account_id)
|
||||
.order_by(AccountInvite.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import select, func
|
||||
from app.core.database import get_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate
|
||||
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
|
||||
from app.api.deps import require_admin
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
@@ -21,7 +21,7 @@ async def list_users(
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
role: Optional[str] = Query(None, description="Filter by role"),
|
||||
team_id: Optional[UUID] = Query(None, description="Filter by team")
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account")
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
@@ -30,8 +30,8 @@ async def list_users(
|
||||
query = query.where(User.is_active == is_active)
|
||||
if role:
|
||||
query = query.where(User.role == role)
|
||||
if team_id:
|
||||
query = query.where(User.team_id == team_id)
|
||||
if account_id:
|
||||
query = query.where(User.account_id == account_id)
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
@@ -91,14 +91,14 @@ async def update_user_role(
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/team-admin", response_model=UserResponse)
|
||||
async def toggle_team_admin(
|
||||
@router.put("/users/{user_id}/account-role", response_model=UserResponse)
|
||||
async def update_account_role(
|
||||
user_id: UUID,
|
||||
data: TeamAdminUpdate,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Toggle is_team_admin for a user (super admin only)."""
|
||||
"""Change a user's account role (super admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -108,15 +108,10 @@ async def toggle_team_admin(
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if data.is_team_admin and user.team_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User must belong to a team to be a team admin"
|
||||
)
|
||||
|
||||
user.is_team_admin = data.is_team_admin
|
||||
await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id,
|
||||
{"is_team_admin": data.is_team_admin})
|
||||
old_role = user.account_role
|
||||
user.account_role = data.account_role
|
||||
await log_audit(db, current_user.id, "user.account_role_change", "user", user.id,
|
||||
{"old_account_role": old_role, "new_account_role": data.account_role})
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
@@ -18,6 +20,9 @@ from app.core.security import (
|
||||
from app.models.user import User
|
||||
from app.models.invite_code import InviteCode
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.token import Token
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
@@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id
|
||||
db.add(token_record)
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
"""Generate a random 8-character alphanumeric display code."""
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("3/minute")
|
||||
async def register(
|
||||
@@ -44,10 +55,46 @@ async def register(
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Register a new user."""
|
||||
# Validate invite code if required
|
||||
"""Register a new user.
|
||||
|
||||
Supports two flows:
|
||||
- account_invite_code: Join an existing account (bypasses platform invite gate)
|
||||
- invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled)
|
||||
|
||||
After user creation, if no account invite was used, a personal Account
|
||||
and free Subscription are created automatically.
|
||||
"""
|
||||
# Check for account invite code FIRST — bypasses platform invite gate
|
||||
account_invite_record = None
|
||||
if user_data.account_invite_code:
|
||||
result = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.code == user_data.account_invite_code
|
||||
)
|
||||
)
|
||||
account_invite_record = result.scalar_one_or_none()
|
||||
|
||||
if not account_invite_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid account invite code"
|
||||
)
|
||||
|
||||
if account_invite_record.is_used:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has already been used"
|
||||
)
|
||||
|
||||
if account_invite_record.is_expired:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account invite code has expired"
|
||||
)
|
||||
|
||||
# Validate platform invite code if required (skip if account invite was provided)
|
||||
invite_code_record = None
|
||||
if settings.REQUIRE_INVITE_CODE:
|
||||
if not account_invite_record and settings.REQUIRE_INVITE_CODE:
|
||||
if not user_data.invite_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -87,17 +134,55 @@ async def register(
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# Create new user
|
||||
new_user = User(
|
||||
email=user_data.email,
|
||||
password_hash=get_password_hash(user_data.password),
|
||||
name=user_data.name,
|
||||
role="engineer",
|
||||
invite_code_id=invite_code_record.id if invite_code_record else None
|
||||
)
|
||||
db.add(new_user)
|
||||
if account_invite_record:
|
||||
# Join existing account via account invite
|
||||
new_user = User(
|
||||
email=user_data.email,
|
||||
password_hash=get_password_hash(user_data.password),
|
||||
name=user_data.name,
|
||||
role="engineer",
|
||||
invite_code_id=invite_code_record.id if invite_code_record else None,
|
||||
account_id=account_invite_record.account_id,
|
||||
account_role=account_invite_record.role,
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush()
|
||||
|
||||
# Mark invite code as used
|
||||
# Mark account invite as used
|
||||
account_invite_record.accepted_by_id = new_user.id
|
||||
account_invite_record.used_at = datetime.now(timezone.utc)
|
||||
else:
|
||||
# Create personal Account first (user needs account_id for NOT NULL constraint)
|
||||
new_account = Account(
|
||||
name=f"{user_data.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush() # Get account ID
|
||||
|
||||
new_user = User(
|
||||
email=user_data.email,
|
||||
password_hash=get_password_hash(user_data.password),
|
||||
name=user_data.name,
|
||||
role="engineer",
|
||||
invite_code_id=invite_code_record.id if invite_code_record else None,
|
||||
account_id=new_account.id,
|
||||
account_role="owner",
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # Get user ID
|
||||
|
||||
# Now set account owner and create subscription
|
||||
new_account.owner_id = new_user.id
|
||||
|
||||
new_subscription = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_subscription)
|
||||
|
||||
# Mark platform invite code as used
|
||||
if invite_code_record:
|
||||
invite_code_record.used_by_id = new_user.id
|
||||
invite_code_record.used_at = datetime.now(timezone.utc)
|
||||
|
||||
@@ -28,11 +28,11 @@ async def list_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(TreeCategory)
|
||||
@@ -41,19 +41,19 @@ async def list_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(TreeCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(TreeCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(TreeCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeCategory.team_id.is_(None), # Global
|
||||
TreeCategory.team_id == current_user.team_id # User's team
|
||||
TreeCategory.account_id.is_(None), # Global
|
||||
TreeCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(TreeCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(TreeCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
|
||||
|
||||
@@ -76,7 +76,7 @@ async def list_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
tree_count=tree_count
|
||||
@@ -101,8 +101,8 @@ async def get_category(
|
||||
detail="Category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -121,7 +121,7 @@ async def get_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -138,10 +138,10 @@ async def create_category(
|
||||
):
|
||||
"""Create a new category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_category(current_user, category_data.team_id):
|
||||
if not can_create_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this category"
|
||||
@@ -150,10 +150,10 @@ async def create_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == slug,
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -164,7 +164,7 @@ async def create_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(TreeCategory.display_order)).where(
|
||||
TreeCategory.team_id == category_data.team_id
|
||||
TreeCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -173,7 +173,7 @@ async def create_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -186,7 +186,7 @@ async def create_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -227,7 +227,7 @@ async def update_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(TreeCategory).where(
|
||||
TreeCategory.slug == new_slug,
|
||||
TreeCategory.team_id == category.team_id,
|
||||
TreeCategory.account_id == category.account_id,
|
||||
TreeCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -257,7 +257,7 @@ async def update_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -25,11 +25,11 @@ async def list_step_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_inactive: bool = Query(False, description="Include inactive categories"),
|
||||
team_only: bool = Query(False, description="Only show team-specific categories")
|
||||
account_only: bool = Query(False, description="Only show account-specific categories")
|
||||
):
|
||||
"""List step categories visible to the user.
|
||||
|
||||
Returns global categories plus team-specific categories for the user's team.
|
||||
Returns global categories plus account-specific categories for the user's account.
|
||||
"""
|
||||
# Build query for accessible categories
|
||||
query = select(StepCategory)
|
||||
@@ -38,19 +38,19 @@ async def list_step_categories(
|
||||
if not include_inactive:
|
||||
query = query.where(StepCategory.is_active == True)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if team_only and current_user.team_id:
|
||||
query = query.where(StepCategory.team_id == current_user.team_id)
|
||||
elif current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if account_only and current_user.account_id:
|
||||
query = query.where(StepCategory.account_id == current_user.account_id)
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
StepCategory.team_id.is_(None), # Global
|
||||
StepCategory.team_id == current_user.team_id # User's team
|
||||
StepCategory.account_id.is_(None), # Global
|
||||
StepCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no team, only show global categories
|
||||
query = query.where(StepCategory.team_id.is_(None))
|
||||
# User has no account, only show global categories
|
||||
query = query.where(StepCategory.account_id.is_(None))
|
||||
|
||||
query = query.order_by(StepCategory.display_order, StepCategory.name)
|
||||
|
||||
@@ -66,7 +66,7 @@ async def list_step_categories(
|
||||
name=cat.name,
|
||||
slug=cat.slug,
|
||||
description=cat.description,
|
||||
team_id=cat.team_id,
|
||||
account_id=cat.account_id,
|
||||
display_order=cat.display_order,
|
||||
is_active=cat.is_active,
|
||||
step_count=0 # Will be computed when step_library exists
|
||||
@@ -91,8 +91,8 @@ async def get_step_category(
|
||||
detail="Step category not found"
|
||||
)
|
||||
|
||||
# Check access: global categories visible to all, team categories only to team members
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global categories visible to all, account categories only to account members
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this step category"
|
||||
@@ -103,7 +103,7 @@ async def get_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
@@ -120,10 +120,10 @@ async def create_step_category(
|
||||
):
|
||||
"""Create a new step category.
|
||||
|
||||
- Global admins can create global categories (team_id=None)
|
||||
- Team admins can create team-specific categories for their team
|
||||
- Global admins can create global categories (account_id=None)
|
||||
- Account admins can create account-specific categories for their account
|
||||
"""
|
||||
if not can_create_step_category(current_user, category_data.team_id):
|
||||
if not can_create_step_category(current_user, category_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this step category"
|
||||
@@ -132,10 +132,10 @@ async def create_step_category(
|
||||
# Generate slug
|
||||
slug = slugify(category_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == slug,
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -146,7 +146,7 @@ async def create_step_category(
|
||||
|
||||
# Get next display order
|
||||
order_query = select(func.max(StepCategory.display_order)).where(
|
||||
StepCategory.team_id == category_data.team_id
|
||||
StepCategory.account_id == category_data.account_id
|
||||
)
|
||||
order_result = await db.execute(order_query)
|
||||
max_order = order_result.scalar() or 0
|
||||
@@ -155,7 +155,7 @@ async def create_step_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
team_id=category_data.team_id,
|
||||
account_id=category_data.account_id,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
@@ -168,7 +168,7 @@ async def create_step_category(
|
||||
name=new_category.name,
|
||||
slug=new_category.slug,
|
||||
description=new_category.description,
|
||||
team_id=new_category.team_id,
|
||||
account_id=new_category.account_id,
|
||||
display_order=new_category.display_order,
|
||||
is_active=new_category.is_active,
|
||||
created_at=new_category.created_at,
|
||||
@@ -209,7 +209,7 @@ async def update_step_category(
|
||||
# Check for duplicate slug
|
||||
existing_query = select(StepCategory).where(
|
||||
StepCategory.slug == new_slug,
|
||||
StepCategory.team_id == category.team_id,
|
||||
StepCategory.account_id == category.account_id,
|
||||
StepCategory.id != category_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
@@ -231,7 +231,7 @@ async def update_step_category(
|
||||
name=category.name,
|
||||
slug=category.slug,
|
||||
description=category.description,
|
||||
team_id=category.team_id,
|
||||
account_id=category.account_id,
|
||||
display_order=category.display_order,
|
||||
is_active=category.is_active,
|
||||
created_at=category.created_at,
|
||||
|
||||
@@ -55,10 +55,10 @@ async def get_step_or_404(
|
||||
|
||||
def build_visibility_filter(user: User):
|
||||
"""Build SQLAlchemy filter for step visibility based on user."""
|
||||
if user.team_id:
|
||||
if user.account_id:
|
||||
return or_(
|
||||
StepLibrary.visibility == 'public',
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id),
|
||||
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id),
|
||||
StepLibrary.created_by == user.id # Own private steps
|
||||
)
|
||||
else:
|
||||
@@ -249,7 +249,7 @@ async def get_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -296,10 +296,10 @@ async def create_step(
|
||||
if not cat_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Invalid category")
|
||||
|
||||
# Team validation: can only set team_id to own team
|
||||
team_id = step_data.team_id
|
||||
if team_id and team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another team")
|
||||
# Account validation: can only set account_id to own account
|
||||
account_id = step_data.account_id
|
||||
if account_id and account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(status_code=403, detail="Cannot create step for another account")
|
||||
|
||||
step = StepLibrary(
|
||||
title=step_data.title,
|
||||
@@ -309,7 +309,7 @@ async def create_step(
|
||||
tags=step_data.tags,
|
||||
visibility=step_data.visibility,
|
||||
created_by=current_user.id,
|
||||
team_id=team_id or current_user.team_id,
|
||||
account_id=account_id or current_user.account_id,
|
||||
)
|
||||
|
||||
db.add(step)
|
||||
@@ -326,7 +326,7 @@ async def create_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
@@ -393,7 +393,7 @@ async def update_step(
|
||||
"tags": step.tags,
|
||||
"visibility": step.visibility,
|
||||
"created_by": step.created_by,
|
||||
"team_id": step.team_id,
|
||||
"account_id": step.account_id,
|
||||
"usage_count": step.usage_count,
|
||||
"rating_average": step.rating_average,
|
||||
"rating_count": step.rating_count,
|
||||
|
||||
@@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"])
|
||||
async def list_tags(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""List tags visible to the user.
|
||||
|
||||
Returns global tags plus team-specific tags for the user's team.
|
||||
Returns global tags plus account-specific tags for the user's account.
|
||||
Tags are ordered by usage count (most used first).
|
||||
"""
|
||||
query = select(TreeTag)
|
||||
|
||||
# Filter by visibility: global OR user's team
|
||||
if include_team and current_user.team_id:
|
||||
# Filter by visibility: global OR user's account
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global
|
||||
TreeTag.team_id == current_user.team_id # User's team
|
||||
TreeTag.account_id.is_(None), # Global
|
||||
TreeTag.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only show global tags
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
|
||||
|
||||
@@ -55,7 +55,7 @@ async def search_tags(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
include_team: bool = Query(True, description="Include team-specific tags")
|
||||
include_account: bool = Query(True, description="Include account-specific tags")
|
||||
):
|
||||
"""Search/autocomplete tags.
|
||||
|
||||
@@ -68,15 +68,15 @@ async def search_tags(
|
||||
)
|
||||
|
||||
# Filter by visibility
|
||||
if include_team and current_user.team_id:
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == current_user.team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.where(TreeTag.team_id.is_(None))
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
|
||||
|
||||
@@ -102,8 +102,8 @@ async def get_tag(
|
||||
detail="Tag not found"
|
||||
)
|
||||
|
||||
# Check access: global tags visible to all, team tags only to team members
|
||||
if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
# Check access: global tags visible to all, account tags only to account members
|
||||
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tag"
|
||||
@@ -120,10 +120,10 @@ async def create_tag(
|
||||
):
|
||||
"""Create a new tag.
|
||||
|
||||
- Global admins can create global tags (team_id=None)
|
||||
- Team members can create team-specific tags for their team
|
||||
- Global admins can create global tags (account_id=None)
|
||||
- Account members can create account-specific tags for their account
|
||||
"""
|
||||
if not can_create_tag(current_user, tag_data.team_id):
|
||||
if not can_create_tag(current_user, tag_data.account_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to create this tag"
|
||||
@@ -132,10 +132,10 @@ async def create_tag(
|
||||
# Generate slug
|
||||
slug = TreeTag.slugify(tag_data.name)
|
||||
|
||||
# Check for duplicate slug within same scope (global or team)
|
||||
# Check for duplicate slug within same scope (global or account)
|
||||
existing_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
TreeTag.team_id == tag_data.team_id
|
||||
TreeTag.account_id == tag_data.account_id
|
||||
)
|
||||
existing = await db.execute(existing_query)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -147,7 +147,7 @@ async def create_tag(
|
||||
new_tag = TreeTag(
|
||||
name=tag_data.name,
|
||||
slug=slug,
|
||||
team_id=tag_data.team_id,
|
||||
account_id=tag_data.account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(new_tag)
|
||||
@@ -200,30 +200,30 @@ async def add_tags_to_tree(
|
||||
continue
|
||||
|
||||
# Try to find existing tag
|
||||
# Determine scope: use tree's team, or global for admin-owned trees
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
# Determine scope: use tree's account, or global for admin-owned trees
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None), # Global tag
|
||||
TreeTag.team_id == tag_team_id # Team tag
|
||||
TreeTag.account_id.is_(None), # Global tag
|
||||
TreeTag.account_id == tag_account_id # Account tag
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
tag = tag_result.scalar_one_or_none()
|
||||
|
||||
if not tag:
|
||||
# Create new tag - prefer team scope unless admin creating on public tree
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
# Fall back to user's team if they can't create in tree's scope
|
||||
new_team_id = current_user.team_id
|
||||
# Create new tag - prefer account scope unless admin creating on public tree
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
# Fall back to user's account if they can't create in tree's scope
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -331,7 +331,7 @@ async def replace_tree_tags(
|
||||
tree.tags.clear()
|
||||
|
||||
# Add new tags
|
||||
tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
for tag_name in tag_data.tags:
|
||||
slug = TreeTag.slugify(tag_name)
|
||||
@@ -340,8 +340,8 @@ async def replace_tree_tags(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tag_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tag_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -349,14 +349,14 @@ async def replace_tree_tags(
|
||||
|
||||
if not tag:
|
||||
# Create new tag
|
||||
new_team_id = tag_team_id
|
||||
if not can_create_tag(current_user, new_team_id):
|
||||
new_team_id = current_user.team_id
|
||||
new_account_id = tag_account_id
|
||||
if not can_create_tag(current_user, new_account_id):
|
||||
new_account_id = current_user.account_id
|
||||
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=new_team_id,
|
||||
account_id=new_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -397,7 +397,7 @@ async def get_tree_tags(
|
||||
# Check if user can view the tree
|
||||
if not tree.is_public:
|
||||
if tree.author_id != current_user.id:
|
||||
if tree.team_id != current_user.team_id:
|
||||
if tree.account_id != current_user.account_id:
|
||||
if not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
from app.core.audit import log_audit
|
||||
|
||||
router = APIRouter(prefix="/trees", tags=["trees"])
|
||||
@@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User):
|
||||
Tree.is_public == True,
|
||||
Tree.author_id == current_user.id,
|
||||
]
|
||||
if current_user.team_id:
|
||||
conditions.append(Tree.team_id == current_user.team_id)
|
||||
if current_user.account_id:
|
||||
conditions.append(Tree.account_id == current_user.account_id)
|
||||
return or_(*conditions)
|
||||
|
||||
|
||||
@@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
|
||||
category_info=category_info,
|
||||
tags=tree.tag_names,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
tags=tree.tag_names,
|
||||
tree_structure=tree.tree_structure,
|
||||
author_id=tree.author_id,
|
||||
team_id=tree.team_id,
|
||||
account_id=tree.account_id,
|
||||
is_active=tree.is_active,
|
||||
is_public=tree.is_public,
|
||||
is_default=tree.is_default,
|
||||
@@ -289,7 +290,7 @@ async def create_tree(
|
||||
detail="Category not found"
|
||||
)
|
||||
# Check category access
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -302,16 +303,25 @@ async def create_tree(
|
||||
category_id=tree_data.category_id,
|
||||
tree_structure=tree_data.tree_structure,
|
||||
author_id=None if is_default else current_user.id, # Default trees have no author
|
||||
team_id=None if is_default else current_user.team_id,
|
||||
account_id=None if is_default else current_user.account_id,
|
||||
is_public=True if is_default else tree_data.is_public, # Default trees are always public
|
||||
is_default=is_default
|
||||
)
|
||||
# Check subscription tree limit
|
||||
if not is_default and current_user.account_id:
|
||||
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
|
||||
if not can_create:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
db.add(new_tree)
|
||||
await db.flush() # Get the ID
|
||||
|
||||
# Handle tags
|
||||
if tree_data.tags:
|
||||
tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
|
||||
# Collect tags to add
|
||||
tags_to_add = []
|
||||
@@ -322,8 +332,8 @@ async def create_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -334,7 +344,7 @@ async def create_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
@@ -420,7 +430,7 @@ async def update_tree(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Category not found"
|
||||
)
|
||||
if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin:
|
||||
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this category"
|
||||
@@ -450,7 +460,7 @@ async def update_tree(
|
||||
)
|
||||
|
||||
# Add new tags
|
||||
tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None)
|
||||
tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None)
|
||||
added_tag_ids = set()
|
||||
|
||||
for tag_name in tags_data:
|
||||
@@ -459,8 +469,8 @@ async def update_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.team_id.is_(None),
|
||||
TreeTag.team_id == tree_team_id
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == tree_account_id
|
||||
)
|
||||
)
|
||||
tag_result = await db.execute(tag_query)
|
||||
@@ -470,7 +480,7 @@ async def update_tree(
|
||||
tag = TreeTag(
|
||||
name=tag_name,
|
||||
slug=slug,
|
||||
team_id=tree_team_id,
|
||||
account_id=tree_account_id,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(tag)
|
||||
|
||||
62
backend/app/api/endpoints/webhooks.py
Normal file
62
backend/app/api/endpoints/webhooks.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Request, HTTPException, status, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
from app.core.stripe_handlers import WEBHOOK_HANDLERS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/webhooks", tags=["webhooks"])
|
||||
|
||||
|
||||
@router.post("/stripe")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
Returns 200 for all events to prevent Stripe retries.
|
||||
Actual processing happens only when Stripe is configured.
|
||||
"""
|
||||
if not settings.stripe_enabled:
|
||||
return {"status": "ok", "message": "Stripe not configured, event ignored"}
|
||||
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get("stripe-signature")
|
||||
|
||||
if not sig_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing stripe-signature header"
|
||||
)
|
||||
|
||||
# Verify webhook signature
|
||||
try:
|
||||
import stripe
|
||||
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("stripe package not installed, cannot verify webhook")
|
||||
return {"status": "ok", "message": "stripe package not installed"}
|
||||
except Exception as e:
|
||||
logger.error("Stripe webhook signature verification failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid signature"
|
||||
)
|
||||
|
||||
event_type = event.get("type", "")
|
||||
handler = WEBHOOK_HANDLERS.get(event_type)
|
||||
|
||||
if handler:
|
||||
try:
|
||||
await handler(event, db)
|
||||
except Exception:
|
||||
logger.exception("Error handling Stripe event %s", event_type)
|
||||
|
||||
return {"status": "ok"}
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -13,3 +13,5 @@ api_router.include_router(folders.router)
|
||||
api_router.include_router(step_categories.router)
|
||||
api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
|
||||
@@ -52,6 +52,16 @@ class Settings(BaseSettings):
|
||||
# Registration
|
||||
REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration
|
||||
|
||||
# Stripe
|
||||
STRIPE_SECRET_KEY: Optional[str] = None
|
||||
STRIPE_PUBLISHABLE_KEY: Optional[str] = None
|
||||
STRIPE_WEBHOOK_SECRET: Optional[str] = None
|
||||
|
||||
@property
|
||||
def stripe_enabled(self) -> bool:
|
||||
"""Check if Stripe is configured."""
|
||||
return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None
|
||||
|
||||
# CORS - set FRONTEND_URL in production (e.g., https://patherly.up.railway.app)
|
||||
CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173", "http://localhost:5174"]
|
||||
FRONTEND_URL: Optional[str] = None
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
Centralized permission checks for ResolutionFlow.
|
||||
|
||||
Role hierarchy: super_admin > team_admin > engineer > viewer
|
||||
Role hierarchy: super_admin > owner > engineer > viewer
|
||||
|
||||
- super_admin: is_super_admin=True, full system access
|
||||
- team_admin: is_team_admin=True + valid team_id, manage team resources
|
||||
- engineer: role='engineer' (default), CRUD own trees/steps
|
||||
- viewer: role='viewer', read-only (can browse, run sessions, rate steps)
|
||||
- owner: account_role='owner', manage account resources
|
||||
- engineer: account_role='engineer' (default), CRUD own trees/steps
|
||||
- viewer: account_role='viewer', read-only (can browse, run sessions, rate steps)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
@@ -21,19 +21,19 @@ if TYPE_CHECKING:
|
||||
|
||||
ROLE_HIERARCHY = {
|
||||
"super_admin": 4,
|
||||
"team_admin": 3,
|
||||
"owner": 3,
|
||||
"engineer": 2,
|
||||
"viewer": 1,
|
||||
}
|
||||
|
||||
|
||||
def get_effective_role(user: User) -> str:
|
||||
"""Get the effective role considering is_super_admin and is_team_admin flags."""
|
||||
"""Get the effective role considering is_super_admin and account_role."""
|
||||
if user.is_super_admin:
|
||||
return "super_admin"
|
||||
if user.is_team_admin and user.team_id is not None:
|
||||
return "team_admin"
|
||||
return user.role # "engineer" or "viewer"
|
||||
if user.account_role == "owner":
|
||||
return "owner"
|
||||
return user.account_role # "engineer" or "viewer"
|
||||
|
||||
|
||||
def has_minimum_role(user: User, minimum_role: str) -> bool:
|
||||
@@ -55,7 +55,7 @@ def can_edit_tree(user: User, tree: Tree) -> bool:
|
||||
return False
|
||||
if tree.author_id == user.id:
|
||||
return True
|
||||
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -78,7 +78,7 @@ def can_manage_category(user: User, category: TreeCategory) -> bool:
|
||||
"""Can the user edit/delete this category?"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -91,7 +91,7 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool:
|
||||
return False
|
||||
if tree.author_id == user.id:
|
||||
return True
|
||||
if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -102,7 +102,7 @@ def can_access_tree(user: User, tree: Tree) -> bool:
|
||||
return True
|
||||
if tree.author_id == user.id:
|
||||
return True
|
||||
if tree.team_id == user.team_id and user.team_id is not None:
|
||||
if tree.account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
@@ -116,35 +116,35 @@ def can_view_step(user: User, step: StepLibrary) -> bool:
|
||||
if step.visibility == "private":
|
||||
return step.created_by == user.id
|
||||
if step.visibility == "team":
|
||||
return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin
|
||||
return (step.account_id == user.account_id and user.account_id is not None) or user.is_super_admin
|
||||
return False
|
||||
|
||||
|
||||
def can_create_tag(user: User, team_id: Optional[UUID]) -> bool:
|
||||
def can_create_tag(user: User, account_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a tag for the given scope?
|
||||
|
||||
- Super admins can create global tags (team_id=None) or any team's tags
|
||||
- Engineers can create team tags for their own team
|
||||
- Super admins can create global tags (account_id=None) or any account's tags
|
||||
- Engineers can create account tags for their own account
|
||||
- Viewers cannot create tags
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if not can_create_content(user):
|
||||
return False
|
||||
if team_id is not None and team_id == user.team_id:
|
||||
if account_id is not None and account_id == user.account_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a category for the given team?
|
||||
def can_create_category(user: User, account_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a category for the given account?
|
||||
|
||||
- Super admins can create global or any team's categories
|
||||
- Team admins can create categories for their own team
|
||||
- Super admins can create global or any account's categories
|
||||
- Account owners can create categories for their own account
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -153,19 +153,19 @@ def can_manage_step_category(user: User, category: StepCategory) -> bool:
|
||||
"""Can the user edit/delete this step category?"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a step category for the given team?
|
||||
def can_create_step_category(user: User, account_id: Optional[UUID]) -> bool:
|
||||
"""Can the user create a step category for the given account?
|
||||
|
||||
- Super admins can create global or any team's step categories
|
||||
- Team admins can create step categories for their own team
|
||||
- Super admins can create global or any account's step categories
|
||||
- Account owners can create step categories for their own account
|
||||
"""
|
||||
if user.is_super_admin:
|
||||
return True
|
||||
if user.is_team_admin and team_id == user.team_id and user.team_id is not None:
|
||||
if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
37
backend/app/core/stripe_handlers.py
Normal file
37
backend/app/core/stripe_handlers.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Stripe webhook event handlers (stub implementations).
|
||||
|
||||
These handlers log events but don't process them until Stripe is fully configured.
|
||||
"""
|
||||
import logging
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_checkout_completed(event: dict, db: AsyncSession) -> None:
|
||||
logger.info("Stripe: checkout.session.completed — %s", event.get("id"))
|
||||
|
||||
|
||||
async def handle_invoice_paid(event: dict, db: AsyncSession) -> None:
|
||||
logger.info("Stripe: invoice.paid — %s", event.get("id"))
|
||||
|
||||
|
||||
async def handle_invoice_payment_failed(event: dict, db: AsyncSession) -> None:
|
||||
logger.warning("Stripe: invoice.payment_failed — %s", event.get("id"))
|
||||
|
||||
|
||||
async def handle_subscription_updated(event: dict, db: AsyncSession) -> None:
|
||||
logger.info("Stripe: customer.subscription.updated — %s", event.get("id"))
|
||||
|
||||
|
||||
async def handle_subscription_deleted(event: dict, db: AsyncSession) -> None:
|
||||
logger.info("Stripe: customer.subscription.deleted — %s", event.get("id"))
|
||||
|
||||
|
||||
WEBHOOK_HANDLERS = {
|
||||
"checkout.session.completed": handle_checkout_completed,
|
||||
"invoice.paid": handle_invoice_paid,
|
||||
"invoice.payment_failed": handle_invoice_payment_failed,
|
||||
"customer.subscription.updated": handle_subscription_updated,
|
||||
"customer.subscription.deleted": handle_subscription_deleted,
|
||||
}
|
||||
113
backend/app/core/subscriptions.py
Normal file
113
backend/app/core/subscriptions.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Subscription limit checks and plan helpers."""
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.tree import Tree
|
||||
from app.models.session import Session
|
||||
|
||||
|
||||
async def get_account_subscription(account_id: UUID, db: AsyncSession) -> Optional[Subscription]:
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_plan_limits(plan: str, db: AsyncSession) -> Optional[PlanLimits]:
|
||||
result = await db.execute(
|
||||
select(PlanLimits).where(PlanLimits.plan == plan)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_user_plan_limits(user_account_id: UUID, db: AsyncSession) -> Optional[PlanLimits]:
|
||||
sub = await get_account_subscription(user_account_id, db)
|
||||
if sub is None:
|
||||
return await get_plan_limits("free", db)
|
||||
return await get_plan_limits(sub.plan, db)
|
||||
|
||||
|
||||
async def check_tree_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
|
||||
"""Check if account can create a new tree.
|
||||
|
||||
Returns: (can_create, limit, current_count)
|
||||
"""
|
||||
sub = await get_account_subscription(account_id, db)
|
||||
if sub is None:
|
||||
return False, 0, 0
|
||||
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if limits is None or limits.max_trees is None:
|
||||
return True, None, 0 # unlimited
|
||||
|
||||
current_count = await db.scalar(
|
||||
select(func.count(Tree.id)).where(
|
||||
Tree.account_id == account_id,
|
||||
Tree.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
current_count = current_count or 0
|
||||
|
||||
return current_count < limits.max_trees, limits.max_trees, current_count
|
||||
|
||||
|
||||
async def check_session_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]:
|
||||
"""Check if account can create a new session this month.
|
||||
|
||||
Returns: (can_create, limit, current_count)
|
||||
"""
|
||||
sub = await get_account_subscription(account_id, db)
|
||||
if sub is None:
|
||||
return False, 0, 0
|
||||
|
||||
limits = await get_plan_limits(sub.plan, db)
|
||||
if limits is None or limits.max_sessions_per_month is None:
|
||||
return True, None, 0 # unlimited
|
||||
|
||||
# Count sessions this calendar month for all users in this account
|
||||
now = datetime.now(timezone.utc)
|
||||
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
from app.models.user import User
|
||||
current_count = await db.scalar(
|
||||
select(func.count(Session.id)).where(
|
||||
Session.user_id.in_(
|
||||
select(User.id).where(User.account_id == account_id)
|
||||
),
|
||||
Session.started_at >= month_start,
|
||||
)
|
||||
)
|
||||
current_count = current_count or 0
|
||||
|
||||
return current_count < limits.max_sessions_per_month, limits.max_sessions_per_month, current_count
|
||||
|
||||
|
||||
async def get_account_usage(account_id: UUID, db: AsyncSession) -> dict:
|
||||
"""Get current usage stats for an account."""
|
||||
tree_count = await db.scalar(
|
||||
select(func.count(Tree.id)).where(
|
||||
Tree.account_id == account_id,
|
||||
Tree.deleted_at.is_(None),
|
||||
)
|
||||
) or 0
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
from app.models.user import User
|
||||
session_count = await db.scalar(
|
||||
select(func.count(Session.id)).where(
|
||||
Session.user_id.in_(
|
||||
select(User.id).where(User.account_id == account_id)
|
||||
),
|
||||
Session.started_at >= month_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
return {"tree_count": tree_count, "session_count_this_month": session_count}
|
||||
@@ -1,5 +1,9 @@
|
||||
from .user import User
|
||||
from .team import Team
|
||||
from .account import Account
|
||||
from .subscription import Subscription
|
||||
from .plan_limits import PlanLimits
|
||||
from .account_invite import AccountInvite
|
||||
from .tree import Tree
|
||||
from .session import Session
|
||||
from .attachment import Attachment
|
||||
@@ -15,6 +19,10 @@ from .audit_log import AuditLog
|
||||
__all__ = [
|
||||
"User",
|
||||
"Team",
|
||||
"Account",
|
||||
"Subscription",
|
||||
"PlanLimits",
|
||||
"AccountInvite",
|
||||
"Tree",
|
||||
"Session",
|
||||
"Attachment",
|
||||
|
||||
38
backend/app/models/account.py
Normal file
38
backend/app/models/account.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.tree import Tree
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.tag import TreeTag
|
||||
from app.models.step_category import StepCategory
|
||||
from app.models.step_library import StepLibrary
|
||||
|
||||
|
||||
class Account(Base):
|
||||
__tablename__ = "accounts"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
display_code: Mapped[str] = mapped_column(String(8), unique=True, nullable=False)
|
||||
owner_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=True)
|
||||
stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Relationships
|
||||
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")
|
||||
users: Mapped[list["User"]] = relationship("User", foreign_keys="[User.account_id]", back_populates="account")
|
||||
subscription: Mapped[Optional["Subscription"]] = relationship("Subscription", back_populates="account", uselist=False)
|
||||
trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.account_id]", back_populates="account")
|
||||
categories: Mapped[list["TreeCategory"]] = relationship("TreeCategory", foreign_keys="[TreeCategory.account_id]", back_populates="account")
|
||||
tags: Mapped[list["TreeTag"]] = relationship("TreeTag", foreign_keys="[TreeTag.account_id]", back_populates="account")
|
||||
step_categories: Mapped[list["StepCategory"]] = relationship("StepCategory", foreign_keys="[StepCategory.account_id]", back_populates="account")
|
||||
step_library: Mapped[list["StepLibrary"]] = relationship("StepLibrary", foreign_keys="[StepLibrary.account_id]", back_populates="account")
|
||||
48
backend/app/models/account_invite.py
Normal file
48
backend/app/models/account_invite.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, CheckConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class AccountInvite(Base):
|
||||
__tablename__ = "account_invites"
|
||||
__table_args__ = (
|
||||
CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False)
|
||||
invited_by_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
|
||||
accepted_by_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
account: Mapped["Account"] = relationship("Account")
|
||||
invited_by: Mapped["User"] = relationship("User", foreign_keys=[invited_by_id])
|
||||
accepted_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[accepted_by_id])
|
||||
|
||||
@property
|
||||
def is_used(self) -> bool:
|
||||
return self.accepted_by_id is not None
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return not self.is_used and not self.is_expired
|
||||
@@ -9,6 +9,7 @@ from app.core.database import Base
|
||||
if TYPE_CHECKING:
|
||||
from app.models.tree import Tree
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -38,6 +39,12 @@ class TreeCategory(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
@@ -57,10 +64,11 @@ class TreeCategory(Base):
|
||||
|
||||
# Relationships
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="categories")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="categories")
|
||||
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
|
||||
trees: Mapped[list["Tree"]] = relationship("Tree", back_populates="category_rel")
|
||||
|
||||
@property
|
||||
def is_global(self) -> bool:
|
||||
"""Returns True if this is a global category (not team-specific)."""
|
||||
return self.team_id is None
|
||||
"""Returns True if this is a global category (not team or account-specific)."""
|
||||
return self.team_id is None and self.account_id is None
|
||||
|
||||
16
backend/app/models/plan_limits.py
Normal file
16
backend/app/models/plan_limits.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import String, Integer, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class PlanLimits(Base):
|
||||
__tablename__ = "plan_limits"
|
||||
|
||||
plan: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
max_trees: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
max_sessions_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
max_users: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"])
|
||||
@@ -8,6 +8,7 @@ from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -37,6 +38,12 @@ class StepCategory(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
@@ -56,9 +63,10 @@ class StepCategory(Base):
|
||||
|
||||
# Relationships
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="step_categories")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_categories")
|
||||
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
|
||||
|
||||
@property
|
||||
def is_global(self) -> bool:
|
||||
"""Returns True if this is a global category (not team-specific)."""
|
||||
return self.team_id is None
|
||||
"""Returns True if this is a global category (not team or account-specific)."""
|
||||
return self.team_id is None and self.account_id is None
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.core.database import Base
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.step_category import StepCategory
|
||||
from app.models.session import Session
|
||||
|
||||
@@ -43,6 +44,12 @@ class StepLibrary(Base):
|
||||
ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
|
||||
# Organization
|
||||
category_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
@@ -91,6 +98,7 @@ class StepLibrary(Base):
|
||||
# Relationships
|
||||
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
|
||||
team: Mapped[Optional["Team"]] = relationship("Team")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_library")
|
||||
category: Mapped[Optional["StepCategory"]] = relationship("StepCategory")
|
||||
ratings: Mapped[list["StepRating"]] = relationship("StepRating", back_populates="step", cascade="all, delete-orphan")
|
||||
usage_logs: Mapped[list["StepUsageLog"]] = relationship("StepUsageLog", back_populates="step", cascade="all, delete-orphan")
|
||||
|
||||
39
backend/app/models/subscription.py
Normal file
39
backend/app/models/subscription.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), unique=True, nullable=False)
|
||||
stripe_subscription_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
stripe_price_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
plan: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||
billing_interval: Mapped[Optional[str]] = mapped_column(String(20), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="active")
|
||||
seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
current_period_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
current_period_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Relationships
|
||||
account: Mapped["Account"] = relationship("Account", back_populates="subscription")
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
return self.status in ("active", "trialing")
|
||||
|
||||
@property
|
||||
def is_paid(self) -> bool:
|
||||
return self.plan in ("pro", "team")
|
||||
@@ -9,6 +9,7 @@ from app.core.database import Base
|
||||
if TYPE_CHECKING:
|
||||
from app.models.tree import Tree
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@@ -50,6 +51,12 @@ class TreeTag(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
@@ -63,6 +70,7 @@ class TreeTag(Base):
|
||||
|
||||
# Relationships
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="tags")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="tags")
|
||||
creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by])
|
||||
trees: Mapped[list["Tree"]] = relationship(
|
||||
"Tree",
|
||||
@@ -72,8 +80,8 @@ class TreeTag(Base):
|
||||
|
||||
@property
|
||||
def is_global(self) -> bool:
|
||||
"""Returns True if this is a global tag (not team-specific)."""
|
||||
return self.team_id is None
|
||||
"""Returns True if this is a global tag (not team or account-specific)."""
|
||||
return self.team_id is None and self.account_id is None
|
||||
|
||||
@classmethod
|
||||
def slugify(cls, name: str) -> str:
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.core.database import Base
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.session import Session
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.tag import TreeTag
|
||||
@@ -47,6 +48,12 @@ class Tree(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False, index=True)
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False, index=True)
|
||||
@@ -75,6 +82,7 @@ class Tree(Base):
|
||||
# Relationships
|
||||
author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees")
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees")
|
||||
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree")
|
||||
|
||||
# New organization relationships
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.tree import Tree
|
||||
from app.models.session import Session
|
||||
from app.models.folder import UserFolder
|
||||
@@ -20,6 +21,10 @@ class User(Base):
|
||||
"role IN ('engineer', 'viewer')",
|
||||
name='ck_users_role_enum'
|
||||
),
|
||||
CheckConstraint(
|
||||
"account_role IN ('owner', 'admin', 'engineer', 'viewer')",
|
||||
name='ck_users_account_role_enum'
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
@@ -34,6 +39,17 @@ class User(Base):
|
||||
is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
|
||||
|
||||
# Account-based multi-tenancy (new)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="RESTRICT"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
|
||||
|
||||
# Legacy team columns (kept for PR A coexistence)
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("teams.id"),
|
||||
@@ -51,6 +67,8 @@ class User(Base):
|
||||
last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="users")
|
||||
owned_account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys="[Account.owner_id]", back_populates="owner", uselist=False)
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="users")
|
||||
trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.author_id]", back_populates="author")
|
||||
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="user")
|
||||
@@ -62,6 +80,11 @@ class User(Base):
|
||||
return self.is_super_admin
|
||||
|
||||
@property
|
||||
def can_manage_team(self) -> bool:
|
||||
"""Returns True if user can manage their team (team admin or super admin)."""
|
||||
return self.is_super_admin or (self.is_team_admin and self.team_id is not None)
|
||||
def is_account_owner(self) -> bool:
|
||||
"""Returns True if user owns their account."""
|
||||
return self.account_role == "owner"
|
||||
|
||||
@property
|
||||
def can_manage_account(self) -> bool:
|
||||
"""Returns True if user can manage their account (owner, admin, or super admin)."""
|
||||
return self.is_super_admin or self.account_role in ("owner", "admin")
|
||||
|
||||
39
backend/app/schemas/account.py
Normal file
39
backend/app/schemas/account.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
display_code: str
|
||||
owner_id: UUID
|
||||
stripe_customer_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
|
||||
|
||||
class AccountInviteCreate(BaseModel):
|
||||
email: str = Field(..., max_length=255)
|
||||
role: str = Field("engineer", pattern="^(engineer|viewer)$")
|
||||
expires_in_days: Optional[int] = Field(None, ge=1, le=30)
|
||||
|
||||
|
||||
class AccountInviteResponse(BaseModel):
|
||||
id: UUID
|
||||
account_id: UUID
|
||||
email: str
|
||||
code: str
|
||||
role: str
|
||||
expires_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -20,7 +20,7 @@ class CategoryBase(BaseModel):
|
||||
|
||||
|
||||
class CategoryCreate(CategoryBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
|
||||
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
@@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel):
|
||||
class CategoryResponse(CategoryBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
@@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel):
|
||||
name: str
|
||||
slug: str
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
tree_count: int = 0
|
||||
|
||||
@@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel):
|
||||
|
||||
|
||||
class StepCategoryCreate(StepCategoryBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.")
|
||||
|
||||
|
||||
class StepCategoryUpdate(BaseModel):
|
||||
@@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel):
|
||||
class StepCategoryResponse(StepCategoryBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
@@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel):
|
||||
name: str
|
||||
slug: str
|
||||
description: Optional[str] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
display_order: int
|
||||
is_active: bool
|
||||
step_count: int = 0
|
||||
|
||||
@@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel):
|
||||
|
||||
|
||||
class StepLibraryCreate(StepLibraryBase):
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
|
||||
|
||||
class StepLibraryUpdate(BaseModel):
|
||||
@@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel):
|
||||
class StepLibraryResponse(StepLibraryBase):
|
||||
id: UUID
|
||||
created_by: UUID
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
rating_average: Decimal
|
||||
rating_count: int
|
||||
|
||||
40
backend/app/schemas/subscription.py
Normal file
40
backend/app/schemas/subscription.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SubscriptionResponse(BaseModel):
|
||||
id: UUID
|
||||
plan: str
|
||||
status: str
|
||||
billing_interval: Optional[str] = None
|
||||
current_period_start: Optional[datetime] = None
|
||||
current_period_end: Optional[datetime] = None
|
||||
cancel_at_period_end: bool = False
|
||||
stripe_subscription_id: Optional[str] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class PlanLimitsResponse(BaseModel):
|
||||
plan: str
|
||||
max_trees: Optional[int] = None
|
||||
max_sessions_per_month: Optional[int] = None
|
||||
max_users: Optional[int] = None
|
||||
custom_branding: bool = False
|
||||
priority_support: bool = False
|
||||
export_formats: list[str] = ["markdown", "text"]
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UsageResponse(BaseModel):
|
||||
tree_count: int
|
||||
session_count_this_month: int
|
||||
|
||||
|
||||
class SubscriptionDetails(BaseModel):
|
||||
subscription: SubscriptionResponse
|
||||
limits: PlanLimitsResponse
|
||||
usage: UsageResponse
|
||||
@@ -19,13 +19,13 @@ class TagBase(BaseModel):
|
||||
|
||||
|
||||
class TagCreate(TagBase):
|
||||
team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.")
|
||||
account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.")
|
||||
|
||||
|
||||
class TagResponse(TagBase):
|
||||
id: UUID
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
created_at: datetime
|
||||
|
||||
@@ -37,7 +37,7 @@ class TagListResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
slug: str
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
usage_count: int
|
||||
|
||||
class Config:
|
||||
@@ -53,4 +53,4 @@ class TagSearchParams(BaseModel):
|
||||
"""Query parameters for tag search/autocomplete."""
|
||||
q: str = Field(..., min_length=1, description="Search query")
|
||||
limit: int = Field(10, ge=1, le=50)
|
||||
include_team: bool = Field(True, description="Include team-specific tags")
|
||||
include_account: bool = Field(True, description="Include account-specific tags")
|
||||
|
||||
@@ -44,7 +44,7 @@ class TreeResponse(TreeBase):
|
||||
id: UUID
|
||||
tree_structure: dict[str, Any]
|
||||
author_id: Optional[UUID] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
category_id: Optional[UUID] = None
|
||||
category_info: Optional[CategoryInfo] = None
|
||||
tags: list[str] = [] # List of tag names
|
||||
@@ -69,7 +69,7 @@ class TreeListResponse(BaseModel):
|
||||
category_info: Optional[CategoryInfo] = None
|
||||
tags: list[str] = [] # List of tag names
|
||||
author_id: Optional[UUID] = None
|
||||
team_id: Optional[UUID] = None
|
||||
account_id: Optional[UUID] = None
|
||||
is_active: bool
|
||||
is_public: bool
|
||||
is_default: bool
|
||||
|
||||
@@ -13,6 +13,7 @@ class UserBase(BaseModel):
|
||||
class UserCreate(UserBase):
|
||||
password: str = Field(..., min_length=10, description="Password must be at least 10 characters")
|
||||
invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)")
|
||||
account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account")
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
@@ -38,11 +39,11 @@ class UserLogin(BaseModel):
|
||||
|
||||
class UserResponse(UserBase):
|
||||
id: UUID
|
||||
role: str
|
||||
role: str = "engineer"
|
||||
account_id: UUID
|
||||
account_role: str
|
||||
is_super_admin: bool = False
|
||||
is_team_admin: bool = False
|
||||
is_active: bool = True
|
||||
team_id: Optional[UUID] = None
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
@@ -54,5 +55,5 @@ class RoleUpdate(BaseModel):
|
||||
role: Literal["engineer", "viewer"]
|
||||
|
||||
|
||||
class TeamAdminUpdate(BaseModel):
|
||||
is_team_admin: bool
|
||||
class AccountRoleUpdate(BaseModel):
|
||||
account_role: str = Field(..., pattern="^(engineer|viewer)$")
|
||||
|
||||
@@ -22,5 +22,8 @@ email-validator==2.1.0
|
||||
# Rate Limiting
|
||||
slowapi==0.1.9
|
||||
|
||||
# Payments
|
||||
stripe==14.3.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
|
||||
@@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await conn.execute(sa.text("CREATE SCHEMA public"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Seed plan_limits for subscription checks
|
||||
await conn.execute(sa.text("""
|
||||
INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats)
|
||||
VALUES
|
||||
('free', 3, 20, 1, false, false, '["markdown", "text"]'),
|
||||
('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'),
|
||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
||||
"""))
|
||||
|
||||
# Create async session maker
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
|
||||
170
backend/tests/test_account_management.py
Normal file
170
backend/tests/test_account_management.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Integration tests for account management endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAccountEndpoints:
|
||||
"""Test suite for account management endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_account(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting current user's account."""
|
||||
response = await client.get("/api/v1/accounts/me", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "name" in data
|
||||
assert "display_code" in data
|
||||
assert "owner_id" in data
|
||||
assert len(data["display_code"]) == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting current user's subscription details."""
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "subscription" in data
|
||||
assert "limits" in data
|
||||
assert "usage" in data
|
||||
assert data["subscription"]["plan"] == "free"
|
||||
assert data["subscription"]["status"] == "active"
|
||||
assert data["limits"]["max_trees"] == 3
|
||||
assert data["limits"]["max_sessions_per_month"] == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_members(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test getting members of current user's account."""
|
||||
response = await client.get("/api/v1/accounts/me/members", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
# Current user should be in members list
|
||||
assert any(m["account_role"] == "owner" for m in data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_my_account(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test updating account name."""
|
||||
response = await client.patch(
|
||||
"/api/v1/accounts/me",
|
||||
json={"name": "Updated Account Name"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Account Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_requires_owner(self, client: AsyncClient):
|
||||
"""Test that non-owners cannot update account settings."""
|
||||
# Register two users
|
||||
owner_data = {
|
||||
"email": "owner@example.com",
|
||||
"password": "OwnerPass123!",
|
||||
"name": "Owner"
|
||||
}
|
||||
await client.post("/api/v1/auth/register", json=owner_data)
|
||||
|
||||
# Login as owner and create an invite
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "owner@example.com",
|
||||
"password": "OwnerPass123!"
|
||||
})
|
||||
owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "member@example.com", "role": "engineer"},
|
||||
headers=owner_headers
|
||||
)
|
||||
assert invite_resp.status_code == 201
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register member with account invite code
|
||||
member_data = {
|
||||
"email": "member@example.com",
|
||||
"password": "MemberPass123!",
|
||||
"name": "Member",
|
||||
"account_invite_code": invite_code
|
||||
}
|
||||
reg_resp = await client.post("/api/v1/auth/register", json=member_data)
|
||||
assert reg_resp.status_code == 201
|
||||
assert reg_resp.json()["account_role"] == "engineer"
|
||||
|
||||
# Login as member
|
||||
member_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "member@example.com",
|
||||
"password": "MemberPass123!"
|
||||
})
|
||||
member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"}
|
||||
|
||||
# Member should not be able to update account
|
||||
response = await client.patch(
|
||||
"/api/v1/accounts/me",
|
||||
json={"name": "Hacked Name"},
|
||||
headers=member_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test creating and listing account invites."""
|
||||
# Create invite
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "invitee@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["email"] == "invitee@example.com"
|
||||
assert data["role"] == "engineer"
|
||||
assert "code" in data
|
||||
|
||||
# List invites
|
||||
list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers)
|
||||
assert list_response.status_code == 200
|
||||
invites = list_response.json()
|
||||
assert len(invites) >= 1
|
||||
assert any(i["email"] == "invitee@example.com" for i in invites)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that account invite code joins user to existing account."""
|
||||
# Get current account
|
||||
account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers)
|
||||
account_id = account_resp.json()["id"]
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "joiner@example.com", "role": "viewer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register with account invite code
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "joiner@example.com",
|
||||
"password": "JoinerPass123!",
|
||||
"name": "Joiner",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
data = reg_resp.json()
|
||||
assert data["account_id"] == account_id
|
||||
assert data["account_role"] == "viewer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_invalid_invite_code(self, client: AsyncClient):
|
||||
"""Test that invalid account invite code is rejected."""
|
||||
response = await client.post("/api/v1/auth/register", json={
|
||||
"email": "bad@example.com",
|
||||
"password": "BadPassword123!",
|
||||
"name": "Bad User",
|
||||
"account_invite_code": "INVALID_CODE"
|
||||
})
|
||||
assert response.status_code == 400
|
||||
assert "invalid" in response.json()["detail"].lower()
|
||||
@@ -169,3 +169,51 @@ class TestAdminEndpoints:
|
||||
log = result.scalar_one_or_none()
|
||||
assert log is not None
|
||||
assert str(log.resource_id) == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_account_role(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test changing a user's account role."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["account_role"] == "viewer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_account_role_invalid(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test that invalid account_role values are rejected."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "owner"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_log_created_on_account_role_change(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession
|
||||
):
|
||||
"""Test that changing account role creates an audit log entry."""
|
||||
user_id = test_user["user_data"]["id"]
|
||||
await client.put(
|
||||
f"/api/v1/admin/users/{user_id}/account-role",
|
||||
json={"account_role": "viewer"},
|
||||
headers=admin_auth_headers
|
||||
)
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AuditLog).where(AuditLog.action == "user.account_role_change")
|
||||
)
|
||||
log = result.scalar_one_or_none()
|
||||
assert log is not None
|
||||
assert str(log.resource_id) == user_id
|
||||
assert log.details["old_account_role"] == "owner"
|
||||
assert log.details["new_account_role"] == "viewer"
|
||||
|
||||
@@ -23,6 +23,8 @@ class TestAuthentication:
|
||||
assert data["email"] == user_data["email"]
|
||||
assert data["name"] == user_data["name"]
|
||||
assert data["role"] == "engineer"
|
||||
assert "account_id" in data
|
||||
assert data["account_role"] == "owner"
|
||||
assert "id" in data
|
||||
assert "password" not in data # Password should not be returned
|
||||
|
||||
@@ -107,6 +109,7 @@ class TestAuthentication:
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["role"] == "engineer"
|
||||
assert data["account_role"] == "owner"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_default_role_is_engineer(self, client: AsyncClient):
|
||||
@@ -121,6 +124,7 @@ class TestAuthentication:
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["role"] == "engineer"
|
||||
assert response.json()["account_role"] == "owner"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_rejects_no_uppercase(self, client: AsyncClient):
|
||||
|
||||
205
backend/tests/test_permissions_account.py
Normal file
205
backend/tests/test_permissions_account.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Integration tests for account-based permissions."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAccountPermissions:
|
||||
"""Test suite for account-based permission checks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db):
|
||||
"""Test that viewers cannot create trees."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Register a user
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "viewer@example.com",
|
||||
"password": "ViewerPass123!",
|
||||
"name": "Viewer User"
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
user_id = PyUUID(reg_resp.json()["id"])
|
||||
|
||||
# Demote to viewer via ORM
|
||||
result = await test_db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "viewer"
|
||||
await test_db.commit()
|
||||
|
||||
# Login as viewer
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "viewer@example.com",
|
||||
"password": "ViewerPass123!"
|
||||
})
|
||||
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Try to create tree
|
||||
response = await client.post("/api/v1/trees", json={
|
||||
"name": "Viewer Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"}
|
||||
}, headers=viewer_headers)
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""Test that viewers can browse/list trees."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Create a public tree as the regular user first
|
||||
await client.post("/api/v1/trees", json={
|
||||
"name": "Public Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": True
|
||||
}, headers=auth_headers)
|
||||
|
||||
# Register viewer
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "viewer2@example.com",
|
||||
"password": "ViewerPass123!",
|
||||
"name": "Viewer 2"
|
||||
})
|
||||
user_id = PyUUID(reg_resp.json()["id"])
|
||||
|
||||
result = await test_db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
user.account_role = "viewer"
|
||||
await test_db.commit()
|
||||
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "viewer2@example.com",
|
||||
"password": "ViewerPass123!"
|
||||
})
|
||||
viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Viewer can list trees
|
||||
response = await client.get("/api/v1/trees", headers=viewer_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""Test that account owner can edit trees created by account members."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
# Get owner's account
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id = me_resp.json()["account_id"]
|
||||
|
||||
# Create invite
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "engineer@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
# Register engineer in same account
|
||||
reg_resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "engineer@example.com",
|
||||
"password": "EngineerPass123!",
|
||||
"name": "Engineer",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
assert reg_resp.status_code == 201
|
||||
assert reg_resp.json()["account_id"] == account_id
|
||||
|
||||
# Login as engineer
|
||||
eng_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "engineer@example.com",
|
||||
"password": "EngineerPass123!"
|
||||
})
|
||||
eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"}
|
||||
|
||||
# Engineer creates a tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Engineer's Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}
|
||||
}, headers=eng_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Owner can edit engineer's tree
|
||||
update_resp = await client.put(
|
||||
f"/api/v1/trees/{tree_id}",
|
||||
json={"name": "Owner Updated Name"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
assert update_resp.json()["name"] == "Owner Updated Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that account members can see each other's non-public trees."""
|
||||
# Get owner's account
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id = me_resp.json()["account_id"]
|
||||
|
||||
# Owner creates a private tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Private Account Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": False
|
||||
}, headers=auth_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Create invite and add member
|
||||
invite_resp = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "teammate@example.com", "role": "engineer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
invite_code = invite_resp.json()["code"]
|
||||
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "teammate@example.com",
|
||||
"password": "TeammatePass123!",
|
||||
"name": "Teammate",
|
||||
"account_invite_code": invite_code
|
||||
})
|
||||
|
||||
mate_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "teammate@example.com",
|
||||
"password": "TeammatePass123!"
|
||||
})
|
||||
mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"}
|
||||
|
||||
# Teammate should see the private tree (same account)
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Private Account Tree"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that users from different accounts cannot see private trees."""
|
||||
# Owner creates a private tree
|
||||
tree_resp = await client.post("/api/v1/trees", json={
|
||||
"name": "Secret Tree",
|
||||
"tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"},
|
||||
"is_public": False
|
||||
}, headers=auth_headers)
|
||||
assert tree_resp.status_code == 201
|
||||
tree_id = tree_resp.json()["id"]
|
||||
|
||||
# Register a completely separate user (different account)
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "outsider@example.com",
|
||||
"password": "OutsiderPass123!",
|
||||
"name": "Outsider"
|
||||
})
|
||||
|
||||
outsider_login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "outsider@example.com",
|
||||
"password": "OutsiderPass123!"
|
||||
})
|
||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||
|
||||
# Outsider should NOT see the private tree
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||
assert response.status_code == 403
|
||||
129
backend/tests/test_subscription_limits.py
Normal file
129
backend/tests/test_subscription_limits.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for subscription limits."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class TestSubscriptionLimits:
|
||||
"""Test suite for subscription plan limits."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that free plan has tree creation limit of 3."""
|
||||
tree_template = {
|
||||
"name": "Limit Test Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test tree"
|
||||
}
|
||||
}
|
||||
|
||||
# Create trees up to the limit
|
||||
for i in range(3):
|
||||
tree_data = {**tree_template, "name": f"Tree {i+1}"}
|
||||
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}"
|
||||
|
||||
# 4th tree should fail with 402
|
||||
tree_data = {**tree_template, "name": "Tree 4 Over Limit"}
|
||||
response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert response.status_code == 402
|
||||
assert "limit" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that subscription details reflect actual usage."""
|
||||
# Check initial usage
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
initial_usage = response.json()["usage"]
|
||||
assert initial_usage["tree_count"] == 0
|
||||
|
||||
# Create a tree
|
||||
tree_data = {
|
||||
"name": "Usage Test Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test"
|
||||
}
|
||||
}
|
||||
create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers)
|
||||
assert create_resp.status_code == 201
|
||||
|
||||
# Check usage increased
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
updated_usage = response.json()["usage"]
|
||||
assert updated_usage["tree_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_super_admin_bypasses_limits(
|
||||
self, client: AsyncClient, admin_auth_headers: dict
|
||||
):
|
||||
"""Test that super admin can create trees without limit checks."""
|
||||
tree_template = {
|
||||
"name": "Admin Tree",
|
||||
"tree_structure": {
|
||||
"id": "root",
|
||||
"type": "solution",
|
||||
"title": "Test",
|
||||
"description": "Test tree"
|
||||
},
|
||||
"is_default": True # Default trees skip limit check
|
||||
}
|
||||
|
||||
# Super admin creating default trees should always work
|
||||
for i in range(5):
|
||||
tree_data = {**tree_template, "name": f"Admin Tree {i+1}"}
|
||||
response = await client.post(
|
||||
"/api/v1/trees", json=tree_data, headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Test that free plan limits are correct."""
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
limits = response.json()["limits"]
|
||||
assert limits["plan"] == "free"
|
||||
assert limits["max_trees"] == 3
|
||||
assert limits["max_sessions_per_month"] == 20
|
||||
assert limits["max_users"] == 1
|
||||
assert limits["custom_branding"] is False
|
||||
assert limits["priority_support"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgraded_plan_has_higher_limits(
|
||||
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
|
||||
):
|
||||
"""Test that upgrading plan increases limits."""
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
|
||||
# Get the user's subscription and upgrade it
|
||||
me_resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id_str = me_resp.json()["account_id"]
|
||||
|
||||
from uuid import UUID
|
||||
account_id = UUID(account_id_str)
|
||||
result = await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account_id)
|
||||
)
|
||||
sub = result.scalar_one()
|
||||
sub.plan = "pro"
|
||||
await test_db.commit()
|
||||
|
||||
# Check limits are now pro
|
||||
response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
limits = response.json()["limits"]
|
||||
assert limits["plan"] == "pro"
|
||||
assert limits["max_trees"] == 25
|
||||
assert limits["max_sessions_per_month"] == 200
|
||||
10
frontend/package-lock.json
generated
10
frontend/package-lock.json
generated
@@ -8,6 +8,7 @@
|
||||
"name": "frontend",
|
||||
"version": "0.0.0",
|
||||
"dependencies": {
|
||||
"@stripe/stripe-js": "^8.7.0",
|
||||
"axios": "^1.13.4",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
@@ -1430,6 +1431,15 @@
|
||||
"win32"
|
||||
]
|
||||
},
|
||||
"node_modules/@stripe/stripe-js": {
|
||||
"version": "8.7.0",
|
||||
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-8.7.0.tgz",
|
||||
"integrity": "sha512-tNUerSstwNC1KuHgX4CASGO0Md3CB26IJzSXmVlSuFvhsBP4ZaEPpY4jxWOn9tfdDscuVT4Kqb8cZ2o9nLCgRQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12.16"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/babel__core": {
|
||||
"version": "7.20.5",
|
||||
"resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@stripe/stripe-js": "^8.7.0",
|
||||
"axios": "^1.13.4",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
|
||||
48
frontend/src/api/accounts.ts
Normal file
48
frontend/src/api/accounts.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import apiClient from './client'
|
||||
import type { Account, SubscriptionDetails, AccountMember, AccountInvite } from '@/types'
|
||||
|
||||
export const accountsApi = {
|
||||
async getMyAccount(): Promise<Account> {
|
||||
const response = await apiClient.get<Account>('/accounts/me')
|
||||
return response.data
|
||||
},
|
||||
|
||||
async getMySubscription(): Promise<SubscriptionDetails> {
|
||||
const response = await apiClient.get<SubscriptionDetails>('/accounts/me/subscription')
|
||||
return response.data
|
||||
},
|
||||
|
||||
async updateMyAccount(data: { name?: string }): Promise<Account> {
|
||||
const response = await apiClient.patch<Account>('/accounts/me', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async getMembers(): Promise<AccountMember[]> {
|
||||
const response = await apiClient.get<AccountMember[]>('/accounts/me/members')
|
||||
return response.data
|
||||
},
|
||||
|
||||
async updateMemberRole(userId: string, role: string): Promise<AccountMember> {
|
||||
const response = await apiClient.patch<AccountMember>(
|
||||
`/accounts/me/members/${userId}/role`,
|
||||
{ role }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async removeMember(userId: string): Promise<void> {
|
||||
await apiClient.delete(`/accounts/me/members/${userId}`)
|
||||
},
|
||||
|
||||
async createInvite(data: { email: string; role: string }): Promise<AccountInvite> {
|
||||
const response = await apiClient.post<AccountInvite>('/accounts/me/invites', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async getInvites(): Promise<AccountInvite[]> {
|
||||
const response = await apiClient.get<AccountInvite[]>('/accounts/me/invites')
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
|
||||
export default accountsApi
|
||||
@@ -2,9 +2,9 @@ import apiClient from './client'
|
||||
import type { Category, CategoryListItem, CategoryCreate, CategoryUpdate } from '@/types'
|
||||
|
||||
export const categoriesApi = {
|
||||
async list(includeInactive = false, teamOnly = false): Promise<CategoryListItem[]> {
|
||||
async list(includeInactive = false, accountOnly = false): Promise<CategoryListItem[]> {
|
||||
const response = await apiClient.get<CategoryListItem[]>('/categories', {
|
||||
params: { include_inactive: includeInactive, team_only: teamOnly },
|
||||
params: { include_inactive: includeInactive, account_only: accountOnly },
|
||||
})
|
||||
return response.data
|
||||
},
|
||||
|
||||
@@ -8,3 +8,4 @@ export { default as categoriesApi } from './categories'
|
||||
export { default as foldersApi } from './folders'
|
||||
export { default as stepsApi } from './steps'
|
||||
export { default as stepCategoriesApi } from './stepCategories'
|
||||
export { default as accountsApi } from './accounts'
|
||||
|
||||
@@ -2,16 +2,16 @@ import apiClient from './client'
|
||||
import type { Tag, TagListItem, TagCreate, TagAssignment } from '@/types'
|
||||
|
||||
export const tagsApi = {
|
||||
async list(includeTeam = true): Promise<TagListItem[]> {
|
||||
async list(includeAccount = true): Promise<TagListItem[]> {
|
||||
const response = await apiClient.get<TagListItem[]>('/tags', {
|
||||
params: { include_team: includeTeam },
|
||||
params: { include_account: includeAccount },
|
||||
})
|
||||
return response.data
|
||||
},
|
||||
|
||||
async search(query: string, limit = 10, includeTeam = true): Promise<TagListItem[]> {
|
||||
async search(query: string, limit = 10, includeAccount = true): Promise<TagListItem[]> {
|
||||
const response = await apiClient.get<TagListItem[]>('/tags/search', {
|
||||
params: { q: query, limit, include_team: includeTeam },
|
||||
params: { q: query, limit, include_account: includeAccount },
|
||||
})
|
||||
return response.data
|
||||
},
|
||||
|
||||
32
frontend/src/components/common/UpgradePrompt.tsx
Normal file
32
frontend/src/components/common/UpgradePrompt.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useSubscription } from '@/hooks/useSubscription'
|
||||
|
||||
interface UpgradePromptProps {
|
||||
feature: string // e.g., "create more trees", "start more sessions"
|
||||
className?: string
|
||||
}
|
||||
|
||||
export function UpgradePrompt({ feature, className }: UpgradePromptProps) {
|
||||
const { plan } = useSubscription()
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
'rounded-lg border border-yellow-500/30 bg-yellow-500/10 p-4',
|
||||
className
|
||||
)}>
|
||||
<h3 className="font-semibold text-foreground">Plan Limit Reached</h3>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Your {plan} plan doesn't allow you to {feature}. Upgrade your plan to continue.
|
||||
</p>
|
||||
<button
|
||||
className={cn(
|
||||
'mt-3 rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
|
||||
'hover:bg-primary/90'
|
||||
)}
|
||||
onClick={() => window.location.href = '/account'}
|
||||
>
|
||||
View Plans
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -49,6 +49,7 @@ export function AppLayout() {
|
||||
const navItems = [
|
||||
{ path: '/trees', label: 'Trees' },
|
||||
{ path: '/sessions', label: 'Sessions' },
|
||||
{ path: '/account', label: 'Account' },
|
||||
{ path: '/settings', label: 'Settings' },
|
||||
]
|
||||
|
||||
@@ -98,12 +99,12 @@ export function AppLayout() {
|
||||
className={cn(
|
||||
'hidden rounded-full px-2 py-0.5 text-xs font-medium sm:inline-block',
|
||||
effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400',
|
||||
effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
|
||||
effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
|
||||
effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400'
|
||||
)}
|
||||
>
|
||||
{effectiveRole === 'super_admin' ? 'Super Admin' :
|
||||
effectiveRole === 'team_admin' ? 'Team Admin' :
|
||||
effectiveRole === 'owner' ? 'Owner' :
|
||||
'Viewer'}
|
||||
</span>
|
||||
)}
|
||||
@@ -158,12 +159,12 @@ export function AppLayout() {
|
||||
className={cn(
|
||||
'mt-1 inline-block rounded-full px-2 py-0.5 text-xs font-medium',
|
||||
effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400',
|
||||
effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
|
||||
effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400',
|
||||
effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400'
|
||||
)}
|
||||
>
|
||||
{effectiveRole === 'super_admin' ? 'Super Admin' :
|
||||
effectiveRole === 'team_admin' ? 'Team Admin' :
|
||||
effectiveRole === 'owner' ? 'Owner' :
|
||||
'Viewer'}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -27,7 +27,7 @@ export function ProtectedRoute({ requiredRole, children }: ProtectedRouteProps)
|
||||
if (requiredRole) {
|
||||
const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
|
||||
super_admin: 4,
|
||||
team_admin: 3,
|
||||
owner: 3,
|
||||
engineer: 2,
|
||||
viewer: 1,
|
||||
}
|
||||
|
||||
24
frontend/src/components/subscription/CheckoutButton.tsx
Normal file
24
frontend/src/components/subscription/CheckoutButton.tsx
Normal file
@@ -0,0 +1,24 @@
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
interface CheckoutButtonProps {
|
||||
plan: 'pro' | 'team'
|
||||
className?: string
|
||||
}
|
||||
|
||||
export function CheckoutButton({ plan, className }: CheckoutButtonProps) {
|
||||
const planLabels = { pro: 'Pro', team: 'Team' }
|
||||
|
||||
return (
|
||||
<button
|
||||
disabled
|
||||
title="Billing coming soon"
|
||||
className={cn(
|
||||
'rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
|
||||
'disabled:opacity-50 disabled:cursor-not-allowed',
|
||||
className
|
||||
)}
|
||||
>
|
||||
Upgrade to {planLabels[plan]} (Coming Soon)
|
||||
</button>
|
||||
)
|
||||
}
|
||||
@@ -119,7 +119,7 @@ export function TreeMetadataForm() {
|
||||
{categories.map((cat) => (
|
||||
<option key={cat.id} value={cat.id}>
|
||||
{cat.name}
|
||||
{cat.team_id ? ' (Team)' : ''}
|
||||
{cat.account_id ? ' (Account)' : ''}
|
||||
</option>
|
||||
))}
|
||||
<option value="__custom__">+ Add custom category</option>
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
/**
|
||||
* Centralized permissions hook for ResolutionFlow.
|
||||
*
|
||||
* Role hierarchy: super_admin > team_admin > engineer > viewer
|
||||
* Role hierarchy: super_admin > owner > engineer > viewer
|
||||
*
|
||||
* Mirrors backend logic in backend/app/core/permissions.py
|
||||
*/
|
||||
import { useAuthStore } from '@/store/authStore'
|
||||
import type { User } from '@/types'
|
||||
|
||||
export type EffectiveRole = 'super_admin' | 'team_admin' | 'engineer' | 'viewer'
|
||||
export type EffectiveRole = 'super_admin' | 'owner' | 'engineer' | 'viewer'
|
||||
|
||||
const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
|
||||
super_admin: 4,
|
||||
team_admin: 3,
|
||||
owner: 3,
|
||||
engineer: 2,
|
||||
viewer: 1,
|
||||
}
|
||||
@@ -20,7 +20,7 @@ const ROLE_HIERARCHY: Record<EffectiveRole, number> = {
|
||||
function getEffectiveRole(user: User | null): EffectiveRole {
|
||||
if (!user) return 'viewer'
|
||||
if (user.is_super_admin) return 'super_admin'
|
||||
if (user.is_team_admin && user.team_id) return 'team_admin'
|
||||
if (user.account_role === 'owner') return 'owner'
|
||||
return user.role as EffectiveRole
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ export function usePermissions() {
|
||||
return {
|
||||
effectiveRole,
|
||||
isSuperAdmin: effectiveRole === 'super_admin',
|
||||
isTeamAdmin: effectiveRole === 'team_admin' || effectiveRole === 'super_admin',
|
||||
isAccountOwner: effectiveRole === 'owner' || effectiveRole === 'super_admin',
|
||||
isEngineer: hasMinimumRole(user, 'engineer'),
|
||||
isViewer: effectiveRole === 'viewer',
|
||||
|
||||
@@ -46,12 +46,12 @@ export function usePermissions() {
|
||||
canCreateSteps: hasMinimumRole(user, 'engineer'),
|
||||
|
||||
// Resource-specific checks
|
||||
canEditTree: (tree: { author_id: string | null; team_id?: string | null }) => {
|
||||
canEditTree: (tree: { author_id: string | null; account_id?: string | null }) => {
|
||||
if (!user) return false
|
||||
if (user.is_super_admin) return true
|
||||
if (!hasMinimumRole(user, 'engineer')) return false
|
||||
if (tree.author_id && tree.author_id === user.id) return true
|
||||
if (user.is_team_admin && tree.team_id === user.team_id && user.team_id) return true
|
||||
if (user.account_role === 'owner' && tree.account_id === user.account_id && user.account_id) return true
|
||||
return false
|
||||
},
|
||||
|
||||
@@ -68,8 +68,8 @@ export function usePermissions() {
|
||||
},
|
||||
|
||||
// Management permissions
|
||||
canManageCategories: hasMinimumRole(user, 'team_admin'),
|
||||
canManageCategories: hasMinimumRole(user, 'owner'),
|
||||
canManageGlobalCategories: effectiveRole === 'super_admin',
|
||||
canManageTeam: effectiveRole === 'super_admin' || (effectiveRole === 'team_admin'),
|
||||
canManageAccount: effectiveRole === 'super_admin' || effectiveRole === 'owner',
|
||||
}
|
||||
}
|
||||
|
||||
45
frontend/src/hooks/useSubscription.ts
Normal file
45
frontend/src/hooks/useSubscription.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { useAuthStore } from '@/store/authStore'
|
||||
|
||||
export function useSubscription() {
|
||||
const subscription = useAuthStore((s) => s.subscription)
|
||||
|
||||
const plan = subscription?.subscription.plan ?? 'free'
|
||||
const limits = subscription?.limits ?? null
|
||||
const usage = subscription?.usage ?? null
|
||||
const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing'
|
||||
|
||||
const isPaidPlan = plan === 'pro' || plan === 'team'
|
||||
|
||||
const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => {
|
||||
if (!limits) return false
|
||||
return limits[feature]
|
||||
}
|
||||
|
||||
const isAtTreeLimit = (): boolean => {
|
||||
if (!limits || !usage) return false
|
||||
if (limits.max_trees === null) return false // unlimited
|
||||
return usage.tree_count >= limits.max_trees
|
||||
}
|
||||
|
||||
const isAtSessionLimit = (): boolean => {
|
||||
if (!limits || !usage) return false
|
||||
if (limits.max_sessions_per_month === null) return false
|
||||
return usage.session_count_this_month >= limits.max_sessions_per_month
|
||||
}
|
||||
|
||||
const formatLimit = (value: number | null): string => {
|
||||
return value === null ? 'Unlimited' : String(value)
|
||||
}
|
||||
|
||||
return {
|
||||
plan,
|
||||
limits,
|
||||
usage,
|
||||
isActive,
|
||||
isPaidPlan,
|
||||
canUseFeature,
|
||||
isAtTreeLimit,
|
||||
isAtSessionLimit,
|
||||
formatLimit,
|
||||
}
|
||||
}
|
||||
494
frontend/src/pages/AccountSettingsPage.tsx
Normal file
494
frontend/src/pages/AccountSettingsPage.tsx
Normal file
@@ -0,0 +1,494 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import { Building2, Users, Mail, Crown, Loader2, AlertCircle, Check, X } from 'lucide-react'
|
||||
import { accountsApi } from '@/api'
|
||||
import type { Account, AccountMember, AccountInvite } from '@/types'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { usePermissions } from '@/hooks/usePermissions'
|
||||
import { useSubscription } from '@/hooks/useSubscription'
|
||||
import { useAuthStore } from '@/store/authStore'
|
||||
import { CheckoutButton } from '@/components/subscription/CheckoutButton'
|
||||
|
||||
export function AccountSettingsPage() {
|
||||
const { isAccountOwner } = usePermissions()
|
||||
const { plan, limits, usage } = useSubscription()
|
||||
const subscription = useAuthStore((s) => s.subscription)
|
||||
|
||||
const [account, setAccount] = useState<Account | null>(null)
|
||||
const [members, setMembers] = useState<AccountMember[]>([])
|
||||
const [invites, setInvites] = useState<AccountInvite[]>([])
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
// Account name editing
|
||||
const [isEditingName, setIsEditingName] = useState(false)
|
||||
const [editedName, setEditedName] = useState('')
|
||||
const [isSavingName, setIsSavingName] = useState(false)
|
||||
|
||||
// Invite form
|
||||
const [inviteEmail, setInviteEmail] = useState('')
|
||||
const [inviteRole, setInviteRole] = useState('engineer')
|
||||
const [isInviting, setIsInviting] = useState(false)
|
||||
const [inviteError, setInviteError] = useState<string | null>(null)
|
||||
const [inviteSuccess, setInviteSuccess] = useState<string | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
loadData()
|
||||
}, [])
|
||||
|
||||
const loadData = async () => {
|
||||
setIsLoading(true)
|
||||
setError(null)
|
||||
try {
|
||||
const accountData = await accountsApi.getMyAccount()
|
||||
setAccount(accountData)
|
||||
setEditedName(accountData.name)
|
||||
|
||||
if (isAccountOwner) {
|
||||
const [membersData, invitesData] = await Promise.all([
|
||||
accountsApi.getMembers(),
|
||||
accountsApi.getInvites(),
|
||||
])
|
||||
setMembers(membersData)
|
||||
setInvites(invitesData)
|
||||
}
|
||||
} catch (err) {
|
||||
setError('Failed to load account information')
|
||||
console.error(err)
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleSaveName = async () => {
|
||||
if (!editedName.trim() || editedName === account?.name) {
|
||||
setIsEditingName(false)
|
||||
return
|
||||
}
|
||||
setIsSavingName(true)
|
||||
try {
|
||||
const updated = await accountsApi.updateMyAccount({ name: editedName.trim() })
|
||||
setAccount(updated)
|
||||
setIsEditingName(false)
|
||||
} catch (err) {
|
||||
console.error('Failed to update account name:', err)
|
||||
} finally {
|
||||
setIsSavingName(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleInvite = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
if (!inviteEmail.trim()) return
|
||||
|
||||
setIsInviting(true)
|
||||
setInviteError(null)
|
||||
setInviteSuccess(null)
|
||||
try {
|
||||
await accountsApi.createInvite({ email: inviteEmail.trim(), role: inviteRole })
|
||||
setInviteSuccess(`Invitation sent to ${inviteEmail}`)
|
||||
setInviteEmail('')
|
||||
// Refresh invites list
|
||||
const invitesData = await accountsApi.getInvites()
|
||||
setInvites(invitesData)
|
||||
} catch (err) {
|
||||
setInviteError('Failed to send invitation')
|
||||
console.error(err)
|
||||
} finally {
|
||||
setIsInviting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleRemoveMember = async (userId: string) => {
|
||||
try {
|
||||
await accountsApi.removeMember(userId)
|
||||
setMembers(members.filter((m) => m.id !== userId))
|
||||
} catch (err) {
|
||||
console.error('Failed to remove member:', err)
|
||||
}
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex justify-center py-12">
|
||||
<div className="h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-6 sm:px-6 sm:py-8">
|
||||
<div className="rounded-md bg-destructive/10 p-4 text-destructive">
|
||||
<div className="flex items-center gap-2">
|
||||
<AlertCircle className="h-5 w-5" />
|
||||
{error}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const sub = subscription?.subscription
|
||||
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-6 sm:px-6 sm:py-8">
|
||||
<div className="mb-8">
|
||||
<div className="flex items-center gap-3">
|
||||
<Building2 className="h-8 w-8 text-primary" />
|
||||
<h1 className="text-2xl font-bold text-foreground sm:text-3xl">Account Settings</h1>
|
||||
</div>
|
||||
<p className="mt-2 text-muted-foreground">
|
||||
Manage your account, subscription, and team
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="max-w-3xl space-y-6">
|
||||
{/* Account Info Section */}
|
||||
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
|
||||
<h2 className="text-lg font-semibold text-card-foreground">Account Information</h2>
|
||||
|
||||
<div className="mt-4 space-y-4">
|
||||
{/* Account Name */}
|
||||
<div>
|
||||
<label className="block font-label text-sm font-medium text-card-foreground">
|
||||
Account Name
|
||||
</label>
|
||||
{isEditingName ? (
|
||||
<div className="mt-1 flex items-center gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={editedName}
|
||||
onChange={(e) => setEditedName(e.target.value)}
|
||||
className={cn(
|
||||
'flex-1 rounded-md border border-input bg-background px-3 py-2',
|
||||
'text-foreground placeholder:text-muted-foreground',
|
||||
'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
|
||||
)}
|
||||
autoFocus
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') handleSaveName()
|
||||
if (e.key === 'Escape') {
|
||||
setEditedName(account?.name ?? '')
|
||||
setIsEditingName(false)
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={handleSaveName}
|
||||
disabled={isSavingName}
|
||||
className={cn(
|
||||
'rounded-md bg-primary p-2 text-primary-foreground',
|
||||
'hover:bg-primary/90 disabled:opacity-50'
|
||||
)}
|
||||
>
|
||||
{isSavingName ? (
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Check className="h-4 w-4" />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={() => {
|
||||
setEditedName(account?.name ?? '')
|
||||
setIsEditingName(false)
|
||||
}}
|
||||
className="rounded-md border border-input p-2 text-muted-foreground hover:bg-accent"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="mt-1 flex items-center gap-2">
|
||||
<span className="text-sm text-foreground">{account?.name}</span>
|
||||
{isAccountOwner && (
|
||||
<button
|
||||
onClick={() => setIsEditingName(true)}
|
||||
className="text-xs text-primary hover:underline"
|
||||
>
|
||||
Edit
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Display Code */}
|
||||
<div>
|
||||
<label className="block font-label text-sm font-medium text-card-foreground">
|
||||
Display Code
|
||||
</label>
|
||||
<p className="mt-1 text-sm font-mono text-muted-foreground">
|
||||
{account?.display_code}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Subscription Section */}
|
||||
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
|
||||
<h2 className="text-lg font-semibold text-card-foreground">Subscription</h2>
|
||||
|
||||
<div className="mt-4 space-y-4">
|
||||
{/* Plan & Status */}
|
||||
<div className="flex items-center gap-3">
|
||||
<span
|
||||
className={cn(
|
||||
'inline-flex items-center gap-1.5 rounded-full px-3 py-1 text-sm font-medium',
|
||||
plan === 'free' && 'bg-secondary text-secondary-foreground',
|
||||
plan === 'pro' && 'bg-primary/10 text-primary',
|
||||
plan === 'team' && 'bg-primary/20 text-primary'
|
||||
)}
|
||||
>
|
||||
<Crown className="h-3.5 w-3.5" />
|
||||
{plan.charAt(0).toUpperCase() + plan.slice(1)} Plan
|
||||
</span>
|
||||
{sub && (
|
||||
<span
|
||||
className={cn(
|
||||
'inline-flex rounded-full px-2.5 py-0.5 text-xs font-medium',
|
||||
sub.status === 'active' && 'bg-green-500/10 text-green-600',
|
||||
sub.status === 'trialing' && 'bg-blue-500/10 text-blue-600',
|
||||
sub.status === 'past_due' && 'bg-yellow-500/10 text-yellow-600',
|
||||
sub.status === 'canceled' && 'bg-destructive/10 text-destructive',
|
||||
sub.status === 'orphaned' && 'bg-muted text-muted-foreground'
|
||||
)}
|
||||
>
|
||||
{sub.status.charAt(0).toUpperCase() + sub.status.slice(1).replace('_', ' ')}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{sub?.current_period_end && (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Current period ends: {new Date(sub.current_period_end).toLocaleDateString()}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Usage Stats */}
|
||||
{limits && usage && (
|
||||
<div className="mt-4 grid gap-3 sm:grid-cols-3">
|
||||
<UsageStat
|
||||
label="Trees"
|
||||
current={usage.tree_count}
|
||||
max={limits.max_trees}
|
||||
/>
|
||||
<UsageStat
|
||||
label="Sessions / month"
|
||||
current={usage.session_count_this_month}
|
||||
max={limits.max_sessions_per_month}
|
||||
/>
|
||||
<UsageStat
|
||||
label="Team members"
|
||||
current={usage.user_count}
|
||||
max={limits.max_users}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Upgrade buttons */}
|
||||
{plan === 'free' && (
|
||||
<div className="mt-4 flex gap-3">
|
||||
<CheckoutButton plan="pro" />
|
||||
<CheckoutButton plan="team" />
|
||||
</div>
|
||||
)}
|
||||
{plan === 'pro' && (
|
||||
<div className="mt-4">
|
||||
<CheckoutButton plan="team" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Team Members Section (owners only) */}
|
||||
{isAccountOwner && (
|
||||
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
|
||||
<div className="flex items-center gap-2">
|
||||
<Users className="h-5 w-5 text-primary" />
|
||||
<h2 className="text-lg font-semibold text-card-foreground">Team Members</h2>
|
||||
</div>
|
||||
|
||||
{members.length === 0 ? (
|
||||
<p className="mt-4 text-sm text-muted-foreground">No team members yet.</p>
|
||||
) : (
|
||||
<div className="mt-4 divide-y divide-border">
|
||||
{members.map((member) => (
|
||||
<div
|
||||
key={member.id}
|
||||
className="flex items-center justify-between py-3 first:pt-0 last:pb-0"
|
||||
>
|
||||
<div>
|
||||
<p className="text-sm font-medium text-foreground">{member.name}</p>
|
||||
<p className="text-xs text-muted-foreground">{member.email}</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<span
|
||||
className={cn(
|
||||
'rounded-full px-2.5 py-0.5 text-xs font-medium',
|
||||
member.account_role === 'owner' && 'bg-primary/10 text-primary',
|
||||
member.account_role === 'engineer' && 'bg-secondary text-secondary-foreground',
|
||||
member.account_role === 'viewer' && 'bg-muted text-muted-foreground'
|
||||
)}
|
||||
>
|
||||
{member.account_role}
|
||||
</span>
|
||||
{!member.is_active && (
|
||||
<span className="rounded-full bg-destructive/10 px-2 py-0.5 text-xs text-destructive">
|
||||
Inactive
|
||||
</span>
|
||||
)}
|
||||
{member.account_role !== 'owner' && (
|
||||
<button
|
||||
onClick={() => handleRemoveMember(member.id)}
|
||||
className="text-muted-foreground hover:text-destructive"
|
||||
title="Remove member"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Invite Member Section (owners only) */}
|
||||
{isAccountOwner && (
|
||||
<div className="rounded-lg border border-border bg-card p-4 shadow-sm sm:p-6">
|
||||
<div className="flex items-center gap-2">
|
||||
<Mail className="h-5 w-5 text-primary" />
|
||||
<h2 className="text-lg font-semibold text-card-foreground">Invite Member</h2>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleInvite} className="mt-4 space-y-3">
|
||||
<div className="flex flex-col gap-3 sm:flex-row">
|
||||
<input
|
||||
type="email"
|
||||
placeholder="Email address"
|
||||
value={inviteEmail}
|
||||
onChange={(e) => setInviteEmail(e.target.value)}
|
||||
required
|
||||
className={cn(
|
||||
'flex-1 rounded-md border border-input bg-background px-3 py-2',
|
||||
'text-foreground placeholder:text-muted-foreground',
|
||||
'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
|
||||
)}
|
||||
/>
|
||||
<select
|
||||
value={inviteRole}
|
||||
onChange={(e) => setInviteRole(e.target.value)}
|
||||
className={cn(
|
||||
'rounded-md border border-input bg-background px-3 py-2',
|
||||
'text-foreground focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary'
|
||||
)}
|
||||
>
|
||||
<option value="engineer">Engineer</option>
|
||||
<option value="viewer">Viewer</option>
|
||||
</select>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isInviting || !inviteEmail.trim()}
|
||||
className={cn(
|
||||
'rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground',
|
||||
'hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed'
|
||||
)}
|
||||
>
|
||||
{isInviting ? (
|
||||
<span className="flex items-center gap-2">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
Sending...
|
||||
</span>
|
||||
) : (
|
||||
'Send Invite'
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{inviteError && (
|
||||
<p className="text-sm text-destructive">{inviteError}</p>
|
||||
)}
|
||||
{inviteSuccess && (
|
||||
<p className="text-sm text-green-600">{inviteSuccess}</p>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* Pending Invites */}
|
||||
{invites.length > 0 && (
|
||||
<div className="mt-6">
|
||||
<h3 className="text-sm font-medium text-card-foreground">Pending Invites</h3>
|
||||
<div className="mt-2 divide-y divide-border">
|
||||
{invites
|
||||
.filter((inv) => !inv.used_at)
|
||||
.map((invite) => (
|
||||
<div
|
||||
key={invite.id}
|
||||
className="flex items-center justify-between py-2"
|
||||
>
|
||||
<div>
|
||||
<p className="text-sm text-foreground">{invite.email}</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Expires {new Date(invite.expires_at).toLocaleDateString()}
|
||||
</p>
|
||||
</div>
|
||||
<span className="rounded-full bg-secondary px-2.5 py-0.5 text-xs text-secondary-foreground">
|
||||
{invite.role}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
/** Small helper component for usage stat display */
|
||||
function UsageStat({
|
||||
label,
|
||||
current,
|
||||
max,
|
||||
}: {
|
||||
label: string
|
||||
current: number
|
||||
max: number | null
|
||||
}) {
|
||||
const isUnlimited = max === null
|
||||
const percentage = isUnlimited ? 0 : Math.min((current / max) * 100, 100)
|
||||
const isNearLimit = !isUnlimited && percentage >= 80
|
||||
const isAtLimit = !isUnlimited && current >= max
|
||||
|
||||
return (
|
||||
<div className="rounded-md border border-border bg-background p-3">
|
||||
<p className="text-xs font-medium text-muted-foreground">{label}</p>
|
||||
<p
|
||||
className={cn(
|
||||
'mt-1 text-lg font-semibold',
|
||||
isAtLimit ? 'text-destructive' : isNearLimit ? 'text-yellow-600' : 'text-foreground'
|
||||
)}
|
||||
>
|
||||
{current}
|
||||
<span className="text-sm font-normal text-muted-foreground">
|
||||
{' '}/ {isUnlimited ? 'Unlimited' : max}
|
||||
</span>
|
||||
</p>
|
||||
{!isUnlimited && (
|
||||
<div className="mt-2 h-1.5 overflow-hidden rounded-full bg-muted">
|
||||
<div
|
||||
className={cn(
|
||||
'h-full rounded-full transition-all',
|
||||
isAtLimit ? 'bg-destructive' : isNearLimit ? 'bg-yellow-500' : 'bg-primary'
|
||||
)}
|
||||
style={{ width: `${percentage}%` }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default AccountSettingsPage
|
||||
@@ -102,7 +102,7 @@ export function TreeEditorPage() {
|
||||
setLoading(true)
|
||||
try {
|
||||
const tree = await treesApi.get(id)
|
||||
if (!canEditTree({ author_id: tree.author_id, team_id: tree.team_id })) {
|
||||
if (!canEditTree({ author_id: tree.author_id, account_id: tree.account_id })) {
|
||||
navigate('/trees')
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,6 +137,7 @@ export function TreeLibraryPage() {
|
||||
try {
|
||||
await treesApi.delete(treeToDelete.id)
|
||||
setTrees(trees.filter((t) => t.id !== treeToDelete.id))
|
||||
window.dispatchEvent(new Event('folder-changed'))
|
||||
} catch (err) {
|
||||
console.error('Failed to delete tree:', err)
|
||||
setError('Failed to delete tree')
|
||||
@@ -352,7 +353,7 @@ export function TreeLibraryPage() {
|
||||
</span>
|
||||
<div className="flex items-center gap-2">
|
||||
<AddToFolderMenu treeId={tree.id} onFolderCreated={handleCreateFolder} />
|
||||
{canEditTree({ author_id: tree.author_id, team_id: tree.team_id }) && (
|
||||
{canEditTree({ author_id: tree.author_id, account_id: tree.account_id }) && (
|
||||
<Link
|
||||
to={`/trees/${tree.id}/edit`}
|
||||
className={cn(
|
||||
|
||||
@@ -6,3 +6,4 @@ export { default as TreeEditorPage } from './TreeEditorPage'
|
||||
export { default as SessionHistoryPage } from './SessionHistoryPage'
|
||||
export { default as SessionDetailPage } from './SessionDetailPage'
|
||||
export { default as SettingsPage } from './SettingsPage'
|
||||
export { default as AccountSettingsPage } from './AccountSettingsPage'
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
SessionHistoryPage,
|
||||
SessionDetailPage,
|
||||
SettingsPage,
|
||||
AccountSettingsPage,
|
||||
} from '@/pages'
|
||||
|
||||
export const router = createBrowserRouter([
|
||||
@@ -64,6 +65,10 @@ export const router = createBrowserRouter([
|
||||
path: 'settings',
|
||||
element: <SettingsPage />,
|
||||
},
|
||||
{
|
||||
path: 'account',
|
||||
element: <AccountSettingsPage />,
|
||||
},
|
||||
],
|
||||
},
|
||||
])
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist } from 'zustand/middleware'
|
||||
import type { User, Token, UserCreate, UserLogin } from '@/types'
|
||||
import type { User, Token, UserCreate, UserLogin, Account, SubscriptionDetails } from '@/types'
|
||||
import { authApi } from '@/api'
|
||||
import { apiClient } from '@/api'
|
||||
|
||||
interface AuthState {
|
||||
user: User | null
|
||||
token: Token | null
|
||||
account: Account | null
|
||||
subscription: SubscriptionDetails | null
|
||||
isAuthenticated: boolean
|
||||
isLoading: boolean
|
||||
error: string | null
|
||||
@@ -25,6 +28,8 @@ export const useAuthStore = create<AuthState>()(
|
||||
(set, get) => ({
|
||||
user: null,
|
||||
token: null,
|
||||
account: null,
|
||||
subscription: null,
|
||||
isAuthenticated: false,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
@@ -70,15 +75,30 @@ export const useAuthStore = create<AuthState>()(
|
||||
} finally {
|
||||
localStorage.removeItem('access_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
set({ user: null, token: null, isAuthenticated: false, error: null })
|
||||
set({ user: null, token: null, account: null, subscription: null, isAuthenticated: false, error: null })
|
||||
}
|
||||
},
|
||||
|
||||
fetchUser: async () => {
|
||||
set({ isLoading: true })
|
||||
try {
|
||||
const user = await authApi.me()
|
||||
set({ user, isLoading: false })
|
||||
const [userResult, accountResult, subscriptionResult] = await Promise.allSettled([
|
||||
authApi.me(),
|
||||
apiClient.get<Account>('/accounts/me').then(r => r.data),
|
||||
apiClient.get<SubscriptionDetails>('/accounts/me/subscription').then(r => r.data),
|
||||
])
|
||||
|
||||
const user = userResult.status === 'fulfilled' ? userResult.value : null
|
||||
const account = accountResult.status === 'fulfilled' ? accountResult.value : null
|
||||
const subscription = subscriptionResult.status === 'fulfilled' ? subscriptionResult.value : null
|
||||
|
||||
if (!user) {
|
||||
// User fetch failed — propagate the error
|
||||
const reason = userResult.status === 'rejected' ? userResult.reason : new Error('Failed to fetch user')
|
||||
throw reason
|
||||
}
|
||||
|
||||
set({ user, account, subscription, isLoading: false })
|
||||
} catch (error: unknown) {
|
||||
const message = error instanceof Error ? error.message : 'Failed to fetch user'
|
||||
set({ error: message, isLoading: false })
|
||||
@@ -95,6 +115,8 @@ export const useAuthStore = create<AuthState>()(
|
||||
partialize: (state) => ({
|
||||
token: state.token,
|
||||
isAuthenticated: state.isAuthenticated,
|
||||
account: state.account,
|
||||
subscription: state.subscription,
|
||||
}),
|
||||
}
|
||||
)
|
||||
|
||||
62
frontend/src/types/account.ts
Normal file
62
frontend/src/types/account.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
export interface Account {
|
||||
id: string
|
||||
name: string
|
||||
display_code: string
|
||||
owner_id: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface Subscription {
|
||||
id: string
|
||||
account_id: string
|
||||
plan: 'free' | 'pro' | 'team'
|
||||
status: 'active' | 'past_due' | 'canceled' | 'trialing' | 'orphaned'
|
||||
current_period_start: string | null
|
||||
current_period_end: string | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface PlanLimits {
|
||||
plan: string
|
||||
max_trees: number | null
|
||||
max_sessions_per_month: number | null
|
||||
max_users: number | null
|
||||
custom_branding: boolean
|
||||
priority_support: boolean
|
||||
export_formats: string[]
|
||||
}
|
||||
|
||||
export interface SubscriptionDetails {
|
||||
subscription: Subscription
|
||||
limits: PlanLimits
|
||||
usage: {
|
||||
tree_count: number
|
||||
session_count_this_month: number
|
||||
user_count: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface AccountInvite {
|
||||
id: string
|
||||
account_id: string
|
||||
email: string
|
||||
role: 'engineer' | 'viewer'
|
||||
code: string
|
||||
invited_by_id: string
|
||||
accepted_by_id: string | null
|
||||
expires_at: string
|
||||
used_at: string | null
|
||||
created_at: string
|
||||
}
|
||||
|
||||
export interface AccountMember {
|
||||
id: string
|
||||
email: string
|
||||
name: string
|
||||
account_role: string
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
last_login: string | null
|
||||
}
|
||||
@@ -5,7 +5,7 @@ export interface Category {
|
||||
name: string
|
||||
slug: string
|
||||
description: string | null
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
display_order: number
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
@@ -18,7 +18,7 @@ export interface CategoryListItem {
|
||||
name: string
|
||||
slug: string
|
||||
description: string | null
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
display_order: number
|
||||
is_active: boolean
|
||||
tree_count: number
|
||||
@@ -27,7 +27,7 @@ export interface CategoryListItem {
|
||||
export interface CategoryCreate {
|
||||
name: string
|
||||
description?: string | null
|
||||
team_id?: string | null
|
||||
account_id?: string | null
|
||||
}
|
||||
|
||||
export interface CategoryUpdate {
|
||||
|
||||
@@ -7,6 +7,7 @@ export * from './tag'
|
||||
export * from './category'
|
||||
export * from './folder'
|
||||
export * from './step'
|
||||
export type { Account, Subscription, PlanLimits, SubscriptionDetails, AccountInvite, AccountMember } from './account'
|
||||
|
||||
// API response wrapper types
|
||||
export interface PaginatedResponse<T> {
|
||||
|
||||
@@ -56,7 +56,7 @@ export interface StepCategory {
|
||||
name: string
|
||||
description?: string
|
||||
display_order: number
|
||||
team_id?: string
|
||||
account_id?: string
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ export interface Tag {
|
||||
id: string
|
||||
name: string
|
||||
slug: string
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
usage_count: number
|
||||
created_at: string
|
||||
}
|
||||
@@ -13,13 +13,13 @@ export interface TagListItem {
|
||||
id: string
|
||||
name: string
|
||||
slug: string
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
usage_count: number
|
||||
}
|
||||
|
||||
export interface TagCreate {
|
||||
name: string
|
||||
team_id?: string | null
|
||||
account_id?: string | null
|
||||
}
|
||||
|
||||
export interface TagAssignment {
|
||||
|
||||
@@ -67,7 +67,7 @@ export interface Tree {
|
||||
tags: string[]
|
||||
tree_structure: TreeStructure
|
||||
author_id: string | null
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
is_active: boolean
|
||||
is_public: boolean
|
||||
is_default: boolean
|
||||
@@ -86,7 +86,7 @@ export interface TreeListItem {
|
||||
category_info: CategoryInfo | null
|
||||
tags: string[]
|
||||
author_id: string | null
|
||||
team_id: string | null
|
||||
account_id: string | null
|
||||
is_active: boolean
|
||||
is_public: boolean
|
||||
is_default: boolean
|
||||
|
||||
@@ -6,8 +6,8 @@ export interface User {
|
||||
name: string
|
||||
role: UserRole
|
||||
is_super_admin: boolean
|
||||
is_team_admin: boolean
|
||||
team_id: string | null
|
||||
account_id: string
|
||||
account_role: 'owner' | 'engineer' | 'viewer'
|
||||
created_at: string
|
||||
last_login: string | null
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user