From 4ccb93ee31c6ca5ef52249e4d98c418e1a72f009 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:38:47 -0500 Subject: [PATCH 1/5] feat: add account-based subscription model with migrations Transition from team-based to account-based multi-tenancy (Free/Pro/Team). Migrations 016-020 create accounts, subscriptions, plan_limits, and account_invites tables, then migrate existing users and content FKs. New models: Account, Subscription, PlanLimits, AccountInvite. Updated models add account_id alongside existing team_id (coexistence for safe two-PR deployment). Permissions and deps refactored for account_role instead of is_team_admin. Co-Authored-By: Claude Opus 4.6 --- .../versions/016_add_subscription_tables.py | 110 +++++++++++ .../versions/017_add_account_id_to_users.py | 31 +++ .../versions/018_migrate_users_to_accounts.py | 187 ++++++++++++++++++ .../019_migrate_team_fks_to_account.py | 56 ++++++ .../020_finalize_account_migration.py | 105 ++++++++++ backend/app/api/deps.py | 40 +++- backend/app/core/config.py | 10 + backend/app/core/permissions.py | 58 +++--- backend/app/core/stripe_handlers.py | 37 ++++ backend/app/core/subscriptions.py | 113 +++++++++++ backend/app/models/__init__.py | 8 + backend/app/models/account.py | 38 ++++ backend/app/models/account_invite.py | 48 +++++ backend/app/models/category.py | 12 +- backend/app/models/plan_limits.py | 16 ++ backend/app/models/step_category.py | 12 +- backend/app/models/step_library.py | 8 + backend/app/models/subscription.py | 39 ++++ backend/app/models/tag.py | 12 +- backend/app/models/tree.py | 8 + backend/app/models/user.py | 29 ++- backend/requirements.txt | 3 + 22 files changed, 933 insertions(+), 47 deletions(-) create mode 100644 backend/alembic/versions/016_add_subscription_tables.py create mode 100644 backend/alembic/versions/017_add_account_id_to_users.py create mode 100644 backend/alembic/versions/018_migrate_users_to_accounts.py create mode 100644 backend/alembic/versions/019_migrate_team_fks_to_account.py create mode 100644 backend/alembic/versions/020_finalize_account_migration.py create mode 100644 backend/app/core/stripe_handlers.py create mode 100644 backend/app/core/subscriptions.py create mode 100644 backend/app/models/account.py create mode 100644 backend/app/models/account_invite.py create mode 100644 backend/app/models/plan_limits.py create mode 100644 backend/app/models/subscription.py diff --git a/backend/alembic/versions/016_add_subscription_tables.py b/backend/alembic/versions/016_add_subscription_tables.py new file mode 100644 index 00000000..960726cc --- /dev/null +++ b/backend/alembic/versions/016_add_subscription_tables.py @@ -0,0 +1,110 @@ +"""add accounts, subscriptions, plan_limits, and account_invites tables + +Revision ID: 016 +Revises: 015 +Create Date: 2026-02-07 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + + +# revision identifiers, used by Alembic. +revision: str = '016' +down_revision: Union[str, None] = '015' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 1. accounts table + op.create_table( + 'accounts', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('display_code', sa.String(8), nullable=False), + sa.Column('owner_id', UUID(as_uuid=True), nullable=True), # nullable until user created + sa.Column('stripe_customer_id', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('display_code', name='uq_accounts_display_code'), + ) + + # 2. subscriptions table + op.create_table( + 'subscriptions', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('stripe_subscription_id', sa.String(255), nullable=True), + sa.Column('stripe_price_id', sa.String(255), nullable=True), + sa.Column('plan', sa.String(50), nullable=False, server_default='free'), + sa.Column('billing_interval', sa.String(20), nullable=True), # 'monthly' or 'annual' + sa.Column('status', sa.String(50), nullable=False, server_default='active'), + sa.Column('seat_limit', sa.Integer, nullable=True), + sa.Column('current_period_start', sa.DateTime(timezone=True), nullable=True), + sa.Column('current_period_end', sa.DateTime(timezone=True), nullable=True), + sa.Column('cancel_at_period_end', sa.Boolean, nullable=False, server_default='false'), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'), + sa.UniqueConstraint('account_id', name='uq_subscriptions_account_id'), + ) + op.create_index('ix_subscriptions_account_id', 'subscriptions', ['account_id']) + op.create_index('ix_subscriptions_plan', 'subscriptions', ['plan']) + + # 3. plan_limits table (configuration — seeded with 3 rows) + op.create_table( + 'plan_limits', + sa.Column('plan', sa.String(50), nullable=False), + sa.Column('max_trees', sa.Integer, nullable=True), # NULL = unlimited + sa.Column('max_sessions_per_month', sa.Integer, nullable=True), + sa.Column('max_users', sa.Integer, nullable=True), + sa.Column('custom_branding', sa.Boolean, nullable=False, server_default='false'), + sa.Column('priority_support', sa.Boolean, nullable=False, server_default='false'), + sa.Column('export_formats', JSONB, nullable=False, server_default='["markdown", "text"]'), + sa.PrimaryKeyConstraint('plan'), + ) + + # Seed plan_limits + op.execute(""" + INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) + VALUES + ('free', 3, 20, 1, false, false, '["markdown", "text"]'), + ('pro', 25, 200, 1, false, true, '["markdown", "text", "html", "pdf"]'), + ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html", "pdf"]') + """) + + # 4. account_invites table + op.create_table( + 'account_invites', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('invited_by_id', UUID(as_uuid=True), nullable=False), + sa.Column('email', sa.String(255), nullable=False), + sa.Column('code', sa.String(32), nullable=False), + sa.Column('role', sa.String(50), nullable=False, server_default='engineer'), + sa.Column('accepted_by_id', UUID(as_uuid=True), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('used_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['invited_by_id'], ['users.id']), + sa.ForeignKeyConstraint(['accepted_by_id'], ['users.id']), + sa.UniqueConstraint('code', name='uq_account_invites_code'), + sa.CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'), + ) + op.create_index('ix_account_invites_account_id', 'account_invites', ['account_id']) + op.create_index('ix_account_invites_email', 'account_invites', ['email']) + + +def downgrade() -> None: + op.drop_table('account_invites') + op.drop_table('plan_limits') + op.drop_table('subscriptions') + op.drop_table('accounts') diff --git a/backend/alembic/versions/017_add_account_id_to_users.py b/backend/alembic/versions/017_add_account_id_to_users.py new file mode 100644 index 00000000..2e205a33 --- /dev/null +++ b/backend/alembic/versions/017_add_account_id_to_users.py @@ -0,0 +1,31 @@ +"""add account_id and account_role columns to users + +Revision ID: 017 +Revises: 016 +Create Date: 2026-02-07 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '017' +down_revision: Union[str, None] = '016' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('users', sa.Column('account_id', UUID(as_uuid=True), nullable=True)) + op.add_column('users', sa.Column('account_role', sa.String(50), nullable=True)) + op.create_index('ix_users_account_id', 'users', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_users_account_id', table_name='users') + op.drop_column('users', 'account_role') + op.drop_column('users', 'account_id') diff --git a/backend/alembic/versions/018_migrate_users_to_accounts.py b/backend/alembic/versions/018_migrate_users_to_accounts.py new file mode 100644 index 00000000..d15401f5 --- /dev/null +++ b/backend/alembic/versions/018_migrate_users_to_accounts.py @@ -0,0 +1,187 @@ +"""migrate existing users and teams to accounts + +Revision ID: 018 +Revises: 017 +Create Date: 2026-02-07 + +This is the most critical migration. It creates a _team_account_mapping table +for deterministic cross-migration lookups, then migrates all users to accounts. + +Three paths: + A) Teams with users → Account with deterministic owner + B) Teams with zero users → Account with owner_id=NULL, subscription status='orphaned' + C) Users without a team → Personal account, user is owner +""" +import secrets +import string +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '018' +down_revision: Union[str, None] = '017' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# Characters for display codes — exclude confusing chars +DISPLAY_CODE_CHARS = string.ascii_uppercase + string.digits +DISPLAY_CODE_CHARS = DISPLAY_CODE_CHARS.replace('0', '').replace('O', '').replace('I', '').replace('1', '').replace('L', '') + + +def _generate_display_code(existing_codes: set) -> str: + """Generate a unique 8-character display code.""" + for _ in range(100): + code = ''.join(secrets.choice(DISPLAY_CODE_CHARS) for _ in range(8)) + if code not in existing_codes: + existing_codes.add(code) + return code + raise RuntimeError("Failed to generate unique display code after 100 attempts") + + +def upgrade() -> None: + conn = op.get_bind() + + # Create mapping table for deterministic cross-migration lookups + op.create_table( + '_team_account_mapping', + sa.Column('team_id', UUID(as_uuid=True), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('owner_user_id', UUID(as_uuid=True), nullable=True), + sa.PrimaryKeyConstraint('team_id'), + ) + + existing_codes: set = set() + + # --- Path A & B: Process all teams --- + teams = conn.execute(sa.text("SELECT id, name FROM teams")).fetchall() + + for team in teams: + team_id = team[0] + team_name = team[1] + display_code = _generate_display_code(existing_codes) + + # Find deterministic owner: team admin first, then earliest user + owner_row = conn.execute(sa.text(""" + SELECT id FROM users + WHERE team_id = :tid + ORDER BY is_team_admin DESC, created_at ASC, id ASC + LIMIT 1 + """), {"tid": team_id}).fetchone() + + owner_user_id = owner_row[0] if owner_row else None + + # Create account + conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at) + VALUES (gen_random_uuid(), :name, :code, :owner_id, NOW(), NOW()) + """), {"name": team_name, "code": display_code, "owner_id": owner_user_id}) + + # Get the account we just created + account_row = conn.execute(sa.text( + "SELECT id FROM accounts WHERE display_code = :code" + ), {"code": display_code}).fetchone() + account_id = account_row[0] + + # Insert mapping + conn.execute(sa.text(""" + INSERT INTO _team_account_mapping (team_id, account_id, owner_user_id) + VALUES (:tid, :aid, :uid) + """), {"tid": team_id, "aid": account_id, "uid": owner_user_id}) + + if owner_user_id is not None: + # Path A: Team with users + # Create active subscription + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW()) + """), {"aid": account_id}) + + # Update all users in this team + # Team admins become owners, others keep their role + conn.execute(sa.text(""" + UPDATE users SET + account_id = :aid, + account_role = CASE + WHEN is_team_admin = true THEN 'owner' + ELSE role + END + WHERE team_id = :tid + """), {"aid": account_id, "tid": team_id}) + else: + # Path B: Team with zero users (orphan) + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'orphaned', NOW(), NOW()) + """), {"aid": account_id}) + + # --- Path C: Users without a team --- + teamless_users = conn.execute(sa.text( + "SELECT id, name FROM users WHERE team_id IS NULL AND account_id IS NULL" + )).fetchall() + + for user in teamless_users: + user_id = user[0] + user_name = user[1] + display_code = _generate_display_code(existing_codes) + + # Create personal account (owner_id set to NULL initially) + conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at) + VALUES (gen_random_uuid(), :name, :code, NULL, NOW(), NOW()) + """), {"name": f"{user_name}'s Account", "code": display_code}) + + account_row = conn.execute(sa.text( + "SELECT id FROM accounts WHERE display_code = :code" + ), {"code": display_code}).fetchone() + account_id = account_row[0] + + # Update user + conn.execute(sa.text(""" + UPDATE users SET account_id = :aid, account_role = 'owner' + WHERE id = :uid + """), {"aid": account_id, "uid": user_id}) + + # Set owner + conn.execute(sa.text(""" + UPDATE accounts SET owner_id = :uid WHERE id = :aid + """), {"uid": user_id, "aid": account_id}) + + # Create free subscription + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW()) + """), {"aid": account_id}) + + # --- Validation --- + orphaned_users = conn.execute(sa.text( + "SELECT COUNT(*) FROM users WHERE account_id IS NULL" + )).scalar() + if orphaned_users > 0: + raise RuntimeError( + f"Migration 018 failed validation: {orphaned_users} users still have NULL account_id" + ) + + team_count = conn.execute(sa.text("SELECT COUNT(*) FROM teams")).scalar() + mapping_count = conn.execute(sa.text("SELECT COUNT(*) FROM _team_account_mapping")).scalar() + if mapping_count != team_count: + raise RuntimeError( + f"Migration 018 failed: mapping count ({mapping_count}) != team count ({team_count})" + ) + + +def downgrade() -> None: + conn = op.get_bind() + + # Clear account data from users + conn.execute(sa.text("UPDATE users SET account_id = NULL, account_role = NULL")) + + # Delete all subscriptions and accounts created by this migration + conn.execute(sa.text("DELETE FROM subscriptions")) + conn.execute(sa.text("DELETE FROM accounts")) + + # Drop mapping table + op.drop_table('_team_account_mapping') diff --git a/backend/alembic/versions/019_migrate_team_fks_to_account.py b/backend/alembic/versions/019_migrate_team_fks_to_account.py new file mode 100644 index 00000000..4f15dc80 --- /dev/null +++ b/backend/alembic/versions/019_migrate_team_fks_to_account.py @@ -0,0 +1,56 @@ +"""add account_id to content tables and backfill from _team_account_mapping + +Revision ID: 019 +Revises: 018 +Create Date: 2026-02-07 + +Uses the _team_account_mapping table from migration 018 for deterministic +FK backfill instead of non-deterministic LIMIT 1 subqueries. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '019' +down_revision: Union[str, None] = '018' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# Tables that have team_id and need account_id +CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories'] + + +def upgrade() -> None: + conn = op.get_bind() + + for table in CONTENT_TABLES: + # Add account_id column + op.add_column(table, sa.Column('account_id', UUID(as_uuid=True), nullable=True)) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + # Backfill from mapping table (deterministic) + conn.execute(sa.text(f""" + UPDATE {table} SET account_id = m.account_id + FROM _team_account_mapping m + WHERE {table}.team_id = m.team_id + """)) + + # Validate: no rows with team_id but missing account_id + orphaned = conn.execute(sa.text(f""" + SELECT COUNT(*) FROM {table} + WHERE team_id IS NOT NULL AND account_id IS NULL + """)).scalar() + if orphaned > 0: + raise RuntimeError( + f"Migration 019 failed: {table} has {orphaned} rows with team_id but no account_id" + ) + + +def downgrade() -> None: + for table in reversed(CONTENT_TABLES): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/020_finalize_account_migration.py b/backend/alembic/versions/020_finalize_account_migration.py new file mode 100644 index 00000000..510c8abb --- /dev/null +++ b/backend/alembic/versions/020_finalize_account_migration.py @@ -0,0 +1,105 @@ +"""finalize account migration — add constraints, clean up orphans + +Revision ID: 020 +Revises: 019 +Create Date: 2026-02-07 + +Adds NOT NULL constraints, foreign keys, and CHECK constraints. +Cleans up orphan accounts (zero-user teams with no content). +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '020' +down_revision: Union[str, None] = '019' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories'] + + +def upgrade() -> None: + conn = op.get_bind() + + # 1. Clean up orphan accounts (zero-user teams with no content) + conn.execute(sa.text(""" + DELETE FROM subscriptions WHERE account_id IN ( + SELECT a.id FROM accounts a + WHERE a.owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id) + ) + """)) + + # Also remove the mapping entries for these orphans + conn.execute(sa.text(""" + DELETE FROM _team_account_mapping WHERE account_id IN ( + SELECT a.id FROM accounts a + WHERE a.owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id) + ) + """)) + + conn.execute(sa.text(""" + DELETE FROM accounts + WHERE owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = accounts.id) + """)) + + # 2. Users: enforce NOT NULL and add FK + CHECK + op.alter_column('users', 'account_id', nullable=False) + op.alter_column('users', 'account_role', nullable=False) + op.create_foreign_key( + 'fk_users_account_id', 'users', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE' + ) + op.create_check_constraint( + 'ck_users_account_role_enum', 'users', + "account_role IN ('owner', 'engineer', 'viewer')" + ) + + # 3. Content tables: add FK on account_id (nullable OK — NULL means global) + for table in CONTENT_TABLES: + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE' + ) + + # 4. Accounts: enforce owner_id NOT NULL + FK + op.alter_column('accounts', 'owner_id', nullable=False) + op.create_foreign_key( + 'fk_accounts_owner_id', 'accounts', 'users', + ['owner_id'], ['id'], ondelete='RESTRICT' + ) + + +def downgrade() -> None: + # Remove account owner FK and nullable constraint + op.drop_constraint('fk_accounts_owner_id', 'accounts', type_='foreignkey') + op.alter_column('accounts', 'owner_id', nullable=True) + + # Remove content table FKs + for table in reversed(CONTENT_TABLES): + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + + # Remove user constraints + op.drop_constraint('ck_users_account_role_enum', 'users', type_='check') + op.drop_constraint('fk_users_account_id', 'users', type_='foreignkey') + op.alter_column('users', 'account_role', nullable=True) + op.alter_column('users', 'account_id', nullable=True) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index c7cfa475..6d353506 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Optional from uuid import UUID from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -8,6 +8,7 @@ from sqlalchemy import select from app.core.database import get_db from app.core.security import decode_token from app.models.user import User +from app.models.plan_limits import PlanLimits oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -90,14 +91,35 @@ async def require_admin( async def require_engineer_or_admin( current_user: Annotated[User, Depends(get_current_active_user)] ) -> User: - """Require engineer, team admin, or super admin role (blocks viewers).""" + """Require engineer, account owner, or super admin role (blocks viewers).""" if current_user.is_super_admin: return current_user - if current_user.is_team_admin and current_user.team_id is not None: + if current_user.account_role in ("owner", "engineer"): return current_user - if current_user.role not in ("engineer",): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Engineer or admin access required" - ) - return current_user + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Engineer or admin access required" + ) + + +async def require_account_owner( + current_user: Annotated[User, Depends(get_current_active_user)] +) -> User: + """Require account owner or super admin access.""" + if current_user.is_super_admin: + return current_user + if current_user.account_role == "owner": + return current_user + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Account owner access required" + ) + + +async def get_plan_limits_for_user( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> Optional[PlanLimits]: + """Get plan limits for the current user's account.""" + from app.core.subscriptions import get_user_plan_limits + return await get_user_plan_limits(current_user.account_id, db) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 45d51cac..2041cf9c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -52,6 +52,16 @@ class Settings(BaseSettings): # Registration REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration + # Stripe + STRIPE_SECRET_KEY: Optional[str] = None + STRIPE_PUBLISHABLE_KEY: Optional[str] = None + STRIPE_WEBHOOK_SECRET: Optional[str] = None + + @property + def stripe_enabled(self) -> bool: + """Check if Stripe is configured.""" + return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None + # CORS - set FRONTEND_URL in production (e.g., https://patherly.up.railway.app) CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173", "http://localhost:5174"] FRONTEND_URL: Optional[str] = None diff --git a/backend/app/core/permissions.py b/backend/app/core/permissions.py index 52fc5bdc..ba34dd62 100644 --- a/backend/app/core/permissions.py +++ b/backend/app/core/permissions.py @@ -1,12 +1,12 @@ """ Centralized permission checks for ResolutionFlow. -Role hierarchy: super_admin > team_admin > engineer > viewer +Role hierarchy: super_admin > owner > engineer > viewer - super_admin: is_super_admin=True, full system access -- team_admin: is_team_admin=True + valid team_id, manage team resources -- engineer: role='engineer' (default), CRUD own trees/steps -- viewer: role='viewer', read-only (can browse, run sessions, rate steps) +- owner: account_role='owner', manage account resources +- engineer: account_role='engineer' (default), CRUD own trees/steps +- viewer: account_role='viewer', read-only (can browse, run sessions, rate steps) """ from __future__ import annotations from typing import Optional, TYPE_CHECKING @@ -21,19 +21,19 @@ if TYPE_CHECKING: ROLE_HIERARCHY = { "super_admin": 4, - "team_admin": 3, + "owner": 3, "engineer": 2, "viewer": 1, } def get_effective_role(user: User) -> str: - """Get the effective role considering is_super_admin and is_team_admin flags.""" + """Get the effective role considering is_super_admin and account_role.""" if user.is_super_admin: return "super_admin" - if user.is_team_admin and user.team_id is not None: - return "team_admin" - return user.role # "engineer" or "viewer" + if user.account_role == "owner": + return "owner" + return user.account_role # "engineer" or "viewer" def has_minimum_role(user: User, minimum_role: str) -> bool: @@ -55,7 +55,7 @@ def can_edit_tree(user: User, tree: Tree) -> bool: return False if tree.author_id == user.id: return True - if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None: return True return False @@ -78,7 +78,7 @@ def can_manage_category(user: User, category: TreeCategory) -> bool: """Can the user edit/delete this category?""" if user.is_super_admin: return True - if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None: return True return False @@ -91,7 +91,7 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool: return False if tree.author_id == user.id: return True - if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None: return True return False @@ -102,7 +102,7 @@ def can_access_tree(user: User, tree: Tree) -> bool: return True if tree.author_id == user.id: return True - if tree.team_id == user.team_id and user.team_id is not None: + if tree.account_id == user.account_id and user.account_id is not None: return True if user.is_super_admin: return True @@ -116,35 +116,35 @@ def can_view_step(user: User, step: StepLibrary) -> bool: if step.visibility == "private": return step.created_by == user.id if step.visibility == "team": - return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin + return (step.account_id == user.account_id and user.account_id is not None) or user.is_super_admin return False -def can_create_tag(user: User, team_id: Optional[UUID]) -> bool: +def can_create_tag(user: User, account_id: Optional[UUID]) -> bool: """Can the user create a tag for the given scope? - - Super admins can create global tags (team_id=None) or any team's tags - - Engineers can create team tags for their own team + - Super admins can create global tags (account_id=None) or any account's tags + - Engineers can create account tags for their own account - Viewers cannot create tags """ if user.is_super_admin: return True if not can_create_content(user): return False - if team_id is not None and team_id == user.team_id: + if account_id is not None and account_id == user.account_id: return True return False -def can_create_category(user: User, team_id: Optional[UUID]) -> bool: - """Can the user create a category for the given team? +def can_create_category(user: User, account_id: Optional[UUID]) -> bool: + """Can the user create a category for the given account? - - Super admins can create global or any team's categories - - Team admins can create categories for their own team + - Super admins can create global or any account's categories + - Account owners can create categories for their own account """ if user.is_super_admin: return True - if user.is_team_admin and team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None: return True return False @@ -153,19 +153,19 @@ def can_manage_step_category(user: User, category: StepCategory) -> bool: """Can the user edit/delete this step category?""" if user.is_super_admin: return True - if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None: return True return False -def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool: - """Can the user create a step category for the given team? +def can_create_step_category(user: User, account_id: Optional[UUID]) -> bool: + """Can the user create a step category for the given account? - - Super admins can create global or any team's step categories - - Team admins can create step categories for their own team + - Super admins can create global or any account's step categories + - Account owners can create step categories for their own account """ if user.is_super_admin: return True - if user.is_team_admin and team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None: return True return False diff --git a/backend/app/core/stripe_handlers.py b/backend/app/core/stripe_handlers.py new file mode 100644 index 00000000..2d771e49 --- /dev/null +++ b/backend/app/core/stripe_handlers.py @@ -0,0 +1,37 @@ +"""Stripe webhook event handlers (stub implementations). + +These handlers log events but don't process them until Stripe is fully configured. +""" +import logging +from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + + +async def handle_checkout_completed(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: checkout.session.completed — %s", event.get("id")) + + +async def handle_invoice_paid(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: invoice.paid — %s", event.get("id")) + + +async def handle_invoice_payment_failed(event: dict, db: AsyncSession) -> None: + logger.warning("Stripe: invoice.payment_failed — %s", event.get("id")) + + +async def handle_subscription_updated(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: customer.subscription.updated — %s", event.get("id")) + + +async def handle_subscription_deleted(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: customer.subscription.deleted — %s", event.get("id")) + + +WEBHOOK_HANDLERS = { + "checkout.session.completed": handle_checkout_completed, + "invoice.paid": handle_invoice_paid, + "invoice.payment_failed": handle_invoice_payment_failed, + "customer.subscription.updated": handle_subscription_updated, + "customer.subscription.deleted": handle_subscription_deleted, +} diff --git a/backend/app/core/subscriptions.py b/backend/app/core/subscriptions.py new file mode 100644 index 00000000..e4e1b735 --- /dev/null +++ b/backend/app/core/subscriptions.py @@ -0,0 +1,113 @@ +"""Subscription limit checks and plan helpers.""" +from typing import Optional +from uuid import UUID +from datetime import datetime, timezone + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.subscription import Subscription +from app.models.plan_limits import PlanLimits +from app.models.tree import Tree +from app.models.session import Session + + +async def get_account_subscription(account_id: UUID, db: AsyncSession) -> Optional[Subscription]: + result = await db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + return result.scalar_one_or_none() + + +async def get_plan_limits(plan: str, db: AsyncSession) -> Optional[PlanLimits]: + result = await db.execute( + select(PlanLimits).where(PlanLimits.plan == plan) + ) + return result.scalar_one_or_none() + + +async def get_user_plan_limits(user_account_id: UUID, db: AsyncSession) -> Optional[PlanLimits]: + sub = await get_account_subscription(user_account_id, db) + if sub is None: + return await get_plan_limits("free", db) + return await get_plan_limits(sub.plan, db) + + +async def check_tree_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]: + """Check if account can create a new tree. + + Returns: (can_create, limit, current_count) + """ + sub = await get_account_subscription(account_id, db) + if sub is None: + return False, 0, 0 + + limits = await get_plan_limits(sub.plan, db) + if limits is None or limits.max_trees is None: + return True, None, 0 # unlimited + + current_count = await db.scalar( + select(func.count(Tree.id)).where( + Tree.account_id == account_id, + Tree.deleted_at.is_(None), + ) + ) + current_count = current_count or 0 + + return current_count < limits.max_trees, limits.max_trees, current_count + + +async def check_session_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]: + """Check if account can create a new session this month. + + Returns: (can_create, limit, current_count) + """ + sub = await get_account_subscription(account_id, db) + if sub is None: + return False, 0, 0 + + limits = await get_plan_limits(sub.plan, db) + if limits is None or limits.max_sessions_per_month is None: + return True, None, 0 # unlimited + + # Count sessions this calendar month for all users in this account + now = datetime.now(timezone.utc) + month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + from app.models.user import User + current_count = await db.scalar( + select(func.count(Session.id)).where( + Session.user_id.in_( + select(User.id).where(User.account_id == account_id) + ), + Session.started_at >= month_start, + ) + ) + current_count = current_count or 0 + + return current_count < limits.max_sessions_per_month, limits.max_sessions_per_month, current_count + + +async def get_account_usage(account_id: UUID, db: AsyncSession) -> dict: + """Get current usage stats for an account.""" + tree_count = await db.scalar( + select(func.count(Tree.id)).where( + Tree.account_id == account_id, + Tree.deleted_at.is_(None), + ) + ) or 0 + + now = datetime.now(timezone.utc) + month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + from app.models.user import User + session_count = await db.scalar( + select(func.count(Session.id)).where( + Session.user_id.in_( + select(User.id).where(User.account_id == account_id) + ), + Session.started_at >= month_start, + ) + ) or 0 + + return {"tree_count": tree_count, "session_count_this_month": session_count} diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 28cb75bc..981c5b15 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,5 +1,9 @@ from .user import User from .team import Team +from .account import Account +from .subscription import Subscription +from .plan_limits import PlanLimits +from .account_invite import AccountInvite from .tree import Tree from .session import Session from .attachment import Attachment @@ -15,6 +19,10 @@ from .audit_log import AuditLog __all__ = [ "User", "Team", + "Account", + "Subscription", + "PlanLimits", + "AccountInvite", "Tree", "Session", "Attachment", diff --git a/backend/app/models/account.py b/backend/app/models/account.py new file mode 100644 index 00000000..6506488f --- /dev/null +++ b/backend/app/models/account.py @@ -0,0 +1,38 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.user import User + from app.models.subscription import Subscription + from app.models.tree import Tree + from app.models.category import TreeCategory + from app.models.tag import TreeTag + from app.models.step_category import StepCategory + from app.models.step_library import StepLibrary + + +class Account(Base): + __tablename__ = "accounts" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(255), nullable=False) + display_code: Mapped[str] = mapped_column(String(8), unique=True, nullable=False) + owner_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=False) + stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + # Relationships + owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account") + users: Mapped[list["User"]] = relationship("User", foreign_keys="[User.account_id]", back_populates="account") + subscription: Mapped[Optional["Subscription"]] = relationship("Subscription", back_populates="account", uselist=False) + trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.account_id]", back_populates="account") + categories: Mapped[list["TreeCategory"]] = relationship("TreeCategory", foreign_keys="[TreeCategory.account_id]", back_populates="account") + tags: Mapped[list["TreeTag"]] = relationship("TreeTag", foreign_keys="[TreeTag.account_id]", back_populates="account") + step_categories: Mapped[list["StepCategory"]] = relationship("StepCategory", foreign_keys="[StepCategory.account_id]", back_populates="account") + step_library: Mapped[list["StepLibrary"]] = relationship("StepLibrary", foreign_keys="[StepLibrary.account_id]", back_populates="account") diff --git a/backend/app/models/account_invite.py b/backend/app/models/account_invite.py new file mode 100644 index 00000000..43b3ed56 --- /dev/null +++ b/backend/app/models/account_invite.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, CheckConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.account import Account + from app.models.user import User + + +class AccountInvite(Base): + __tablename__ = "account_invites" + __table_args__ = ( + CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'), + ) + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False) + invited_by_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + email: Mapped[str] = mapped_column(String(255), nullable=False) + code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer") + accepted_by_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # Relationships + account: Mapped["Account"] = relationship("Account") + invited_by: Mapped["User"] = relationship("User", foreign_keys=[invited_by_id]) + accepted_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[accepted_by_id]) + + @property + def is_used(self) -> bool: + return self.accepted_by_id is not None + + @property + def is_expired(self) -> bool: + if self.expires_at is None: + return False + return datetime.now(timezone.utc) > self.expires_at + + @property + def is_valid(self) -> bool: + return not self.is_used and not self.is_expired diff --git a/backend/app/models/category.py b/backend/app/models/category.py index 6d375d50..57bae701 100644 --- a/backend/app/models/category.py +++ b/backend/app/models/category.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.tree import Tree from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -38,6 +39,12 @@ class TreeCategory(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -57,10 +64,11 @@ class TreeCategory(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="categories") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="categories") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) trees: Mapped[list["Tree"]] = relationship("Tree", back_populates="category_rel") @property def is_global(self) -> bool: - """Returns True if this is a global category (not team-specific).""" - return self.team_id is None + """Returns True if this is a global category (not team or account-specific).""" + return self.team_id is None and self.account_id is None diff --git a/backend/app/models/plan_limits.py b/backend/app/models/plan_limits.py new file mode 100644 index 00000000..1a6b0511 --- /dev/null +++ b/backend/app/models/plan_limits.py @@ -0,0 +1,16 @@ +from sqlalchemy import String, Integer, Boolean +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import JSONB +from app.core.database import Base + + +class PlanLimits(Base): + __tablename__ = "plan_limits" + + plan: Mapped[str] = mapped_column(String(50), primary_key=True) + max_trees: Mapped[int | None] = mapped_column(Integer, nullable=True) + max_sessions_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True) + max_users: Mapped[int | None] = mapped_column(Integer, nullable=True) + custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"]) diff --git a/backend/app/models/step_category.py b/backend/app/models/step_category.py index 59471ccb..da207926 100644 --- a/backend/app/models/step_category.py +++ b/backend/app/models/step_category.py @@ -8,6 +8,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -37,6 +38,12 @@ class StepCategory(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -56,9 +63,10 @@ class StepCategory(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="step_categories") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_categories") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) @property def is_global(self) -> bool: - """Returns True if this is a global category (not team-specific).""" - return self.team_id is None + """Returns True if this is a global category (not team or account-specific).""" + return self.team_id is None and self.account_id is None diff --git a/backend/app/models/step_library.py b/backend/app/models/step_library.py index 5fb78614..ae2b319b 100644 --- a/backend/app/models/step_library.py +++ b/backend/app/models/step_library.py @@ -10,6 +10,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User from app.models.team import Team + from app.models.account import Account from app.models.step_category import StepCategory from app.models.session import Session @@ -43,6 +44,12 @@ class StepLibrary(Base): ForeignKey("teams.id", ondelete="CASCADE"), nullable=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) # Organization category_id: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -91,6 +98,7 @@ class StepLibrary(Base): # Relationships creator: Mapped["User"] = relationship("User", foreign_keys=[created_by]) team: Mapped[Optional["Team"]] = relationship("Team") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_library") category: Mapped[Optional["StepCategory"]] = relationship("StepCategory") ratings: Mapped[list["StepRating"]] = relationship("StepRating", back_populates="step", cascade="all, delete-orphan") usage_logs: Mapped[list["StepUsageLog"]] = relationship("StepUsageLog", back_populates="step", cascade="all, delete-orphan") diff --git a/backend/app/models/subscription.py b/backend/app/models/subscription.py new file mode 100644 index 00000000..54b2a440 --- /dev/null +++ b/backend/app/models/subscription.py @@ -0,0 +1,39 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.account import Account + + +class Subscription(Base): + __tablename__ = "subscriptions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), unique=True, nullable=False) + stripe_subscription_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + stripe_price_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + plan: Mapped[str] = mapped_column(String(50), nullable=False, default="free") + billing_interval: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + status: Mapped[str] = mapped_column(String(50), nullable=False, default="active") + seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + current_period_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + current_period_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + # Relationships + account: Mapped["Account"] = relationship("Account", back_populates="subscription") + + @property + def is_active(self) -> bool: + return self.status in ("active", "trialing") + + @property + def is_paid(self) -> bool: + return self.plan in ("pro", "team") diff --git a/backend/app/models/tag.py b/backend/app/models/tag.py index 6ebb6b38..5152c3a9 100644 --- a/backend/app/models/tag.py +++ b/backend/app/models/tag.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.tree import Tree from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -50,6 +51,12 @@ class TreeTag(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), @@ -63,6 +70,7 @@ class TreeTag(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="tags") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="tags") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) trees: Mapped[list["Tree"]] = relationship( "Tree", @@ -72,8 +80,8 @@ class TreeTag(Base): @property def is_global(self) -> bool: - """Returns True if this is a global tag (not team-specific).""" - return self.team_id is None + """Returns True if this is a global tag (not team or account-specific).""" + return self.team_id is None and self.account_id is None @classmethod def slugify(cls, name: str) -> str: diff --git a/backend/app/models/tree.py b/backend/app/models/tree.py index 56237d96..c9b305a7 100644 --- a/backend/app/models/tree.py +++ b/backend/app/models/tree.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User from app.models.team import Team + from app.models.account import Account from app.models.session import Session from app.models.category import TreeCategory from app.models.tag import TreeTag @@ -47,6 +48,12 @@ class Tree(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) is_active: Mapped[bool] = mapped_column(Boolean, default=True) is_public: Mapped[bool] = mapped_column(Boolean, default=False, index=True) is_default: Mapped[bool] = mapped_column(Boolean, default=False, index=True) @@ -75,6 +82,7 @@ class Tree(Base): # Relationships author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees") team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees") sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree") # New organization relationships diff --git a/backend/app/models/user.py b/backend/app/models/user.py index ac7648f7..ec835640 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -8,6 +8,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.team import Team + from app.models.account import Account from app.models.tree import Tree from app.models.session import Session from app.models.folder import UserFolder @@ -20,6 +21,10 @@ class User(Base): "role IN ('engineer', 'viewer')", name='ck_users_role_enum' ), + CheckConstraint( + "account_role IN ('owner', 'admin', 'engineer', 'viewer')", + name='ck_users_account_role_enum' + ), ) id: Mapped[uuid.UUID] = mapped_column( @@ -34,6 +39,17 @@ class User(Base): is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") + + # Account-based multi-tenancy (new) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="RESTRICT"), + nullable=True, + index=True + ) + account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer") + + # Legacy team columns (kept for PR A coexistence) team_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id"), @@ -51,6 +67,8 @@ class User(Base): last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) # Relationships + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="users") + owned_account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys="[Account.owner_id]", back_populates="owner", uselist=False) team: Mapped[Optional["Team"]] = relationship("Team", back_populates="users") trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.author_id]", back_populates="author") sessions: Mapped[list["Session"]] = relationship("Session", back_populates="user") @@ -62,6 +80,11 @@ class User(Base): return self.is_super_admin @property - def can_manage_team(self) -> bool: - """Returns True if user can manage their team (team admin or super admin).""" - return self.is_super_admin or (self.is_team_admin and self.team_id is not None) + def is_account_owner(self) -> bool: + """Returns True if user owns their account.""" + return self.account_role == "owner" + + @property + def can_manage_account(self) -> bool: + """Returns True if user can manage their account (owner, admin, or super admin).""" + return self.is_super_admin or self.account_role in ("owner", "admin") diff --git a/backend/requirements.txt b/backend/requirements.txt index 73b11ac3..5bfbf210 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -22,5 +22,8 @@ email-validator==2.1.0 # Rate Limiting slowapi==0.1.9 +# Payments +stripe==14.3.0 + # Utilities python-dotenv==1.0.1 From e0089a9c5aeb72a7e988b031ad8ec92e2d292ed8 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:39:01 -0500 Subject: [PATCH 2/5] feat: update all endpoints and schemas for account-based model Replace team_id with account_id across all API endpoints (trees, categories, tags, steps, step_categories, admin, auth). Add new accounts and webhooks endpoints. Registration now atomically creates Account + Subscription, with account_invite_code bypassing the platform invite gate. Schemas updated for account_id/account_role. 82 tests passing including 18 new tests for accounts, subscriptions, and permissions. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/endpoints/accounts.py | 236 +++++++++++++++++++ backend/app/api/endpoints/admin.py | 29 +-- backend/app/api/endpoints/auth.py | 84 ++++++- backend/app/api/endpoints/categories.py | 48 ++-- backend/app/api/endpoints/step_categories.py | 48 ++-- backend/app/api/endpoints/steps.py | 20 +- backend/app/api/endpoints/tags.py | 76 +++--- backend/app/api/endpoints/trees.py | 40 ++-- backend/app/api/endpoints/webhooks.py | 62 +++++ backend/app/api/router.py | 4 +- backend/app/schemas/account.py | 39 +++ backend/app/schemas/category.py | 6 +- backend/app/schemas/step_category.py | 6 +- backend/app/schemas/step_library.py | 4 +- backend/app/schemas/subscription.py | 40 ++++ backend/app/schemas/tag.py | 8 +- backend/app/schemas/tree.py | 4 +- backend/app/schemas/user.py | 11 +- backend/tests/conftest.py | 9 + backend/tests/test_account_management.py | 170 +++++++++++++ backend/tests/test_admin.py | 48 ++++ backend/tests/test_auth.py | 4 + backend/tests/test_permissions_account.py | 205 ++++++++++++++++ backend/tests/test_subscription_limits.py | 129 ++++++++++ 24 files changed, 1178 insertions(+), 152 deletions(-) create mode 100644 backend/app/api/endpoints/accounts.py create mode 100644 backend/app/api/endpoints/webhooks.py create mode 100644 backend/app/schemas/account.py create mode 100644 backend/app/schemas/subscription.py create mode 100644 backend/tests/test_account_management.py create mode 100644 backend/tests/test_permissions_account.py create mode 100644 backend/tests/test_subscription_limits.py diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py new file mode 100644 index 00000000..c2228584 --- /dev/null +++ b/backend/app/api/endpoints/accounts.py @@ -0,0 +1,236 @@ +from datetime import datetime, timezone, timedelta +from typing import Annotated, Optional +from uuid import UUID +import secrets +import string +from fastapi import APIRouter, Depends, HTTPException, status, Query +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.core.database import get_db +from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage +from app.models.account import Account +from app.models.account_invite import AccountInvite +from app.models.subscription import Subscription +from app.models.user import User +from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse +from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails +from app.schemas.user import UserResponse, AccountRoleUpdate +from app.api.deps import get_current_active_user, require_account_owner + +router = APIRouter(prefix="/accounts", tags=["accounts"]) + + +@router.get("/me", response_model=AccountResponse) +async def get_my_account( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get current user's account.""" + result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = result.scalar_one_or_none() + if not account: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Account not found" + ) + return account + + +@router.get("/me/subscription", response_model=SubscriptionDetails) +async def get_my_subscription( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get current user's subscription details including limits and usage.""" + sub = await get_account_subscription(current_user.account_id, db) + if not sub: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No subscription found" + ) + + limits = await get_plan_limits(sub.plan, db) + if not limits: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Plan limits not configured" + ) + + usage = await get_account_usage(current_user.account_id, db) + + return SubscriptionDetails( + subscription=SubscriptionResponse.model_validate(sub), + limits=PlanLimitsResponse.model_validate(limits), + usage=UsageResponse(**usage), + ) + + +@router.get("/me/members", response_model=list[UserResponse]) +async def get_my_members( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get members of current user's account.""" + result = await db.execute( + select(User).where(User.account_id == current_user.account_id) + .order_by(User.created_at) + ) + return result.scalars().all() + + +@router.patch("/me", response_model=AccountResponse) +async def update_my_account( + data: AccountUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Update account settings (owner only).""" + result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = result.scalar_one_or_none() + if not account: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Account not found" + ) + + update_data = data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(account, field, value) + + await db.commit() + await db.refresh(account) + return account + + +@router.patch("/me/members/{user_id}/role", response_model=UserResponse) +async def update_member_role( + user_id: UUID, + data: AccountRoleUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Change a member's role within the account (owner only).""" + result = await db.execute( + select(User).where( + User.id == user_id, + User.account_id == current_user.account_id + ) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found in your account" + ) + + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot change your own role" + ) + + user.account_role = data.account_role + await db.commit() + await db.refresh(user) + return user + + +@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def remove_member( + user_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Remove a member from the account (owner only). + + The removed user gets a new personal account. + """ + result = await db.execute( + select(User).where( + User.id == user_id, + User.account_id == current_user.account_id + ) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found in your account" + ) + + if user.id == current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot remove yourself from your own account" + ) + + # Create a personal account for the removed user + chars = string.ascii_uppercase + string.digits + display_code = ''.join(secrets.choice(chars) for _ in range(8)) + + new_account = Account( + name=f"{user.name}'s Account", + display_code=display_code, + owner_id=user.id, + ) + db.add(new_account) + await db.flush() + + new_sub = Subscription( + account_id=new_account.id, + plan="free", + status="active", + ) + db.add(new_sub) + + user.account_id = new_account.id + user.account_role = "owner" + + await db.commit() + return None + + +@router.post("/me/invites", response_model=AccountInviteResponse, status_code=status.HTTP_201_CREATED) +async def create_invite( + data: AccountInviteCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Create an invite to join this account (owner only).""" + code = secrets.token_urlsafe(16) + + expires_at = None + if data.expires_in_days: + expires_at = datetime.now(timezone.utc) + timedelta(days=data.expires_in_days) + + invite = AccountInvite( + account_id=current_user.account_id, + invited_by_id=current_user.id, + email=data.email, + code=code, + role=data.role, + expires_at=expires_at, + ) + db.add(invite) + await db.commit() + await db.refresh(invite) + return invite + + +@router.get("/me/invites", response_model=list[AccountInviteResponse]) +async def list_invites( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """List invites for this account (owner only).""" + result = await db.execute( + select(AccountInvite) + .where(AccountInvite.account_id == current_user.account_id) + .order_by(AccountInvite.created_at.desc()) + ) + return result.scalars().all() diff --git a/backend/app/api/endpoints/admin.py b/backend/app/api/endpoints/admin.py index daa04166..e6bde866 100644 --- a/backend/app/api/endpoints/admin.py +++ b/backend/app/api/endpoints/admin.py @@ -7,7 +7,7 @@ from sqlalchemy import select, func from app.core.database import get_db from app.core.audit import log_audit from app.models.user import User -from app.schemas.user import UserResponse, RoleUpdate, TeamAdminUpdate +from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate from app.api.deps import require_admin router = APIRouter(prefix="/admin", tags=["admin"]) @@ -21,7 +21,7 @@ async def list_users( limit: int = Query(100, ge=1, le=100), is_active: Optional[bool] = Query(None, description="Filter by active status"), role: Optional[str] = Query(None, description="Filter by role"), - team_id: Optional[UUID] = Query(None, description="Filter by team") + account_id: Optional[UUID] = Query(None, description="Filter by account") ): """List all users (super admin only).""" query = select(User) @@ -30,8 +30,8 @@ async def list_users( query = query.where(User.is_active == is_active) if role: query = query.where(User.role == role) - if team_id: - query = query.where(User.team_id == team_id) + if account_id: + query = query.where(User.account_id == account_id) query = query.order_by(User.created_at.desc()).offset(skip).limit(limit) @@ -91,14 +91,14 @@ async def update_user_role( return user -@router.put("/users/{user_id}/team-admin", response_model=UserResponse) -async def toggle_team_admin( +@router.put("/users/{user_id}/account-role", response_model=UserResponse) +async def update_account_role( user_id: UUID, - data: TeamAdminUpdate, + data: AccountRoleUpdate, db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(require_admin)] ): - """Toggle is_team_admin for a user (super admin only).""" + """Change a user's account role (super admin only).""" result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() @@ -108,15 +108,10 @@ async def toggle_team_admin( detail="User not found" ) - if data.is_team_admin and user.team_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User must belong to a team to be a team admin" - ) - - user.is_team_admin = data.is_team_admin - await log_audit(db, current_user.id, "user.team_admin_toggle", "user", user.id, - {"is_team_admin": data.is_team_admin}) + old_role = user.account_role + user.account_role = data.account_role + await log_audit(db, current_user.id, "user.account_role_change", "user", user.id, + {"old_account_role": old_role, "new_account_role": data.account_role}) await db.commit() await db.refresh(user) return user diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 5450860a..385c31fd 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -1,3 +1,5 @@ +import secrets +import string from datetime import datetime, timezone from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status, Request @@ -18,6 +20,9 @@ from app.core.security import ( from app.models.user import User from app.models.invite_code import InviteCode from app.models.refresh_token import RefreshToken +from app.models.account import Account +from app.models.subscription import Subscription +from app.models.account_invite import AccountInvite from app.schemas.user import UserCreate, UserResponse, UserLogin from app.schemas.token import Token from app.api.deps import get_current_active_user, get_refresh_token_payload @@ -37,6 +42,12 @@ async def _store_refresh_token(db: AsyncSession, refresh_token_str: str, user_id db.add(token_record) +def _generate_display_code() -> str: + """Generate a random 8-character alphanumeric display code.""" + chars = string.ascii_uppercase + string.digits + return ''.join(secrets.choice(chars) for _ in range(8)) + + @router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) @limiter.limit("3/minute") async def register( @@ -44,10 +55,46 @@ async def register( user_data: UserCreate, db: Annotated[AsyncSession, Depends(get_db)] ): - """Register a new user.""" - # Validate invite code if required + """Register a new user. + + Supports two flows: + - account_invite_code: Join an existing account (bypasses platform invite gate) + - invite_code: Platform invite code (when REQUIRE_INVITE_CODE is enabled) + + After user creation, if no account invite was used, a personal Account + and free Subscription are created automatically. + """ + # Check for account invite code FIRST — bypasses platform invite gate + account_invite_record = None + if user_data.account_invite_code: + result = await db.execute( + select(AccountInvite).where( + AccountInvite.code == user_data.account_invite_code + ) + ) + account_invite_record = result.scalar_one_or_none() + + if not account_invite_record: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid account invite code" + ) + + if account_invite_record.is_used: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Account invite code has already been used" + ) + + if account_invite_record.is_expired: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Account invite code has expired" + ) + + # Validate platform invite code if required (skip if account invite was provided) invite_code_record = None - if settings.REQUIRE_INVITE_CODE: + if not account_invite_record and settings.REQUIRE_INVITE_CODE: if not user_data.invite_code: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -96,8 +143,37 @@ async def register( invite_code_id=invite_code_record.id if invite_code_record else None ) db.add(new_user) + await db.flush() # Get user ID before creating account - # Mark invite code as used + if account_invite_record: + # Join existing account via account invite + new_user.account_id = account_invite_record.account_id + new_user.account_role = account_invite_record.role + + # Mark account invite as used + account_invite_record.accepted_by_id = new_user.id + account_invite_record.used_at = datetime.now(timezone.utc) + else: + # Create personal Account + free Subscription + new_account = Account( + name=f"{user_data.name}'s Account", + display_code=_generate_display_code(), + owner_id=new_user.id, + ) + db.add(new_account) + await db.flush() # Get account ID + + new_subscription = Subscription( + account_id=new_account.id, + plan="free", + status="active", + ) + db.add(new_subscription) + + new_user.account_id = new_account.id + new_user.account_role = "owner" + + # Mark platform invite code as used if invite_code_record: invite_code_record.used_by_id = new_user.id invite_code_record.used_at = datetime.now(timezone.utc) diff --git a/backend/app/api/endpoints/categories.py b/backend/app/api/endpoints/categories.py index 22341ef8..73505c05 100644 --- a/backend/app/api/endpoints/categories.py +++ b/backend/app/api/endpoints/categories.py @@ -28,11 +28,11 @@ async def list_categories( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], include_inactive: bool = Query(False, description="Include inactive categories"), - team_only: bool = Query(False, description="Only show team-specific categories") + account_only: bool = Query(False, description="Only show account-specific categories") ): """List categories visible to the user. - Returns global categories plus team-specific categories for the user's team. + Returns global categories plus account-specific categories for the user's account. """ # Build query for accessible categories query = select(TreeCategory) @@ -41,19 +41,19 @@ async def list_categories( if not include_inactive: query = query.where(TreeCategory.is_active == True) - # Filter by visibility: global OR user's team - if team_only and current_user.team_id: - query = query.where(TreeCategory.team_id == current_user.team_id) - elif current_user.team_id: + # Filter by visibility: global OR user's account + if account_only and current_user.account_id: + query = query.where(TreeCategory.account_id == current_user.account_id) + elif current_user.account_id: query = query.where( or_( - TreeCategory.team_id.is_(None), # Global - TreeCategory.team_id == current_user.team_id # User's team + TreeCategory.account_id.is_(None), # Global + TreeCategory.account_id == current_user.account_id # User's account ) ) else: - # User has no team, only show global categories - query = query.where(TreeCategory.team_id.is_(None)) + # User has no account, only show global categories + query = query.where(TreeCategory.account_id.is_(None)) query = query.order_by(TreeCategory.display_order, TreeCategory.name) @@ -76,7 +76,7 @@ async def list_categories( name=cat.name, slug=cat.slug, description=cat.description, - team_id=cat.team_id, + account_id=cat.account_id, display_order=cat.display_order, is_active=cat.is_active, tree_count=tree_count @@ -101,8 +101,8 @@ async def get_category( detail="Category not found" ) - # Check access: global categories visible to all, team categories only to team members - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global categories visible to all, account categories only to account members + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -121,7 +121,7 @@ async def get_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, @@ -138,10 +138,10 @@ async def create_category( ): """Create a new category. - - Global admins can create global categories (team_id=None) - - Team admins can create team-specific categories for their team + - Global admins can create global categories (account_id=None) + - Account admins can create account-specific categories for their account """ - if not can_create_category(current_user, category_data.team_id): + if not can_create_category(current_user, category_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this category" @@ -150,10 +150,10 @@ async def create_category( # Generate slug slug = slugify(category_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(TreeCategory).where( TreeCategory.slug == slug, - TreeCategory.team_id == category_data.team_id + TreeCategory.account_id == category_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -164,7 +164,7 @@ async def create_category( # Get next display order order_query = select(func.max(TreeCategory.display_order)).where( - TreeCategory.team_id == category_data.team_id + TreeCategory.account_id == category_data.account_id ) order_result = await db.execute(order_query) max_order = order_result.scalar() or 0 @@ -173,7 +173,7 @@ async def create_category( name=category_data.name, slug=slug, description=category_data.description, - team_id=category_data.team_id, + account_id=category_data.account_id, display_order=max_order + 1, created_by=current_user.id ) @@ -186,7 +186,7 @@ async def create_category( name=new_category.name, slug=new_category.slug, description=new_category.description, - team_id=new_category.team_id, + account_id=new_category.account_id, display_order=new_category.display_order, is_active=new_category.is_active, created_at=new_category.created_at, @@ -227,7 +227,7 @@ async def update_category( # Check for duplicate slug existing_query = select(TreeCategory).where( TreeCategory.slug == new_slug, - TreeCategory.team_id == category.team_id, + TreeCategory.account_id == category.account_id, TreeCategory.id != category_id ) existing = await db.execute(existing_query) @@ -257,7 +257,7 @@ async def update_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, diff --git a/backend/app/api/endpoints/step_categories.py b/backend/app/api/endpoints/step_categories.py index 3480929e..5d890225 100644 --- a/backend/app/api/endpoints/step_categories.py +++ b/backend/app/api/endpoints/step_categories.py @@ -25,11 +25,11 @@ async def list_step_categories( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], include_inactive: bool = Query(False, description="Include inactive categories"), - team_only: bool = Query(False, description="Only show team-specific categories") + account_only: bool = Query(False, description="Only show account-specific categories") ): """List step categories visible to the user. - Returns global categories plus team-specific categories for the user's team. + Returns global categories plus account-specific categories for the user's account. """ # Build query for accessible categories query = select(StepCategory) @@ -38,19 +38,19 @@ async def list_step_categories( if not include_inactive: query = query.where(StepCategory.is_active == True) - # Filter by visibility: global OR user's team - if team_only and current_user.team_id: - query = query.where(StepCategory.team_id == current_user.team_id) - elif current_user.team_id: + # Filter by visibility: global OR user's account + if account_only and current_user.account_id: + query = query.where(StepCategory.account_id == current_user.account_id) + elif current_user.account_id: query = query.where( or_( - StepCategory.team_id.is_(None), # Global - StepCategory.team_id == current_user.team_id # User's team + StepCategory.account_id.is_(None), # Global + StepCategory.account_id == current_user.account_id # User's account ) ) else: - # User has no team, only show global categories - query = query.where(StepCategory.team_id.is_(None)) + # User has no account, only show global categories + query = query.where(StepCategory.account_id.is_(None)) query = query.order_by(StepCategory.display_order, StepCategory.name) @@ -66,7 +66,7 @@ async def list_step_categories( name=cat.name, slug=cat.slug, description=cat.description, - team_id=cat.team_id, + account_id=cat.account_id, display_order=cat.display_order, is_active=cat.is_active, step_count=0 # Will be computed when step_library exists @@ -91,8 +91,8 @@ async def get_step_category( detail="Step category not found" ) - # Check access: global categories visible to all, team categories only to team members - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global categories visible to all, account categories only to account members + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this step category" @@ -103,7 +103,7 @@ async def get_step_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, @@ -120,10 +120,10 @@ async def create_step_category( ): """Create a new step category. - - Global admins can create global categories (team_id=None) - - Team admins can create team-specific categories for their team + - Global admins can create global categories (account_id=None) + - Account admins can create account-specific categories for their account """ - if not can_create_step_category(current_user, category_data.team_id): + if not can_create_step_category(current_user, category_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this step category" @@ -132,10 +132,10 @@ async def create_step_category( # Generate slug slug = slugify(category_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(StepCategory).where( StepCategory.slug == slug, - StepCategory.team_id == category_data.team_id + StepCategory.account_id == category_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -146,7 +146,7 @@ async def create_step_category( # Get next display order order_query = select(func.max(StepCategory.display_order)).where( - StepCategory.team_id == category_data.team_id + StepCategory.account_id == category_data.account_id ) order_result = await db.execute(order_query) max_order = order_result.scalar() or 0 @@ -155,7 +155,7 @@ async def create_step_category( name=category_data.name, slug=slug, description=category_data.description, - team_id=category_data.team_id, + account_id=category_data.account_id, display_order=max_order + 1, created_by=current_user.id ) @@ -168,7 +168,7 @@ async def create_step_category( name=new_category.name, slug=new_category.slug, description=new_category.description, - team_id=new_category.team_id, + account_id=new_category.account_id, display_order=new_category.display_order, is_active=new_category.is_active, created_at=new_category.created_at, @@ -209,7 +209,7 @@ async def update_step_category( # Check for duplicate slug existing_query = select(StepCategory).where( StepCategory.slug == new_slug, - StepCategory.team_id == category.team_id, + StepCategory.account_id == category.account_id, StepCategory.id != category_id ) existing = await db.execute(existing_query) @@ -231,7 +231,7 @@ async def update_step_category( name=category.name, slug=category.slug, description=category.description, - team_id=category.team_id, + account_id=category.account_id, display_order=category.display_order, is_active=category.is_active, created_at=category.created_at, diff --git a/backend/app/api/endpoints/steps.py b/backend/app/api/endpoints/steps.py index 49605379..d1e5a160 100644 --- a/backend/app/api/endpoints/steps.py +++ b/backend/app/api/endpoints/steps.py @@ -55,10 +55,10 @@ async def get_step_or_404( def build_visibility_filter(user: User): """Build SQLAlchemy filter for step visibility based on user.""" - if user.team_id: + if user.account_id: return or_( StepLibrary.visibility == 'public', - and_(StepLibrary.visibility == 'team', StepLibrary.team_id == user.team_id), + and_(StepLibrary.visibility == 'team', StepLibrary.account_id == user.account_id), StepLibrary.created_by == user.id # Own private steps ) else: @@ -249,7 +249,7 @@ async def get_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, @@ -296,10 +296,10 @@ async def create_step( if not cat_result.scalar_one_or_none(): raise HTTPException(status_code=400, detail="Invalid category") - # Team validation: can only set team_id to own team - team_id = step_data.team_id - if team_id and team_id != current_user.team_id and not current_user.is_super_admin: - raise HTTPException(status_code=403, detail="Cannot create step for another team") + # Account validation: can only set account_id to own account + account_id = step_data.account_id + if account_id and account_id != current_user.account_id and not current_user.is_super_admin: + raise HTTPException(status_code=403, detail="Cannot create step for another account") step = StepLibrary( title=step_data.title, @@ -309,7 +309,7 @@ async def create_step( tags=step_data.tags, visibility=step_data.visibility, created_by=current_user.id, - team_id=team_id or current_user.team_id, + account_id=account_id or current_user.account_id, ) db.add(step) @@ -326,7 +326,7 @@ async def create_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, @@ -393,7 +393,7 @@ async def update_step( "tags": step.tags, "visibility": step.visibility, "created_by": step.created_by, - "team_id": step.team_id, + "account_id": step.account_id, "usage_count": step.usage_count, "rating_average": step.rating_average, "rating_count": step.rating_count, diff --git a/backend/app/api/endpoints/tags.py b/backend/app/api/endpoints/tags.py index 4f764544..334e33f8 100644 --- a/backend/app/api/endpoints/tags.py +++ b/backend/app/api/endpoints/tags.py @@ -20,26 +20,26 @@ router = APIRouter(prefix="/tags", tags=["tags"]) async def list_tags( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], - include_team: bool = Query(True, description="Include team-specific tags") + include_account: bool = Query(True, description="Include account-specific tags") ): """List tags visible to the user. - Returns global tags plus team-specific tags for the user's team. + Returns global tags plus account-specific tags for the user's account. Tags are ordered by usage count (most used first). """ query = select(TreeTag) - # Filter by visibility: global OR user's team - if include_team and current_user.team_id: + # Filter by visibility: global OR user's account + if include_account and current_user.account_id: query = query.where( or_( - TreeTag.team_id.is_(None), # Global - TreeTag.team_id == current_user.team_id # User's team + TreeTag.account_id.is_(None), # Global + TreeTag.account_id == current_user.account_id # User's account ) ) else: # Only show global tags - query = query.where(TreeTag.team_id.is_(None)) + query = query.where(TreeTag.account_id.is_(None)) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name) @@ -55,7 +55,7 @@ async def search_tags( current_user: Annotated[User, Depends(get_current_active_user)], q: str = Query(..., min_length=1, description="Search query"), limit: int = Query(10, ge=1, le=50), - include_team: bool = Query(True, description="Include team-specific tags") + include_account: bool = Query(True, description="Include account-specific tags") ): """Search/autocomplete tags. @@ -68,15 +68,15 @@ async def search_tags( ) # Filter by visibility - if include_team and current_user.team_id: + if include_account and current_user.account_id: query = query.where( or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == current_user.team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == current_user.account_id ) ) else: - query = query.where(TreeTag.team_id.is_(None)) + query = query.where(TreeTag.account_id.is_(None)) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit) @@ -102,8 +102,8 @@ async def get_tag( detail="Tag not found" ) - # Check access: global tags visible to all, team tags only to team members - if tag.team_id and tag.team_id != current_user.team_id and not current_user.is_super_admin: + # Check access: global tags visible to all, account tags only to account members + if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this tag" @@ -120,10 +120,10 @@ async def create_tag( ): """Create a new tag. - - Global admins can create global tags (team_id=None) - - Team members can create team-specific tags for their team + - Global admins can create global tags (account_id=None) + - Account members can create account-specific tags for their account """ - if not can_create_tag(current_user, tag_data.team_id): + if not can_create_tag(current_user, tag_data.account_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to create this tag" @@ -132,10 +132,10 @@ async def create_tag( # Generate slug slug = TreeTag.slugify(tag_data.name) - # Check for duplicate slug within same scope (global or team) + # Check for duplicate slug within same scope (global or account) existing_query = select(TreeTag).where( TreeTag.slug == slug, - TreeTag.team_id == tag_data.team_id + TreeTag.account_id == tag_data.account_id ) existing = await db.execute(existing_query) if existing.scalar_one_or_none(): @@ -147,7 +147,7 @@ async def create_tag( new_tag = TreeTag( name=tag_data.name, slug=slug, - team_id=tag_data.team_id, + account_id=tag_data.account_id, created_by=current_user.id ) db.add(new_tag) @@ -200,30 +200,30 @@ async def add_tags_to_tree( continue # Try to find existing tag - # Determine scope: use tree's team, or global for admin-owned trees - tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + # Determine scope: use tree's account, or global for admin-owned trees + tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), # Global tag - TreeTag.team_id == tag_team_id # Team tag + TreeTag.account_id.is_(None), # Global tag + TreeTag.account_id == tag_account_id # Account tag ) ) tag_result = await db.execute(tag_query) tag = tag_result.scalar_one_or_none() if not tag: - # Create new tag - prefer team scope unless admin creating on public tree - new_team_id = tag_team_id - if not can_create_tag(current_user, new_team_id): - # Fall back to user's team if they can't create in tree's scope - new_team_id = current_user.team_id + # Create new tag - prefer account scope unless admin creating on public tree + new_account_id = tag_account_id + if not can_create_tag(current_user, new_account_id): + # Fall back to user's account if they can't create in tree's scope + new_account_id = current_user.account_id tag = TreeTag( name=tag_name, slug=slug, - team_id=new_team_id, + account_id=new_account_id, created_by=current_user.id ) db.add(tag) @@ -331,7 +331,7 @@ async def replace_tree_tags( tree.tags.clear() # Add new tags - tag_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tag_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) for tag_name in tag_data.tags: slug = TreeTag.slugify(tag_name) @@ -340,8 +340,8 @@ async def replace_tree_tags( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tag_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tag_account_id ) ) tag_result = await db.execute(tag_query) @@ -349,14 +349,14 @@ async def replace_tree_tags( if not tag: # Create new tag - new_team_id = tag_team_id - if not can_create_tag(current_user, new_team_id): - new_team_id = current_user.team_id + new_account_id = tag_account_id + if not can_create_tag(current_user, new_account_id): + new_account_id = current_user.account_id tag = TreeTag( name=tag_name, slug=slug, - team_id=new_team_id, + account_id=new_account_id, created_by=current_user.id ) db.add(tag) @@ -397,7 +397,7 @@ async def get_tree_tags( # Check if user can view the tree if not tree.is_public: if tree.author_id != current_user.id: - if tree.team_id != current_user.team_id: + if tree.account_id != current_user.account_id: if not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 8f238f74..0a0fb6d0 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -15,6 +15,7 @@ from app.models.folder import UserFolder, user_folder_trees from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin from app.core.permissions import can_edit_tree, can_access_tree +from app.core.subscriptions import check_tree_limit from app.core.audit import log_audit router = APIRouter(prefix="/trees", tags=["trees"]) @@ -37,8 +38,8 @@ def build_tree_access_filter(current_user: User): Tree.is_public == True, Tree.author_id == current_user.id, ] - if current_user.team_id: - conditions.append(Tree.team_id == current_user.team_id) + if current_user.account_id: + conditions.append(Tree.account_id == current_user.account_id) return or_(*conditions) @@ -61,7 +62,7 @@ def build_tree_response(tree: Tree) -> TreeListResponse: category_info=category_info, tags=tree.tag_names, author_id=tree.author_id, - team_id=tree.team_id, + account_id=tree.account_id, is_active=tree.is_active, is_public=tree.is_public, is_default=tree.is_default, @@ -92,7 +93,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse: tags=tree.tag_names, tree_structure=tree.tree_structure, author_id=tree.author_id, - team_id=tree.team_id, + account_id=tree.account_id, is_active=tree.is_active, is_public=tree.is_public, is_default=tree.is_default, @@ -289,7 +290,7 @@ async def create_tree( detail="Category not found" ) # Check category access - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -302,16 +303,25 @@ async def create_tree( category_id=tree_data.category_id, tree_structure=tree_data.tree_structure, author_id=None if is_default else current_user.id, # Default trees have no author - team_id=None if is_default else current_user.team_id, + account_id=None if is_default else current_user.account_id, is_public=True if is_default else tree_data.is_public, # Default trees are always public is_default=is_default ) + # Check subscription tree limit + if not is_default and current_user.account_id: + can_create, limit, count = await check_tree_limit(current_user.account_id, db) + if not can_create: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees." + ) + db.add(new_tree) await db.flush() # Get the ID # Handle tags if tree_data.tags: - tree_team_id = new_tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tree_account_id = new_tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) # Collect tags to add tags_to_add = [] @@ -322,8 +332,8 @@ async def create_tree( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tree_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tree_account_id ) ) tag_result = await db.execute(tag_query) @@ -334,7 +344,7 @@ async def create_tree( tag = TreeTag( name=tag_name, slug=slug, - team_id=tree_team_id, + account_id=tree_account_id, created_by=current_user.id ) db.add(tag) @@ -420,7 +430,7 @@ async def update_tree( status_code=status.HTTP_404_NOT_FOUND, detail="Category not found" ) - if category.team_id and category.team_id != current_user.team_id and not current_user.is_super_admin: + if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have access to this category" @@ -450,7 +460,7 @@ async def update_tree( ) # Add new tags - tree_team_id = tree.team_id or (current_user.team_id if not current_user.is_super_admin else None) + tree_account_id = tree.account_id or (current_user.account_id if not current_user.is_super_admin else None) added_tag_ids = set() for tag_name in tags_data: @@ -459,8 +469,8 @@ async def update_tree( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.team_id.is_(None), - TreeTag.team_id == tree_team_id + TreeTag.account_id.is_(None), + TreeTag.account_id == tree_account_id ) ) tag_result = await db.execute(tag_query) @@ -470,7 +480,7 @@ async def update_tree( tag = TreeTag( name=tag_name, slug=slug, - team_id=tree_team_id, + account_id=tree_account_id, created_by=current_user.id ) db.add(tag) diff --git a/backend/app/api/endpoints/webhooks.py b/backend/app/api/endpoints/webhooks.py new file mode 100644 index 00000000..1773ec22 --- /dev/null +++ b/backend/app/api/endpoints/webhooks.py @@ -0,0 +1,62 @@ +import logging +from fastapi import APIRouter, Request, HTTPException, status, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.config import settings +from app.core.stripe_handlers import WEBHOOK_HANDLERS + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/webhooks", tags=["webhooks"]) + + +@router.post("/stripe") +async def stripe_webhook( + request: Request, + db: AsyncSession = Depends(get_db), +): + """Handle Stripe webhook events. + + Returns 200 for all events to prevent Stripe retries. + Actual processing happens only when Stripe is configured. + """ + if not settings.stripe_enabled: + return {"status": "ok", "message": "Stripe not configured, event ignored"} + + payload = await request.body() + sig_header = request.headers.get("stripe-signature") + + if not sig_header: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing stripe-signature header" + ) + + # Verify webhook signature + try: + import stripe + stripe.api_key = settings.STRIPE_SECRET_KEY + event = stripe.Webhook.construct_event( + payload, sig_header, settings.STRIPE_WEBHOOK_SECRET + ) + except ImportError: + logger.warning("stripe package not installed, cannot verify webhook") + return {"status": "ok", "message": "stripe package not installed"} + except Exception as e: + logger.error("Stripe webhook signature verification failed: %s", e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid signature" + ) + + event_type = event.get("type", "") + handler = WEBHOOK_HANDLERS.get(event_type) + + if handler: + try: + await handler(event, db) + except Exception: + logger.exception("Error handling Stripe event %s", event_type) + + return {"status": "ok"} diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 5940ea3f..05a773bc 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin +from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks api_router = APIRouter() @@ -13,3 +13,5 @@ api_router.include_router(folders.router) api_router.include_router(step_categories.router) api_router.include_router(steps.router) api_router.include_router(admin.router) +api_router.include_router(accounts.router) +api_router.include_router(webhooks.router) diff --git a/backend/app/schemas/account.py b/backend/app/schemas/account.py new file mode 100644 index 00000000..8a9a101e --- /dev/null +++ b/backend/app/schemas/account.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID +from pydantic import BaseModel, Field + + +class AccountResponse(BaseModel): + id: UUID + name: str + display_code: str + owner_id: UUID + stripe_customer_id: Optional[str] = None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class AccountUpdate(BaseModel): + name: Optional[str] = Field(None, min_length=1, max_length=255) + + +class AccountInviteCreate(BaseModel): + email: str = Field(..., max_length=255) + role: str = Field("engineer", pattern="^(engineer|viewer)$") + expires_in_days: Optional[int] = Field(None, ge=1, le=30) + + +class AccountInviteResponse(BaseModel): + id: UUID + account_id: UUID + email: str + code: str + role: str + expires_at: Optional[datetime] = None + created_at: datetime + used_at: Optional[datetime] = None + + model_config = {"from_attributes": True} diff --git a/backend/app/schemas/category.py b/backend/app/schemas/category.py index 2cca0694..13e9955c 100644 --- a/backend/app/schemas/category.py +++ b/backend/app/schemas/category.py @@ -20,7 +20,7 @@ class CategoryBase(BaseModel): class CategoryCreate(CategoryBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.") class CategoryUpdate(BaseModel): @@ -33,7 +33,7 @@ class CategoryUpdate(BaseModel): class CategoryResponse(CategoryBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool created_at: datetime @@ -49,7 +49,7 @@ class CategoryListResponse(BaseModel): name: str slug: str description: Optional[str] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool tree_count: int = 0 diff --git a/backend/app/schemas/step_category.py b/backend/app/schemas/step_category.py index 9d03667e..106c32c9 100644 --- a/backend/app/schemas/step_category.py +++ b/backend/app/schemas/step_category.py @@ -20,7 +20,7 @@ class StepCategoryBase(BaseModel): class StepCategoryCreate(StepCategoryBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific category. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific category. NULL for global.") class StepCategoryUpdate(BaseModel): @@ -33,7 +33,7 @@ class StepCategoryUpdate(BaseModel): class StepCategoryResponse(StepCategoryBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool created_at: datetime @@ -49,7 +49,7 @@ class StepCategoryListResponse(BaseModel): name: str slug: str description: Optional[str] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None display_order: int is_active: bool step_count: int = 0 diff --git a/backend/app/schemas/step_library.py b/backend/app/schemas/step_library.py index dbd7357a..93390c60 100644 --- a/backend/app/schemas/step_library.py +++ b/backend/app/schemas/step_library.py @@ -30,7 +30,7 @@ class StepLibraryBase(BaseModel): class StepLibraryCreate(StepLibraryBase): - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None class StepLibraryUpdate(BaseModel): @@ -45,7 +45,7 @@ class StepLibraryUpdate(BaseModel): class StepLibraryResponse(StepLibraryBase): id: UUID created_by: UUID - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int rating_average: Decimal rating_count: int diff --git a/backend/app/schemas/subscription.py b/backend/app/schemas/subscription.py new file mode 100644 index 00000000..9b832926 --- /dev/null +++ b/backend/app/schemas/subscription.py @@ -0,0 +1,40 @@ +from typing import Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +class SubscriptionResponse(BaseModel): + id: UUID + plan: str + status: str + billing_interval: Optional[str] = None + current_period_start: Optional[datetime] = None + current_period_end: Optional[datetime] = None + cancel_at_period_end: bool = False + stripe_subscription_id: Optional[str] = None + + model_config = {"from_attributes": True} + + +class PlanLimitsResponse(BaseModel): + plan: str + max_trees: Optional[int] = None + max_sessions_per_month: Optional[int] = None + max_users: Optional[int] = None + custom_branding: bool = False + priority_support: bool = False + export_formats: list[str] = ["markdown", "text"] + + model_config = {"from_attributes": True} + + +class UsageResponse(BaseModel): + tree_count: int + session_count_this_month: int + + +class SubscriptionDetails(BaseModel): + subscription: SubscriptionResponse + limits: PlanLimitsResponse + usage: UsageResponse diff --git a/backend/app/schemas/tag.py b/backend/app/schemas/tag.py index 2de4bfcf..47912057 100644 --- a/backend/app/schemas/tag.py +++ b/backend/app/schemas/tag.py @@ -19,13 +19,13 @@ class TagBase(BaseModel): class TagCreate(TagBase): - team_id: Optional[UUID] = Field(None, description="Team ID for team-specific tag. NULL for global.") + account_id: Optional[UUID] = Field(None, description="Account ID for account-specific tag. NULL for global.") class TagResponse(TagBase): id: UUID slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int created_at: datetime @@ -37,7 +37,7 @@ class TagListResponse(BaseModel): id: UUID name: str slug: str - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None usage_count: int class Config: @@ -53,4 +53,4 @@ class TagSearchParams(BaseModel): """Query parameters for tag search/autocomplete.""" q: str = Field(..., min_length=1, description="Search query") limit: int = Field(10, ge=1, le=50) - include_team: bool = Field(True, description="Include team-specific tags") + include_account: bool = Field(True, description="Include account-specific tags") diff --git a/backend/app/schemas/tree.py b/backend/app/schemas/tree.py index 8810f0d5..86aa7567 100644 --- a/backend/app/schemas/tree.py +++ b/backend/app/schemas/tree.py @@ -44,7 +44,7 @@ class TreeResponse(TreeBase): id: UUID tree_structure: dict[str, Any] author_id: Optional[UUID] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None category_id: Optional[UUID] = None category_info: Optional[CategoryInfo] = None tags: list[str] = [] # List of tag names @@ -69,7 +69,7 @@ class TreeListResponse(BaseModel): category_info: Optional[CategoryInfo] = None tags: list[str] = [] # List of tag names author_id: Optional[UUID] = None - team_id: Optional[UUID] = None + account_id: Optional[UUID] = None is_active: bool is_public: bool is_default: bool diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index b2f5d81e..6777f668 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -13,6 +13,7 @@ class UserBase(BaseModel): class UserCreate(UserBase): password: str = Field(..., min_length=10, description="Password must be at least 10 characters") invite_code: Optional[str] = Field(None, description="Invite code for registration (required when invite system is enabled)") + account_invite_code: Optional[str] = Field(None, description="Account invite code to join an existing account") @field_validator('password') @classmethod @@ -38,11 +39,11 @@ class UserLogin(BaseModel): class UserResponse(UserBase): id: UUID - role: str + role: str = "engineer" + account_id: UUID + account_role: str is_super_admin: bool = False - is_team_admin: bool = False is_active: bool = True - team_id: Optional[UUID] = None created_at: datetime last_login: Optional[datetime] = None @@ -54,5 +55,5 @@ class RoleUpdate(BaseModel): role: Literal["engineer", "viewer"] -class TeamAdminUpdate(BaseModel): - is_team_admin: bool +class AccountRoleUpdate(BaseModel): + account_role: str = Field(..., pattern="^(engineer|viewer)$") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7845ba82..e7b7667a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -55,6 +55,15 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]: await conn.execute(sa.text("CREATE SCHEMA public")) await conn.run_sync(Base.metadata.create_all) + # Seed plan_limits for subscription checks + await conn.execute(sa.text(""" + INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) + VALUES + ('free', 3, 20, 1, false, false, '["markdown", "text"]'), + ('pro', 25, 200, 5, true, false, '["markdown", "text", "html"]'), + ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]') + """)) + # Create async session maker async_session_maker = async_sessionmaker( engine, diff --git a/backend/tests/test_account_management.py b/backend/tests/test_account_management.py new file mode 100644 index 00000000..a8fed198 --- /dev/null +++ b/backend/tests/test_account_management.py @@ -0,0 +1,170 @@ +"""Integration tests for account management endpoints.""" + +import pytest +from httpx import AsyncClient + + +class TestAccountEndpoints: + """Test suite for account management endpoints.""" + + @pytest.mark.asyncio + async def test_get_my_account(self, client: AsyncClient, auth_headers: dict): + """Test getting current user's account.""" + response = await client.get("/api/v1/accounts/me", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "name" in data + assert "display_code" in data + assert "owner_id" in data + assert len(data["display_code"]) == 8 + + @pytest.mark.asyncio + async def test_get_my_subscription(self, client: AsyncClient, auth_headers: dict): + """Test getting current user's subscription details.""" + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert "subscription" in data + assert "limits" in data + assert "usage" in data + assert data["subscription"]["plan"] == "free" + assert data["subscription"]["status"] == "active" + assert data["limits"]["max_trees"] == 3 + assert data["limits"]["max_sessions_per_month"] == 20 + + @pytest.mark.asyncio + async def test_get_my_members(self, client: AsyncClient, auth_headers: dict): + """Test getting members of current user's account.""" + response = await client.get("/api/v1/accounts/me/members", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) >= 1 + # Current user should be in members list + assert any(m["account_role"] == "owner" for m in data) + + @pytest.mark.asyncio + async def test_update_my_account(self, client: AsyncClient, auth_headers: dict): + """Test updating account name.""" + response = await client.patch( + "/api/v1/accounts/me", + json={"name": "Updated Account Name"}, + headers=auth_headers + ) + assert response.status_code == 200 + assert response.json()["name"] == "Updated Account Name" + + @pytest.mark.asyncio + async def test_update_account_requires_owner(self, client: AsyncClient): + """Test that non-owners cannot update account settings.""" + # Register two users + owner_data = { + "email": "owner@example.com", + "password": "OwnerPass123!", + "name": "Owner" + } + await client.post("/api/v1/auth/register", json=owner_data) + + # Login as owner and create an invite + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "owner@example.com", + "password": "OwnerPass123!" + }) + owner_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "member@example.com", "role": "engineer"}, + headers=owner_headers + ) + assert invite_resp.status_code == 201 + invite_code = invite_resp.json()["code"] + + # Register member with account invite code + member_data = { + "email": "member@example.com", + "password": "MemberPass123!", + "name": "Member", + "account_invite_code": invite_code + } + reg_resp = await client.post("/api/v1/auth/register", json=member_data) + assert reg_resp.status_code == 201 + assert reg_resp.json()["account_role"] == "engineer" + + # Login as member + member_login = await client.post("/api/v1/auth/login/json", json={ + "email": "member@example.com", + "password": "MemberPass123!" + }) + member_headers = {"Authorization": f"Bearer {member_login.json()['access_token']}"} + + # Member should not be able to update account + response = await client.patch( + "/api/v1/accounts/me", + json={"name": "Hacked Name"}, + headers=member_headers + ) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_create_and_list_invites(self, client: AsyncClient, auth_headers: dict): + """Test creating and listing account invites.""" + # Create invite + response = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "invitee@example.com", "role": "engineer"}, + headers=auth_headers + ) + assert response.status_code == 201 + data = response.json() + assert data["email"] == "invitee@example.com" + assert data["role"] == "engineer" + assert "code" in data + + # List invites + list_response = await client.get("/api/v1/accounts/me/invites", headers=auth_headers) + assert list_response.status_code == 200 + invites = list_response.json() + assert len(invites) >= 1 + assert any(i["email"] == "invitee@example.com" for i in invites) + + @pytest.mark.asyncio + async def test_register_with_account_invite(self, client: AsyncClient, auth_headers: dict): + """Test that account invite code joins user to existing account.""" + # Get current account + account_resp = await client.get("/api/v1/accounts/me", headers=auth_headers) + account_id = account_resp.json()["id"] + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "joiner@example.com", "role": "viewer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + # Register with account invite code + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "joiner@example.com", + "password": "JoinerPass123!", + "name": "Joiner", + "account_invite_code": invite_code + }) + assert reg_resp.status_code == 201 + data = reg_resp.json() + assert data["account_id"] == account_id + assert data["account_role"] == "viewer" + + @pytest.mark.asyncio + async def test_register_with_invalid_invite_code(self, client: AsyncClient): + """Test that invalid account invite code is rejected.""" + response = await client.post("/api/v1/auth/register", json={ + "email": "bad@example.com", + "password": "BadPassword123!", + "name": "Bad User", + "account_invite_code": "INVALID_CODE" + }) + assert response.status_code == 400 + assert "invalid" in response.json()["detail"].lower() diff --git a/backend/tests/test_admin.py b/backend/tests/test_admin.py index aebc9ebc..1c96c03b 100644 --- a/backend/tests/test_admin.py +++ b/backend/tests/test_admin.py @@ -169,3 +169,51 @@ class TestAdminEndpoints: log = result.scalar_one_or_none() assert log is not None assert str(log.resource_id) == user_id + + @pytest.mark.asyncio + async def test_change_account_role( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict + ): + """Test changing a user's account role.""" + user_id = test_user["user_data"]["id"] + response = await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "viewer"}, + headers=admin_auth_headers + ) + assert response.status_code == 200 + assert response.json()["account_role"] == "viewer" + + @pytest.mark.asyncio + async def test_change_account_role_invalid( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict + ): + """Test that invalid account_role values are rejected.""" + user_id = test_user["user_data"]["id"] + response = await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "owner"}, + headers=admin_auth_headers + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_audit_log_created_on_account_role_change( + self, client: AsyncClient, admin_auth_headers: dict, test_user: dict, test_db: AsyncSession + ): + """Test that changing account role creates an audit log entry.""" + user_id = test_user["user_data"]["id"] + await client.put( + f"/api/v1/admin/users/{user_id}/account-role", + json={"account_role": "viewer"}, + headers=admin_auth_headers + ) + + result = await test_db.execute( + select(AuditLog).where(AuditLog.action == "user.account_role_change") + ) + log = result.scalar_one_or_none() + assert log is not None + assert str(log.resource_id) == user_id + assert log.details["old_account_role"] == "owner" + assert log.details["new_account_role"] == "viewer" diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 5e578900..f1463949 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -23,6 +23,8 @@ class TestAuthentication: assert data["email"] == user_data["email"] assert data["name"] == user_data["name"] assert data["role"] == "engineer" + assert "account_id" in data + assert data["account_role"] == "owner" assert "id" in data assert "password" not in data # Password should not be returned @@ -107,6 +109,7 @@ class TestAuthentication: assert response.status_code == 201 data = response.json() assert data["role"] == "engineer" + assert data["account_role"] == "owner" @pytest.mark.asyncio async def test_register_default_role_is_engineer(self, client: AsyncClient): @@ -121,6 +124,7 @@ class TestAuthentication: assert response.status_code == 201 assert response.json()["role"] == "engineer" + assert response.json()["account_role"] == "owner" @pytest.mark.asyncio async def test_register_rejects_no_uppercase(self, client: AsyncClient): diff --git a/backend/tests/test_permissions_account.py b/backend/tests/test_permissions_account.py new file mode 100644 index 00000000..211a7340 --- /dev/null +++ b/backend/tests/test_permissions_account.py @@ -0,0 +1,205 @@ +"""Integration tests for account-based permissions.""" + +import pytest +from httpx import AsyncClient + + +class TestAccountPermissions: + """Test suite for account-based permission checks.""" + + @pytest.mark.asyncio + async def test_viewer_cannot_create_tree(self, client: AsyncClient, test_db): + """Test that viewers cannot create trees.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Register a user + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "viewer@example.com", + "password": "ViewerPass123!", + "name": "Viewer User" + }) + assert reg_resp.status_code == 201 + user_id = PyUUID(reg_resp.json()["id"]) + + # Demote to viewer via ORM + result = await test_db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one() + user.account_role = "viewer" + await test_db.commit() + + # Login as viewer + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "viewer@example.com", + "password": "ViewerPass123!" + }) + viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Try to create tree + response = await client.post("/api/v1/trees", json={ + "name": "Viewer Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "Test", "description": "Test"} + }, headers=viewer_headers) + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_viewer_can_list_trees(self, client: AsyncClient, auth_headers: dict, test_db): + """Test that viewers can browse/list trees.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Create a public tree as the regular user first + await client.post("/api/v1/trees", json={ + "name": "Public Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": True + }, headers=auth_headers) + + # Register viewer + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "viewer2@example.com", + "password": "ViewerPass123!", + "name": "Viewer 2" + }) + user_id = PyUUID(reg_resp.json()["id"]) + + result = await test_db.execute(select(User).where(User.id == user_id)) + user = result.scalar_one() + user.account_role = "viewer" + await test_db.commit() + + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "viewer2@example.com", + "password": "ViewerPass123!" + }) + viewer_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Viewer can list trees + response = await client.get("/api/v1/trees", headers=viewer_headers) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_owner_can_edit_account_members_tree(self, client: AsyncClient, auth_headers: dict, test_db): + """Test that account owner can edit trees created by account members.""" + from sqlalchemy import select + from app.models.user import User + from uuid import UUID as PyUUID + + # Get owner's account + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id = me_resp.json()["account_id"] + + # Create invite + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "engineer@example.com", "role": "engineer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + # Register engineer in same account + reg_resp = await client.post("/api/v1/auth/register", json={ + "email": "engineer@example.com", + "password": "EngineerPass123!", + "name": "Engineer", + "account_invite_code": invite_code + }) + assert reg_resp.status_code == 201 + assert reg_resp.json()["account_id"] == account_id + + # Login as engineer + eng_login = await client.post("/api/v1/auth/login/json", json={ + "email": "engineer@example.com", + "password": "EngineerPass123!" + }) + eng_headers = {"Authorization": f"Bearer {eng_login.json()['access_token']}"} + + # Engineer creates a tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Engineer's Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"} + }, headers=eng_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Owner can edit engineer's tree + update_resp = await client.put( + f"/api/v1/trees/{tree_id}", + json={"name": "Owner Updated Name"}, + headers=auth_headers + ) + assert update_resp.status_code == 200 + assert update_resp.json()["name"] == "Owner Updated Name" + + @pytest.mark.asyncio + async def test_account_scoped_visibility(self, client: AsyncClient, auth_headers: dict): + """Test that account members can see each other's non-public trees.""" + # Get owner's account + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id = me_resp.json()["account_id"] + + # Owner creates a private tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Private Account Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": False + }, headers=auth_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Create invite and add member + invite_resp = await client.post( + "/api/v1/accounts/me/invites", + json={"email": "teammate@example.com", "role": "engineer"}, + headers=auth_headers + ) + invite_code = invite_resp.json()["code"] + + await client.post("/api/v1/auth/register", json={ + "email": "teammate@example.com", + "password": "TeammatePass123!", + "name": "Teammate", + "account_invite_code": invite_code + }) + + mate_login = await client.post("/api/v1/auth/login/json", json={ + "email": "teammate@example.com", + "password": "TeammatePass123!" + }) + mate_headers = {"Authorization": f"Bearer {mate_login.json()['access_token']}"} + + # Teammate should see the private tree (same account) + response = await client.get(f"/api/v1/trees/{tree_id}", headers=mate_headers) + assert response.status_code == 200 + assert response.json()["name"] == "Private Account Tree" + + @pytest.mark.asyncio + async def test_different_account_cannot_see_private_tree(self, client: AsyncClient, auth_headers: dict): + """Test that users from different accounts cannot see private trees.""" + # Owner creates a private tree + tree_resp = await client.post("/api/v1/trees", json={ + "name": "Secret Tree", + "tree_structure": {"id": "root", "type": "solution", "title": "T", "description": "T"}, + "is_public": False + }, headers=auth_headers) + assert tree_resp.status_code == 201 + tree_id = tree_resp.json()["id"] + + # Register a completely separate user (different account) + await client.post("/api/v1/auth/register", json={ + "email": "outsider@example.com", + "password": "OutsiderPass123!", + "name": "Outsider" + }) + + outsider_login = await client.post("/api/v1/auth/login/json", json={ + "email": "outsider@example.com", + "password": "OutsiderPass123!" + }) + outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"} + + # Outsider should NOT see the private tree + response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers) + assert response.status_code == 403 diff --git a/backend/tests/test_subscription_limits.py b/backend/tests/test_subscription_limits.py new file mode 100644 index 00000000..540e42af --- /dev/null +++ b/backend/tests/test_subscription_limits.py @@ -0,0 +1,129 @@ +"""Integration tests for subscription limits.""" + +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + + +class TestSubscriptionLimits: + """Test suite for subscription plan limits.""" + + @pytest.mark.asyncio + async def test_free_plan_tree_limit(self, client: AsyncClient, auth_headers: dict): + """Test that free plan has tree creation limit of 3.""" + tree_template = { + "name": "Limit Test Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test tree" + } + } + + # Create trees up to the limit + for i in range(3): + tree_data = {**tree_template, "name": f"Tree {i+1}"} + response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert response.status_code == 201, f"Failed creating tree {i+1}: {response.json()}" + + # 4th tree should fail with 402 + tree_data = {**tree_template, "name": "Tree 4 Over Limit"} + response = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert response.status_code == 402 + assert "limit" in response.json()["detail"].lower() + + @pytest.mark.asyncio + async def test_subscription_details_show_usage(self, client: AsyncClient, auth_headers: dict): + """Test that subscription details reflect actual usage.""" + # Check initial usage + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + initial_usage = response.json()["usage"] + assert initial_usage["tree_count"] == 0 + + # Create a tree + tree_data = { + "name": "Usage Test Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test" + } + } + create_resp = await client.post("/api/v1/trees", json=tree_data, headers=auth_headers) + assert create_resp.status_code == 201 + + # Check usage increased + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + updated_usage = response.json()["usage"] + assert updated_usage["tree_count"] == 1 + + @pytest.mark.asyncio + async def test_super_admin_bypasses_limits( + self, client: AsyncClient, admin_auth_headers: dict + ): + """Test that super admin can create trees without limit checks.""" + tree_template = { + "name": "Admin Tree", + "tree_structure": { + "id": "root", + "type": "solution", + "title": "Test", + "description": "Test tree" + }, + "is_default": True # Default trees skip limit check + } + + # Super admin creating default trees should always work + for i in range(5): + tree_data = {**tree_template, "name": f"Admin Tree {i+1}"} + response = await client.post( + "/api/v1/trees", json=tree_data, headers=admin_auth_headers + ) + assert response.status_code == 201 + + @pytest.mark.asyncio + async def test_free_plan_limits_correct(self, client: AsyncClient, auth_headers: dict): + """Test that free plan limits are correct.""" + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + limits = response.json()["limits"] + assert limits["plan"] == "free" + assert limits["max_trees"] == 3 + assert limits["max_sessions_per_month"] == 20 + assert limits["max_users"] == 1 + assert limits["custom_branding"] is False + assert limits["priority_support"] is False + + @pytest.mark.asyncio + async def test_upgraded_plan_has_higher_limits( + self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession + ): + """Test that upgrading plan increases limits.""" + from app.models.subscription import Subscription + from app.models.user import User + + # Get the user's subscription and upgrade it + me_resp = await client.get("/api/v1/auth/me", headers=auth_headers) + account_id_str = me_resp.json()["account_id"] + + from uuid import UUID + account_id = UUID(account_id_str) + result = await test_db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + sub = result.scalar_one() + sub.plan = "pro" + await test_db.commit() + + # Check limits are now pro + response = await client.get("/api/v1/accounts/me/subscription", headers=auth_headers) + assert response.status_code == 200 + limits = response.json()["limits"] + assert limits["plan"] == "pro" + assert limits["max_trees"] == 25 + assert limits["max_sessions_per_month"] == 200 From 7a6f839ef402c77578e0633487a48374f0729c6a Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:39:15 -0500 Subject: [PATCH 3/5] feat: update frontend for account-based subscriptions Replace all team_id/team_admin references with account_id/owner across types, store, hooks, API clients, components, and pages. Add new AccountSettingsPage, UpgradePrompt, CheckoutButton, useSubscription hook, and accounts API client. AuthStore now parallel-fetches account and subscription data alongside user profile. Also fix folder sidebar not refreshing after tree deletion by dispatching the folder-changed event in handleDeleteTree. Co-Authored-By: Claude Opus 4.6 --- frontend/package-lock.json | 10 + frontend/package.json | 1 + frontend/src/api/accounts.ts | 48 ++ frontend/src/api/categories.ts | 4 +- frontend/src/api/index.ts | 1 + frontend/src/api/tags.ts | 8 +- .../src/components/common/UpgradePrompt.tsx | 32 ++ frontend/src/components/layout/AppLayout.tsx | 9 +- .../src/components/layout/ProtectedRoute.tsx | 2 +- .../subscription/CheckoutButton.tsx | 24 + .../tree-editor/TreeMetadataForm.tsx | 2 +- frontend/src/hooks/usePermissions.ts | 18 +- frontend/src/hooks/useSubscription.ts | 45 ++ frontend/src/pages/AccountSettingsPage.tsx | 494 ++++++++++++++++++ frontend/src/pages/TreeEditorPage.tsx | 2 +- frontend/src/pages/TreeLibraryPage.tsx | 3 +- frontend/src/pages/index.ts | 1 + frontend/src/router.tsx | 5 + frontend/src/store/authStore.ts | 30 +- frontend/src/types/account.ts | 62 +++ frontend/src/types/category.ts | 6 +- frontend/src/types/index.ts | 1 + frontend/src/types/step.ts | 2 +- frontend/src/types/tag.ts | 6 +- frontend/src/types/tree.ts | 4 +- frontend/src/types/user.ts | 4 +- 26 files changed, 786 insertions(+), 38 deletions(-) create mode 100644 frontend/src/api/accounts.ts create mode 100644 frontend/src/components/common/UpgradePrompt.tsx create mode 100644 frontend/src/components/subscription/CheckoutButton.tsx create mode 100644 frontend/src/hooks/useSubscription.ts create mode 100644 frontend/src/pages/AccountSettingsPage.tsx create mode 100644 frontend/src/types/account.ts diff --git a/frontend/package-lock.json b/frontend/package-lock.json index d655c72c..24a295ec 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,6 +8,7 @@ "name": "frontend", "version": "0.0.0", "dependencies": { + "@stripe/stripe-js": "^8.7.0", "axios": "^1.13.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", @@ -1430,6 +1431,15 @@ "win32" ] }, + "node_modules/@stripe/stripe-js": { + "version": "8.7.0", + "resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-8.7.0.tgz", + "integrity": "sha512-tNUerSstwNC1KuHgX4CASGO0Md3CB26IJzSXmVlSuFvhsBP4ZaEPpY4jxWOn9tfdDscuVT4Kqb8cZ2o9nLCgRQ==", + "license": "MIT", + "engines": { + "node": ">=12.16" + } + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", diff --git a/frontend/package.json b/frontend/package.json index c038dd9c..672da33f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,6 +10,7 @@ "preview": "vite preview" }, "dependencies": { + "@stripe/stripe-js": "^8.7.0", "axios": "^1.13.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/frontend/src/api/accounts.ts b/frontend/src/api/accounts.ts new file mode 100644 index 00000000..07bffc9a --- /dev/null +++ b/frontend/src/api/accounts.ts @@ -0,0 +1,48 @@ +import apiClient from './client' +import type { Account, SubscriptionDetails, AccountMember, AccountInvite } from '@/types' + +export const accountsApi = { + async getMyAccount(): Promise { + const response = await apiClient.get('/accounts/me') + return response.data + }, + + async getMySubscription(): Promise { + const response = await apiClient.get('/accounts/me/subscription') + return response.data + }, + + async updateMyAccount(data: { name?: string }): Promise { + const response = await apiClient.patch('/accounts/me', data) + return response.data + }, + + async getMembers(): Promise { + const response = await apiClient.get('/accounts/me/members') + return response.data + }, + + async updateMemberRole(userId: string, role: string): Promise { + const response = await apiClient.patch( + `/accounts/me/members/${userId}/role`, + { role } + ) + return response.data + }, + + async removeMember(userId: string): Promise { + await apiClient.delete(`/accounts/me/members/${userId}`) + }, + + async createInvite(data: { email: string; role: string }): Promise { + const response = await apiClient.post('/accounts/me/invites', data) + return response.data + }, + + async getInvites(): Promise { + const response = await apiClient.get('/accounts/me/invites') + return response.data + }, +} + +export default accountsApi diff --git a/frontend/src/api/categories.ts b/frontend/src/api/categories.ts index f3b38085..e227e9c6 100644 --- a/frontend/src/api/categories.ts +++ b/frontend/src/api/categories.ts @@ -2,9 +2,9 @@ import apiClient from './client' import type { Category, CategoryListItem, CategoryCreate, CategoryUpdate } from '@/types' export const categoriesApi = { - async list(includeInactive = false, teamOnly = false): Promise { + async list(includeInactive = false, accountOnly = false): Promise { const response = await apiClient.get('/categories', { - params: { include_inactive: includeInactive, team_only: teamOnly }, + params: { include_inactive: includeInactive, account_only: accountOnly }, }) return response.data }, diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 267fafa1..a03da465 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -8,3 +8,4 @@ export { default as categoriesApi } from './categories' export { default as foldersApi } from './folders' export { default as stepsApi } from './steps' export { default as stepCategoriesApi } from './stepCategories' +export { default as accountsApi } from './accounts' diff --git a/frontend/src/api/tags.ts b/frontend/src/api/tags.ts index 2907209c..5dfe46fe 100644 --- a/frontend/src/api/tags.ts +++ b/frontend/src/api/tags.ts @@ -2,16 +2,16 @@ import apiClient from './client' import type { Tag, TagListItem, TagCreate, TagAssignment } from '@/types' export const tagsApi = { - async list(includeTeam = true): Promise { + async list(includeAccount = true): Promise { const response = await apiClient.get('/tags', { - params: { include_team: includeTeam }, + params: { include_account: includeAccount }, }) return response.data }, - async search(query: string, limit = 10, includeTeam = true): Promise { + async search(query: string, limit = 10, includeAccount = true): Promise { const response = await apiClient.get('/tags/search', { - params: { q: query, limit, include_team: includeTeam }, + params: { q: query, limit, include_account: includeAccount }, }) return response.data }, diff --git a/frontend/src/components/common/UpgradePrompt.tsx b/frontend/src/components/common/UpgradePrompt.tsx new file mode 100644 index 00000000..e67c17fa --- /dev/null +++ b/frontend/src/components/common/UpgradePrompt.tsx @@ -0,0 +1,32 @@ +import { cn } from '@/lib/utils' +import { useSubscription } from '@/hooks/useSubscription' + +interface UpgradePromptProps { + feature: string // e.g., "create more trees", "start more sessions" + className?: string +} + +export function UpgradePrompt({ feature, className }: UpgradePromptProps) { + const { plan } = useSubscription() + + return ( +
+

