From 4ccb93ee31c6ca5ef52249e4d98c418e1a72f009 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Sat, 7 Feb 2026 02:38:47 -0500 Subject: [PATCH] feat: add account-based subscription model with migrations Transition from team-based to account-based multi-tenancy (Free/Pro/Team). Migrations 016-020 create accounts, subscriptions, plan_limits, and account_invites tables, then migrate existing users and content FKs. New models: Account, Subscription, PlanLimits, AccountInvite. Updated models add account_id alongside existing team_id (coexistence for safe two-PR deployment). Permissions and deps refactored for account_role instead of is_team_admin. Co-Authored-By: Claude Opus 4.6 --- .../versions/016_add_subscription_tables.py | 110 +++++++++++ .../versions/017_add_account_id_to_users.py | 31 +++ .../versions/018_migrate_users_to_accounts.py | 187 ++++++++++++++++++ .../019_migrate_team_fks_to_account.py | 56 ++++++ .../020_finalize_account_migration.py | 105 ++++++++++ backend/app/api/deps.py | 40 +++- backend/app/core/config.py | 10 + backend/app/core/permissions.py | 58 +++--- backend/app/core/stripe_handlers.py | 37 ++++ backend/app/core/subscriptions.py | 113 +++++++++++ backend/app/models/__init__.py | 8 + backend/app/models/account.py | 38 ++++ backend/app/models/account_invite.py | 48 +++++ backend/app/models/category.py | 12 +- backend/app/models/plan_limits.py | 16 ++ backend/app/models/step_category.py | 12 +- backend/app/models/step_library.py | 8 + backend/app/models/subscription.py | 39 ++++ backend/app/models/tag.py | 12 +- backend/app/models/tree.py | 8 + backend/app/models/user.py | 29 ++- backend/requirements.txt | 3 + 22 files changed, 933 insertions(+), 47 deletions(-) create mode 100644 backend/alembic/versions/016_add_subscription_tables.py create mode 100644 backend/alembic/versions/017_add_account_id_to_users.py create mode 100644 backend/alembic/versions/018_migrate_users_to_accounts.py create mode 100644 backend/alembic/versions/019_migrate_team_fks_to_account.py create mode 100644 backend/alembic/versions/020_finalize_account_migration.py create mode 100644 backend/app/core/stripe_handlers.py create mode 100644 backend/app/core/subscriptions.py create mode 100644 backend/app/models/account.py create mode 100644 backend/app/models/account_invite.py create mode 100644 backend/app/models/plan_limits.py create mode 100644 backend/app/models/subscription.py diff --git a/backend/alembic/versions/016_add_subscription_tables.py b/backend/alembic/versions/016_add_subscription_tables.py new file mode 100644 index 00000000..960726cc --- /dev/null +++ b/backend/alembic/versions/016_add_subscription_tables.py @@ -0,0 +1,110 @@ +"""add accounts, subscriptions, plan_limits, and account_invites tables + +Revision ID: 016 +Revises: 015 +Create Date: 2026-02-07 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + + +# revision identifiers, used by Alembic. +revision: str = '016' +down_revision: Union[str, None] = '015' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 1. accounts table + op.create_table( + 'accounts', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('display_code', sa.String(8), nullable=False), + sa.Column('owner_id', UUID(as_uuid=True), nullable=True), # nullable until user created + sa.Column('stripe_customer_id', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('display_code', name='uq_accounts_display_code'), + ) + + # 2. subscriptions table + op.create_table( + 'subscriptions', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('stripe_subscription_id', sa.String(255), nullable=True), + sa.Column('stripe_price_id', sa.String(255), nullable=True), + sa.Column('plan', sa.String(50), nullable=False, server_default='free'), + sa.Column('billing_interval', sa.String(20), nullable=True), # 'monthly' or 'annual' + sa.Column('status', sa.String(50), nullable=False, server_default='active'), + sa.Column('seat_limit', sa.Integer, nullable=True), + sa.Column('current_period_start', sa.DateTime(timezone=True), nullable=True), + sa.Column('current_period_end', sa.DateTime(timezone=True), nullable=True), + sa.Column('cancel_at_period_end', sa.Boolean, nullable=False, server_default='false'), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'), + sa.UniqueConstraint('account_id', name='uq_subscriptions_account_id'), + ) + op.create_index('ix_subscriptions_account_id', 'subscriptions', ['account_id']) + op.create_index('ix_subscriptions_plan', 'subscriptions', ['plan']) + + # 3. plan_limits table (configuration — seeded with 3 rows) + op.create_table( + 'plan_limits', + sa.Column('plan', sa.String(50), nullable=False), + sa.Column('max_trees', sa.Integer, nullable=True), # NULL = unlimited + sa.Column('max_sessions_per_month', sa.Integer, nullable=True), + sa.Column('max_users', sa.Integer, nullable=True), + sa.Column('custom_branding', sa.Boolean, nullable=False, server_default='false'), + sa.Column('priority_support', sa.Boolean, nullable=False, server_default='false'), + sa.Column('export_formats', JSONB, nullable=False, server_default='["markdown", "text"]'), + sa.PrimaryKeyConstraint('plan'), + ) + + # Seed plan_limits + op.execute(""" + INSERT INTO plan_limits (plan, max_trees, max_sessions_per_month, max_users, custom_branding, priority_support, export_formats) + VALUES + ('free', 3, 20, 1, false, false, '["markdown", "text"]'), + ('pro', 25, 200, 1, false, true, '["markdown", "text", "html", "pdf"]'), + ('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html", "pdf"]') + """) + + # 4. account_invites table + op.create_table( + 'account_invites', + sa.Column('id', UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('invited_by_id', UUID(as_uuid=True), nullable=False), + sa.Column('email', sa.String(255), nullable=False), + sa.Column('code', sa.String(32), nullable=False), + sa.Column('role', sa.String(50), nullable=False, server_default='engineer'), + sa.Column('accepted_by_id', UUID(as_uuid=True), nullable=True), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('NOW()'), nullable=False), + sa.Column('used_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['invited_by_id'], ['users.id']), + sa.ForeignKeyConstraint(['accepted_by_id'], ['users.id']), + sa.UniqueConstraint('code', name='uq_account_invites_code'), + sa.CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'), + ) + op.create_index('ix_account_invites_account_id', 'account_invites', ['account_id']) + op.create_index('ix_account_invites_email', 'account_invites', ['email']) + + +def downgrade() -> None: + op.drop_table('account_invites') + op.drop_table('plan_limits') + op.drop_table('subscriptions') + op.drop_table('accounts') diff --git a/backend/alembic/versions/017_add_account_id_to_users.py b/backend/alembic/versions/017_add_account_id_to_users.py new file mode 100644 index 00000000..2e205a33 --- /dev/null +++ b/backend/alembic/versions/017_add_account_id_to_users.py @@ -0,0 +1,31 @@ +"""add account_id and account_role columns to users + +Revision ID: 017 +Revises: 016 +Create Date: 2026-02-07 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '017' +down_revision: Union[str, None] = '016' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('users', sa.Column('account_id', UUID(as_uuid=True), nullable=True)) + op.add_column('users', sa.Column('account_role', sa.String(50), nullable=True)) + op.create_index('ix_users_account_id', 'users', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_users_account_id', table_name='users') + op.drop_column('users', 'account_role') + op.drop_column('users', 'account_id') diff --git a/backend/alembic/versions/018_migrate_users_to_accounts.py b/backend/alembic/versions/018_migrate_users_to_accounts.py new file mode 100644 index 00000000..d15401f5 --- /dev/null +++ b/backend/alembic/versions/018_migrate_users_to_accounts.py @@ -0,0 +1,187 @@ +"""migrate existing users and teams to accounts + +Revision ID: 018 +Revises: 017 +Create Date: 2026-02-07 + +This is the most critical migration. It creates a _team_account_mapping table +for deterministic cross-migration lookups, then migrates all users to accounts. + +Three paths: + A) Teams with users → Account with deterministic owner + B) Teams with zero users → Account with owner_id=NULL, subscription status='orphaned' + C) Users without a team → Personal account, user is owner +""" +import secrets +import string +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '018' +down_revision: Union[str, None] = '017' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# Characters for display codes — exclude confusing chars +DISPLAY_CODE_CHARS = string.ascii_uppercase + string.digits +DISPLAY_CODE_CHARS = DISPLAY_CODE_CHARS.replace('0', '').replace('O', '').replace('I', '').replace('1', '').replace('L', '') + + +def _generate_display_code(existing_codes: set) -> str: + """Generate a unique 8-character display code.""" + for _ in range(100): + code = ''.join(secrets.choice(DISPLAY_CODE_CHARS) for _ in range(8)) + if code not in existing_codes: + existing_codes.add(code) + return code + raise RuntimeError("Failed to generate unique display code after 100 attempts") + + +def upgrade() -> None: + conn = op.get_bind() + + # Create mapping table for deterministic cross-migration lookups + op.create_table( + '_team_account_mapping', + sa.Column('team_id', UUID(as_uuid=True), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('owner_user_id', UUID(as_uuid=True), nullable=True), + sa.PrimaryKeyConstraint('team_id'), + ) + + existing_codes: set = set() + + # --- Path A & B: Process all teams --- + teams = conn.execute(sa.text("SELECT id, name FROM teams")).fetchall() + + for team in teams: + team_id = team[0] + team_name = team[1] + display_code = _generate_display_code(existing_codes) + + # Find deterministic owner: team admin first, then earliest user + owner_row = conn.execute(sa.text(""" + SELECT id FROM users + WHERE team_id = :tid + ORDER BY is_team_admin DESC, created_at ASC, id ASC + LIMIT 1 + """), {"tid": team_id}).fetchone() + + owner_user_id = owner_row[0] if owner_row else None + + # Create account + conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at) + VALUES (gen_random_uuid(), :name, :code, :owner_id, NOW(), NOW()) + """), {"name": team_name, "code": display_code, "owner_id": owner_user_id}) + + # Get the account we just created + account_row = conn.execute(sa.text( + "SELECT id FROM accounts WHERE display_code = :code" + ), {"code": display_code}).fetchone() + account_id = account_row[0] + + # Insert mapping + conn.execute(sa.text(""" + INSERT INTO _team_account_mapping (team_id, account_id, owner_user_id) + VALUES (:tid, :aid, :uid) + """), {"tid": team_id, "aid": account_id, "uid": owner_user_id}) + + if owner_user_id is not None: + # Path A: Team with users + # Create active subscription + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW()) + """), {"aid": account_id}) + + # Update all users in this team + # Team admins become owners, others keep their role + conn.execute(sa.text(""" + UPDATE users SET + account_id = :aid, + account_role = CASE + WHEN is_team_admin = true THEN 'owner' + ELSE role + END + WHERE team_id = :tid + """), {"aid": account_id, "tid": team_id}) + else: + # Path B: Team with zero users (orphan) + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'orphaned', NOW(), NOW()) + """), {"aid": account_id}) + + # --- Path C: Users without a team --- + teamless_users = conn.execute(sa.text( + "SELECT id, name FROM users WHERE team_id IS NULL AND account_id IS NULL" + )).fetchall() + + for user in teamless_users: + user_id = user[0] + user_name = user[1] + display_code = _generate_display_code(existing_codes) + + # Create personal account (owner_id set to NULL initially) + conn.execute(sa.text(""" + INSERT INTO accounts (id, name, display_code, owner_id, created_at, updated_at) + VALUES (gen_random_uuid(), :name, :code, NULL, NOW(), NOW()) + """), {"name": f"{user_name}'s Account", "code": display_code}) + + account_row = conn.execute(sa.text( + "SELECT id FROM accounts WHERE display_code = :code" + ), {"code": display_code}).fetchone() + account_id = account_row[0] + + # Update user + conn.execute(sa.text(""" + UPDATE users SET account_id = :aid, account_role = 'owner' + WHERE id = :uid + """), {"aid": account_id, "uid": user_id}) + + # Set owner + conn.execute(sa.text(""" + UPDATE accounts SET owner_id = :uid WHERE id = :aid + """), {"uid": user_id, "aid": account_id}) + + # Create free subscription + conn.execute(sa.text(""" + INSERT INTO subscriptions (id, account_id, plan, status, created_at, updated_at) + VALUES (gen_random_uuid(), :aid, 'free', 'active', NOW(), NOW()) + """), {"aid": account_id}) + + # --- Validation --- + orphaned_users = conn.execute(sa.text( + "SELECT COUNT(*) FROM users WHERE account_id IS NULL" + )).scalar() + if orphaned_users > 0: + raise RuntimeError( + f"Migration 018 failed validation: {orphaned_users} users still have NULL account_id" + ) + + team_count = conn.execute(sa.text("SELECT COUNT(*) FROM teams")).scalar() + mapping_count = conn.execute(sa.text("SELECT COUNT(*) FROM _team_account_mapping")).scalar() + if mapping_count != team_count: + raise RuntimeError( + f"Migration 018 failed: mapping count ({mapping_count}) != team count ({team_count})" + ) + + +def downgrade() -> None: + conn = op.get_bind() + + # Clear account data from users + conn.execute(sa.text("UPDATE users SET account_id = NULL, account_role = NULL")) + + # Delete all subscriptions and accounts created by this migration + conn.execute(sa.text("DELETE FROM subscriptions")) + conn.execute(sa.text("DELETE FROM accounts")) + + # Drop mapping table + op.drop_table('_team_account_mapping') diff --git a/backend/alembic/versions/019_migrate_team_fks_to_account.py b/backend/alembic/versions/019_migrate_team_fks_to_account.py new file mode 100644 index 00000000..4f15dc80 --- /dev/null +++ b/backend/alembic/versions/019_migrate_team_fks_to_account.py @@ -0,0 +1,56 @@ +"""add account_id to content tables and backfill from _team_account_mapping + +Revision ID: 019 +Revises: 018 +Create Date: 2026-02-07 + +Uses the _team_account_mapping table from migration 018 for deterministic +FK backfill instead of non-deterministic LIMIT 1 subqueries. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '019' +down_revision: Union[str, None] = '018' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# Tables that have team_id and need account_id +CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories'] + + +def upgrade() -> None: + conn = op.get_bind() + + for table in CONTENT_TABLES: + # Add account_id column + op.add_column(table, sa.Column('account_id', UUID(as_uuid=True), nullable=True)) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + # Backfill from mapping table (deterministic) + conn.execute(sa.text(f""" + UPDATE {table} SET account_id = m.account_id + FROM _team_account_mapping m + WHERE {table}.team_id = m.team_id + """)) + + # Validate: no rows with team_id but missing account_id + orphaned = conn.execute(sa.text(f""" + SELECT COUNT(*) FROM {table} + WHERE team_id IS NOT NULL AND account_id IS NULL + """)).scalar() + if orphaned > 0: + raise RuntimeError( + f"Migration 019 failed: {table} has {orphaned} rows with team_id but no account_id" + ) + + +def downgrade() -> None: + for table in reversed(CONTENT_TABLES): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/020_finalize_account_migration.py b/backend/alembic/versions/020_finalize_account_migration.py new file mode 100644 index 00000000..510c8abb --- /dev/null +++ b/backend/alembic/versions/020_finalize_account_migration.py @@ -0,0 +1,105 @@ +"""finalize account migration — add constraints, clean up orphans + +Revision ID: 020 +Revises: 019 +Create Date: 2026-02-07 + +Adds NOT NULL constraints, foreign keys, and CHECK constraints. +Cleans up orphan accounts (zero-user teams with no content). +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '020' +down_revision: Union[str, None] = '019' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +CONTENT_TABLES = ['trees', 'step_library', 'tree_categories', 'tree_tags', 'step_categories'] + + +def upgrade() -> None: + conn = op.get_bind() + + # 1. Clean up orphan accounts (zero-user teams with no content) + conn.execute(sa.text(""" + DELETE FROM subscriptions WHERE account_id IN ( + SELECT a.id FROM accounts a + WHERE a.owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id) + ) + """)) + + # Also remove the mapping entries for these orphans + conn.execute(sa.text(""" + DELETE FROM _team_account_mapping WHERE account_id IN ( + SELECT a.id FROM accounts a + WHERE a.owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = a.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = a.id) + ) + """)) + + conn.execute(sa.text(""" + DELETE FROM accounts + WHERE owner_id IS NULL + AND NOT EXISTS (SELECT 1 FROM trees t WHERE t.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM tree_categories tc WHERE tc.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM tree_tags tt WHERE tt.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM step_categories sc WHERE sc.account_id = accounts.id) + AND NOT EXISTS (SELECT 1 FROM step_library sl WHERE sl.account_id = accounts.id) + """)) + + # 2. Users: enforce NOT NULL and add FK + CHECK + op.alter_column('users', 'account_id', nullable=False) + op.alter_column('users', 'account_role', nullable=False) + op.create_foreign_key( + 'fk_users_account_id', 'users', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE' + ) + op.create_check_constraint( + 'ck_users_account_role_enum', 'users', + "account_role IN ('owner', 'engineer', 'viewer')" + ) + + # 3. Content tables: add FK on account_id (nullable OK — NULL means global) + for table in CONTENT_TABLES: + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE' + ) + + # 4. Accounts: enforce owner_id NOT NULL + FK + op.alter_column('accounts', 'owner_id', nullable=False) + op.create_foreign_key( + 'fk_accounts_owner_id', 'accounts', 'users', + ['owner_id'], ['id'], ondelete='RESTRICT' + ) + + +def downgrade() -> None: + # Remove account owner FK and nullable constraint + op.drop_constraint('fk_accounts_owner_id', 'accounts', type_='foreignkey') + op.alter_column('accounts', 'owner_id', nullable=True) + + # Remove content table FKs + for table in reversed(CONTENT_TABLES): + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + + # Remove user constraints + op.drop_constraint('ck_users_account_role_enum', 'users', type_='check') + op.drop_constraint('fk_users_account_id', 'users', type_='foreignkey') + op.alter_column('users', 'account_role', nullable=True) + op.alter_column('users', 'account_id', nullable=True) diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index c7cfa475..6d353506 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Optional from uuid import UUID from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -8,6 +8,7 @@ from sqlalchemy import select from app.core.database import get_db from app.core.security import decode_token from app.models.user import User +from app.models.plan_limits import PlanLimits oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -90,14 +91,35 @@ async def require_admin( async def require_engineer_or_admin( current_user: Annotated[User, Depends(get_current_active_user)] ) -> User: - """Require engineer, team admin, or super admin role (blocks viewers).""" + """Require engineer, account owner, or super admin role (blocks viewers).""" if current_user.is_super_admin: return current_user - if current_user.is_team_admin and current_user.team_id is not None: + if current_user.account_role in ("owner", "engineer"): return current_user - if current_user.role not in ("engineer",): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Engineer or admin access required" - ) - return current_user + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Engineer or admin access required" + ) + + +async def require_account_owner( + current_user: Annotated[User, Depends(get_current_active_user)] +) -> User: + """Require account owner or super admin access.""" + if current_user.is_super_admin: + return current_user + if current_user.account_role == "owner": + return current_user + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Account owner access required" + ) + + +async def get_plan_limits_for_user( + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> Optional[PlanLimits]: + """Get plan limits for the current user's account.""" + from app.core.subscriptions import get_user_plan_limits + return await get_user_plan_limits(current_user.account_id, db) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 45d51cac..2041cf9c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -52,6 +52,16 @@ class Settings(BaseSettings): # Registration REQUIRE_INVITE_CODE: bool = True # Set to False to allow open registration + # Stripe + STRIPE_SECRET_KEY: Optional[str] = None + STRIPE_PUBLISHABLE_KEY: Optional[str] = None + STRIPE_WEBHOOK_SECRET: Optional[str] = None + + @property + def stripe_enabled(self) -> bool: + """Check if Stripe is configured.""" + return self.STRIPE_SECRET_KEY is not None and self.STRIPE_WEBHOOK_SECRET is not None + # CORS - set FRONTEND_URL in production (e.g., https://patherly.up.railway.app) CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:5173", "http://localhost:5174"] FRONTEND_URL: Optional[str] = None diff --git a/backend/app/core/permissions.py b/backend/app/core/permissions.py index 52fc5bdc..ba34dd62 100644 --- a/backend/app/core/permissions.py +++ b/backend/app/core/permissions.py @@ -1,12 +1,12 @@ """ Centralized permission checks for ResolutionFlow. -Role hierarchy: super_admin > team_admin > engineer > viewer +Role hierarchy: super_admin > owner > engineer > viewer - super_admin: is_super_admin=True, full system access -- team_admin: is_team_admin=True + valid team_id, manage team resources -- engineer: role='engineer' (default), CRUD own trees/steps -- viewer: role='viewer', read-only (can browse, run sessions, rate steps) +- owner: account_role='owner', manage account resources +- engineer: account_role='engineer' (default), CRUD own trees/steps +- viewer: account_role='viewer', read-only (can browse, run sessions, rate steps) """ from __future__ import annotations from typing import Optional, TYPE_CHECKING @@ -21,19 +21,19 @@ if TYPE_CHECKING: ROLE_HIERARCHY = { "super_admin": 4, - "team_admin": 3, + "owner": 3, "engineer": 2, "viewer": 1, } def get_effective_role(user: User) -> str: - """Get the effective role considering is_super_admin and is_team_admin flags.""" + """Get the effective role considering is_super_admin and account_role.""" if user.is_super_admin: return "super_admin" - if user.is_team_admin and user.team_id is not None: - return "team_admin" - return user.role # "engineer" or "viewer" + if user.account_role == "owner": + return "owner" + return user.account_role # "engineer" or "viewer" def has_minimum_role(user: User, minimum_role: str) -> bool: @@ -55,7 +55,7 @@ def can_edit_tree(user: User, tree: Tree) -> bool: return False if tree.author_id == user.id: return True - if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None: return True return False @@ -78,7 +78,7 @@ def can_manage_category(user: User, category: TreeCategory) -> bool: """Can the user edit/delete this category?""" if user.is_super_admin: return True - if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None: return True return False @@ -91,7 +91,7 @@ def can_manage_tree_tags(user: User, tree: Tree) -> bool: return False if tree.author_id == user.id: return True - if user.is_team_admin and tree.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and tree.account_id == user.account_id and user.account_id is not None: return True return False @@ -102,7 +102,7 @@ def can_access_tree(user: User, tree: Tree) -> bool: return True if tree.author_id == user.id: return True - if tree.team_id == user.team_id and user.team_id is not None: + if tree.account_id == user.account_id and user.account_id is not None: return True if user.is_super_admin: return True @@ -116,35 +116,35 @@ def can_view_step(user: User, step: StepLibrary) -> bool: if step.visibility == "private": return step.created_by == user.id if step.visibility == "team": - return (step.team_id == user.team_id and user.team_id is not None) or user.is_super_admin + return (step.account_id == user.account_id and user.account_id is not None) or user.is_super_admin return False -def can_create_tag(user: User, team_id: Optional[UUID]) -> bool: +def can_create_tag(user: User, account_id: Optional[UUID]) -> bool: """Can the user create a tag for the given scope? - - Super admins can create global tags (team_id=None) or any team's tags - - Engineers can create team tags for their own team + - Super admins can create global tags (account_id=None) or any account's tags + - Engineers can create account tags for their own account - Viewers cannot create tags """ if user.is_super_admin: return True if not can_create_content(user): return False - if team_id is not None and team_id == user.team_id: + if account_id is not None and account_id == user.account_id: return True return False -def can_create_category(user: User, team_id: Optional[UUID]) -> bool: - """Can the user create a category for the given team? +def can_create_category(user: User, account_id: Optional[UUID]) -> bool: + """Can the user create a category for the given account? - - Super admins can create global or any team's categories - - Team admins can create categories for their own team + - Super admins can create global or any account's categories + - Account owners can create categories for their own account """ if user.is_super_admin: return True - if user.is_team_admin and team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None: return True return False @@ -153,19 +153,19 @@ def can_manage_step_category(user: User, category: StepCategory) -> bool: """Can the user edit/delete this step category?""" if user.is_super_admin: return True - if user.is_team_admin and category.team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and category.account_id == user.account_id and user.account_id is not None: return True return False -def can_create_step_category(user: User, team_id: Optional[UUID]) -> bool: - """Can the user create a step category for the given team? +def can_create_step_category(user: User, account_id: Optional[UUID]) -> bool: + """Can the user create a step category for the given account? - - Super admins can create global or any team's step categories - - Team admins can create step categories for their own team + - Super admins can create global or any account's step categories + - Account owners can create step categories for their own account """ if user.is_super_admin: return True - if user.is_team_admin and team_id == user.team_id and user.team_id is not None: + if user.account_role == "owner" and account_id == user.account_id and user.account_id is not None: return True return False diff --git a/backend/app/core/stripe_handlers.py b/backend/app/core/stripe_handlers.py new file mode 100644 index 00000000..2d771e49 --- /dev/null +++ b/backend/app/core/stripe_handlers.py @@ -0,0 +1,37 @@ +"""Stripe webhook event handlers (stub implementations). + +These handlers log events but don't process them until Stripe is fully configured. +""" +import logging +from sqlalchemy.ext.asyncio import AsyncSession + +logger = logging.getLogger(__name__) + + +async def handle_checkout_completed(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: checkout.session.completed — %s", event.get("id")) + + +async def handle_invoice_paid(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: invoice.paid — %s", event.get("id")) + + +async def handle_invoice_payment_failed(event: dict, db: AsyncSession) -> None: + logger.warning("Stripe: invoice.payment_failed — %s", event.get("id")) + + +async def handle_subscription_updated(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: customer.subscription.updated — %s", event.get("id")) + + +async def handle_subscription_deleted(event: dict, db: AsyncSession) -> None: + logger.info("Stripe: customer.subscription.deleted — %s", event.get("id")) + + +WEBHOOK_HANDLERS = { + "checkout.session.completed": handle_checkout_completed, + "invoice.paid": handle_invoice_paid, + "invoice.payment_failed": handle_invoice_payment_failed, + "customer.subscription.updated": handle_subscription_updated, + "customer.subscription.deleted": handle_subscription_deleted, +} diff --git a/backend/app/core/subscriptions.py b/backend/app/core/subscriptions.py new file mode 100644 index 00000000..e4e1b735 --- /dev/null +++ b/backend/app/core/subscriptions.py @@ -0,0 +1,113 @@ +"""Subscription limit checks and plan helpers.""" +from typing import Optional +from uuid import UUID +from datetime import datetime, timezone + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.subscription import Subscription +from app.models.plan_limits import PlanLimits +from app.models.tree import Tree +from app.models.session import Session + + +async def get_account_subscription(account_id: UUID, db: AsyncSession) -> Optional[Subscription]: + result = await db.execute( + select(Subscription).where(Subscription.account_id == account_id) + ) + return result.scalar_one_or_none() + + +async def get_plan_limits(plan: str, db: AsyncSession) -> Optional[PlanLimits]: + result = await db.execute( + select(PlanLimits).where(PlanLimits.plan == plan) + ) + return result.scalar_one_or_none() + + +async def get_user_plan_limits(user_account_id: UUID, db: AsyncSession) -> Optional[PlanLimits]: + sub = await get_account_subscription(user_account_id, db) + if sub is None: + return await get_plan_limits("free", db) + return await get_plan_limits(sub.plan, db) + + +async def check_tree_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]: + """Check if account can create a new tree. + + Returns: (can_create, limit, current_count) + """ + sub = await get_account_subscription(account_id, db) + if sub is None: + return False, 0, 0 + + limits = await get_plan_limits(sub.plan, db) + if limits is None or limits.max_trees is None: + return True, None, 0 # unlimited + + current_count = await db.scalar( + select(func.count(Tree.id)).where( + Tree.account_id == account_id, + Tree.deleted_at.is_(None), + ) + ) + current_count = current_count or 0 + + return current_count < limits.max_trees, limits.max_trees, current_count + + +async def check_session_limit(account_id: UUID, db: AsyncSession) -> tuple[bool, Optional[int], int]: + """Check if account can create a new session this month. + + Returns: (can_create, limit, current_count) + """ + sub = await get_account_subscription(account_id, db) + if sub is None: + return False, 0, 0 + + limits = await get_plan_limits(sub.plan, db) + if limits is None or limits.max_sessions_per_month is None: + return True, None, 0 # unlimited + + # Count sessions this calendar month for all users in this account + now = datetime.now(timezone.utc) + month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + from app.models.user import User + current_count = await db.scalar( + select(func.count(Session.id)).where( + Session.user_id.in_( + select(User.id).where(User.account_id == account_id) + ), + Session.started_at >= month_start, + ) + ) + current_count = current_count or 0 + + return current_count < limits.max_sessions_per_month, limits.max_sessions_per_month, current_count + + +async def get_account_usage(account_id: UUID, db: AsyncSession) -> dict: + """Get current usage stats for an account.""" + tree_count = await db.scalar( + select(func.count(Tree.id)).where( + Tree.account_id == account_id, + Tree.deleted_at.is_(None), + ) + ) or 0 + + now = datetime.now(timezone.utc) + month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + from app.models.user import User + session_count = await db.scalar( + select(func.count(Session.id)).where( + Session.user_id.in_( + select(User.id).where(User.account_id == account_id) + ), + Session.started_at >= month_start, + ) + ) or 0 + + return {"tree_count": tree_count, "session_count_this_month": session_count} diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 28cb75bc..981c5b15 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,5 +1,9 @@ from .user import User from .team import Team +from .account import Account +from .subscription import Subscription +from .plan_limits import PlanLimits +from .account_invite import AccountInvite from .tree import Tree from .session import Session from .attachment import Attachment @@ -15,6 +19,10 @@ from .audit_log import AuditLog __all__ = [ "User", "Team", + "Account", + "Subscription", + "PlanLimits", + "AccountInvite", "Tree", "Session", "Attachment", diff --git a/backend/app/models/account.py b/backend/app/models/account.py new file mode 100644 index 00000000..6506488f --- /dev/null +++ b/backend/app/models/account.py @@ -0,0 +1,38 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.user import User + from app.models.subscription import Subscription + from app.models.tree import Tree + from app.models.category import TreeCategory + from app.models.tag import TreeTag + from app.models.step_category import StepCategory + from app.models.step_library import StepLibrary + + +class Account(Base): + __tablename__ = "accounts" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(255), nullable=False) + display_code: Mapped[str] = mapped_column(String(8), unique=True, nullable=False) + owner_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="RESTRICT"), nullable=False) + stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + # Relationships + owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account") + users: Mapped[list["User"]] = relationship("User", foreign_keys="[User.account_id]", back_populates="account") + subscription: Mapped[Optional["Subscription"]] = relationship("Subscription", back_populates="account", uselist=False) + trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.account_id]", back_populates="account") + categories: Mapped[list["TreeCategory"]] = relationship("TreeCategory", foreign_keys="[TreeCategory.account_id]", back_populates="account") + tags: Mapped[list["TreeTag"]] = relationship("TreeTag", foreign_keys="[TreeTag.account_id]", back_populates="account") + step_categories: Mapped[list["StepCategory"]] = relationship("StepCategory", foreign_keys="[StepCategory.account_id]", back_populates="account") + step_library: Mapped[list["StepLibrary"]] = relationship("StepLibrary", foreign_keys="[StepLibrary.account_id]", back_populates="account") diff --git a/backend/app/models/account_invite.py b/backend/app/models/account_invite.py new file mode 100644 index 00000000..43b3ed56 --- /dev/null +++ b/backend/app/models/account_invite.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, CheckConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.account import Account + from app.models.user import User + + +class AccountInvite(Base): + __tablename__ = "account_invites" + __table_args__ = ( + CheckConstraint("role IN ('engineer', 'viewer')", name='ck_account_invites_role'), + ) + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False) + invited_by_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + email: Mapped[str] = mapped_column(String(255), nullable=False) + code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer") + accepted_by_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + + # Relationships + account: Mapped["Account"] = relationship("Account") + invited_by: Mapped["User"] = relationship("User", foreign_keys=[invited_by_id]) + accepted_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[accepted_by_id]) + + @property + def is_used(self) -> bool: + return self.accepted_by_id is not None + + @property + def is_expired(self) -> bool: + if self.expires_at is None: + return False + return datetime.now(timezone.utc) > self.expires_at + + @property + def is_valid(self) -> bool: + return not self.is_used and not self.is_expired diff --git a/backend/app/models/category.py b/backend/app/models/category.py index 6d375d50..57bae701 100644 --- a/backend/app/models/category.py +++ b/backend/app/models/category.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.tree import Tree from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -38,6 +39,12 @@ class TreeCategory(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -57,10 +64,11 @@ class TreeCategory(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="categories") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="categories") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) trees: Mapped[list["Tree"]] = relationship("Tree", back_populates="category_rel") @property def is_global(self) -> bool: - """Returns True if this is a global category (not team-specific).""" - return self.team_id is None + """Returns True if this is a global category (not team or account-specific).""" + return self.team_id is None and self.account_id is None diff --git a/backend/app/models/plan_limits.py b/backend/app/models/plan_limits.py new file mode 100644 index 00000000..1a6b0511 --- /dev/null +++ b/backend/app/models/plan_limits.py @@ -0,0 +1,16 @@ +from sqlalchemy import String, Integer, Boolean +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import JSONB +from app.core.database import Base + + +class PlanLimits(Base): + __tablename__ = "plan_limits" + + plan: Mapped[str] = mapped_column(String(50), primary_key=True) + max_trees: Mapped[int | None] = mapped_column(Integer, nullable=True) + max_sessions_per_month: Mapped[int | None] = mapped_column(Integer, nullable=True) + max_users: Mapped[int | None] = mapped_column(Integer, nullable=True) + custom_branding: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + priority_support: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + export_formats: Mapped[list] = mapped_column(JSONB, nullable=False, default=lambda: ["markdown", "text"]) diff --git a/backend/app/models/step_category.py b/backend/app/models/step_category.py index 59471ccb..da207926 100644 --- a/backend/app/models/step_category.py +++ b/backend/app/models/step_category.py @@ -8,6 +8,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -37,6 +38,12 @@ class StepCategory(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -56,9 +63,10 @@ class StepCategory(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="step_categories") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_categories") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) @property def is_global(self) -> bool: - """Returns True if this is a global category (not team-specific).""" - return self.team_id is None + """Returns True if this is a global category (not team or account-specific).""" + return self.team_id is None and self.account_id is None diff --git a/backend/app/models/step_library.py b/backend/app/models/step_library.py index 5fb78614..ae2b319b 100644 --- a/backend/app/models/step_library.py +++ b/backend/app/models/step_library.py @@ -10,6 +10,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User from app.models.team import Team + from app.models.account import Account from app.models.step_category import StepCategory from app.models.session import Session @@ -43,6 +44,12 @@ class StepLibrary(Base): ForeignKey("teams.id", ondelete="CASCADE"), nullable=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) # Organization category_id: Mapped[Optional[uuid.UUID]] = mapped_column( @@ -91,6 +98,7 @@ class StepLibrary(Base): # Relationships creator: Mapped["User"] = relationship("User", foreign_keys=[created_by]) team: Mapped[Optional["Team"]] = relationship("Team") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="step_library") category: Mapped[Optional["StepCategory"]] = relationship("StepCategory") ratings: Mapped[list["StepRating"]] = relationship("StepRating", back_populates="step", cascade="all, delete-orphan") usage_logs: Mapped[list["StepUsageLog"]] = relationship("StepUsageLog", back_populates="step", cascade="all, delete-orphan") diff --git a/backend/app/models/subscription.py b/backend/app/models/subscription.py new file mode 100644 index 00000000..54b2a440 --- /dev/null +++ b/backend/app/models/subscription.py @@ -0,0 +1,39 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.account import Account + + +class Subscription(Base): + __tablename__ = "subscriptions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), unique=True, nullable=False) + stripe_subscription_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + stripe_price_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + plan: Mapped[str] = mapped_column(String(50), nullable=False, default="free") + billing_interval: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + status: Mapped[str] = mapped_column(String(50), nullable=False, default="active") + seat_limit: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + current_period_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + current_period_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + cancel_at_period_end: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + + # Relationships + account: Mapped["Account"] = relationship("Account", back_populates="subscription") + + @property + def is_active(self) -> bool: + return self.status in ("active", "trialing") + + @property + def is_paid(self) -> bool: + return self.plan in ("pro", "team") diff --git a/backend/app/models/tag.py b/backend/app/models/tag.py index 6ebb6b38..5152c3a9 100644 --- a/backend/app/models/tag.py +++ b/backend/app/models/tag.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.tree import Tree from app.models.team import Team + from app.models.account import Account from app.models.user import User @@ -50,6 +51,12 @@ class TreeTag(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), @@ -63,6 +70,7 @@ class TreeTag(Base): # Relationships team: Mapped[Optional["Team"]] = relationship("Team", back_populates="tags") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="tags") creator: Mapped[Optional["User"]] = relationship("User", foreign_keys=[created_by]) trees: Mapped[list["Tree"]] = relationship( "Tree", @@ -72,8 +80,8 @@ class TreeTag(Base): @property def is_global(self) -> bool: - """Returns True if this is a global tag (not team-specific).""" - return self.team_id is None + """Returns True if this is a global tag (not team or account-specific).""" + return self.team_id is None and self.account_id is None @classmethod def slugify(cls, name: str) -> str: diff --git a/backend/app/models/tree.py b/backend/app/models/tree.py index 56237d96..c9b305a7 100644 --- a/backend/app/models/tree.py +++ b/backend/app/models/tree.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User from app.models.team import Team + from app.models.account import Account from app.models.session import Session from app.models.category import TreeCategory from app.models.tag import TreeTag @@ -47,6 +48,12 @@ class Tree(Base): nullable=True, index=True ) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=True, + index=True + ) is_active: Mapped[bool] = mapped_column(Boolean, default=True) is_public: Mapped[bool] = mapped_column(Boolean, default=False, index=True) is_default: Mapped[bool] = mapped_column(Boolean, default=False, index=True) @@ -75,6 +82,7 @@ class Tree(Base): # Relationships author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees") team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees") + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees") sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree") # New organization relationships diff --git a/backend/app/models/user.py b/backend/app/models/user.py index ac7648f7..ec835640 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -8,6 +8,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.team import Team + from app.models.account import Account from app.models.tree import Tree from app.models.session import Session from app.models.folder import UserFolder @@ -20,6 +21,10 @@ class User(Base): "role IN ('engineer', 'viewer')", name='ck_users_role_enum' ), + CheckConstraint( + "account_role IN ('owner', 'admin', 'engineer', 'viewer')", + name='ck_users_account_role_enum' + ), ) id: Mapped[uuid.UUID] = mapped_column( @@ -34,6 +39,17 @@ class User(Base): is_super_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_team_admin: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") + + # Account-based multi-tenancy (new) + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="RESTRICT"), + nullable=True, + index=True + ) + account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer") + + # Legacy team columns (kept for PR A coexistence) team_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id"), @@ -51,6 +67,8 @@ class User(Base): last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) # Relationships + account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="users") + owned_account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys="[Account.owner_id]", back_populates="owner", uselist=False) team: Mapped[Optional["Team"]] = relationship("Team", back_populates="users") trees: Mapped[list["Tree"]] = relationship("Tree", foreign_keys="[Tree.author_id]", back_populates="author") sessions: Mapped[list["Session"]] = relationship("Session", back_populates="user") @@ -62,6 +80,11 @@ class User(Base): return self.is_super_admin @property - def can_manage_team(self) -> bool: - """Returns True if user can manage their team (team admin or super admin).""" - return self.is_super_admin or (self.is_team_admin and self.team_id is not None) + def is_account_owner(self) -> bool: + """Returns True if user owns their account.""" + return self.account_role == "owner" + + @property + def can_manage_account(self) -> bool: + """Returns True if user can manage their account (owner, admin, or super admin).""" + return self.is_super_admin or self.account_role in ("owner", "admin") diff --git a/backend/requirements.txt b/backend/requirements.txt index 73b11ac3..5bfbf210 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -22,5 +22,8 @@ email-validator==2.1.0 # Rate Limiting slowapi==0.1.9 +# Payments +stripe==14.3.0 + # Utilities python-dotenv==1.0.1