Plan Limit Reached

+

+ Your {plan} plan doesn't allow you to {feature}. Upgrade your plan to continue. +

+ +
+ ) +} diff --git a/frontend/src/components/layout/AppLayout.tsx b/frontend/src/components/layout/AppLayout.tsx index b543f913..fbec1a01 100644 --- a/frontend/src/components/layout/AppLayout.tsx +++ b/frontend/src/components/layout/AppLayout.tsx @@ -49,6 +49,7 @@ export function AppLayout() { const navItems = [ { path: '/trees', label: 'Trees' }, { path: '/sessions', label: 'Sessions' }, + { path: '/account', label: 'Account' }, { path: '/settings', label: 'Settings' }, ] @@ -98,12 +99,12 @@ export function AppLayout() { className={cn( 'hidden rounded-full px-2 py-0.5 text-xs font-medium sm:inline-block', effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400', - effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400', + effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400', effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400' )} > {effectiveRole === 'super_admin' ? 'Super Admin' : - effectiveRole === 'team_admin' ? 'Team Admin' : + effectiveRole === 'owner' ? 'Owner' : 'Viewer'} )} @@ -158,12 +159,12 @@ export function AppLayout() { className={cn( 'mt-1 inline-block rounded-full px-2 py-0.5 text-xs font-medium', effectiveRole === 'super_admin' && 'bg-red-500/10 text-red-600 dark:text-red-400', - effectiveRole === 'team_admin' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400', + effectiveRole === 'owner' && 'bg-blue-500/10 text-blue-600 dark:text-blue-400', effectiveRole === 'viewer' && 'bg-gray-500/10 text-gray-600 dark:text-gray-400' )} > {effectiveRole === 'super_admin' ? 'Super Admin' : - effectiveRole === 'team_admin' ? 'Team Admin' : + effectiveRole === 'owner' ? 'Owner' : 'Viewer'} )} diff --git a/frontend/src/components/layout/ProtectedRoute.tsx b/frontend/src/components/layout/ProtectedRoute.tsx index 96163c44..c432c3b9 100644 --- a/frontend/src/components/layout/ProtectedRoute.tsx +++ b/frontend/src/components/layout/ProtectedRoute.tsx @@ -27,7 +27,7 @@ export function ProtectedRoute({ requiredRole, children }: ProtectedRouteProps) if (requiredRole) { const ROLE_HIERARCHY: Record = { super_admin: 4, - team_admin: 3, + owner: 3, engineer: 2, viewer: 1, } diff --git a/frontend/src/components/subscription/CheckoutButton.tsx b/frontend/src/components/subscription/CheckoutButton.tsx new file mode 100644 index 00000000..c3e2d332 --- /dev/null +++ b/frontend/src/components/subscription/CheckoutButton.tsx @@ -0,0 +1,24 @@ +import { cn } from '@/lib/utils' + +interface CheckoutButtonProps { + plan: 'pro' | 'team' + className?: string +} + +export function CheckoutButton({ plan, className }: CheckoutButtonProps) { + const planLabels = { pro: 'Pro', team: 'Team' } + + return ( + + ) +} diff --git a/frontend/src/components/tree-editor/TreeMetadataForm.tsx b/frontend/src/components/tree-editor/TreeMetadataForm.tsx index eb99831d..2adc9fff 100644 --- a/frontend/src/components/tree-editor/TreeMetadataForm.tsx +++ b/frontend/src/components/tree-editor/TreeMetadataForm.tsx @@ -119,7 +119,7 @@ export function TreeMetadataForm() { {categories.map((cat) => ( ))} diff --git a/frontend/src/hooks/usePermissions.ts b/frontend/src/hooks/usePermissions.ts index 5dc39fb2..55706d4a 100644 --- a/frontend/src/hooks/usePermissions.ts +++ b/frontend/src/hooks/usePermissions.ts @@ -1,18 +1,18 @@ /** * Centralized permissions hook for ResolutionFlow. * - * Role hierarchy: super_admin > team_admin > engineer > viewer + * Role hierarchy: super_admin > owner > engineer > viewer * * Mirrors backend logic in backend/app/core/permissions.py */ import { useAuthStore } from '@/store/authStore' import type { User } from '@/types' -export type EffectiveRole = 'super_admin' | 'team_admin' | 'engineer' | 'viewer' +export type EffectiveRole = 'super_admin' | 'owner' | 'engineer' | 'viewer' const ROLE_HIERARCHY: Record = { super_admin: 4, - team_admin: 3, + owner: 3, engineer: 2, viewer: 1, } @@ -20,7 +20,7 @@ const ROLE_HIERARCHY: Record = { function getEffectiveRole(user: User | null): EffectiveRole { if (!user) return 'viewer' if (user.is_super_admin) return 'super_admin' - if (user.is_team_admin && user.team_id) return 'team_admin' + if (user.account_role === 'owner') return 'owner' return user.role as EffectiveRole } @@ -37,7 +37,7 @@ export function usePermissions() { return { effectiveRole, isSuperAdmin: effectiveRole === 'super_admin', - isTeamAdmin: effectiveRole === 'team_admin' || effectiveRole === 'super_admin', + isAccountOwner: effectiveRole === 'owner' || effectiveRole === 'super_admin', isEngineer: hasMinimumRole(user, 'engineer'), isViewer: effectiveRole === 'viewer', @@ -46,12 +46,12 @@ export function usePermissions() { canCreateSteps: hasMinimumRole(user, 'engineer'), // Resource-specific checks - canEditTree: (tree: { author_id: string | null; team_id?: string | null }) => { + canEditTree: (tree: { author_id: string | null; account_id?: string | null }) => { if (!user) return false if (user.is_super_admin) return true if (!hasMinimumRole(user, 'engineer')) return false if (tree.author_id && tree.author_id === user.id) return true - if (user.is_team_admin && tree.team_id === user.team_id && user.team_id) return true + if (user.account_role === 'owner' && tree.account_id === user.account_id && user.account_id) return true return false }, @@ -68,8 +68,8 @@ export function usePermissions() { }, // Management permissions - canManageCategories: hasMinimumRole(user, 'team_admin'), + canManageCategories: hasMinimumRole(user, 'owner'), canManageGlobalCategories: effectiveRole === 'super_admin', - canManageTeam: effectiveRole === 'super_admin' || (effectiveRole === 'team_admin'), + canManageAccount: effectiveRole === 'super_admin' || effectiveRole === 'owner', } } diff --git a/frontend/src/hooks/useSubscription.ts b/frontend/src/hooks/useSubscription.ts new file mode 100644 index 00000000..715a86bb --- /dev/null +++ b/frontend/src/hooks/useSubscription.ts @@ -0,0 +1,45 @@ +import { useAuthStore } from '@/store/authStore' + +export function useSubscription() { + const subscription = useAuthStore((s) => s.subscription) + + const plan = subscription?.subscription.plan ?? 'free' + const limits = subscription?.limits ?? null + const usage = subscription?.usage ?? null + const isActive = subscription?.subscription.status === 'active' || subscription?.subscription.status === 'trialing' + + const isPaidPlan = plan === 'pro' || plan === 'team' + + const canUseFeature = (feature: 'custom_branding' | 'priority_support'): boolean => { + if (!limits) return false + return limits[feature] + } + + const isAtTreeLimit = (): boolean => { + if (!limits || !usage) return false + if (limits.max_trees === null) return false // unlimited + return usage.tree_count >= limits.max_trees + } + + const isAtSessionLimit = (): boolean => { + if (!limits || !usage) return false + if (limits.max_sessions_per_month === null) return false + return usage.session_count_this_month >= limits.max_sessions_per_month + } + + const formatLimit = (value: number | null): string => { + return value === null ? 'Unlimited' : String(value) + } + + return { + plan, + limits, + usage, + isActive, + isPaidPlan, + canUseFeature, + isAtTreeLimit, + isAtSessionLimit, + formatLimit, + } +} diff --git a/frontend/src/pages/AccountSettingsPage.tsx b/frontend/src/pages/AccountSettingsPage.tsx new file mode 100644 index 00000000..d58e9058 --- /dev/null +++ b/frontend/src/pages/AccountSettingsPage.tsx @@ -0,0 +1,494 @@ +import { useEffect, useState } from 'react' +import { Building2, Users, Mail, Crown, Loader2, AlertCircle, Check, X } from 'lucide-react' +import { accountsApi } from '@/api' +import type { Account, AccountMember, AccountInvite } from '@/types' +import { cn } from '@/lib/utils' +import { usePermissions } from '@/hooks/usePermissions' +import { useSubscription } from '@/hooks/useSubscription' +import { useAuthStore } from '@/store/authStore' +import { CheckoutButton } from '@/components/subscription/CheckoutButton' + +export function AccountSettingsPage() { + const { isAccountOwner } = usePermissions() + const { plan, limits, usage } = useSubscription() + const subscription = useAuthStore((s) => s.subscription) + + const [account, setAccount] = useState(null) + const [members, setMembers] = useState([]) + const [invites, setInvites] = useState([]) + const [isLoading, setIsLoading] = useState(true) + const [error, setError] = useState(null) + + // Account name editing + const [isEditingName, setIsEditingName] = useState(false) + const [editedName, setEditedName] = useState('') + const [isSavingName, setIsSavingName] = useState(false) + + // Invite form + const [inviteEmail, setInviteEmail] = useState('') + const [inviteRole, setInviteRole] = useState('engineer') + const [isInviting, setIsInviting] = useState(false) + const [inviteError, setInviteError] = useState(null) + const [inviteSuccess, setInviteSuccess] = useState(null) + + useEffect(() => { + loadData() + }, []) + + const loadData = async () => { + setIsLoading(true) + setError(null) + try { + const accountData = await accountsApi.getMyAccount() + setAccount(accountData) + setEditedName(accountData.name) + + if (isAccountOwner) { + const [membersData, invitesData] = await Promise.all([ + accountsApi.getMembers(), + accountsApi.getInvites(), + ]) + setMembers(membersData) + setInvites(invitesData) + } + } catch (err) { + setError('Failed to load account information') + console.error(err) + } finally { + setIsLoading(false) + } + } + + const handleSaveName = async () => { + if (!editedName.trim() || editedName === account?.name) { + setIsEditingName(false) + return + } + setIsSavingName(true) + try { + const updated = await accountsApi.updateMyAccount({ name: editedName.trim() }) + setAccount(updated) + setIsEditingName(false) + } catch (err) { + console.error('Failed to update account name:', err) + } finally { + setIsSavingName(false) + } + } + + const handleInvite = async (e: React.FormEvent) => { + e.preventDefault() + if (!inviteEmail.trim()) return + + setIsInviting(true) + setInviteError(null) + setInviteSuccess(null) + try { + await accountsApi.createInvite({ email: inviteEmail.trim(), role: inviteRole }) + setInviteSuccess(`Invitation sent to ${inviteEmail}`) + setInviteEmail('') + // Refresh invites list + const invitesData = await accountsApi.getInvites() + setInvites(invitesData) + } catch (err) { + setInviteError('Failed to send invitation') + console.error(err) + } finally { + setIsInviting(false) + } + } + + const handleRemoveMember = async (userId: string) => { + try { + await accountsApi.removeMember(userId) + setMembers(members.filter((m) => m.id !== userId)) + } catch (err) { + console.error('Failed to remove member:', err) + } + } + + if (isLoading) { + return ( +
+
+
+ ) + } + + if (error) { + return ( +
+
+
+ + {error} +
+
+
+ ) + } + + const sub = subscription?.subscription + + return ( +
+
+
+ +

Account Settings

+
+

+ Manage your account, subscription, and team +

+
+ +
+ {/* Account Info Section */} +
+

Account Information

+ +
+ {/* Account Name */} +
+ + {isEditingName ? ( +
+ setEditedName(e.target.value)} + className={cn( + 'flex-1 rounded-md border border-input bg-background px-3 py-2', + 'text-foreground placeholder:text-muted-foreground', + 'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary' + )} + autoFocus + onKeyDown={(e) => { + if (e.key === 'Enter') handleSaveName() + if (e.key === 'Escape') { + setEditedName(account?.name ?? '') + setIsEditingName(false) + } + }} + /> + + +
+ ) : ( +
+ {account?.name} + {isAccountOwner && ( + + )} +
+ )} +
+ + {/* Display Code */} +
+ +

+ {account?.display_code} +

+
+
+
+ + {/* Subscription Section */} +
+

Subscription

+ +
+ {/* Plan & Status */} +
+ + + {plan.charAt(0).toUpperCase() + plan.slice(1)} Plan + + {sub && ( + + {sub.status.charAt(0).toUpperCase() + sub.status.slice(1).replace('_', ' ')} + + )} +
+ + {sub?.current_period_end && ( +

+ Current period ends: {new Date(sub.current_period_end).toLocaleDateString()} +

+ )} + + {/* Usage Stats */} + {limits && usage && ( +
+ + + +
+ )} + + {/* Upgrade buttons */} + {plan === 'free' && ( +
+ + +
+ )} + {plan === 'pro' && ( +
+ +
+ )} +
+
+ + {/* Team Members Section (owners only) */} + {isAccountOwner && ( +
+
+ +

Team Members

+
+ + {members.length === 0 ? ( +

No team members yet.

+ ) : ( +
+ {members.map((member) => ( +
+
+

{member.name}

+

{member.email}

+
+
+ + {member.account_role} + + {!member.is_active && ( + + Inactive + + )} + {member.account_role !== 'owner' && ( + + )} +
+
+ ))} +
+ )} +
+ )} + + {/* Invite Member Section (owners only) */} + {isAccountOwner && ( +
+
+ +

Invite Member

+
+ +
+
+ setInviteEmail(e.target.value)} + required + className={cn( + 'flex-1 rounded-md border border-input bg-background px-3 py-2', + 'text-foreground placeholder:text-muted-foreground', + 'focus:border-primary focus:outline-none focus:ring-1 focus:ring-primary' + )} + /> + + +
+ + {inviteError && ( +

{inviteError}

+ )} + {inviteSuccess && ( +

{inviteSuccess}

+ )} +
+ + {/* Pending Invites */} + {invites.length > 0 && ( +
+

Pending Invites

+
+ {invites + .filter((inv) => !inv.used_at) + .map((invite) => ( +
+
+

{invite.email}

+

+ Expires {new Date(invite.expires_at).toLocaleDateString()} +

+
+ + {invite.role} + +
+ ))} +
+
+ )} +
+ )} +
+
+ ) +} + +/** Small helper component for usage stat display */ +function UsageStat({ + label, + current, + max, +}: { + label: string + current: number + max: number | null +}) { + const isUnlimited = max === null + const percentage = isUnlimited ? 0 : Math.min((current / max) * 100, 100) + const isNearLimit = !isUnlimited && percentage >= 80 + const isAtLimit = !isUnlimited && current >= max + + return ( +
+

{label}

+

+ {current} + + {' '}/ {isUnlimited ? 'Unlimited' : max} + +

+ {!isUnlimited && ( +
+
+
+ )} +
+ ) +} + +export default AccountSettingsPage diff --git a/frontend/src/pages/TreeEditorPage.tsx b/frontend/src/pages/TreeEditorPage.tsx index 374da3fb..781fa26f 100644 --- a/frontend/src/pages/TreeEditorPage.tsx +++ b/frontend/src/pages/TreeEditorPage.tsx @@ -102,7 +102,7 @@ export function TreeEditorPage() { setLoading(true) try { const tree = await treesApi.get(id) - if (!canEditTree({ author_id: tree.author_id, team_id: tree.team_id })) { + if (!canEditTree({ author_id: tree.author_id, account_id: tree.account_id })) { navigate('/trees') return } diff --git a/frontend/src/pages/TreeLibraryPage.tsx b/frontend/src/pages/TreeLibraryPage.tsx index 30b89c8c..84b0271c 100644 --- a/frontend/src/pages/TreeLibraryPage.tsx +++ b/frontend/src/pages/TreeLibraryPage.tsx @@ -137,6 +137,7 @@ export function TreeLibraryPage() { try { await treesApi.delete(treeToDelete.id) setTrees(trees.filter((t) => t.id !== treeToDelete.id)) + window.dispatchEvent(new Event('folder-changed')) } catch (err) { console.error('Failed to delete tree:', err) setError('Failed to delete tree') @@ -352,7 +353,7 @@ export function TreeLibraryPage() {
- {canEditTree({ author_id: tree.author_id, team_id: tree.team_id }) && ( + {canEditTree({ author_id: tree.author_id, account_id: tree.account_id }) && ( , }, + { + path: 'account', + element: , + }, ], }, ]) diff --git a/frontend/src/store/authStore.ts b/frontend/src/store/authStore.ts index 457697c3..cb8976ee 100644 --- a/frontend/src/store/authStore.ts +++ b/frontend/src/store/authStore.ts @@ -1,11 +1,14 @@ import { create } from 'zustand' import { persist } from 'zustand/middleware' -import type { User, Token, UserCreate, UserLogin } from '@/types' +import type { User, Token, UserCreate, UserLogin, Account, SubscriptionDetails } from '@/types' import { authApi } from '@/api' +import { apiClient } from '@/api' interface AuthState { user: User | null token: Token | null + account: Account | null + subscription: SubscriptionDetails | null isAuthenticated: boolean isLoading: boolean error: string | null @@ -25,6 +28,8 @@ export const useAuthStore = create()( (set, get) => ({ user: null, token: null, + account: null, + subscription: null, isAuthenticated: false, isLoading: false, error: null, @@ -70,15 +75,30 @@ export const useAuthStore = create()( } finally { localStorage.removeItem('access_token') localStorage.removeItem('refresh_token') - set({ user: null, token: null, isAuthenticated: false, error: null }) + set({ user: null, token: null, account: null, subscription: null, isAuthenticated: false, error: null }) } }, fetchUser: async () => { set({ isLoading: true }) try { - const user = await authApi.me() - set({ user, isLoading: false }) + const [userResult, accountResult, subscriptionResult] = await Promise.allSettled([ + authApi.me(), + apiClient.get('/accounts/me').then(r => r.data), + apiClient.get('/accounts/me/subscription').then(r => r.data), + ]) + + const user = userResult.status === 'fulfilled' ? userResult.value : null + const account = accountResult.status === 'fulfilled' ? accountResult.value : null + const subscription = subscriptionResult.status === 'fulfilled' ? subscriptionResult.value : null + + if (!user) { + // User fetch failed — propagate the error + const reason = userResult.status === 'rejected' ? userResult.reason : new Error('Failed to fetch user') + throw reason + } + + set({ user, account, subscription, isLoading: false }) } catch (error: unknown) { const message = error instanceof Error ? error.message : 'Failed to fetch user' set({ error: message, isLoading: false }) @@ -95,6 +115,8 @@ export const useAuthStore = create()( partialize: (state) => ({ token: state.token, isAuthenticated: state.isAuthenticated, + account: state.account, + subscription: state.subscription, }), } ) diff --git a/frontend/src/types/account.ts b/frontend/src/types/account.ts new file mode 100644 index 00000000..4841530e --- /dev/null +++ b/frontend/src/types/account.ts @@ -0,0 +1,62 @@ +export interface Account { + id: string + name: string + display_code: string + owner_id: string + created_at: string + updated_at: string +} + +export interface Subscription { + id: string + account_id: string + plan: 'free' | 'pro' | 'team' + status: 'active' | 'past_due' | 'canceled' | 'trialing' | 'orphaned' + current_period_start: string | null + current_period_end: string | null + created_at: string + updated_at: string +} + +export interface PlanLimits { + plan: string + max_trees: number | null + max_sessions_per_month: number | null + max_users: number | null + custom_branding: boolean + priority_support: boolean + export_formats: string[] +} + +export interface SubscriptionDetails { + subscription: Subscription + limits: PlanLimits + usage: { + tree_count: number + session_count_this_month: number + user_count: number + } +} + +export interface AccountInvite { + id: string + account_id: string + email: string + role: 'engineer' | 'viewer' + code: string + invited_by_id: string + accepted_by_id: string | null + expires_at: string + used_at: string | null + created_at: string +} + +export interface AccountMember { + id: string + email: string + name: string + account_role: string + is_active: boolean + created_at: string + last_login: string | null +} diff --git a/frontend/src/types/category.ts b/frontend/src/types/category.ts index 6a318768..906a9b14 100644 --- a/frontend/src/types/category.ts +++ b/frontend/src/types/category.ts @@ -5,7 +5,7 @@ export interface Category { name: string slug: string description: string | null - team_id: string | null + account_id: string | null display_order: number is_active: boolean created_at: string @@ -18,7 +18,7 @@ export interface CategoryListItem { name: string slug: string description: string | null - team_id: string | null + account_id: string | null display_order: number is_active: boolean tree_count: number @@ -27,7 +27,7 @@ export interface CategoryListItem { export interface CategoryCreate { name: string description?: string | null - team_id?: string | null + account_id?: string | null } export interface CategoryUpdate { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 9c1f892b..0172e43d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -7,6 +7,7 @@ export * from './tag' export * from './category' export * from './folder' export * from './step' +export type { Account, Subscription, PlanLimits, SubscriptionDetails, AccountInvite, AccountMember } from './account' // API response wrapper types export interface PaginatedResponse { diff --git a/frontend/src/types/step.ts b/frontend/src/types/step.ts index 19ead785..d9f12f00 100644 --- a/frontend/src/types/step.ts +++ b/frontend/src/types/step.ts @@ -56,7 +56,7 @@ export interface StepCategory { name: string description?: string display_order: number - team_id?: string + account_id?: string is_active: boolean } diff --git a/frontend/src/types/tag.ts b/frontend/src/types/tag.ts index e33e2a9a..4b8813c7 100644 --- a/frontend/src/types/tag.ts +++ b/frontend/src/types/tag.ts @@ -4,7 +4,7 @@ export interface Tag { id: string name: string slug: string - team_id: string | null + account_id: string | null usage_count: number created_at: string } @@ -13,13 +13,13 @@ export interface TagListItem { id: string name: string slug: string - team_id: string | null + account_id: string | null usage_count: number } export interface TagCreate { name: string - team_id?: string | null + account_id?: string | null } export interface TagAssignment { diff --git a/frontend/src/types/tree.ts b/frontend/src/types/tree.ts index 1774068e..f62789c6 100644 --- a/frontend/src/types/tree.ts +++ b/frontend/src/types/tree.ts @@ -67,7 +67,7 @@ export interface Tree { tags: string[] tree_structure: TreeStructure author_id: string | null - team_id: string | null + account_id: string | null is_active: boolean is_public: boolean is_default: boolean @@ -86,7 +86,7 @@ export interface TreeListItem { category_info: CategoryInfo | null tags: string[] author_id: string | null - team_id: string | null + account_id: string | null is_active: boolean is_public: boolean is_default: boolean diff --git a/frontend/src/types/user.ts b/frontend/src/types/user.ts index b5fe4a59..89136508 100644 --- a/frontend/src/types/user.ts +++ b/frontend/src/types/user.ts @@ -6,8 +6,8 @@ export interface User { name: string role: UserRole is_super_admin: boolean - is_team_admin: boolean - team_id: string | null + account_id: string + account_role: 'owner' | 'engineer' | 'viewer' created_at: string last_login: string | null } From 974e86a5023f0e4804df511d9d598301c514c67e Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:55:53 -0500 Subject: [PATCH 4/5] fix: resolve circular FK between users and accounts on registration Account.owner_id and User.account_id are both NOT NULL, creating a circular dependency that prevents inserting either row first. Fix by: 1. Making owner_id nullable (set immediately after user creation) 2. Creating Account before User, then setting owner_id after flush 3. Removing NOT NULL enforcement on owner_id in migration 020 Co-Authored-By: Claude Opus 4.6 --- .../020_finalize_account_migration.py | 3 +- backend/app/api/endpoints/auth.py | 45 +++++++++++-------- backend/app/models/account.py | 2 +- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/backend/alembic/versions/020_finalize_account_migration.py b/backend/alembic/versions/020_finalize_account_migration.py index 510c8abb..2cd29772 100644 --- a/backend/alembic/versions/020_finalize_account_migration.py +++ b/backend/alembic/versions/020_finalize_account_migration.py @@ -81,8 +81,7 @@ def upgrade() -> None: ['account_id'], ['id'], ondelete='CASCADE' ) - # 4. Accounts: enforce owner_id NOT NULL + FK - op.alter_column('accounts', 'owner_id', nullable=False) + # 4. Accounts: add owner FK (owner_id stays nullable due to circular FK with users) op.create_foreign_key( 'fk_accounts_owner_id', 'accounts', 'users', ['owner_id'], ['id'], ondelete='RESTRICT' diff --git a/backend/app/api/endpoints/auth.py b/backend/app/api/endpoints/auth.py index 385c31fd..44384c73 100644 --- a/backend/app/api/endpoints/auth.py +++ b/backend/app/api/endpoints/auth.py @@ -134,35 +134,47 @@ async def register( detail="Email already registered" ) - # Create new user - new_user = User( - email=user_data.email, - password_hash=get_password_hash(user_data.password), - name=user_data.name, - role="engineer", - invite_code_id=invite_code_record.id if invite_code_record else None - ) - db.add(new_user) - await db.flush() # Get user ID before creating account - if account_invite_record: # Join existing account via account invite - new_user.account_id = account_invite_record.account_id - new_user.account_role = account_invite_record.role + new_user = User( + email=user_data.email, + password_hash=get_password_hash(user_data.password), + name=user_data.name, + role="engineer", + invite_code_id=invite_code_record.id if invite_code_record else None, + account_id=account_invite_record.account_id, + account_role=account_invite_record.role, + ) + db.add(new_user) + await db.flush() # Mark account invite as used account_invite_record.accepted_by_id = new_user.id account_invite_record.used_at = datetime.now(timezone.utc) else: - # Create personal Account + free Subscription + # Create personal Account first (user needs account_id for NOT NULL constraint) new_account = Account( name=f"{user_data.name}'s Account", display_code=_generate_display_code(), - owner_id=new_user.id, ) db.add(new_account) await db.flush() # Get account ID + new_user = User( + email=user_data.email, + password_hash=get_password_hash(user_data.password), + name=user_data.name, + role="engineer", + invite_code_id=invite_code_record.id if invite_code_record else None, + account_id=new_account.id, + account_role="owner", + ) + db.add(new_user) + await db.flush() # Get user ID + + # Now set account owner and create subscription + new_account.owner_id = new_user.id + new_subscription = Subscription( account_id=new_account.id, plan="free", @@ -170,9 +182,6 @@ async def register( ) db.add(new_subscription) - new_user.account_id = new_account.id - new_user.account_role = "owner" - # Mark platform invite code as used if invite_code_record: invite_code_record.used_by_id = new_user.id diff --git a/backend/app/models/account.py b/backend/app/models/account.py index 6506488f..e9e8be18 100644 --- a/backend/app/models/account.py +++ b/backend/app/models/account.py @@ -22,7 +22,7 @@ class Account(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name: Mapped[str] = mapped_column(String(255), nullable=False) display_code: Mapped[str] = mapped_column(String(8), unique=True, nullable=False) - owner_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=False) + owner_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=True) stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) From 8dbb87e4d2aab6b49a575fad6c07409bfd29f41b Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 03:06:35 -0500 Subject: [PATCH 5/5] fix: add migration 021 to make accounts.owner_id nullable on existing DBs Railway already ran the old migration 020 which enforced NOT NULL on owner_id. Since alembic won't re-run a corrected 020, this new migration explicitly reverts the constraint for databases that already applied it. Co-Authored-By: Claude Opus 4.6 --- .../versions/021_fix_owner_id_nullable.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 backend/alembic/versions/021_fix_owner_id_nullable.py diff --git a/backend/alembic/versions/021_fix_owner_id_nullable.py b/backend/alembic/versions/021_fix_owner_id_nullable.py new file mode 100644 index 00000000..28322bd2 --- /dev/null +++ b/backend/alembic/versions/021_fix_owner_id_nullable.py @@ -0,0 +1,30 @@ +"""fix accounts.owner_id to be nullable (circular FK with users) + +Revision ID: 021 +Revises: 020 +Create Date: 2026-02-07 + +The original migration 020 enforced NOT NULL on owner_id, but this creates +a circular FK problem: Account needs owner_id (User) and User needs +account_id (Account). Making owner_id nullable resolves this — we create +the Account first, then the User, then set owner_id. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '021' +down_revision: Union[str, None] = '020' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.alter_column('accounts', 'owner_id', nullable=True) + + +def downgrade() -> None: + op.alter_column('accounts', 'owner_id', nullable=False)