diff --git a/backend/alembic/versions/0b470d9e6cf1_create_db_roles.py b/backend/alembic/versions/0b470d9e6cf1_create_db_roles.py new file mode 100644 index 00000000..d8d81327 --- /dev/null +++ b/backend/alembic/versions/0b470d9e6cf1_create_db_roles.py @@ -0,0 +1,102 @@ +"""create_db_roles + +Revision ID: 0b470d9e6cf1 +Revises: a9f3b2c1d4e5 +Create Date: 2026-04-10 03:58:10.207919 + +""" +import os +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import text + + +# revision identifiers, used by Alembic. +revision: str = '0b470d9e6cf1' +down_revision: Union[str, None] = 'a9f3b2c1d4e5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Passwords from env vars. For local dev, defaults are sufficient. + # For production (Railway), set DB_APP_ROLE_PASSWORD and + # DB_ADMIN_ROLE_PASSWORD as environment variables before running migrations. + # Passwords must not contain single quotes. + app_pw = os.environ.get("DB_APP_ROLE_PASSWORD", "app_secret_change_me") + admin_pw = os.environ.get("DB_ADMIN_ROLE_PASSWORD", "admin_secret_change_me") + + # Fetch the current database name dynamically — avoids hardcoding + # (the DB is named 'resolutionflow' in dev, potentially different elsewhere). + conn = op.get_bind() + db_name = conn.execute(text("SELECT current_database()")).scalar() + + # ── Application role ──────────────────────────────────────────────────── + # Subject to RLS. Used by FastAPI at runtime via DATABASE_URL. + op.execute(f""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_app') THEN + CREATE ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}'; + ELSE + ALTER ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}'; + END IF; + END $$ + """) + op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_app") + op.execute("GRANT USAGE ON SCHEMA public TO resolutionflow_app") + op.execute( + "GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public " + "TO resolutionflow_app" + ) + op.execute( + "GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO resolutionflow_app" + ) + # Ensure future tables automatically get the same permissions + op.execute( + "ALTER DEFAULT PRIVILEGES IN SCHEMA public " + "GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO resolutionflow_app" + ) + op.execute( + "ALTER DEFAULT PRIVILEGES IN SCHEMA public " + "GRANT USAGE, SELECT ON SEQUENCES TO resolutionflow_app" + ) + + # ── Admin role ────────────────────────────────────────────────────────── + # BYPASSRLS. Used by Alembic (DATABASE_URL_SYNC) and /admin/* endpoints + # (ADMIN_DATABASE_URL) after Task 11. + op.execute(f""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_admin') THEN + CREATE ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}'; + ELSE + ALTER ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}'; + END IF; + END $$ + """) + op.execute("GRANT resolutionflow_app TO resolutionflow_admin") + op.execute("ALTER ROLE resolutionflow_admin BYPASSRLS") + op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_admin") + + +def downgrade() -> None: + conn = op.get_bind() + db_name = conn.execute(text("SELECT current_database()")).scalar() + + op.execute( + "REVOKE ALL ON ALL TABLES IN SCHEMA public FROM resolutionflow_app" + ) + op.execute( + "REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM resolutionflow_app" + ) + op.execute( + f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_app" + ) + op.execute( + f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_admin" + ) + op.execute("DROP ROLE IF EXISTS resolutionflow_admin") + op.execute("DROP ROLE IF EXISTS resolutionflow_app") diff --git a/backend/alembic/versions/174f442795b7_set_not_null_account_id_phase1.py b/backend/alembic/versions/174f442795b7_set_not_null_account_id_phase1.py new file mode 100644 index 00000000..ce576dae --- /dev/null +++ b/backend/alembic/versions/174f442795b7_set_not_null_account_id_phase1.py @@ -0,0 +1,86 @@ +"""set NOT NULL on all previously-nullable account_id columns + +Revision ID: 174f442795b7 +Revises: 3a40fe11b427 +Create Date: 2026-04-09 00:00:00.000000 + +All tables in this migration had account_id set to nullable previously. +Task 9 (create_global_content_tables) cleared all NULL rows. +This migration enforces the NOT NULL constraint. +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '174f442795b7' +down_revision: Union[str, None] = '3a40fe11b427' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # tree_embeddings: backfill from trees (must happen before SET NOT NULL) + op.execute(""" + UPDATE tree_embeddings te + SET account_id = t.account_id + FROM trees t + WHERE te.tree_id = t.id + AND te.account_id IS NULL + """) + + # feedback: backfill from users + op.execute(""" + UPDATE feedback f + SET account_id = u.account_id + FROM users u + WHERE f.user_id = u.id + AND f.account_id IS NULL + """) + + # Verify ALL tables before touching any SET NOT NULL + tables_with_account_id = [ + 'users', 'trees', 'tree_categories', 'tree_tags', + 'step_categories', 'step_library', 'tree_embeddings', 'feedback', + ] + for table in tables_with_account_id: + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows in {table}. " + "Run Task 9 (create_global_content_tables) first, or " + "manually backfill/delete orphaned rows." + ) + + # SET NOT NULL on all + for table in tables_with_account_id: + op.alter_column(table, 'account_id', nullable=False) + + # Create indexes where they don't already exist + new_indexes = [ + ('tree_embeddings', 'ix_tree_embeddings_account_id'), + ('feedback', 'ix_feedback_account_id'), + ] + for table, index_name in new_indexes: + result = op.get_bind().execute(sa.text( + f"SELECT 1 FROM pg_indexes WHERE tablename='{table}' AND indexname='{index_name}'" + )) + if not result.fetchone(): + op.create_index(index_name, table, ['account_id']) + + +def downgrade() -> None: + # Revert to nullable + for table in ('users', 'trees', 'tree_categories', 'tree_tags', + 'step_categories', 'step_library', 'tree_embeddings', 'feedback'): + op.alter_column(table, 'account_id', nullable=True) + for table, index_name in ( + ('tree_embeddings', 'ix_tree_embeddings_account_id'), + ('feedback', 'ix_feedback_account_id'), + ): + try: + op.drop_index(index_name, table_name=table) + except Exception: + pass diff --git a/backend/alembic/versions/2c6aabd89bc6_add_account_id_target_lists.py b/backend/alembic/versions/2c6aabd89bc6_add_account_id_target_lists.py new file mode 100644 index 00000000..1107e373 --- /dev/null +++ b/backend/alembic/versions/2c6aabd89bc6_add_account_id_target_lists.py @@ -0,0 +1,62 @@ +"""add account_id to target_lists (keep team_id) + +Revision ID: 2c6aabd89bc6 +Revises: 78fc200abac1 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '2c6aabd89bc6' +down_revision: Union[str, None] = '78fc200abac1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('target_lists', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_target_lists_account_id', 'target_lists', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Primary: team_id → team admin user → account_id + op.execute(""" + UPDATE target_lists tl + SET account_id = u.account_id + FROM users u + WHERE u.team_id = tl.team_id + AND u.is_team_admin = TRUE + AND u.account_id IS NOT NULL + AND tl.account_id IS NULL + """) + + # Fallback: created_by → users.account_id + op.execute(""" + UPDATE target_lists tl + SET account_id = u.account_id + FROM users u + WHERE tl.created_by = u.id + AND u.account_id IS NOT NULL + AND tl.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM target_lists WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} target_lists rows have NULL account_id. " + "No team admin found for these teams. Resolve before re-running." + ) + + op.alter_column('target_lists', 'account_id', nullable=False) + op.create_index('ix_target_lists_account_id', 'target_lists', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_target_lists_account_id', table_name='target_lists') + op.drop_constraint('fk_target_lists_account_id', 'target_lists', type_='foreignkey') + op.drop_column('target_lists', 'account_id') diff --git a/backend/alembic/versions/3a40fe11b427_create_global_content_tables.py b/backend/alembic/versions/3a40fe11b427_create_global_content_tables.py new file mode 100644 index 00000000..01dbac47 --- /dev/null +++ b/backend/alembic/versions/3a40fe11b427_create_global_content_tables.py @@ -0,0 +1,155 @@ +"""create template_trees and platform_steps global content tables + +Revision ID: 3a40fe11b427 +Revises: 2c6aabd89bc6 +Create Date: 2026-04-09 00:00:00.000000 + +These tables hold platform-owned content that is readable by all +authenticated users. No account_id. No RLS. Ever. +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + +revision: str = '3a40fe11b427' +down_revision: Union[str, None] = '2c6aabd89bc6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Create template_trees ───────────────────────────────────────────────── + op.create_table( + 'template_trees', + sa.Column('id', UUID(), primary_key=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('category', sa.String(100), nullable=True), + sa.Column('tree_type', sa.String(20), nullable=False), + sa.Column('tree_structure', JSONB(), nullable=False), + sa.Column('tags', JSONB(), nullable=False, server_default='[]'), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('source_tree_id', UUID(), sa.ForeignKey('trees.id', ondelete='SET NULL'), nullable=True), + ) + op.create_index('ix_template_trees_tree_type', 'template_trees', ['tree_type']) + + # ── Create platform_steps ──────────────────────────────────────────────── + op.create_table( + 'platform_steps', + sa.Column('id', UUID(), primary_key=True), + sa.Column('title', sa.String(255), nullable=False), + sa.Column('step_type', sa.String(50), nullable=False), + sa.Column('content', JSONB(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('source_step_id', UUID(), sa.ForeignKey('step_library.id', ondelete='SET NULL'), nullable=True), + ) + op.create_index('ix_platform_steps_step_type', 'platform_steps', ['step_type']) + + # ── Copy is_default=TRUE trees → template_trees ───────────────────────── + # Note: trees.tags is a relationship via tree_tags join table — no direct column. + # Aggregate tag names via a correlated subquery. + op.execute(""" + INSERT INTO template_trees + (id, name, description, category, tree_type, tree_structure, + tags, is_active, created_at, updated_at, source_tree_id) + SELECT + gen_random_uuid(), t.name, t.description, t.category, t.tree_type, + t.tree_structure, + COALESCE( + (SELECT jsonb_agg(tt.name ORDER BY tt.name) + FROM tree_tag_assignments ta + JOIN tree_tags tt ON tt.id = ta.tag_id + WHERE ta.tree_id = t.id), + '[]'::jsonb + ), + t.is_active, + COALESCE(t.created_at, NOW()), COALESCE(t.updated_at, NOW()), t.id + FROM trees t + WHERE t.is_default = TRUE + """) + + # ── Copy visibility='public' steps → platform_steps ───────────────────── + op.execute(""" + INSERT INTO platform_steps + (id, title, step_type, content, is_active, created_at, updated_at, source_step_id) + SELECT + gen_random_uuid(), title, step_type, content, is_active, + COALESCE(created_at, NOW()), COALESCE(updated_at, NOW()), id + FROM step_library + WHERE visibility = 'public' + """) + + # ── Create platform sentinel account ───────────────────────────────────── + op.execute(""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES ( + '00000000-0000-0000-0000-000000000001', + 'ResolutionFlow Platform', + 'PLATFORM', + NOW(), + NOW() + ) + ON CONFLICT (id) DO NOTHING + """) + + # ── Assign is_default trees to platform account ────────────────────────── + op.execute(""" + UPDATE trees + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE is_default = TRUE + AND account_id IS NULL + """) + + # ── Assign global categories/tags/steps to platform account ───────────── + op.execute(""" + UPDATE tree_categories + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + op.execute(""" + UPDATE tree_tags + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + op.execute(""" + UPDATE step_categories + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + op.execute(""" + UPDATE step_library + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + # ── Verify zero NULLs in all 5 tables ─────────────────────────────────── + for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows remain in {table} " + "after platform account assignment. Investigate before re-running." + ) + + +def downgrade() -> None: + platform_id = '00000000-0000-0000-0000-000000000001' + for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'): + op.execute(f"UPDATE {table} SET account_id = NULL WHERE account_id = '{platform_id}'") + + op.execute(f"DELETE FROM accounts WHERE id = '{platform_id}'") + op.drop_index('ix_platform_steps_step_type', table_name='platform_steps') + op.drop_index('ix_template_trees_tree_type', table_name='template_trees') + op.drop_table('platform_steps') + op.drop_table('template_trees') diff --git a/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py b/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py new file mode 100644 index 00000000..92d0e6e5 --- /dev/null +++ b/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py @@ -0,0 +1,77 @@ +"""add account_id to AI branching tables + +Revision ID: 478c159e5654 +Revises: cc214c63aa30 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '478c159e5654' +down_revision: Union[str, None] = 'cc214c63aa30' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + ai_tables = ('session_branches', 'session_handoffs', 'fork_points', 'ai_session_steps') + + # Step 1: ADD COLUMN (nullable) + for table in ai_tables: + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + op.add_column('ai_suggestions', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_ai_suggestions_account_id', 'ai_suggestions', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Step 2: BACKFILL + for table in ai_tables: + op.execute(f""" + UPDATE {table} t + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE t.session_id = ai.id + AND t.account_id IS NULL + """) + + op.execute(""" + UPDATE ai_suggestions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # Step 3: VERIFY zero NULLs + for table in ai_tables + ('ai_suggestions',): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows in {table}." + ) + + # Step 4: SET NOT NULL + for table in ai_tables + ('ai_suggestions',): + op.alter_column(table, 'account_id', nullable=False) + + # Step 5: CREATE INDEX + for table in ai_tables + ('ai_suggestions',): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('session_branches', 'session_handoffs', 'fork_points', + 'ai_session_steps', 'ai_suggestions'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/7167e9374b0c_add_account_id_step_ratings.py b/backend/alembic/versions/7167e9374b0c_add_account_id_step_ratings.py new file mode 100644 index 00000000..e34ac86e --- /dev/null +++ b/backend/alembic/versions/7167e9374b0c_add_account_id_step_ratings.py @@ -0,0 +1,46 @@ +"""add account_id to step_ratings and step_usage_log + +Revision ID: 7167e9374b0c +Revises: 478c159e5654 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '7167e9374b0c' +down_revision: Union[str, None] = '478c159e5654' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('step_ratings', 'step_usage_log'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + # Backfill: from the RATER/LOGGER user's account (not the step's account) + op.execute(f""" + UPDATE {table} t + SET account_id = u.account_id + FROM users u + WHERE t.user_id = u.id + AND t.account_id IS NULL + """) + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('step_ratings', 'step_usage_log'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/78fc200abac1_add_account_id_script_tables.py b/backend/alembic/versions/78fc200abac1_add_account_id_script_tables.py new file mode 100644 index 00000000..74116db6 --- /dev/null +++ b/backend/alembic/versions/78fc200abac1_add_account_id_script_tables.py @@ -0,0 +1,103 @@ +"""add account_id to script_builder_sessions, script_templates, script_generations + +Revision ID: 78fc200abac1 +Revises: 7f136778f5a8 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '78fc200abac1' +down_revision: Union[str, None] = '7f136778f5a8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +PLATFORM_ACCOUNT_ID = '00000000-0000-0000-0000-000000000001' + + +def upgrade() -> None: + # Ensure the platform sentinel account exists before any fallback assignments. + # Migration 3a40fe11b427 also inserts this with ON CONFLICT DO NOTHING — safe. + op.execute(f""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES ( + '{PLATFORM_ACCOUNT_ID}', + 'ResolutionFlow Platform', + 'PLATFORM', + NOW(), + NOW() + ) + ON CONFLICT (id) DO NOTHING + """) + + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # script_builder_sessions: user_id → users.account_id + op.execute(""" + UPDATE script_builder_sessions sbs + SET account_id = u.account_id + FROM users u + WHERE sbs.user_id = u.id + AND sbs.account_id IS NULL + """) + + # script_templates: created_by → users.account_id (nullable created_by) + op.execute(""" + UPDATE script_templates st + SET account_id = u.account_id + FROM users u + WHERE st.created_by = u.id + AND st.account_id IS NULL + """) + # Fallback: team_id → team admin user + op.execute(""" + UPDATE script_templates st + SET account_id = u.account_id + FROM users u + WHERE u.team_id = st.team_id + AND u.is_team_admin = TRUE + AND u.account_id IS NOT NULL + AND st.account_id IS NULL + """) + # Final fallback: platform-seeded templates with NULL team_id AND NULL created_by + # (e.g. the 6 AD templates inserted by migration 057) → platform sentinel account + op.execute(f""" + UPDATE script_templates + SET account_id = '{PLATFORM_ACCOUNT_ID}' + WHERE account_id IS NULL + """) + + # script_generations: user_id → users.account_id + op.execute(""" + UPDATE script_generations sg + SET account_id = u.account_id + FROM users u + WHERE sg.user_id = u.id + AND sg.account_id IS NULL + """) + + # VERIFY + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/7f136778f5a8_add_account_id_maintenance.py b/backend/alembic/versions/7f136778f5a8_add_account_id_maintenance.py new file mode 100644 index 00000000..fbbc5cbd --- /dev/null +++ b/backend/alembic/versions/7f136778f5a8_add_account_id_maintenance.py @@ -0,0 +1,62 @@ +"""add account_id to maintenance_schedules + +Revision ID: 7f136778f5a8 +Revises: 8aac5b372402 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '7f136778f5a8' +down_revision: Union[str, None] = '8aac5b372402' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('maintenance_schedules', + sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_maintenance_schedules_account_id', 'maintenance_schedules', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Primary: tree_id → trees.account_id (only where tree.account_id is NOT NULL) + op.execute(""" + UPDATE maintenance_schedules ms + SET account_id = t.account_id + FROM trees t + WHERE ms.tree_id = t.id + AND t.account_id IS NOT NULL + AND ms.account_id IS NULL + """) + + # Fallback: created_by → users.account_id (for is_default trees with NULL account_id) + op.execute(""" + UPDATE maintenance_schedules ms + SET account_id = u.account_id + FROM users u + WHERE ms.created_by = u.id + AND u.account_id IS NOT NULL + AND ms.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM maintenance_schedules WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} maintenance_schedules rows have NULL account_id. " + "Check if created_by is NULL — those rows need manual resolution." + ) + + op.alter_column('maintenance_schedules', 'account_id', nullable=False) + op.create_index('ix_maintenance_schedules_account_id', 'maintenance_schedules', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_maintenance_schedules_account_id', table_name='maintenance_schedules') + op.drop_constraint('fk_maintenance_schedules_account_id', 'maintenance_schedules', type_='foreignkey') + op.drop_column('maintenance_schedules', 'account_id') diff --git a/backend/alembic/versions/8aac5b372402_add_account_id_psa_notifications.py b/backend/alembic/versions/8aac5b372402_add_account_id_psa_notifications.py new file mode 100644 index 00000000..1637e0b1 --- /dev/null +++ b/backend/alembic/versions/8aac5b372402_add_account_id_psa_notifications.py @@ -0,0 +1,81 @@ +"""add account_id to PSA and notification tables + +Revision ID: 8aac5b372402 +Revises: a1d2a84b9abb +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '8aac5b372402' +down_revision: Union[str, None] = 'a1d2a84b9abb' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Step 1: ADD COLUMN + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Step 2: BACKFILL + # psa_post_log: prefer psa_connection → fallback to posted_by user + # Note: cannot reference the updated table (ppl) inside the FROM clause JOIN, + # so use a correlated subquery for psa_connections lookup instead. + op.execute(""" + UPDATE psa_post_log ppl + SET account_id = COALESCE( + (SELECT account_id FROM psa_connections WHERE id = ppl.psa_connection_id), + u.account_id + ) + FROM users u + WHERE ppl.posted_by = u.id + AND ppl.account_id IS NULL + """) + + # psa_member_mappings: via psa_connection + op.execute(""" + UPDATE psa_member_mappings pmm + SET account_id = pc.account_id + FROM psa_connections pc + WHERE pmm.psa_connection_id = pc.id + AND pmm.account_id IS NULL + """) + + # notification_logs: via notification_config + op.execute(""" + UPDATE notification_logs nl + SET account_id = nc.account_id + FROM notification_configs nc + WHERE nl.notification_config_id = nc.id + AND nl.account_id IS NULL + """) + + # Step 3: VERIFY + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + + # Step 4: SET NOT NULL + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.alter_column(table, 'account_id', nullable=False) + + # Step 5: CREATE INDEX + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py b/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py new file mode 100644 index 00000000..ca32f0d2 --- /dev/null +++ b/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py @@ -0,0 +1,45 @@ +"""add account_id to user personalization tables + +Revision ID: a1d2a84b9abb +Revises: 7167e9374b0c +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'a1d2a84b9abb' +down_revision: Union[str, None] = '7167e9374b0c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + op.execute(f""" + UPDATE {table} t + SET account_id = u.account_id + FROM users u + WHERE t.user_id = u.id + AND t.account_id IS NULL + """) + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/alembic/versions/a9f3b2c1d4e5_merge_phase1_with_main.py b/backend/alembic/versions/a9f3b2c1d4e5_merge_phase1_with_main.py new file mode 100644 index 00000000..2f14f7b1 --- /dev/null +++ b/backend/alembic/versions/a9f3b2c1d4e5_merge_phase1_with_main.py @@ -0,0 +1,24 @@ +"""merge Phase 1 tenant isolation chain with main head + +Revision ID: a9f3b2c1d4e5 +Revises: 070, 174f442795b7 +Create Date: 2026-04-09 00:00:00.000000 + +Merge migration: consolidates the Phase 1 account_id chain (cc214c63aa30 → … → 174f442795b7) +with the main sequential chain (… → 070) into a single head so that +`alembic upgrade head` works without ambiguity. +""" +from typing import Sequence, Union + +revision: str = 'a9f3b2c1d4e5' +down_revision: Union[str, tuple] = ('070', '174f442795b7') +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/backend/alembic/versions/c5f48b9890f9_enable_rls_phase1.py b/backend/alembic/versions/c5f48b9890f9_enable_rls_phase1.py new file mode 100644 index 00000000..333c5ca2 --- /dev/null +++ b/backend/alembic/versions/c5f48b9890f9_enable_rls_phase1.py @@ -0,0 +1,108 @@ +"""enable_rls_phase1 + +Revision ID: c5f48b9890f9 +Revises: 0b470d9e6cf1 +Create Date: 2026-04-10 04:01:13.043321 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c5f48b9890f9' +down_revision: Union[str, None] = '0b470d9e6cf1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_NULL_UUID = "00000000-0000-0000-0000-000000000000" +_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001" +_CURRENT_ACCOUNT = ( + f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), " + f"'{_NULL_UUID}')::uuid" +) + + +def upgrade() -> None: + # ── trees ─────────────────────────────────────────────────────────────── + # Extended policy mirrors can_access_tree() in app/core/permissions.py. + # Tenant sees: own rows, platform rows, any default tree, any public tree, + # any gallery-featured tree. + # is_gallery_featured = TRUE is included because /public/templates is a + # no-auth endpoint — no tenant context is set, so gallery trees must pass + # RLS on their own flag rather than relying on account_id or is_public. + # Private/team trees from other accounts are hidden. + op.execute("ALTER TABLE trees ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE trees FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON trees + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + OR is_default = TRUE + OR is_public = TRUE + OR is_gallery_featured = TRUE + ) + """) + + # ── tree_tags ──────────────────────────────────────────────────────────── + # Own account + platform tags (global tags visible to all tenants). + op.execute("ALTER TABLE tree_tags ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE tree_tags FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON tree_tags + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + ) + """) + + # ── tree_categories ────────────────────────────────────────────────────── + op.execute("ALTER TABLE tree_categories ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE tree_categories FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON tree_categories + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + ) + """) + + # ── step_categories ────────────────────────────────────────────────────── + op.execute("ALTER TABLE step_categories ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE step_categories FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON step_categories + USING ( + account_id = {_CURRENT_ACCOUNT} + OR account_id = '{_PLATFORM_UUID}'::uuid + ) + """) + + # ── psa_connections ────────────────────────────────────────────────────── + # Tenant-only — PSA credentials must never cross tenant boundaries. + op.execute("ALTER TABLE psa_connections ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE psa_connections FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON psa_connections + USING (account_id = {_CURRENT_ACCOUNT}) + """) + + # ── flow_proposals ──────────────────────────────────────────────────────── + # Tenant-only. + op.execute("ALTER TABLE flow_proposals ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE flow_proposals FORCE ROW LEVEL SECURITY") + op.execute(f""" + CREATE POLICY tenant_isolation ON flow_proposals + USING (account_id = {_CURRENT_ACCOUNT}) + """) + + +def downgrade() -> None: + for table in ["trees", "tree_tags", "tree_categories", "step_categories", + "psa_connections", "flow_proposals"]: + op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}") + op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") + op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY") diff --git a/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py b/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py new file mode 100644 index 00000000..bd929f52 --- /dev/null +++ b/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py @@ -0,0 +1,95 @@ +"""add account_id to core session tables + +Revision ID: cc214c63aa30 +Revises: b8d2f4a6c091 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'cc214c63aa30' +down_revision: Union[str, None] = '064' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = ('067',) + + +def upgrade() -> None: + # ── Step 1: ADD COLUMN (nullable) ──────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', + table, 'accounts', + ['account_id'], ['id'], + ondelete='CASCADE', + ) + + # ── Step 2: BACKFILL ───────────────────────────────────────────────────── + # sessions: direct join to users + op.execute(""" + UPDATE sessions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # attachments: chain through sessions (now backfilled above) + op.execute(""" + UPDATE attachments a + SET account_id = s.account_id + FROM sessions s + WHERE a.session_id = s.id + AND a.account_id IS NULL + """) + + # session_supporting_data: same chain + op.execute(""" + UPDATE session_supporting_data sd + SET account_id = s.account_id + FROM sessions s + WHERE sd.session_id = s.id + AND sd.account_id IS NULL + """) + + # session_resolution_outputs: FK is to ai_sessions, not sessions + op.execute(""" + UPDATE session_resolution_outputs sro + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE sro.session_id = ai.id + AND sro.account_id IS NULL + """) + + # ── Step 3: VERIFY zero NULLs — raises if any remain ──────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows remain in {table}. " + f"Fix the backfill before re-running." + ) + + # ── Step 4: SET NOT NULL ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.alter_column(table, 'account_id', nullable=False) + + # ── Step 5: CREATE INDEX ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 4bd3fd3c..bae3f935 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -10,6 +10,8 @@ from app.core.database import get_db from app.core.security import decode_token from app.models.user import User from app.models.plan_limits import PlanLimits +from app.core.tenant_context import set_current_account_id, clear_current_account_id +from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints # Routes that are allowed even when must_change_password is True _PASSWORD_CHANGE_ALLOWLIST = { @@ -192,18 +194,42 @@ async def get_plan_limits_for_user( return await get_user_plan_limits(current_user.account_id, db) -async def get_tenant_context( +async def require_tenant_context( current_user: Annotated[User, Depends(get_current_active_user)], -) -> UUID: - """Return the current user's account_id. +): + """Set per-request tenant context for RLS. - Use this dependency instead of reading current_user.account_id directly. - Raises 403 if the user has no account association (should not happen in - normal flows — users are always associated with an account on registration). + Raises 403 if the authenticated user has no account_id — never falls back + to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a + malformed account). + + Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to + issue set_config('app.current_account_id', …, true) on every transaction. + + Applied to every user-facing router. NOT applied to /admin/* routers or + public endpoints (auth, shared, webhooks). """ if current_user.account_id is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="User not associated with any account", + detail="User account required", ) - return current_user.account_id + token = set_current_account_id(current_user.account_id) + try: + yield + finally: + clear_current_account_id(token) + + +async def require_admin_db( + db: Annotated[AsyncSession, Depends(get_admin_db)], + current_user: Annotated[User, Depends(require_admin)], +) -> AsyncSession: + """Return a BYPASSRLS admin DB session after verifying super_admin role. + + Use on /admin/* endpoints that query RLS-protected tables. Replaces + Depends(get_db) on the db parameter of those endpoints. + The current_user dep is still declared separately on the endpoint if + the user object is needed in the handler. + """ + return db diff --git a/backend/app/api/endpoints/admin.py b/backend/app/api/endpoints/admin.py index 8450e0bd..76786c11 100644 --- a/backend/app/api/endpoints/admin.py +++ b/backend/app/api/endpoints/admin.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func from sqlalchemy.orm import selectinload -from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.core.audit import log_audit from app.core.config import settings from app.core.security import get_password_hash, generate_temp_password, create_password_reset_token, decode_token, hash_token @@ -37,7 +37,7 @@ router = APIRouter(prefix="/admin", tags=["admin"]) @router.get("/users", response_model=list[UserResponse]) async def list_users( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=100), @@ -74,7 +74,7 @@ def _generate_display_code() -> str: @router.post("/users", response_model=AdminUserCreateResponse, status_code=status.HTTP_201_CREATED) async def create_user( data: AdminUserCreate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Create a new user with a temporary password (super admin only). @@ -199,7 +199,7 @@ async def create_user( @router.get("/users/{user_id}", response_model=UserDetailResponse) async def get_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Get enriched user details (super admin only).""" @@ -317,7 +317,7 @@ async def get_user( async def update_user_role( user_id: UUID, role_data: RoleUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Change user role (super admin only).""" @@ -349,7 +349,7 @@ async def update_user_role( async def update_account_role( user_id: UUID, data: AccountRoleUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Change a user's account role (super admin only).""" @@ -375,7 +375,7 @@ async def update_account_role( async def update_super_admin_status( user_id: UUID, data: dict, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Promote or demote a user to/from super admin (super admin only).""" @@ -414,7 +414,7 @@ async def update_super_admin_status( @router.put("/users/{user_id}/deactivate", response_model=UserResponse) async def deactivate_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Deactivate a user account (super admin only).""" @@ -443,7 +443,7 @@ async def deactivate_user( @router.put("/users/{user_id}/activate", response_model=UserResponse) async def activate_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)] ): """Reactivate a user account (super admin only).""" @@ -467,7 +467,7 @@ async def activate_user( async def move_user_account( user_id: UUID, data: MoveUserAccount, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Move a user to a different account (super admin only).""" @@ -520,7 +520,7 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User, async def update_user_plan( user_id: UUID, data: SubscriptionPlanUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Change a user's subscription plan (super admin only).""" @@ -539,7 +539,7 @@ async def update_user_plan( async def extend_user_trial( user_id: UUID, data: ExtendTrialRequest, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Extend or start a trial for a user's subscription (super admin only).""" @@ -569,7 +569,7 @@ async def extend_user_trial( async def admin_reset_password( user_id: UUID, data: AdminPasswordReset, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Admin-triggered password reset (super admin only). @@ -640,7 +640,7 @@ async def admin_reset_password( @router.put("/users/{user_id}/archive", response_model=UserResponse) async def archive_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Archive (soft delete) a user (super admin only).""" @@ -675,7 +675,7 @@ async def archive_user( @router.put("/users/{user_id}/restore", response_model=UserResponse) async def restore_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Restore an archived user (super admin only).""" @@ -700,7 +700,7 @@ async def restore_user( @router.get("/users/{user_id}/hard-delete-check", response_model=HardDeleteCheckResponse) async def hard_delete_check( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Check if a user can be hard-deleted (super admin only). Returns blockers.""" @@ -773,7 +773,7 @@ async def hard_delete_check( @router.delete("/users/{user_id}/hard-delete", status_code=status.HTTP_204_NO_CONTENT) async def hard_delete_user( user_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Permanently delete a user (super admin only). User must be archived first.""" @@ -833,7 +833,7 @@ async def hard_delete_user( @router.post("/invites", status_code=status.HTTP_201_CREATED) async def admin_create_invite( data: dict, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Quick-invite a user to an account (super admin only). diff --git a/backend/app/api/endpoints/admin_categories.py b/backend/app/api/endpoints/admin_categories.py index 39218bcb..bfecc31e 100644 --- a/backend/app/api/endpoints/admin_categories.py +++ b/backend/app/api/endpoints/admin_categories.py @@ -4,25 +4,26 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func -from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.core.audit import log_audit from app.models.user import User from app.models.category import TreeCategory from app.models.tree import Tree from app.schemas.admin import GlobalCategoryCreate, GlobalCategoryUpdate, GlobalCategoryResponse from app.api.deps import require_admin +from app.core.service_account import PLATFORM_ACCOUNT_ID router = APIRouter(prefix="/admin/categories", tags=["admin-categories"]) @router.get("/global", response_model=list[GlobalCategoryResponse]) async def list_global_categories( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """List all global categories (account_id IS NULL).""" result = await db.execute( - select(TreeCategory).where(TreeCategory.account_id.is_(None)).order_by(TreeCategory.name) + select(TreeCategory).where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID).order_by(TreeCategory.name) ) categories = result.scalars().all() @@ -45,36 +46,36 @@ async def list_global_categories( @router.post("/global", response_model=GlobalCategoryResponse, status_code=status.HTTP_201_CREATED) async def create_global_category( data: GlobalCategoryCreate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Create a global category.""" # Check slug uniqueness for global categories existing = await db.execute( - select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None)) + select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID) ) if existing.scalar_one_or_none(): raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Global category with this slug already exists") - category = TreeCategory(name=data.name, slug=data.slug, account_id=None) + category = TreeCategory(name=data.name, slug=data.slug, account_id=PLATFORM_ACCOUNT_ID) db.add(category) await log_audit(db, current_user.id, "global_category.create", "category", details={"name": data.name}) await db.commit() await db.refresh(category) - return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=None, tree_count=0) + return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=PLATFORM_ACCOUNT_ID, tree_count=0) @router.put("/global/{category_id}", response_model=GlobalCategoryResponse) async def update_global_category( category_id: UUID, data: GlobalCategoryUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Update a global category.""" result = await db.execute( - select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None)) + select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID) ) category = result.scalar_one_or_none() if not category: @@ -86,7 +87,7 @@ async def update_global_category( # Check slug uniqueness existing = await db.execute( select(TreeCategory).where( - TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None), TreeCategory.id != category_id + TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID, TreeCategory.id != category_id ) ) if existing.scalar_one_or_none(): @@ -103,19 +104,19 @@ async def update_global_category( return GlobalCategoryResponse( id=category.id, name=category.name, slug=category.slug, - account_id=None, tree_count=tree_count, + account_id=PLATFORM_ACCOUNT_ID, tree_count=tree_count, ) @router.delete("/global/{category_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_global_category( category_id: UUID, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Delete (archive) a global category.""" result = await db.execute( - select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None)) + select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID) ) category = result.scalar_one_or_none() if not category: diff --git a/backend/app/api/endpoints/admin_dashboard.py b/backend/app/api/endpoints/admin_dashboard.py index 33d8f564..90859b18 100644 --- a/backend/app/api/endpoints/admin_dashboard.py +++ b/backend/app/api/endpoints/admin_dashboard.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func -from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.models.user import User from app.models.subscription import Subscription from app.models.tree import Tree @@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin/dashboard", tags=["admin-dashboard"]) @router.get("/metrics", response_model=DashboardMetrics) async def get_dashboard_metrics( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Get platform overview metrics.""" @@ -45,7 +45,7 @@ async def get_dashboard_metrics( @router.get("/activity", response_model=list[ActivityEntry]) async def get_dashboard_activity( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Get recent audit log entries for activity feed.""" diff --git a/backend/app/api/endpoints/admin_gallery.py b/backend/app/api/endpoints/admin_gallery.py index 8292bfb4..d3cc61d6 100644 --- a/backend/app/api/endpoints/admin_gallery.py +++ b/backend/app/api/endpoints/admin_gallery.py @@ -12,7 +12,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import require_admin -from app.core.database import get_db +from app.core.admin_database import get_admin_db from app.models.script_template import ScriptTemplate from app.models.tree import Tree from app.models.user import User @@ -66,7 +66,7 @@ def _script_summary(script: ScriptTemplate) -> dict: @router.get("/featured") async def list_featured( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """List all featured flows and scripts (super admin only).""" @@ -92,7 +92,7 @@ async def list_featured( @router.get("/items") async def list_all_items( - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """List ALL flows and scripts with their gallery status (super admin only).""" @@ -119,7 +119,7 @@ async def list_all_items( async def toggle_flow_featured( flow_id: UUID, body: FeatureToggle, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Toggle is_gallery_featured on a flow (super admin only).""" @@ -138,7 +138,7 @@ async def toggle_flow_featured( async def update_flow_sort_order( flow_id: UUID, body: SortOrderUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Update gallery_sort_order on a flow (super admin only).""" @@ -157,7 +157,7 @@ async def update_flow_sort_order( async def toggle_script_featured( script_id: UUID, body: FeatureToggle, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Toggle is_gallery_featured on a script (super admin only).""" @@ -176,7 +176,7 @@ async def toggle_script_featured( async def update_script_sort_order( script_id: UUID, body: SortOrderUpdate, - db: Annotated[AsyncSession, Depends(get_db)], + db: Annotated[AsyncSession, Depends(get_admin_db)], current_user: Annotated[User, Depends(require_admin)], ): """Update gallery_sort_order on a script (super admin only).""" diff --git a/backend/app/api/endpoints/categories.py b/backend/app/api/endpoints/categories.py index f0d7d010..0d6517c3 100644 --- a/backend/app/api/endpoints/categories.py +++ b/backend/app/api/endpoints/categories.py @@ -12,6 +12,7 @@ from app.models.user import User from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse from app.api.deps import get_current_active_user from app.core.permissions import can_manage_category, can_create_category +from app.core.service_account import PLATFORM_ACCOUNT_ID from app.core.filters import tenant_filter router = APIRouter(prefix="/categories", tags=["categories"]) @@ -48,13 +49,13 @@ async def list_categories( elif current_user.account_id: query = query.where( or_( - TreeCategory.account_id.is_(None), # Global + TreeCategory.account_id == PLATFORM_ACCOUNT_ID, # Global TreeCategory.account_id == current_user.account_id # User's account ) ) else: # User has no account, only show global categories - query = query.where(TreeCategory.account_id.is_(None)) + query = query.where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID) query = query.order_by(TreeCategory.display_order, TreeCategory.name) @@ -176,7 +177,7 @@ async def create_category( name=category_data.name, slug=slug, description=category_data.description, - account_id=category_data.account_id, + account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID, display_order=max_order + 1, created_by=current_user.id ) diff --git a/backend/app/api/endpoints/scripts.py b/backend/app/api/endpoints/scripts.py index 180c0d43..3db7175d 100644 --- a/backend/app/api/endpoints/scripts.py +++ b/backend/app/api/endpoints/scripts.py @@ -197,6 +197,7 @@ async def create_template( template = ScriptTemplate( category_id=data.category_id, team_id=current_user.team_id, + account_id=current_user.account_id, created_by=current_user.id, name=data.name, slug=slug, @@ -364,6 +365,7 @@ async def generate_script( generation = ScriptGeneration( template_id=template.id, user_id=current_user.id, + account_id=current_user.account_id, team_id=current_user.team_id, session_id=data.session_id, ai_session_id=data.ai_session_id, diff --git a/backend/app/api/endpoints/step_categories.py b/backend/app/api/endpoints/step_categories.py index 53770bee..c194ea6c 100644 --- a/backend/app/api/endpoints/step_categories.py +++ b/backend/app/api/endpoints/step_categories.py @@ -16,6 +16,7 @@ from app.schemas.step_category import ( ) from app.api.deps import get_current_active_user from app.core.permissions import can_manage_step_category, can_create_step_category +from app.core.service_account import PLATFORM_ACCOUNT_ID router = APIRouter(prefix="/step-categories", tags=["step-categories"]) @@ -44,13 +45,13 @@ async def list_step_categories( elif current_user.account_id: query = query.where( or_( - StepCategory.account_id.is_(None), # Global + StepCategory.account_id == PLATFORM_ACCOUNT_ID, # Global StepCategory.account_id == current_user.account_id # User's account ) ) else: # User has no account, only show global categories - query = query.where(StepCategory.account_id.is_(None)) + query = query.where(StepCategory.account_id == PLATFORM_ACCOUNT_ID) query = query.order_by(StepCategory.display_order, StepCategory.name) @@ -155,7 +156,7 @@ async def create_step_category( name=category_data.name, slug=slug, description=category_data.description, - account_id=category_data.account_id, + account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID, display_order=max_order + 1, created_by=current_user.id ) diff --git a/backend/app/api/endpoints/tags.py b/backend/app/api/endpoints/tags.py index b8438e8b..ac5cea4d 100644 --- a/backend/app/api/endpoints/tags.py +++ b/backend/app/api/endpoints/tags.py @@ -12,6 +12,7 @@ from app.models.user import User from app.schemas.tag import TagCreate, TagResponse, TagListResponse, TagAssignment from app.api.deps import get_current_active_user from app.core.permissions import can_manage_tree_tags, can_create_tag +from app.core.service_account import PLATFORM_ACCOUNT_ID router = APIRouter(prefix="/tags", tags=["tags"]) @@ -33,13 +34,13 @@ async def list_tags( if include_account and current_user.account_id: query = query.where( or_( - TreeTag.account_id.is_(None), # Global + TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global TreeTag.account_id == current_user.account_id # User's account ) ) else: # Only show global tags - query = query.where(TreeTag.account_id.is_(None)) + query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name) @@ -71,12 +72,12 @@ async def search_tags( if include_account and current_user.account_id: query = query.where( or_( - TreeTag.account_id.is_(None), + TreeTag.account_id == PLATFORM_ACCOUNT_ID, TreeTag.account_id == current_user.account_id ) ) else: - query = query.where(TreeTag.account_id.is_(None)) + query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID) query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit) @@ -147,7 +148,7 @@ async def create_tag( new_tag = TreeTag( name=tag_data.name, slug=slug, - account_id=tag_data.account_id, + account_id=tag_data.account_id if tag_data.account_id is not None else PLATFORM_ACCOUNT_ID, created_by=current_user.id ) db.add(new_tag) @@ -206,7 +207,7 @@ async def add_tags_to_tree( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.account_id.is_(None), # Global tag + TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global tag TreeTag.account_id == tag_account_id # Account tag ) ) @@ -340,7 +341,7 @@ async def replace_tree_tags( tag_query = select(TreeTag).where( TreeTag.slug == slug, or_( - TreeTag.account_id.is_(None), + TreeTag.account_id == PLATFORM_ACCOUNT_ID, TreeTag.account_id == tag_account_id ) ) diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 75cdaf70..6a16297e 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -29,6 +29,7 @@ from app.core.subscriptions import check_tree_limit, get_account_subscription, g from app.core.audit import log_audit from app.core.config import settings from app.core.tree_validation import can_publish_tree +from app.core.service_account import PLATFORM_ACCOUNT_ID from app.core.step_sync import sync_steps_from_tree, deactivate_synced_steps_for_tree from app.services.rag_service import index_tree as rag_index_tree @@ -391,6 +392,7 @@ async def get_tree( ) if not tree.is_active or not can_access_tree(current_user, tree): + # Always 404, never 403. A 403 confirms the resource exists. raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Tree not found" @@ -470,7 +472,7 @@ async def create_tree( tree_structure=tree_data.tree_structure, intake_form=intake_form_data, author_id=service_account_id if is_default else current_user.id, - account_id=None if is_default else current_user.account_id, + account_id=PLATFORM_ACCOUNT_ID if is_default else current_user.account_id, is_public=True if is_default else tree_data.is_public, # Default trees are always public is_default=is_default, status=tree_data.status diff --git a/backend/app/api/router.py b/backend/app/api/router.py index d588afc9..ed32ba58 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,51 +1,89 @@ -from fastapi import APIRouter -from app.api.endpoints import auth, trees, sessions, sidebar, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares, shared, tree_markdown -from app.api.endpoints import admin_dashboard, admin_audit, admin_plan_limits, admin_feature_flags, admin_settings, admin_categories -from app.api.endpoints import ratings, analytics -from app.api.endpoints import target_lists -from app.api.endpoints import maintenance_schedules -from app.api.endpoints import feedback -from app.api.endpoints import ai_builder -from app.api.endpoints import ai_fix -from app.api.endpoints import ai_chat -from app.api.endpoints import copilot -from app.api.endpoints import assistant_chat -from app.api.endpoints import survey -from app.api.endpoints import admin_survey -from app.api.endpoints import tree_transfer -from app.api.endpoints import ai_suggestions -from app.api.endpoints import kb_accelerator -from app.api.endpoints import beta_signup -from app.api.endpoints import scripts -from app.api.endpoints import integrations -from app.api.endpoints import onboarding -from app.api.endpoints import branding -from app.api.endpoints import supporting_data -from app.api.endpoints import ai_sessions -from app.api.endpoints import flow_proposals -from app.api.endpoints import flowpilot_analytics -from app.api.endpoints import notifications -from app.api.endpoints import public_templates -from app.api.endpoints import admin_gallery -from app.api.endpoints import uploads -from app.api.endpoints import script_builder -from app.api.endpoints import beta_feedback -from app.api.endpoints import session_branches -from app.api.endpoints import session_handoffs -from app.api.endpoints import session_resolutions +from fastapi import APIRouter, Depends + +from app.api.deps import require_tenant_context +from app.api.endpoints import ( + admin, + admin_audit, + admin_categories, + admin_dashboard, + admin_feature_flags, + admin_gallery, + admin_plan_limits, + admin_settings, + admin_survey, + ai_builder, + ai_chat, + ai_fix, + ai_sessions, + ai_suggestions, + analytics, + assistant_chat, + auth, + beta_feedback, + beta_signup, + branding, + categories, + copilot, + feedback, + flow_proposals, + flowpilot_analytics, + folders, + integrations, + invite, + kb_accelerator, + maintenance_schedules, + notifications, + onboarding, + public_templates, + ratings, + scripts, + script_builder, + session_branches, + session_handoffs, + session_resolutions, + sessions, + shared, + shares, + sidebar, + step_categories, + steps, + supporting_data, + survey, + tags, + target_lists, + tree_markdown, + tree_transfer, + trees, + uploads, + webhooks, + accounts, +) api_router = APIRouter() +# --------------------------------------------------------------------------- +# Public / unauthenticated endpoints — no tenant context +# +# Note: auth.router contains both public endpoints (register, login, +# forgot-password, reset-password, email/verify) and authenticated endpoints +# (GET/PATCH /me, logout, change-password, email/send-verification). +# The authenticated auth endpoints only query the `users` table, which is +# excluded from Phase 1 RLS. They work correctly without tenant context +# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS. +# --------------------------------------------------------------------------- api_router.include_router(auth.router) -api_router.include_router(trees.router) -api_router.include_router(sidebar.router) -api_router.include_router(sessions.router) -api_router.include_router(invite.router) -api_router.include_router(categories.router) -api_router.include_router(tags.router) -api_router.include_router(folders.router) -api_router.include_router(step_categories.router) -api_router.include_router(steps.router) +api_router.include_router(shared.router) # Public share links (no auth) +api_router.include_router(beta_signup.router) +api_router.include_router(webhooks.router) # Stripe webhook receiver +api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited) + +# --------------------------------------------------------------------------- +# Admin endpoints — super_admin only +# admin_categories, admin_gallery, admin_dashboard, admin query Phase 1 RLS +# tables and MUST use get_admin_db (migrated in Task 8). The remaining admin +# endpoints (admin_audit, admin_plan_limits, admin_feature_flags, +# admin_settings, admin_survey) are safe until Phase 2 extends RLS. +# --------------------------------------------------------------------------- api_router.include_router(admin.router) api_router.include_router(admin_dashboard.router) api_router.include_router(admin_audit.router) @@ -53,42 +91,54 @@ api_router.include_router(admin_plan_limits.router) api_router.include_router(admin_feature_flags.router) api_router.include_router(admin_settings.router) api_router.include_router(admin_categories.router) -api_router.include_router(accounts.router) -api_router.include_router(webhooks.router) -api_router.include_router(shares.router) -api_router.include_router(shared.router) # Public endpoints (no auth) -api_router.include_router(tree_markdown.router) -api_router.include_router(ratings.router) -api_router.include_router(analytics.router) -api_router.include_router(target_lists.router) -api_router.include_router(maintenance_schedules.router) -api_router.include_router(feedback.router) -api_router.include_router(ai_builder.router) -api_router.include_router(ai_fix.router) -api_router.include_router(ai_chat.router) -api_router.include_router(copilot.router) -api_router.include_router(assistant_chat.router) -api_router.include_router(survey.router) api_router.include_router(admin_survey.router) -api_router.include_router(tree_transfer.router) -api_router.include_router(ai_suggestions.router) -api_router.include_router(kb_accelerator.router) -api_router.include_router(beta_signup.router) -api_router.include_router(scripts.router) -api_router.include_router(integrations.router) -api_router.include_router(onboarding.router) -api_router.include_router(branding.router) -api_router.include_router(supporting_data.router) -api_router.include_router(session_handoffs.queue_router) # Must be before ai_sessions to avoid /{session_id} conflict -api_router.include_router(session_resolutions.router) # Must be before ai_sessions to avoid /{session_id} conflict -api_router.include_router(ai_sessions.router) -api_router.include_router(flow_proposals.router) -api_router.include_router(flowpilot_analytics.router) -api_router.include_router(notifications.router) -api_router.include_router(public_templates.router) api_router.include_router(admin_gallery.router) -api_router.include_router(uploads.router) -api_router.include_router(script_builder.router) -api_router.include_router(beta_feedback.router) -api_router.include_router(session_branches.router) -api_router.include_router(session_handoffs.router) + +# --------------------------------------------------------------------------- +# User-facing endpoints — tenant context required +# --------------------------------------------------------------------------- +_tenant_deps = [Depends(require_tenant_context)] + +api_router.include_router(trees.router, dependencies=_tenant_deps) +api_router.include_router(sidebar.router, dependencies=_tenant_deps) +api_router.include_router(sessions.router, dependencies=_tenant_deps) +api_router.include_router(invite.router, dependencies=_tenant_deps) +api_router.include_router(categories.router, dependencies=_tenant_deps) +api_router.include_router(tags.router, dependencies=_tenant_deps) +api_router.include_router(folders.router, dependencies=_tenant_deps) +api_router.include_router(step_categories.router, dependencies=_tenant_deps) +api_router.include_router(steps.router, dependencies=_tenant_deps) +api_router.include_router(accounts.router, dependencies=_tenant_deps) +api_router.include_router(shares.router, dependencies=_tenant_deps) +api_router.include_router(tree_markdown.router, dependencies=_tenant_deps) +api_router.include_router(ratings.router, dependencies=_tenant_deps) +api_router.include_router(analytics.router, dependencies=_tenant_deps) +api_router.include_router(target_lists.router, dependencies=_tenant_deps) +api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps) +api_router.include_router(feedback.router, dependencies=_tenant_deps) +api_router.include_router(ai_builder.router, dependencies=_tenant_deps) +api_router.include_router(ai_fix.router, dependencies=_tenant_deps) +api_router.include_router(ai_chat.router, dependencies=_tenant_deps) +api_router.include_router(copilot.router, dependencies=_tenant_deps) +api_router.include_router(assistant_chat.router, dependencies=_tenant_deps) +api_router.include_router(survey.router, dependencies=_tenant_deps) +api_router.include_router(tree_transfer.router, dependencies=_tenant_deps) +api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps) +api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps) +api_router.include_router(scripts.router, dependencies=_tenant_deps) +api_router.include_router(integrations.router, dependencies=_tenant_deps) +api_router.include_router(onboarding.router, dependencies=_tenant_deps) +api_router.include_router(branding.router, dependencies=_tenant_deps) +api_router.include_router(supporting_data.router, dependencies=_tenant_deps) +# session_handoffs queue router must come before ai_sessions to avoid conflict +api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps) +api_router.include_router(session_resolutions.router, dependencies=_tenant_deps) +api_router.include_router(ai_sessions.router, dependencies=_tenant_deps) +api_router.include_router(flow_proposals.router, dependencies=_tenant_deps) +api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps) +api_router.include_router(notifications.router, dependencies=_tenant_deps) +api_router.include_router(uploads.router, dependencies=_tenant_deps) +api_router.include_router(script_builder.router, dependencies=_tenant_deps) +api_router.include_router(beta_feedback.router, dependencies=_tenant_deps) +api_router.include_router(session_branches.router, dependencies=_tenant_deps) +api_router.include_router(session_handoffs.router, dependencies=_tenant_deps) diff --git a/backend/app/core/admin_database.py b/backend/app/core/admin_database.py new file mode 100644 index 00000000..1e84a132 --- /dev/null +++ b/backend/app/core/admin_database.py @@ -0,0 +1,36 @@ +# backend/app/core/admin_database.py +""" +Admin database engine — connects as resolutionflow_admin (BYPASSRLS). + +Use ONLY for /admin/* endpoints and internal tooling. +Never use this engine from user-facing endpoints. +""" +from collections.abc import AsyncGenerator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import settings + +admin_engine = create_async_engine( + settings.ADMIN_DATABASE_URL, + echo=settings.DEBUG, + future=True, +) + +_admin_session_factory = async_sessionmaker( + admin_engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +async def get_admin_db() -> AsyncGenerator[AsyncSession, None]: + """Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints.""" + async with _admin_session_factory() as session: + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 7fdf7fb6..5d31b789 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -23,10 +23,33 @@ class Settings(BaseSettings): return v.replace("postgresql://", "postgresql+asyncpg://", 1) return v - @property - def DATABASE_URL_SYNC(self) -> str: - """Get sync URL by removing asyncpg prefix from DATABASE_URL.""" - return self.DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://", 1) + # Sync URL for Alembic migrations. Defaults to DATABASE_URL (sync-converted). + # Set explicitly in .env to use a different role for migrations (e.g. superuser) + # when DATABASE_URL has been switched to the app role. + DATABASE_URL_SYNC: str = "" + + @field_validator("DATABASE_URL_SYNC", mode="before") + @classmethod + def default_database_url_sync(cls, v: str, info) -> str: + """Fall back to sync-converted DATABASE_URL if not explicitly set.""" + if not v: + base = info.data.get("DATABASE_URL", "") + return base.replace("postgresql+asyncpg://", "postgresql://", 1) + return v + + # Admin database — resolutionflow_admin role, BYPASSRLS. + # Used by /admin/* endpoints. Defaults to DATABASE_URL for local dev. + ADMIN_DATABASE_URL: str = "" + + @field_validator("ADMIN_DATABASE_URL", mode="before") + @classmethod + def default_admin_database_url(cls, v: str, info) -> str: + """Fall back to DATABASE_URL if ADMIN_DATABASE_URL is not set.""" + if not v: + return info.data.get("DATABASE_URL", "") + if v.startswith("postgresql://"): + return v.replace("postgresql://", "postgresql+asyncpg://", 1) + return v # JWT Settings SECRET_KEY: str = _DEFAULT_SECRET_KEY diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 45dc8288..c8132156 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -1,6 +1,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.orm import DeclarativeBase from .config import settings +from app.core.tenant_context import register_tenant_listener # Create async engine engine = create_async_engine( @@ -16,6 +17,11 @@ async_session_maker = async_sessionmaker( expire_on_commit=False ) +# Register the RLS tenant context listener on the app engine. +# Fires at the start of every transaction; issues set_config automatically. +# Must NOT be called on admin_engine — admin connections bypass RLS. +register_tenant_listener(engine) + class Base(DeclarativeBase): """Base class for all database models.""" diff --git a/backend/app/core/service_account.py b/backend/app/core/service_account.py index a2175981..9d00a1d9 100644 --- a/backend/app/core/service_account.py +++ b/backend/app/core/service_account.py @@ -18,6 +18,10 @@ logger = logging.getLogger(__name__) SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com" SERVICE_ACCOUNT_NAME = "ResolutionFlow" + +# Well-known UUID for the platform account — owns all default/global content. +# Created by migration 3a40fe11b427_create_global_content_tables. +PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001") SYSTEM_ACCOUNT_NAME = "ResolutionFlow System" SYSTEM_ACCOUNT_DISPLAY_CODE = "RF-SYS-1" diff --git a/backend/app/core/tenant_context.py b/backend/app/core/tenant_context.py new file mode 100644 index 00000000..9cdb80c2 --- /dev/null +++ b/backend/app/core/tenant_context.py @@ -0,0 +1,92 @@ +# backend/app/core/tenant_context.py +""" +Per-request tenant context for row-level security. + +Flow: + 1. require_tenant_context (FastAPI dep) calls set_current_account_id(). + 2. The SQLAlchemy transaction-begin listener fires on every new transaction + and calls set_config('app.current_account_id', , true) automatically. + 3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE) + to filter rows. + +The ContextVar is asyncio-task-scoped: each concurrent request has its own value. +set_config with is_local=true is transaction-scoped: it resets on COMMIT or +ROLLBACK, so the listener re-applies it at the start of every transaction. +""" +import contextvars +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlalchemy import event, or_, text +from sqlalchemy.ext.asyncio import AsyncEngine + +if TYPE_CHECKING: + from app.models.user import User + +# One slot per async task — each concurrent request gets its own value. +_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "current_account_id", default=None +) + +# Platform account — global content visible to all tenants. +PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001") + + +def set_current_account_id(account_id: UUID) -> contextvars.Token: + """Set tenant context for the current request coroutine. + + Returns a token so the caller can reset it after the request. + """ + return _current_account_id.set(str(account_id)) + + +def clear_current_account_id(token: contextvars.Token) -> None: + """Reset the ContextVar to its previous value (call in finally block).""" + _current_account_id.reset(token) + + +def get_current_account_id() -> str | None: + """Return the account_id string for the current request, or None.""" + return _current_account_id.get() + + +def register_tenant_listener(engine: AsyncEngine) -> None: + """Register the transaction-begin listener on the given engine. + + Must be called once at application startup, AFTER the engine is created. + The listener issues set_config() at the start of every transaction so that + the setting is re-applied automatically even when a request commits + mid-flight and starts a new transaction. + + Do NOT call this on admin_engine — admin connections must never set tenant + context automatically. + """ + + @event.listens_for(engine.sync_engine, "begin") + def _on_transaction_begin(conn) -> None: # noqa: ANN001 + account_id = _current_account_id.get() + if account_id: + # set_config(name, value, is_local=true) ≡ SET LOCAL. + # Unlike SET LOCAL, set_config IS parameterisable. + conn.execute( + text("SELECT set_config('app.current_account_id', :id, true)"), + {"id": account_id}, + ) + # If no account_id is set, do nothing. The RLS policy falls back to a + # null-matching UUID and returns zero rows — fail-closed behaviour. + + +def tenant_filter(Model, current_user: "User"): # noqa: ANN001 + """SQLAlchemy filter clause for tables that contain platform-owned rows. + + Use for: tree_tags, tree_categories, step_categories, step_library, + template_trees, platform_steps. + + For tenant-only tables (trees, sessions, psa_connections, etc.) use: + Model.account_id == current_user.account_id + directly. + """ + return or_( + Model.account_id == current_user.account_id, + Model.account_id == PLATFORM_ACCOUNT_ID, + ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index fd3a754a..0441624f 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -54,6 +54,8 @@ from .session_branch import SessionBranch from .fork_point import ForkPoint from .session_handoff import SessionHandoff from .session_resolution_output import SessionResolutionOutput +from .template_tree import TemplateTree +from .platform_step import PlatformStep __all__ = [ "User", @@ -122,4 +124,6 @@ __all__ = [ "ForkPoint", "SessionHandoff", "SessionResolutionOutput", + "TemplateTree", + "PlatformStep", ] diff --git a/backend/app/models/ai_session_step.py b/backend/app/models/ai_session_step.py index 1642632b..09ffc4c1 100644 --- a/backend/app/models/ai_session_step.py +++ b/backend/app/models/ai_session_step.py @@ -50,6 +50,13 @@ class AISessionStep(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Denormalized from ai_sessions.account_id for direct tenant filtering.", + ) step_order: Mapped[int] = mapped_column( Integer, nullable=False, comment="Sequential position in the session (0-indexed)", diff --git a/backend/app/models/ai_suggestion.py b/backend/app/models/ai_suggestion.py index 8ee65dd5..12321c9a 100644 --- a/backend/app/models/ai_suggestion.py +++ b/backend/app/models/ai_suggestion.py @@ -28,6 +28,12 @@ class AISuggestion(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) session_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("ai_chat_sessions.id", ondelete="SET NULL"), diff --git a/backend/app/models/attachment.py b/backend/app/models/attachment.py index dc5266b6..910f697c 100644 --- a/backend/app/models/attachment.py +++ b/backend/app/models/attachment.py @@ -20,6 +20,12 @@ class Attachment(Base): ForeignKey("sessions.id"), nullable=False ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) file_name: Mapped[str] = mapped_column(String(255), nullable=False) file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) diff --git a/backend/app/models/category.py b/backend/app/models/category.py index eb3a56a6..abc7f8d7 100644 --- a/backend/app/models/category.py +++ b/backend/app/models/category.py @@ -39,10 +39,10 @@ class TreeCategory(Base): nullable=True, index=True ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, index=True ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) diff --git a/backend/app/models/feedback.py b/backend/app/models/feedback.py index bfd50302..59501204 100644 --- a/backend/app/models/feedback.py +++ b/backend/app/models/feedback.py @@ -1,6 +1,5 @@ import uuid from datetime import datetime, timezone -from typing import Optional from sqlalchemy import String, Text, DateTime, ForeignKey from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.dialects.postgresql import UUID @@ -11,7 +10,7 @@ class Feedback(Base): __tablename__ = "feedback" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="SET NULL"), nullable=True) + account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True) user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False) email: Mapped[str] = mapped_column(String(255), nullable=False) feedback_type: Mapped[str] = mapped_column(String(50), nullable=False) diff --git a/backend/app/models/folder.py b/backend/app/models/folder.py index 7edaeaef..50923c86 100644 --- a/backend/app/models/folder.py +++ b/backend/app/models/folder.py @@ -46,6 +46,12 @@ class UserFolder(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) name: Mapped[str] = mapped_column(String(100), nullable=False) color: Mapped[str] = mapped_column(String(7), nullable=False, default="#6366f1") icon: Mapped[str] = mapped_column(String(50), nullable=False, default="folder") diff --git a/backend/app/models/fork_point.py b/backend/app/models/fork_point.py index a5700774..8c89d49d 100644 --- a/backend/app/models/fork_point.py +++ b/backend/app/models/fork_point.py @@ -23,6 +23,12 @@ class ForkPoint(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) parent_branch_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=False) trigger_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True) fork_reason: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/backend/app/models/maintenance_schedule.py b/backend/app/models/maintenance_schedule.py index 91280eb4..f8e38246 100644 --- a/backend/app/models/maintenance_schedule.py +++ b/backend/app/models/maintenance_schedule.py @@ -23,6 +23,12 @@ class MaintenanceSchedule(Base): created_by: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) cron_expression: Mapped[str] = mapped_column(String(100), nullable=False) timezone: Mapped[str] = mapped_column(String(100), nullable=False, default="UTC") target_list_id: Mapped[Optional[uuid.UUID]] = mapped_column( diff --git a/backend/app/models/notification_log.py b/backend/app/models/notification_log.py index 5ee4e932..99f8a7cb 100644 --- a/backend/app/models/notification_log.py +++ b/backend/app/models/notification_log.py @@ -31,6 +31,12 @@ class NotificationLog(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) event: Mapped[str] = mapped_column(String(50), nullable=False) payload: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) status: Mapped[str] = mapped_column(String(20), default="sent") diff --git a/backend/app/models/platform_step.py b/backend/app/models/platform_step.py new file mode 100644 index 00000000..39e79733 --- /dev/null +++ b/backend/app/models/platform_step.py @@ -0,0 +1,37 @@ +"""Platform step model — platform-owned steps, readable by all users. + +No account_id. No RLS. Readable by any authenticated user. +Populated by promoting visibility='public' steps from step_library. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class PlatformStep(Base): + __tablename__ = "platform_steps" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + title: Mapped[str] = mapped_column(String(255), nullable=False) + step_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + content: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + source_step_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("step_library.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/models/psa_member_mapping.py b/backend/app/models/psa_member_mapping.py index e85925d8..6ca18109 100644 --- a/backend/app/models/psa_member_mapping.py +++ b/backend/app/models/psa_member_mapping.py @@ -25,6 +25,12 @@ class PsaMemberMapping(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) user_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), diff --git a/backend/app/models/psa_post_log.py b/backend/app/models/psa_post_log.py index 14697507..9e4018f8 100644 --- a/backend/app/models/psa_post_log.py +++ b/backend/app/models/psa_post_log.py @@ -35,6 +35,12 @@ class PsaPostLog(Base): ForeignKey("psa_connections.id", ondelete="SET NULL"), nullable=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) ticket_id: Mapped[str] = mapped_column(String(100), nullable=False) note_type: Mapped[str] = mapped_column(String(50), nullable=False) content_posted: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/backend/app/models/script_builder_session.py b/backend/app/models/script_builder_session.py index f7075494..723a4cfb 100644 --- a/backend/app/models/script_builder_session.py +++ b/backend/app/models/script_builder_session.py @@ -29,6 +29,12 @@ class ScriptBuilderSession(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) team_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL"), diff --git a/backend/app/models/script_template.py b/backend/app/models/script_template.py index 838d2f3c..3624f031 100644 --- a/backend/app/models/script_template.py +++ b/backend/app/models/script_template.py @@ -44,6 +44,12 @@ class ScriptTemplate(Base): team_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=True, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) @@ -97,6 +103,12 @@ class ScriptGeneration(Base): user_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) team_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL"), nullable=True, index=True ) diff --git a/backend/app/models/session.py b/backend/app/models/session.py index c191572b..5bcd6241 100644 --- a/backend/app/models/session.py +++ b/backend/app/models/session.py @@ -31,6 +31,12 @@ class Session(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list) decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list) diff --git a/backend/app/models/session_branch.py b/backend/app/models/session_branch.py index ab6cc50e..e3716806 100644 --- a/backend/app/models/session_branch.py +++ b/backend/app/models/session_branch.py @@ -35,6 +35,12 @@ class SessionBranch(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) parent_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=True) fork_point_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True) branch_order: Mapped[int] = mapped_column(Integer, nullable=False, default=1) diff --git a/backend/app/models/session_handoff.py b/backend/app/models/session_handoff.py index 0fd53128..1b44df56 100644 --- a/backend/app/models/session_handoff.py +++ b/backend/app/models/session_handoff.py @@ -27,6 +27,12 @@ class SessionHandoff(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) handed_off_by: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) intent: Mapped[str] = mapped_column(String(20), nullable=False) source_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="SET NULL"), nullable=True) diff --git a/backend/app/models/session_resolution_output.py b/backend/app/models/session_resolution_output.py index cb56fa42..3ae32549 100644 --- a/backend/app/models/session_resolution_output.py +++ b/backend/app/models/session_resolution_output.py @@ -23,6 +23,12 @@ class SessionResolutionOutput(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) output_type: Mapped[str] = mapped_column(String(30), nullable=False) generated_content: Mapped[str] = mapped_column(Text, nullable=False) structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}") diff --git a/backend/app/models/step_category.py b/backend/app/models/step_category.py index da207926..73b2e17b 100644 --- a/backend/app/models/step_category.py +++ b/backend/app/models/step_category.py @@ -38,10 +38,10 @@ class StepCategory(Base): nullable=True, index=True ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, index=True ) display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) diff --git a/backend/app/models/step_library.py b/backend/app/models/step_library.py index e93c1f75..3c0b35ae 100644 --- a/backend/app/models/step_library.py +++ b/backend/app/models/step_library.py @@ -46,10 +46,10 @@ class StepLibrary(Base): ForeignKey("teams.id", ondelete="CASCADE"), nullable=True ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, index=True ) @@ -143,6 +143,13 @@ class StepRating(Base): ForeignKey("users.id", ondelete="CASCADE"), nullable=False ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Account of the RATER (not the step owner).", + ) rating: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) was_helpful: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) review_text: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) @@ -187,6 +194,13 @@ class StepUsageLog(Base): ForeignKey("users.id", ondelete="CASCADE"), nullable=False ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Account of the user who logged this usage.", + ) session_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), diff --git a/backend/app/models/supporting_data.py b/backend/app/models/supporting_data.py index ea04cd91..d69f66e2 100644 --- a/backend/app/models/supporting_data.py +++ b/backend/app/models/supporting_data.py @@ -14,6 +14,12 @@ class SessionSupportingData(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) label: Mapped[str] = mapped_column(String(255), nullable=False) data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/backend/app/models/tag.py b/backend/app/models/tag.py index 5152c3a9..7bb758d1 100644 --- a/backend/app/models/tag.py +++ b/backend/app/models/tag.py @@ -51,10 +51,10 @@ class TreeTag(Base): nullable=True, index=True ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, index=True ) usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) diff --git a/backend/app/models/target_list.py b/backend/app/models/target_list.py index f2dbd7ac..b1169d72 100644 --- a/backend/app/models/target_list.py +++ b/backend/app/models/target_list.py @@ -9,6 +9,7 @@ from app.core.database import Base if TYPE_CHECKING: from app.models.user import User from app.models.team import Team + from app.models.account import Account class TargetList(Base): @@ -21,6 +22,12 @@ class TargetList(Base): UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) created_by: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) diff --git a/backend/app/models/template_tree.py b/backend/app/models/template_tree.py new file mode 100644 index 00000000..e67f70ec --- /dev/null +++ b/backend/app/models/template_tree.py @@ -0,0 +1,40 @@ +"""Template tree model — platform-owned troubleshooting trees, readable by all users. + +No account_id. No RLS. Readable by any authenticated user. +Populated by promoting is_default=TRUE trees from the trees table. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class TemplateTree(Base): + __tablename__ = "template_trees" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + tree_type: Mapped[str] = mapped_column(String(20), nullable=False, index=True) + tree_structure: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + tags: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + source_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("trees.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/models/tree.py b/backend/app/models/tree.py index 3557a158..8f0c5e8c 100644 --- a/backend/app/models/tree.py +++ b/backend/app/models/tree.py @@ -76,10 +76,10 @@ class Tree(Base): nullable=True, index=True ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, index=True ) is_active: Mapped[bool] = mapped_column(Boolean, default=True) diff --git a/backend/app/models/tree_embedding.py b/backend/app/models/tree_embedding.py index 6fba4466..064ccf07 100644 --- a/backend/app/models/tree_embedding.py +++ b/backend/app/models/tree_embedding.py @@ -37,10 +37,10 @@ class TreeEmbedding(Base): ForeignKey("trees.id", ondelete="CASCADE"), nullable=False, ) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), - nullable=True, + nullable=False, ) chunk_type: Mapped[str] = mapped_column( String(30), diff --git a/backend/app/models/user.py b/backend/app/models/user.py index c7d566d7..e1274183 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -43,10 +43,10 @@ class User(Base): must_change_password: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") # Account-based multi-tenancy (new) - account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + account_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="RESTRICT"), - nullable=True, + nullable=False, index=True ) account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer") diff --git a/backend/app/models/user_pinned_tree.py b/backend/app/models/user_pinned_tree.py index c27edd08..d23b463a 100644 --- a/backend/app/models/user_pinned_tree.py +++ b/backend/app/models/user_pinned_tree.py @@ -24,6 +24,12 @@ class UserPinnedTree(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) tree_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("trees.id", ondelete="CASCADE"), diff --git a/backend/tests/test_phase1_migrations.py b/backend/tests/test_phase1_migrations.py new file mode 100644 index 00000000..eefeba17 --- /dev/null +++ b/backend/tests/test_phase1_migrations.py @@ -0,0 +1,545 @@ +"""Phase 1 migration tests — verify account_id backfill correctness. + +These tests create objects via ORM (which uses the updated models), +then verify account_id is populated correctly. They run against a +real PostgreSQL test DB (same as all other integration tests). +""" +import pytest +import uuid +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, text + +from app.models.account import Account +from app.models.user import User +from app.models.tree import Tree +from app.models.session import Session +from app.models.attachment import Attachment +from app.models.supporting_data import SessionSupportingData +from app.models.session_resolution_output import SessionResolutionOutput +from app.models.ai_session import AISession +from app.core.security import get_password_hash + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]: + account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8]) + db.add(account) + await db.flush() + user = User( + email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com", + name=f"User {suffix}", + password_hash=get_password_hash("TestPass123!"), + is_active=True, + account_id=account.id, + account_role="engineer", + ) + db.add(user) + await db.flush() + return account, user + + +async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree: + tree = Tree( + name=f"Tree {uuid.uuid4().hex[:6]}", + account_id=account.id, + author_id=user.id, + visibility="team", + tree_type="troubleshooting", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + db.add(tree) + await db.flush() + return tree + + +async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session: + s = Session( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + tree_snapshot={}, + ) + db.add(s) + await db.flush() + return s + + +# ── Group 1: Core sessions ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_account_id_matches_user(test_db: AsyncSession): + """sessions.account_id must equal the user's account_id.""" + account, user = await _make_account_and_user(test_db, "s1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + await test_db.commit() + + result = await test_db.execute(select(Session).where(Session.id == session.id)) + row = result.scalar_one() + assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}" + + +@pytest.mark.asyncio +async def test_attachment_account_id_matches_session(test_db: AsyncSession): + """attachments.account_id must match the parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "att1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + attachment = Attachment( + session_id=session.id, + account_id=account.id, + file_name="test.png", + file_type="image/png", + ) + test_db.add(attachment) + await test_db.commit() + + result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_session_supporting_data_account_id(test_db: AsyncSession): + """session_supporting_data.account_id must match parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "sd1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + sd = SessionSupportingData( + session_id=session.id, + account_id=account.id, + label="Log snippet", + data_type="text_snippet", + content="error: connection refused", + ) + test_db.add(sd) + await test_db.commit() + + result = await test_db.execute( + select(SessionSupportingData).where(SessionSupportingData.id == sd.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_session_resolution_output_account_id(test_db: AsyncSession): + """session_resolution_outputs.account_id must match the parent ai_session's account_id. + + NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions). + """ + account, user = await _make_account_and_user(test_db, "sro1") + + ai_session = AISession( + user_id=user.id, + account_id=account.id, + problem_summary="test resolution output", + problem_domain="networking", + status="active", + ) + test_db.add(ai_session) + await test_db.flush() + + output = SessionResolutionOutput( + session_id=ai_session.id, + account_id=account.id, + output_type="psa_ticket_notes", + generated_content="Ticket notes content", + generated_by_model="gpt-4", + ) + test_db.add(output) + await test_db.commit() + + result = await test_db.execute( + select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 2: AI & branching ─────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession): + """session_branches.account_id must match parent ai_session.account_id.""" + from app.models.session_branch import SessionBranch + + account, user = await _make_account_and_user(test_db, "sb1") + ai_session = AISession( + user_id=user.id, + account_id=account.id, + problem_summary="test", + problem_domain="networking", + status="active", + ) + test_db.add(ai_session) + await test_db.flush() + + branch = SessionBranch( + session_id=ai_session.id, + account_id=account.id, + label="Branch A", + branch_order=1, + conversation_messages=[], + ) + test_db.add(branch) + await test_db.commit() + + result = await test_db.execute( + select(SessionBranch).where(SessionBranch.id == branch.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession): + """ai_suggestions.account_id must match the creating user's account_id.""" + from app.models.ai_suggestion import AISuggestion + + account, user = await _make_account_and_user(test_db, "ais1") + tree = await _make_tree(test_db, account, user) + + suggestion = AISuggestion( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + action_type="add_node", + changes_json={}, + status="pending", + ) + test_db.add(suggestion) + await test_db.commit() + + result = await test_db.execute( + select(AISuggestion).where(AISuggestion.id == suggestion.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 3: Steps & ratings ────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession): + """step_ratings.account_id must be the RATER's account, not the step's account.""" + from app.models.step_library import StepLibrary, StepRating + + account_a, user_a = await _make_account_and_user(test_db, "sr-rater") + account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner") + + # Step owned by account_b + step = StepLibrary( + title="A step", + step_type="action", + content={"text": "do something"}, + created_by=user_b.id, + account_id=account_b.id, + visibility="public", + ) + test_db.add(step) + await test_db.flush() + + # user_a (account_a) rates the step + rating = StepRating( + step_id=step.id, + user_id=user_a.id, + account_id=account_a.id, # rater's account, not step owner's + was_helpful=True, + is_verified_use=False, + is_visible=True, + ) + test_db.add(rating) + await test_db.commit() + + result = await test_db.execute(select(StepRating).where(StepRating.id == rating.id)) + row = result.scalar_one() + assert row.account_id == account_a.id, ( + f"account_id should be rater's account ({account_a.id}), got {row.account_id}" + ) + + +@pytest.mark.asyncio +async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession): + """step_usage_log.account_id must be the LOGGER's account (user who used the step).""" + from app.models.step_library import StepLibrary, StepUsageLog + + account, user = await _make_account_and_user(test_db, "sul1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + step = StepLibrary( + title="A usage step", + step_type="action", + content={"text": "do something"}, + created_by=user.id, + account_id=account.id, + visibility="team", + ) + test_db.add(step) + await test_db.flush() + + log = StepUsageLog( + step_id=step.id, + user_id=user.id, + account_id=account.id, + session_id=session.id, + ) + test_db.add(log) + await test_db.commit() + + result = await test_db.execute(select(StepUsageLog).where(StepUsageLog.id == log.id)) + row = result.scalar_one() + assert row.account_id == account.id, ( + f"account_id should be logger's account ({account.id}), got {row.account_id}" + ) + + +# ── Group 4: User personalization ──────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_user_folder_account_id_matches_user(test_db: AsyncSession): + """user_folders.account_id must match the owning user's account_id.""" + from app.models.folder import UserFolder + + account, user = await _make_account_and_user(test_db, "uf1") + folder = UserFolder( + user_id=user.id, + account_id=account.id, + name="My Folder", + color="#6366f1", + icon="folder", + display_order=0, + ) + test_db.add(folder) + await test_db.commit() + + result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession): + """user_pinned_trees.account_id must match the pinning user's account_id.""" + from app.models.user_pinned_tree import UserPinnedTree + + account, user = await _make_account_and_user(test_db, "pt1") + tree = await _make_tree(test_db, account, user) + pin = UserPinnedTree( + user_id=user.id, + tree_id=tree.id, + account_id=account.id, + display_order=0, + ) + test_db.add(pin) + await test_db.commit() + + result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 5: PSA & notifications ───────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession): + """psa_member_mappings.account_id must match psa_connection's account_id.""" + from app.models.psa_connection import PsaConnection + from app.models.psa_member_mapping import PsaMemberMapping + + account, user = await _make_account_and_user(test_db, "psa1") + conn = PsaConnection( + account_id=account.id, + provider="connectwise", + display_name="Test CW", + site_url="https://cw.example.com", + company_id="TEST", + credentials_encrypted="placeholder", + ) + test_db.add(conn) + await test_db.flush() + + mapping = PsaMemberMapping( + psa_connection_id=conn.id, + user_id=user.id, + account_id=account.id, + external_member_id="cw-123", + external_member_name="Test User", + matched_by="manual_admin", + ) + test_db.add(mapping) + await test_db.commit() + + result = await test_db.execute( + select(PsaMemberMapping).where(PsaMemberMapping.id == mapping.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 6: Maintenance ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession): + """maintenance_schedules.account_id must match the tree's account_id.""" + from app.models.maintenance_schedule import MaintenanceSchedule + + account, user = await _make_account_and_user(test_db, "ms1") + tree = Tree( + name="Maintenance Flow", + account_id=account.id, + author_id=user.id, + visibility="team", + tree_type="maintenance", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + test_db.add(tree) + await test_db.flush() + + schedule = MaintenanceSchedule( + tree_id=tree.id, + account_id=account.id, + created_by=user.id, + cron_expression="0 9 * * 1", + timezone="UTC", + is_active=True, + ) + test_db.add(schedule) + await test_db.commit() + + result = await test_db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 7: Legacy team_id tables ─────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_script_builder_session_account_id(test_db: AsyncSession): + """script_builder_sessions.account_id must match user's account_id.""" + from app.models.script_builder_session import ScriptBuilderSession + + account, user = await _make_account_and_user(test_db, "sbs1") + sbs = ScriptBuilderSession( + user_id=user.id, + account_id=account.id, + language="powershell", + ) + test_db.add(sbs) + await test_db.commit() + + result = await test_db.execute( + select(ScriptBuilderSession).where(ScriptBuilderSession.id == sbs.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 8: TargetList ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_target_list_account_id_from_team_admin(test_db: AsyncSession): + """target_lists.account_id must be set to the team admin's account_id.""" + from app.models.target_list import TargetList + from app.models.team import Team + + account, user = await _make_account_and_user(test_db, "tl1") + # Make user a team admin + team = Team(name=f"Team {uuid.uuid4().hex[:6]}") + test_db.add(team) + await test_db.flush() + + user.team_id = team.id + user.is_team_admin = True + await test_db.flush() + + target_list = TargetList( + team_id=team.id, + account_id=account.id, + created_by=user.id, + name="Server Targets", + targets=[{"label": "SRV-01"}], + ) + test_db.add(target_list) + await test_db.commit() + + result = await test_db.execute( + select(TargetList).where(TargetList.id == target_list.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +# ── Group 10 (runs first): Global content tables ────────────────────────────── + +@pytest.mark.asyncio +async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession): + """template_trees must exist and must NOT have an account_id column.""" + result = await test_db.execute(text(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'template_trees' + """)) + columns = {row[0] for row in result.fetchall()} + assert 'id' in columns, "template_trees.id must exist" + assert 'account_id' not in columns, "template_trees must not have account_id (global content)" + + +@pytest.mark.asyncio +async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession): + """platform_steps must exist and must NOT have an account_id column.""" + result = await test_db.execute(text(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'platform_steps' + """)) + columns = {row[0] for row in result.fetchall()} + assert 'id' in columns, "platform_steps.id must exist" + assert 'account_id' not in columns, "platform_steps must not have account_id (global content)" + + +# ── Group 9: SET NOT NULL on existing nullable columns ──────────────────────── + +@pytest.mark.asyncio +async def test_tree_account_id_is_not_null(test_db: AsyncSession): + """trees.account_id must be NOT NULL after Phase 1 — enforced at DB level.""" + from sqlalchemy.exc import IntegrityError + with pytest.raises(IntegrityError): + test_db.add(Tree( + name="Bad tree", + # account_id intentionally omitted + author_id=None, + visibility="private", + tree_type="troubleshooting", + tree_structure={}, + is_active=True, + status="draft", + )) + await test_db.flush() + + +@pytest.mark.asyncio +async def test_user_account_id_is_not_null(test_db: AsyncSession): + """users.account_id must be NOT NULL after Phase 1.""" + from sqlalchemy.exc import IntegrityError + with pytest.raises(IntegrityError): + test_db.add(User( + email=f"orphan-{uuid.uuid4().hex[:6]}@example.com", + name="Orphan", + password_hash=get_password_hash("x"), + is_active=True, + role="engineer", + account_role="engineer", + # account_id intentionally omitted + )) + await test_db.flush() diff --git a/backend/tests/test_rls_isolation.py b/backend/tests/test_rls_isolation.py new file mode 100644 index 00000000..5d6572e2 --- /dev/null +++ b/backend/tests/test_rls_isolation.py @@ -0,0 +1,266 @@ +# backend/tests/test_rls_isolation.py +""" +RLS foundation tests. + +Connect directly as resolutionflow_app (not superuser) and verify: + - Tenant A cannot read Tenant B's rows + - No tenant context set → zero rows for private data (fail-closed) + - Platform rows (PLATFORM_ACCOUNT_ID) are visible to all tenants + +Tests bypass FastAPI entirely — raw asyncpg connections only. +MUST FAIL before Task 10 (RLS migration) and PASS after it. + +Run with: + DB_APP_ROLE_PASSWORD=app_secret_change_me pytest tests/test_rls_isolation.py -v + +The test DB is patherly_test (matches conftest.py default). +""" +import os +import uuid + +import asyncpg +import pytest + +_DB_HOST = os.getenv("TEST_DB_HOST", "localhost") +_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432")) +_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py +_APP_PASSWORD = os.getenv("DB_APP_ROLE_PASSWORD", "app_secret_change_me") +_ADMIN_DSN = f"postgresql://postgres:postgres@{_DB_HOST}:{_DB_PORT}/{_DB_NAME}" + +PLATFORM_ACCOUNT_ID = "00000000-0000-0000-0000-000000000001" +ACCOUNT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" +ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +async def admin_conn(): + """Superuser asyncpg connection for fixture setup and teardown.""" + conn = await asyncpg.connect(_ADMIN_DSN) + yield conn + await conn.close() + + +@pytest.fixture(scope="module", autouse=True) +async def seed_rls_test_data(admin_conn): + """ + Create two isolated test accounts, one user per account, and one private + tree per account. Trees require a valid author_id FK to users, so users + must be created first. + + accounts.display_code must be unique and 8 chars (NOT NULL constraint). + """ + # Insert accounts + await admin_conn.execute(f""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES + ('{ACCOUNT_A_ID}', 'RLS Tenant A', 'RLSA0001', NOW(), NOW()), + ('{ACCOUNT_B_ID}', 'RLS Tenant B', 'RLSB0001', NOW(), NOW()) + ON CONFLICT (id) DO NOTHING + """) + + # Insert one user per account (users.account_id NOT NULL, password_hash NOT NULL) + user_a_id = str(uuid.uuid4()) + user_b_id = str(uuid.uuid4()) + await admin_conn.execute(f""" + INSERT INTO users ( + id, email, password_hash, name, role, is_active, account_id, + account_role, created_at + ) VALUES + ('{user_a_id}', 'rls-user-a@example.com', + 'placeholder', 'RLS User A', 'engineer', TRUE, + '{ACCOUNT_A_ID}', 'engineer', NOW()), + ('{user_b_id}', 'rls-user-b@example.com', + 'placeholder', 'RLS User B', 'engineer', TRUE, + '{ACCOUNT_B_ID}', 'engineer', NOW()) + ON CONFLICT (email) DO NOTHING + """) + + # Look up the user IDs we just inserted (ON CONFLICT may have skipped) + row_a = await admin_conn.fetchrow( + "SELECT id FROM users WHERE email = 'rls-user-a@example.com'" + ) + row_b = await admin_conn.fetchrow( + "SELECT id FROM users WHERE email = 'rls-user-b@example.com'" + ) + actual_user_a = str(row_a["id"]) + actual_user_b = str(row_b["id"]) + + # Insert one private tree per account with explicit author_id + await admin_conn.execute(f""" + INSERT INTO trees ( + id, name, tree_structure, account_id, author_id, is_active, is_default, + is_public, visibility, tree_type, created_at, updated_at + ) VALUES + (gen_random_uuid(), 'RLS Tree A', '[]'::jsonb, '{ACCOUNT_A_ID}', '{actual_user_a}', + TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW()), + (gen_random_uuid(), 'RLS Tree B', '[]'::jsonb, '{ACCOUNT_B_ID}', '{actual_user_b}', + TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW()) + """) + + # One platform-owned tree_tag (global, visible to all tenants) + await admin_conn.execute(f""" + INSERT INTO tree_tags ( + id, name, slug, account_id, usage_count, created_at + ) VALUES ( + gen_random_uuid(), 'rls-global-tag', 'rls-global-tag', + '{PLATFORM_ACCOUNT_ID}', 0, NOW() + ) ON CONFLICT DO NOTHING + """) + + yield + + # Cleanup + await admin_conn.execute( + f"DELETE FROM trees WHERE account_id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')" + ) + await admin_conn.execute( + "DELETE FROM users WHERE email IN " + "('rls-user-a@example.com', 'rls-user-b@example.com')" + ) + await admin_conn.execute( + f"DELETE FROM accounts WHERE id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')" + ) + await admin_conn.execute("DELETE FROM tree_tags WHERE slug = 'rls-global-tag'") + + +@pytest.fixture +async def conn_a(): + """App-role connection, tenant context = Account A.""" + conn = await asyncpg.connect( + host=_DB_HOST, port=_DB_PORT, database=_DB_NAME, + user="resolutionflow_app", password=_APP_PASSWORD, + ) + await conn.execute( + "SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_A_ID + ) + yield conn + await conn.close() + + +@pytest.fixture +async def conn_b(): + """App-role connection, tenant context = Account B.""" + conn = await asyncpg.connect( + host=_DB_HOST, port=_DB_PORT, database=_DB_NAME, + user="resolutionflow_app", password=_APP_PASSWORD, + ) + await conn.execute( + "SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_B_ID + ) + yield conn + await conn.close() + + +@pytest.fixture +async def conn_no_context(): + """App-role connection with NO tenant context set.""" + conn = await asyncpg.connect( + host=_DB_HOST, port=_DB_PORT, database=_DB_NAME, + user="resolutionflow_app", password=_APP_PASSWORD, + ) + yield conn + await conn.close() + + +# --------------------------------------------------------------------------- +# trees +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_trees_account_a_cannot_see_account_b_rows(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0, "Account A should not see Account B trees" + + +@pytest.mark.asyncio +async def test_trees_account_a_can_see_own_rows(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'" + ) + assert len(rows) >= 1, "Account A should see its own trees" + + +@pytest.mark.asyncio +async def test_trees_no_context_sees_no_private_trees(conn_no_context): + rows = await conn_no_context.fetch( + "SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE" + ) + assert len(rows) == 0, "No-context connection should see no private trees" + + +# --------------------------------------------------------------------------- +# tree_tags — platform visibility +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0 + + +@pytest.mark.asyncio +async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b): + rows_a = await conn_a.fetch( + f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'" + ) + rows_b = await conn_b.fetch( + f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'" + ) + assert len(rows_a) >= 1, "Account A should see platform tags" + assert len(rows_b) >= 1, "Account B should see platform tags" + + +# --------------------------------------------------------------------------- +# tree_categories — platform visibility +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tree_categories_account_a_cannot_see_account_b(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0 + + +# --------------------------------------------------------------------------- +# step_categories — platform visibility +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_step_categories_account_a_cannot_see_account_b(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0 + + +# --------------------------------------------------------------------------- +# psa_connections — tenant-only +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_psa_connections_account_a_cannot_see_account_b(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0 + + +# --------------------------------------------------------------------------- +# flow_proposals — tenant-only +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_flow_proposals_account_a_cannot_see_account_b(conn_a): + rows = await conn_a.fetch( + f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'" + ) + assert len(rows) == 0 diff --git a/backend/tests/test_script_templates.py b/backend/tests/test_script_templates.py index 868bf10e..ae7501dd 100644 --- a/backend/tests/test_script_templates.py +++ b/backend/tests/test_script_templates.py @@ -1,4 +1,6 @@ """Integration tests for Script Template Editor permissions and share endpoint.""" +from uuid import UUID as PyUUID + import pytest from httpx import AsyncClient from sqlalchemy import select @@ -65,6 +67,9 @@ class TestScriptTemplatePermissions: data = resp.json() assert data["name"] == "Test Template" assert data["created_by"] is not None + result = await test_db.execute(select(ScriptTemplate).where(ScriptTemplate.id == PyUUID(data["id"]))) + template = result.scalar_one() + assert template.account_id is not None @pytest.mark.asyncio async def test_engineer_can_edit_own_template(self, client, auth_headers, test_db): diff --git a/backend/tests/test_scripts.py b/backend/tests/test_scripts.py index eb31c79f..cf17f9a4 100644 --- a/backend/tests/test_scripts.py +++ b/backend/tests/test_scripts.py @@ -6,14 +6,18 @@ from datetime import datetime, timezone import pytest import sqlalchemy as sa +from app.models.script_template import ScriptGeneration +from app.models.user import User # ── Fixtures ────────────────────────────────────────────────────────────── @pytest.fixture -async def seed_script_data(test_db): +async def seed_script_data(test_db, test_user): """Seed script categories and templates into the test database.""" now = datetime.now(timezone.utc) cat_id = uuid.UUID("00000000-0000-0000-0000-000000000001") + user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"])) + user = user_result.scalar_one() # Insert category await test_db.execute( @@ -142,20 +146,20 @@ async def seed_script_data(test_db): await test_db.execute( sa.text(""" INSERT INTO script_templates ( - id, category_id, name, slug, description, + id, category_id, account_id, name, slug, description, script_body, parameters_schema, default_values, validation_rules, tags, complexity, estimated_runtime, requires_elevation, requires_modules, version, is_verified, is_active, usage_count, created_at, updated_at ) VALUES ( - :id, :category_id, :name, :slug, :description, + :id, :category_id, :account_id, :name, :slug, :description, :script_body, CAST(:parameters_schema AS jsonb), '{}'::jsonb, '{}'::jsonb, CAST(:tags AS jsonb), :complexity, :estimated_runtime, :requires_elevation, '[]'::jsonb, 1, true, true, 0, :now, :now ) """), - {**tmpl, "category_id": cat_id, "now": now}, + {**tmpl, "category_id": cat_id, "account_id": user.account_id, "now": now}, ) await test_db.commit() @@ -245,7 +249,7 @@ async def test_get_template_detail_not_found(client, auth_headers): # ── Generate ────────────────────────────────────────────────────────────── @pytest.mark.asyncio -async def test_generate_script_success(client, auth_headers, seed_script_data): +async def test_generate_script_success(client, auth_headers, seed_script_data, test_db, test_user): list_resp = await client.get( "/api/v1/scripts/templates?search=unlock", headers=auth_headers, @@ -265,6 +269,13 @@ async def test_generate_script_success(client, auth_headers, seed_script_data): assert "script" in data assert "jsmith" in data["script"] assert "id" in data + generation_result = await test_db.execute( + sa.select(ScriptGeneration).where(ScriptGeneration.id == uuid.UUID(data["id"])) + ) + generation = generation_result.scalar_one() + user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"])) + user = user_result.scalar_one() + assert generation.account_id == user.account_id @pytest.mark.asyncio diff --git a/backend/tests/test_tenant_context.py b/backend/tests/test_tenant_context.py new file mode 100644 index 00000000..e4ad183e --- /dev/null +++ b/backend/tests/test_tenant_context.py @@ -0,0 +1,58 @@ +import asyncio +from uuid import UUID +import pytest +from unittest.mock import MagicMock + +from app.core.tenant_context import ( + set_current_account_id, + clear_current_account_id, + get_current_account_id, +) + + +def test_contextvar_is_none_by_default(): + assert get_current_account_id() is None + + +def test_set_and_clear(): + account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001") + token = set_current_account_id(account_id) + assert get_current_account_id() == str(account_id) + clear_current_account_id(token) + assert get_current_account_id() is None + + +def test_tasks_are_isolated(): + """Each asyncio task has its own ContextVar value.""" + results = {} + + async def set_in_task(name: str, value: str): + token = set_current_account_id(UUID(value)) + await asyncio.sleep(0) + results[name] = get_current_account_id() + clear_current_account_id(token) + + async def run(): + await asyncio.gather( + set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"), + set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"), + ) + + asyncio.run(run()) + assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001" + assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002" + + +@pytest.mark.asyncio +async def test_require_tenant_context_raises_403_when_no_account(): + from fastapi import HTTPException + from app.api.deps import require_tenant_context + + user = MagicMock() + user.account_id = None + + gen = require_tenant_context(current_user=user) + with pytest.raises(HTTPException) as exc_info: + await gen.__anext__() + assert exc_info.value.status_code == 403 + assert "account required" in exc_info.value.detail.lower() diff --git a/backend/tests/test_trees.py b/backend/tests/test_trees.py index 300a50f6..8a79c6fc 100644 --- a/backend/tests/test_trees.py +++ b/backend/tests/test_trees.py @@ -447,3 +447,55 @@ class TestVisibilityFilter: assert "author_name" in trees[0] # visibility key should be present assert "visibility" in trees[0] + + @pytest.mark.asyncio + async def test_get_tree_returns_404_not_403_for_other_account_tree( + self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession + ): + """Account A must not learn that Account B's private tree exists.""" + from app.models.tree import Tree + from app.models.account import Account + from app.models.user import User + from app.core.security import get_password_hash + import uuid + + # Create a second account and user + account_b = Account(name="Other Corp", display_code="OTH00001") + test_db.add(account_b) + await test_db.flush() + + user_b = User( + email=f"user-b-{uuid.uuid4().hex[:6]}@example.com", + name="User B", + password_hash=get_password_hash("TestPass123!"), + is_active=True, + account_id=account_b.id, + account_role="engineer", + ) + test_db.add(user_b) + await test_db.flush() + + # Create a private tree belonging to account_b + private_tree = Tree( + name="Secret Tree", + account_id=account_b.id, + author_id=user_b.id, + visibility="private", + tree_type="troubleshooting", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + is_default=False, + is_public=False, + status="published", + ) + test_db.add(private_tree) + await test_db.commit() + + response = await client.get( + f"/api/v1/trees/{private_tree.id}", + headers=auth_headers, + ) + assert response.status_code == 404, ( + f"Expected 404 but got {response.status_code} — " + "leaking tree existence to wrong tenant" + ) diff --git a/docs/superpowers/plans/2026-04-09-tenant-isolation-phase-1.md b/docs/superpowers/plans/2026-04-09-tenant-isolation-phase-1.md new file mode 100644 index 00000000..86d4138a --- /dev/null +++ b/docs/superpowers/plans/2026-04-09-tenant-isolation-phase-1.md @@ -0,0 +1,2527 @@ +# Tenant Isolation — Phase 1 Schema Migrations + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add `account_id` to every tenant table that lacks it, backfill from existing FK chains, enforce NOT NULL, and create the global content tables (`template_trees`, `platform_steps`) that replace the legacy `is_default`/`visibility='public'` patterns. + +**Architecture:** Each task is one Alembic migration file covering one logical domain group. Every migration follows the non-negotiable sequence: ADD nullable → backfill → verify zero NULLs → SET NOT NULL → CREATE INDEX. Any migration that cannot zero-out NULLs at step 3 must roll back in full — no partial state. RLS is NOT enabled in this phase. `get_db()` is NOT modified. Schema only. + +**Tech Stack:** Python 3.11 · FastAPI · SQLAlchemy 2.0 async · Alembic · PostgreSQL 16 · pytest-asyncio + +**Spec:** `docs/superpowers/specs/2026-04-09-tenant-data-isolation-design.md` + +**Prerequisite:** Phase 0 merged to `main` (PRs #131 + #132 ✓). Alembic current head: `b8d2f4a6c091`. + +**Task ordering note:** Task 9 (global content separation) runs before Task 10 (SET NOT NULL on trees/categories/tags/steps). This is a dependency: `is_default=TRUE` trees have `account_id=NULL` and cannot satisfy the zero-NULL check until they are moved to `template_trees`. + +--- + +## File Map + +| File | Action | Group | +|---|---|---| +| `backend/alembic/versions/_add_account_id_core_sessions.py` | Create | 1 | +| `backend/alembic/versions/_add_account_id_ai_branching.py` | Create | 2 | +| `backend/alembic/versions/_add_account_id_step_ratings.py` | Create | 3 | +| `backend/alembic/versions/_add_account_id_user_personalization.py` | Create | 4 | +| `backend/alembic/versions/_add_account_id_psa_notifications.py` | Create | 5 | +| `backend/alembic/versions/_add_account_id_maintenance.py` | Create | 6 | +| `backend/alembic/versions/_add_account_id_script_tables.py` | Create | 7 | +| `backend/alembic/versions/_add_account_id_target_lists.py` | Create | 8 | +| `backend/alembic/versions/_create_global_content_tables.py` | Create | 9 | +| `backend/alembic/versions/_set_not_null_account_id_phase1.py` | Create | 10 | +| `backend/app/models/session.py` | Modify | 1 | +| `backend/app/models/attachment.py` | Modify | 1 | +| `backend/app/models/supporting_data.py` | Modify | 1 | +| `backend/app/models/session_resolution_output.py` | Modify | 1 | +| `backend/app/models/session_branch.py` | Modify | 2 | +| `backend/app/models/session_handoff.py` | Modify | 2 | +| `backend/app/models/fork_point.py` | Modify | 2 | +| `backend/app/models/ai_session_step.py` | Modify | 2 | +| `backend/app/models/ai_suggestion.py` | Modify | 2 | +| `backend/app/models/step_library.py` | Modify | 3 (StepRating, StepUsageLog) | +| `backend/app/models/folder.py` | Modify | 4 | +| `backend/app/models/user_pinned_tree.py` | Modify | 4 | +| `backend/app/models/psa_post_log.py` | Modify | 5 | +| `backend/app/models/psa_member_mapping.py` | Modify | 5 | +| `backend/app/models/notification_log.py` | Modify | 5 | +| `backend/app/models/maintenance_schedule.py` | Modify | 6 | +| `backend/app/models/script_builder_session.py` | Modify | 7 | +| `backend/app/models/script_template.py` | Modify | 7 (ScriptTemplate, ScriptGeneration) | +| `backend/app/models/target_list.py` | Modify | 8 | +| `backend/app/models/template_tree.py` | Create | 9 | +| `backend/app/models/platform_step.py` | Create | 9 | +| `backend/app/models/user.py` | Modify | 10 | +| `backend/app/models/tree.py` | Modify | 10 | +| `backend/app/models/category.py` | Modify | 10 | +| `backend/app/models/tag.py` | Modify | 10 | +| `backend/app/models/step_category.py` | Modify | 10 | +| `backend/app/models/step_library.py` | Modify | 10 (StepLibrary account_id NOT NULL) | +| `backend/app/models/tree_embedding.py` | Modify | 10 | +| `backend/app/models/feedback.py` | Modify | 10 | +| `backend/tests/test_phase1_migrations.py` | Create | all tasks | + +--- + +## Task 1: Group 1 — Core sessions + +**Tables:** `sessions`, `attachments`, `session_supporting_data`, `session_resolution_outputs` + +**Backfill paths:** +- `sessions`: `sessions.user_id → users.account_id` +- `attachments`: `attachments.session_id → sessions.account_id` (chain — sessions must be backfilled first in same migration) +- `session_supporting_data`: same chain as attachments +- `session_resolution_outputs`: `session_resolution_outputs.session_id → ai_sessions.account_id` (FK is to `ai_sessions`, not `sessions`) + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_core_sessions.py` +- Modify: `backend/app/models/session.py`, `attachment.py`, `supporting_data.py`, `session_resolution_output.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 1.1: Create the branch** + +```bash +git checkout main && git pull origin main +git checkout -b feat/tenant-isolation-phase-1 +``` + +- [ ] **Step 1.2: Write the failing test** + +Create `backend/tests/test_phase1_migrations.py`: + +```python +"""Phase 1 migration tests — verify account_id backfill correctness. + +These tests create objects via ORM (which uses the updated models), +then verify account_id is populated correctly. They run against a +real PostgreSQL test DB (same as all other integration tests). +""" +import pytest +import uuid +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, text + +from app.models.account import Account +from app.models.user import User +from app.models.tree import Tree +from app.models.session import Session +from app.models.attachment import Attachment +from app.models.supporting_data import SessionSupportingData +from app.models.session_resolution_output import SessionResolutionOutput +from app.models.ai_session import AISession +from app.core.security import get_password_hash + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]: + account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8]) + db.add(account) + await db.flush() + user = User( + email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com", + name=f"User {suffix}", + password_hash=get_password_hash("TestPass123!"), + is_active=True, + account_id=account.id, + account_role="engineer", + ) + db.add(user) + await db.flush() + return account, user + + +async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree: + tree = Tree( + name=f"Tree {uuid.uuid4().hex[:6]}", + account_id=account.id, + author_id=user.id, + visibility="team", + tree_type="troubleshooting", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + db.add(tree) + await db.flush() + return tree + + +async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session: + s = Session( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + tree_snapshot={}, + ) + db.add(s) + await db.flush() + return s + + +# ── Group 1: Core sessions ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_account_id_matches_user(test_db: AsyncSession): + """sessions.account_id must equal the user's account_id.""" + account, user = await _make_account_and_user(test_db, "s1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + await test_db.commit() + + result = await test_db.execute(select(Session).where(Session.id == session.id)) + row = result.scalar_one() + assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}" + + +@pytest.mark.asyncio +async def test_attachment_account_id_matches_session(test_db: AsyncSession): + """attachments.account_id must match the parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "att1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + attachment = Attachment( + session_id=session.id, + account_id=account.id, + file_name="test.png", + file_type="image/png", + ) + test_db.add(attachment) + await test_db.commit() + + result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_session_supporting_data_account_id(test_db: AsyncSession): + """session_supporting_data.account_id must match parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "sd1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + sd = SessionSupportingData( + session_id=session.id, + account_id=account.id, + label="Log snippet", + data_type="text_snippet", + content="error: connection refused", + ) + test_db.add(sd) + await test_db.commit() + + result = await test_db.execute( + select(SessionSupportingData).where(SessionSupportingData.id == sd.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 1.3: Run test to confirm it fails (model doesn't have account_id yet)** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_session_account_id_matches_user -v --override-ini="addopts=" +``` + +Expected: FAIL — `Session` model has no `account_id` attribute. + +- [ ] **Step 1.4: Generate the Alembic migration file** + +```bash +cd backend && alembic revision -m "add_account_id_core_sessions" +``` + +This prints a path like `alembic/versions/xxxx_add_account_id_core_sessions.py`. Open that file and replace its contents with: + +```python +"""add account_id to core session tables + +Revision ID: +Revises: b8d2f4a6c091 +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = 'b8d2f4a6c091' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Step 1: ADD COLUMN (nullable) ──────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', + table, 'accounts', + ['account_id'], ['id'], + ondelete='CASCADE', + ) + + # ── Step 2: BACKFILL ───────────────────────────────────────────────────── + # sessions: direct join to users + op.execute(""" + UPDATE sessions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # attachments: chain through sessions (now backfilled above) + op.execute(""" + UPDATE attachments a + SET account_id = s.account_id + FROM sessions s + WHERE a.session_id = s.id + AND a.account_id IS NULL + """) + + # session_supporting_data: same chain + op.execute(""" + UPDATE session_supporting_data sd + SET account_id = s.account_id + FROM sessions s + WHERE sd.session_id = s.id + AND sd.account_id IS NULL + """) + + # session_resolution_outputs: FK is to ai_sessions, not sessions + op.execute(""" + UPDATE session_resolution_outputs sro + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE sro.session_id = ai.id + AND sro.account_id IS NULL + """) + + # ── Step 3: VERIFY zero NULLs — raises if any remain ──────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows remain in {table}. " + f"Fix the backfill before re-running." + ) + + # ── Step 4: SET NOT NULL ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.alter_column(table, 'account_id', nullable=False) + + # ── Step 5: CREATE INDEX ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +**Important:** The `down_revision` must be `b8d2f4a6c091` (current head). Do NOT change the auto-generated `revision` value at the top. + +- [ ] **Step 1.5: Run the migration against the test database** + +```bash +cd backend && alembic upgrade head +``` + +Expected: `Running upgrade b8d2f4a6c091 -> , add account_id to core session tables` + +If it errors with "NULL rows remain", investigate the backfill SQL — there are rows whose users have NULL account_id. + +- [ ] **Step 1.6: Verify zero NULLs manually** + +```bash +cd backend && python -c " +import asyncio +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import text +import os + +async def check(): + url = os.environ.get('DATABASE_URL', 'postgresql+asyncpg://postgres:postgres@localhost:5432/resolutionflow_test') + engine = create_async_engine(url) + async with engine.connect() as conn: + for t in ('sessions', 'attachments', 'session_supporting_data', 'session_resolution_outputs'): + r = await conn.execute(text(f'SELECT COUNT(*) FROM {t} WHERE account_id IS NULL')) + print(f'{t}: {r.scalar()} NULLs') + await engine.dispose() + +asyncio.run(check()) +" +``` + +Expected output (all zeros): +``` +sessions: 0 NULLs +attachments: 0 NULLs +session_supporting_data: 0 NULLs +session_resolution_outputs: 0 NULLs +``` + +- [ ] **Step 1.7: Update SQLAlchemy models** + +In `backend/app/models/session.py`, add after the `user_id` column (around line 33): + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +In `backend/app/models/attachment.py`, add after the `session_id` column (around line 22): + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +In `backend/app/models/supporting_data.py`, add after `session_id` (around line 16): + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +In `backend/app/models/session_resolution_output.py`, add after `session_id` (around line 25): + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +Each model also needs `from app.models.account import Account` added if missing from TYPE_CHECKING block. + +- [ ] **Step 1.8: Run tests** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py -k "test_session or test_attachment or test_session_supporting" -v --override-ini="addopts=" +``` + +Expected: all 3 tests PASS. + +- [ ] **Step 1.9: Run full test suite** + +```bash +cd backend && python -m pytest --override-ini="addopts=" +``` + +Expected: all tests pass (no regressions from model changes). + +- [ ] **Step 1.10: Commit** + +```bash +git add backend/alembic/versions/*add_account_id_core_sessions* \ + backend/app/models/session.py \ + backend/app/models/attachment.py \ + backend/app/models/supporting_data.py \ + backend/app/models/session_resolution_output.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 1 — add account_id to core session tables + +Migration sequence: add nullable → backfill via user_id/ai_session chain +→ verify zero NULLs → SET NOT NULL → CREATE INDEX. + +Tables: sessions, attachments, session_supporting_data, + session_resolution_outputs + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 2: Group 2 — AI & branching + +**Tables:** `session_branches`, `session_handoffs`, `fork_points`, `ai_session_steps`, `ai_suggestions` + +**Backfill paths:** +- `session_branches`, `session_handoffs`, `fork_points`, `ai_session_steps`: all have `session_id → ai_sessions.account_id` +- `ai_suggestions`: `user_id → users.account_id` + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_ai_branching.py` +- Modify: `backend/app/models/session_branch.py`, `session_handoff.py`, `fork_point.py`, `ai_session_step.py`, `ai_suggestion.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 2.1: Write the failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 2: AI & branching ─────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession): + """session_branches.account_id must match parent ai_session.account_id.""" + from app.models.session_branch import SessionBranch + + account, user = await _make_account_and_user(test_db, "sb1") + ai_session = AISession( + user_id=user.id, + account_id=account.id, + problem_summary="test", + problem_domain="networking", + status="active", + ) + test_db.add(ai_session) + await test_db.flush() + + branch = SessionBranch( + session_id=ai_session.id, + account_id=account.id, + label="Branch A", + branch_order=1, + conversation_messages=[], + ) + test_db.add(branch) + await test_db.commit() + + result = await test_db.execute( + select(SessionBranch).where(SessionBranch.id == branch.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession): + """ai_suggestions.account_id must match the creating user's account_id.""" + from app.models.ai_suggestion import AISuggestion + + account, user = await _make_account_and_user(test_db, "ais1") + tree = await _make_tree(test_db, account, user) + + suggestion = AISuggestion( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + action_type="add_node", + changes_json={}, + status="pending", + ) + test_db.add(suggestion) + await test_db.commit() + + result = await test_db.execute( + select(AISuggestion).where(AISuggestion.id == suggestion.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 2.2: Run tests to confirm they fail** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_session_branch_account_id_matches_ai_session -v --override-ini="addopts=" +``` + +Expected: FAIL — `SessionBranch` has no `account_id`. + +- [ ] **Step 2.3: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_ai_branching" +``` + +Replace the generated file content with: + +```python +"""add account_id to AI branching tables + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Step 1: ADD COLUMN (nullable) + ai_tables = ('session_branches', 'session_handoffs', 'fork_points', + 'ai_session_steps') + for table in ai_tables: + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + op.add_column('ai_suggestions', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_ai_suggestions_account_id', 'ai_suggestions', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Step 2: BACKFILL + # session_branches, session_handoffs, fork_points, ai_session_steps + # all FK to ai_sessions via session_id + for table in ai_tables: + op.execute(f""" + UPDATE {table} t + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE t.session_id = ai.id + AND t.account_id IS NULL + """) + + # ai_suggestions: user_id → users.account_id + op.execute(""" + UPDATE ai_suggestions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # Step 3: VERIFY + for table in ai_tables + ('ai_suggestions',): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows in {table}." + ) + + # Step 4: SET NOT NULL + for table in ai_tables + ('ai_suggestions',): + op.alter_column(table, 'account_id', nullable=False) + + # Step 5: CREATE INDEX + for table in ai_tables + ('ai_suggestions',): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('session_branches', 'session_handoffs', 'fork_points', + 'ai_session_steps', 'ai_suggestions'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +**Note:** Replace `` with the actual revision hash generated in Task 1 (check the file that was created: `revision: str = '...'`). + +- [ ] **Step 2.4: Run migration** + +```bash +cd backend && alembic upgrade head +``` + +- [ ] **Step 2.5: Verify zero NULLs** + +```bash +cd backend && python -c " +import asyncio +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import text +import os + +async def check(): + url = os.environ.get('DATABASE_URL', 'postgresql+asyncpg://postgres:postgres@localhost:5432/resolutionflow_test') + engine = create_async_engine(url) + async with engine.connect() as conn: + for t in ('session_branches', 'session_handoffs', 'fork_points', 'ai_session_steps', 'ai_suggestions'): + r = await conn.execute(text(f'SELECT COUNT(*) FROM {t} WHERE account_id IS NULL')) + print(f'{t}: {r.scalar()} NULLs') + await engine.dispose() + +asyncio.run(check()) +" +``` + +Expected: all zeros. + +- [ ] **Step 2.6: Update SQLAlchemy models** + +In each of these files, add `account_id` as NOT NULL after the `session_id` or `user_id` column: + +**`backend/app/models/session_branch.py`** — add after `session_id` column (line 37): +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +**`backend/app/models/session_handoff.py`** — add after `session_id` column (line 29): +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +**`backend/app/models/fork_point.py`** — add after `session_id` column (line 25): +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +**`backend/app/models/ai_session_step.py`** — add after `session_id` column (line 52): +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Denormalized from ai_sessions.account_id for direct tenant filtering.", + ) +``` + +**`backend/app/models/ai_suggestion.py`** — add after `user_id` column (line 29): +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +- [ ] **Step 2.7: Run tests, full suite, commit** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py -k "branch or suggestion" -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_ai_branching* \ + backend/app/models/session_branch.py \ + backend/app/models/session_handoff.py \ + backend/app/models/fork_point.py \ + backend/app/models/ai_session_step.py \ + backend/app/models/ai_suggestion.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 2 — add account_id to AI branching tables + +Tables: session_branches, session_handoffs, fork_points, + ai_session_steps, ai_suggestions +Backfill: session_id → ai_sessions.account_id (all except +ai_suggestions which uses user_id → users.account_id) + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 3: Group 3 — Steps & ratings + +**Tables:** `step_ratings`, `step_usage_log` + +**Note:** `session_ratings` ALREADY has `account_id NOT NULL` — do not touch it. + +**Backfill paths:** Both use `user_id → users.account_id` (the rating user's account, per design). + +**Table name:** `step_usage_log` (singular, not plural — check `StepUsageLog.__tablename__`). + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_step_ratings.py` +- Modify: `backend/app/models/step_library.py` (StepRating and StepUsageLog classes) +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 3.1: Write the failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 3: Steps & ratings ────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession): + """step_ratings.account_id must be the RATER's account, not the step's account.""" + from app.models.step_library import StepLibrary, StepRating + + account_a, user_a = await _make_account_and_user(test_db, "sr-rater") + account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner") + + # Step owned by account_b + step = StepLibrary( + title="A step", + step_type="action", + content={"text": "do something"}, + created_by=user_b.id, + account_id=account_b.id, + visibility="public", + ) + test_db.add(step) + await test_db.flush() + + # user_a (account_a) rates the step + rating = StepRating( + step_id=step.id, + user_id=user_a.id, + account_id=account_a.id, # rater's account, not step owner's + was_helpful=True, + is_verified_use=False, + is_visible=True, + ) + test_db.add(rating) + await test_db.commit() + + result = await test_db.execute(select(StepRating).where(StepRating.id == rating.id)) + row = result.scalar_one() + assert row.account_id == account_a.id, ( + f"account_id should be rater's account ({account_a.id}), got {row.account_id}" + ) +``` + +- [ ] **Step 3.2: Run test to confirm fail** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_step_rating_account_id_is_rater_account -v --override-ini="addopts=" +``` + +Expected: FAIL — `StepRating` has no `account_id`. + +- [ ] **Step 3.3: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_step_ratings" +``` + +Replace file content: + +```python +"""add account_id to step_ratings and step_usage_log + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('step_ratings', 'step_usage_log'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + # Backfill: from the RATER/LOGGER user's account (not the step's account) + op.execute(f""" + UPDATE {table} t + SET account_id = u.account_id + FROM users u + WHERE t.user_id = u.id + AND t.account_id IS NULL + """) + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('step_ratings', 'step_usage_log'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +- [ ] **Step 3.4: Run migration and verify** + +```bash +cd backend && alembic upgrade head +``` + +```bash +cd backend && python -c " +import asyncio +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import text +import os + +async def check(): + url = os.environ.get('DATABASE_URL', 'postgresql+asyncpg://postgres:postgres@localhost:5432/resolutionflow_test') + engine = create_async_engine(url) + async with engine.connect() as conn: + for t in ('step_ratings', 'step_usage_log'): + r = await conn.execute(text(f'SELECT COUNT(*) FROM {t} WHERE account_id IS NULL')) + print(f'{t}: {r.scalar()} NULLs') + await engine.dispose() + +asyncio.run(check()) +" +``` + +- [ ] **Step 3.5: Update SQLAlchemy models in `backend/app/models/step_library.py`** + +In the `StepRating` class (starts around line 125), add after the `user_id` column: + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Account of the RATER (not the step owner).", + ) +``` + +In the `StepUsageLog` class (starts around line 172), add after the `user_id` column: + +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Account of the user who logged this usage.", + ) +``` + +- [ ] **Step 3.6: Run tests, full suite, commit** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_step_rating_account_id_is_rater_account -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_step_ratings* \ + backend/app/models/step_library.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 3 — add account_id to step_ratings and step_usage_log + +Backfill from rater/user's account_id (not the step's account_id). +This is an explicit design decision — step rating data is attributed +to the account that performed the rating. + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 4: Group 4 — User personalization + +**Tables:** `user_folders`, `user_pinned_trees` + +**Backfill:** `user_id → users.account_id` + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_user_personalization.py` +- Modify: `backend/app/models/folder.py`, `backend/app/models/user_pinned_tree.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 4.1: Write failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 4: User personalization ──────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_user_folder_account_id_matches_user(test_db: AsyncSession): + """user_folders.account_id must match the owning user's account_id.""" + from app.models.folder import UserFolder + + account, user = await _make_account_and_user(test_db, "uf1") + folder = UserFolder( + user_id=user.id, + account_id=account.id, + name="My Folder", + color="#6366f1", + icon="folder", + display_order=0, + ) + test_db.add(folder) + await test_db.commit() + + result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession): + """user_pinned_trees.account_id must match the pinning user's account_id.""" + from app.models.user_pinned_tree import UserPinnedTree + + account, user = await _make_account_and_user(test_db, "pt1") + tree = await _make_tree(test_db, account, user) + pin = UserPinnedTree( + user_id=user.id, + tree_id=tree.id, + account_id=account.id, + display_order=0, + ) + test_db.add(pin) + await test_db.commit() + + result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id)) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 4.2: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_user_personalization" +``` + +```python +"""add account_id to user personalization tables + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + op.execute(f""" + UPDATE {table} t + SET account_id = u.account_id + FROM users u + WHERE t.user_id = u.id + AND t.account_id IS NULL + """) + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +- [ ] **Step 4.3: Run migration, verify, update models, test, commit** + +```bash +cd backend && alembic upgrade head +``` + +In `backend/app/models/folder.py`, add to `UserFolder` after `user_id`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +In `backend/app/models/user_pinned_tree.py`, add after `user_id`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py -k "folder or pinned" -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_user_personalization* \ + backend/app/models/folder.py \ + backend/app/models/user_pinned_tree.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 4 — add account_id to user_folders and user_pinned_trees + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 5: Group 5 — PSA & notifications + +**Tables:** `psa_post_log`, `psa_member_mappings`, `notification_logs` + +**Backfill paths:** +- `psa_post_log`: `psa_connection_id → psa_connections.account_id`. If `psa_connection_id` is NULL, fall back to `posted_by → users.account_id`. +- `psa_member_mappings`: `psa_connection_id → psa_connections.account_id` +- `notification_logs`: `notification_config_id → notification_configs.account_id` + +**Pre-check:** `psa_connections.account_id` is already NOT NULL ✓. `notification_configs.account_id` must also be NOT NULL — verify before running. + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_psa_notifications.py` +- Modify: `backend/app/models/psa_post_log.py`, `psa_member_mapping.py`, `notification_log.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 5.1: Write failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 5: PSA & notifications ───────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession): + """psa_member_mappings.account_id must match psa_connection's account_id.""" + from app.models.psa_connection import PsaConnection + from app.models.psa_member_mapping import PsaMemberMapping + + account, user = await _make_account_and_user(test_db, "psa1") + conn = PsaConnection( + account_id=account.id, + provider="connectwise", + display_name="Test CW", + site_url="https://cw.example.com", + company_id="TEST", + credentials_encrypted="placeholder", + ) + test_db.add(conn) + await test_db.flush() + + mapping = PsaMemberMapping( + psa_connection_id=conn.id, + user_id=user.id, + account_id=account.id, + external_member_id="cw-123", + external_member_name="Test User", + matched_by="manual_admin", + ) + test_db.add(mapping) + await test_db.commit() + + result = await test_db.execute( + select(PsaMemberMapping).where(PsaMemberMapping.id == mapping.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 5.2: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_psa_notifications" +``` + +```python +"""add account_id to PSA and notification tables + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Step 1: ADD COLUMN + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Step 2: BACKFILL + # psa_post_log: prefer psa_connection → fallback to posted_by user + op.execute(""" + UPDATE psa_post_log ppl + SET account_id = COALESCE(pc.account_id, u.account_id) + FROM users u + LEFT JOIN psa_connections pc ON pc.id = ppl.psa_connection_id + WHERE ppl.posted_by = u.id + AND ppl.account_id IS NULL + """) + + # psa_member_mappings: via psa_connection + op.execute(""" + UPDATE psa_member_mappings pmm + SET account_id = pc.account_id + FROM psa_connections pc + WHERE pmm.psa_connection_id = pc.id + AND pmm.account_id IS NULL + """) + + # notification_logs: via notification_config + op.execute(""" + UPDATE notification_logs nl + SET account_id = nc.account_id + FROM notification_configs nc + WHERE nl.notification_config_id = nc.id + AND nl.account_id IS NULL + """) + + # Step 3: VERIFY + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + + # Step 4: SET NOT NULL + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.alter_column(table, 'account_id', nullable=False) + + # Step 5: CREATE INDEX + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +- [ ] **Step 5.3: Run migration, verify, update models, test, commit** + +```bash +cd backend && alembic upgrade head +``` + +Add `account_id` (NOT NULL, FK to accounts) to: +- `backend/app/models/psa_post_log.py` — after `ai_session_id` column +- `backend/app/models/psa_member_mapping.py` — after `psa_connection_id` column +- `backend/app/models/notification_log.py` — after `notification_config_id` column + +Each follows the same pattern: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py -k "psa" -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_psa_notifications* \ + backend/app/models/psa_post_log.py \ + backend/app/models/psa_member_mapping.py \ + backend/app/models/notification_log.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 5 — add account_id to PSA and notification tables + +psa_post_log: backfill via psa_connection, fallback to posted_by user +psa_member_mappings: backfill via psa_connection +notification_logs: backfill via notification_config + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 6: Group 6 — Maintenance + +**Table:** `maintenance_schedules` + +**Backfill path:** `tree_id → trees.account_id`. Note: `trees.account_id` is still nullable at this point. Any maintenance schedule whose tree has `account_id=NULL` (i.e., is_default=TRUE) will not backfill. Fall back to `created_by → users.account_id` for those rows. + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_maintenance.py` +- Modify: `backend/app/models/maintenance_schedule.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 6.1: Write failing test** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 6: Maintenance ────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession): + """maintenance_schedules.account_id must match the tree's account_id.""" + from app.models.maintenance_schedule import MaintenanceSchedule + + account, user = await _make_account_and_user(test_db, "ms1") + tree = Tree( + name="Maintenance Flow", + account_id=account.id, + author_id=user.id, + visibility="team", + tree_type="maintenance", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + test_db.add(tree) + await test_db.flush() + + schedule = MaintenanceSchedule( + tree_id=tree.id, + account_id=account.id, + created_by=user.id, + cron_expression="0 9 * * 1", + timezone="UTC", + is_active=True, + ) + test_db.add(schedule) + await test_db.commit() + + result = await test_db.execute( + select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 6.2: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_maintenance" +``` + +```python +"""add account_id to maintenance_schedules + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('maintenance_schedules', + sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_maintenance_schedules_account_id', 'maintenance_schedules', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Primary: tree_id → trees.account_id + op.execute(""" + UPDATE maintenance_schedules ms + SET account_id = t.account_id + FROM trees t + WHERE ms.tree_id = t.id + AND t.account_id IS NOT NULL + AND ms.account_id IS NULL + """) + + # Fallback: created_by → users.account_id (for is_default trees with NULL account_id) + op.execute(""" + UPDATE maintenance_schedules ms + SET account_id = u.account_id + FROM users u + WHERE ms.created_by = u.id + AND u.account_id IS NOT NULL + AND ms.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM maintenance_schedules WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} maintenance_schedules rows have NULL account_id. " + "Check if created_by is NULL — those rows need manual resolution." + ) + + op.alter_column('maintenance_schedules', 'account_id', nullable=False) + op.create_index('ix_maintenance_schedules_account_id', 'maintenance_schedules', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_maintenance_schedules_account_id', table_name='maintenance_schedules') + op.drop_constraint('fk_maintenance_schedules_account_id', 'maintenance_schedules', type_='foreignkey') + op.drop_column('maintenance_schedules', 'account_id') +``` + +- [ ] **Step 6.3: Run migration, verify, update model, test, commit** + +```bash +cd backend && alembic upgrade head +``` + +In `backend/app/models/maintenance_schedule.py`, add after `created_by`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_maintenance_schedule_account_id_matches_tree -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_maintenance* \ + backend/app/models/maintenance_schedule.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 6 — add account_id to maintenance_schedules + +Primary backfill: tree_id → trees.account_id +Fallback: created_by → users.account_id (for is_default tree rows) + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 7: Group 7 — Legacy team_id tables + +**Tables:** `script_builder_sessions`, `script_templates`, `script_generations` + +**Backfill paths:** +- `script_builder_sessions`: `user_id → users.account_id` +- `script_templates`: `created_by → users.account_id` (`created_by` is nullable — handle with fallback) +- `script_generations`: `user_id → users.account_id` + +**Important:** Do NOT drop `team_id` — keep it until all application code is updated. + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_script_tables.py` +- Modify: `backend/app/models/script_builder_session.py`, `backend/app/models/script_template.py` (ScriptTemplate and ScriptGeneration) +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 7.1: Write failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 7: Legacy team_id tables ─────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_script_builder_session_account_id(test_db: AsyncSession): + """script_builder_sessions.account_id must match user's account_id.""" + from app.models.script_builder_session import ScriptBuilderSession + + account, user = await _make_account_and_user(test_db, "sbs1") + sbs = ScriptBuilderSession( + user_id=user.id, + account_id=account.id, + language="powershell", + ) + test_db.add(sbs) + await test_db.commit() + + result = await test_db.execute( + select(ScriptBuilderSession).where(ScriptBuilderSession.id == sbs.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 7.2: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_script_tables" +``` + +```python +"""add account_id to script_builder_sessions, script_templates, script_generations + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # script_builder_sessions: user_id → users.account_id + op.execute(""" + UPDATE script_builder_sessions sbs + SET account_id = u.account_id + FROM users u + WHERE sbs.user_id = u.id + AND sbs.account_id IS NULL + """) + + # script_templates: created_by → users.account_id + # created_by is nullable, so left join + only set where not null + op.execute(""" + UPDATE script_templates st + SET account_id = u.account_id + FROM users u + WHERE st.created_by = u.id + AND st.account_id IS NULL + """) + # Fallback for script_templates with NULL created_by: team_id → team admin user + op.execute(""" + UPDATE script_templates st + SET account_id = u.account_id + FROM users u + WHERE u.team_id = st.team_id + AND u.is_team_admin = TRUE + AND st.account_id IS NULL + AND EXISTS (SELECT 1 FROM users u2 WHERE u2.team_id = st.team_id) + """) + + # script_generations: user_id → users.account_id + op.execute(""" + UPDATE script_generations sg + SET account_id = u.account_id + FROM users u + WHERE sg.user_id = u.id + AND sg.account_id IS NULL + """) + + # VERIFY + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('script_builder_sessions', 'script_templates', 'script_generations'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') +``` + +- [ ] **Step 7.3: Run migration, verify, update models, test, commit** + +```bash +cd backend && alembic upgrade head +``` + +In `backend/app/models/script_builder_session.py`, add after `user_id`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +In `backend/app/models/script_template.py`: +- In `ScriptTemplate` class, add after `team_id`: +```python + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` +- In `ScriptGeneration` class, add after `user_id`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_script_builder_session_account_id -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_script_tables* \ + backend/app/models/script_builder_session.py \ + backend/app/models/script_template.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 7 — add account_id to script tables (keep team_id) + +team_id is kept in all three tables — drop deferred until app code +is fully migrated off team_id references. + +Tables: script_builder_sessions, script_templates, script_generations +Backfill: user_id/created_by → users.account_id + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 8: Group 8 — TargetList + +**Table:** `target_lists` + +**Backfill path:** `team_id → users WHERE is_team_admin=TRUE → account_id` + +**Context:** Zero rows in production (confirmed 2026-04-09). The migration is schema-only in practice but must be correct for any future rows. The `team_id` FK to `teams` is NOT NULL — keep it. Do NOT drop it. + +**Files:** +- Create: `backend/alembic/versions/_add_account_id_target_lists.py` +- Modify: `backend/app/models/target_list.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 8.1: Write failing test** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 8: TargetList ──────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_target_list_account_id_from_team_admin(test_db: AsyncSession): + """target_lists.account_id must be set to the team admin's account_id.""" + from app.models.target_list import TargetList + from app.models.team import Team + + account, user = await _make_account_and_user(test_db, "tl1") + # Make user a team admin + team = Team(name=f"Team {uuid.uuid4().hex[:6]}") + test_db.add(team) + await test_db.flush() + + user.team_id = team.id + user.is_team_admin = True + await test_db.flush() + + target_list = TargetList( + team_id=team.id, + account_id=account.id, + created_by=user.id, + name="Server Targets", + targets=[{"label": "SRV-01"}], + ) + test_db.add(target_list) + await test_db.commit() + + result = await test_db.execute( + select(TargetList).where(TargetList.id == target_list.id) + ) + row = result.scalar_one() + assert row.account_id == account.id +``` + +- [ ] **Step 8.2: Generate migration** + +```bash +cd backend && alembic revision -m "add_account_id_target_lists" +``` + +```python +"""add account_id to target_lists (keep team_id) + +Revision ID: +Revises: +Create Date: +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('target_lists', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_target_lists_account_id', 'target_lists', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Backfill: team_id → team admin user → account_id + # If any row cannot be backfilled (no team admin found) → ROLLBACK + op.execute(""" + UPDATE target_lists tl + SET account_id = u.account_id + FROM users u + WHERE u.team_id = tl.team_id + AND u.is_team_admin = TRUE + AND u.account_id IS NOT NULL + AND tl.account_id IS NULL + """) + + # Secondary fallback: created_by user + op.execute(""" + UPDATE target_lists tl + SET account_id = u.account_id + FROM users u + WHERE tl.created_by = u.id + AND u.account_id IS NOT NULL + AND tl.account_id IS NULL + """) + + result = op.get_bind().execute( + sa.text("SELECT COUNT(*) FROM target_lists WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} target_lists rows have NULL account_id. " + "No team admin found for these teams. Resolve before re-running." + ) + + op.alter_column('target_lists', 'account_id', nullable=False) + op.create_index('ix_target_lists_account_id', 'target_lists', ['account_id']) + + +def downgrade() -> None: + op.drop_index('ix_target_lists_account_id', table_name='target_lists') + op.drop_constraint('fk_target_lists_account_id', 'target_lists', type_='foreignkey') + op.drop_column('target_lists', 'account_id') +``` + +- [ ] **Step 8.3: Run migration, verify, update model, test, commit** + +```bash +cd backend && alembic upgrade head +``` + +In `backend/app/models/target_list.py`, add after `team_id`: +```python + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +Also add the Account import to TYPE_CHECKING: +```python +if TYPE_CHECKING: + from app.models.user import User + from app.models.team import Team + from app.models.account import Account +``` + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py::test_target_list_account_id_from_team_admin -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*add_account_id_target_lists* \ + backend/app/models/target_list.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 8 — add account_id to target_lists (keep team_id) + +Zero rows in production — this is a schema-only migration in practice. +team_id kept for app code compatibility. Drop deferred to later cleanup. +Backfill: team_id → team admin user → account_id; fallback: created_by. + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 9: Group 10 — Global content separation (runs before Group 9) + +**Why this runs before Task 10:** `trees` has `account_id=NULL` for `is_default=TRUE` rows (platform trees). Task 10 sets trees.account_id NOT NULL, which would fail without first handling these rows. This task moves them to `template_trees` (no account_id column), then Task 10 can safely SET NOT NULL on trees. + +**Action:** +1. Create `template_trees` table — stores platform-owned troubleshooting trees (no account_id, no RLS) +2. Create `platform_steps` table — stores platform-owned steps (no account_id, no RLS) +3. Copy `is_default=TRUE` trees to `template_trees` +4. Copy `visibility='public'` steps from `step_library` to `platform_steps` +5. Remove the copied rows from `trees` (set `is_default=FALSE` and assign a NULL-safe account) — or delete if no sessions reference them +6. Handle global `tree_categories`, `tree_tags`, `step_categories` (NULL `account_id` rows = global platform items) — assign to a "ResolutionFlow Platform" internal account created in this migration + +**Files:** +- Create: `backend/alembic/versions/_create_global_content_tables.py` +- Create: `backend/app/models/template_tree.py` +- Create: `backend/app/models/platform_step.py` +- Modify: `backend/app/models/__init__.py` (register new models) +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 9.1: Write failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 10 (runs first): Global content tables ────────────────────────────── + +@pytest.mark.asyncio +async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession): + """template_trees must exist and must NOT have an account_id column.""" + result = await test_db.execute(text(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'template_trees' + """)) + columns = {row[0] for row in result.fetchall()} + assert 'id' in columns, "template_trees.id must exist" + assert 'account_id' not in columns, "template_trees must not have account_id (global content)" + + +@pytest.mark.asyncio +async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession): + """platform_steps must exist and must NOT have an account_id column.""" + result = await test_db.execute(text(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'platform_steps' + """)) + columns = {row[0] for row in result.fetchall()} + assert 'id' in columns, "platform_steps.id must exist" + assert 'account_id' not in columns, "platform_steps must not have account_id (global content)" +``` + +- [ ] **Step 9.2: Generate migration** + +```bash +cd backend && alembic revision -m "create_global_content_tables" +``` + +```python +"""create template_trees and platform_steps global content tables + +Revision ID: +Revises: +Create Date: + +These tables hold platform-owned content that is readable by all +authenticated users. No account_id. No RLS. Ever. +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Create template_trees ───────────────────────────────────────────────── + op.create_table( + 'template_trees', + sa.Column('id', UUID(), primary_key=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('category', sa.String(100), nullable=True), + sa.Column('tree_type', sa.String(20), nullable=False), + sa.Column('tree_structure', JSONB(), nullable=False), + sa.Column('tags', JSONB(), nullable=False, server_default='[]'), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + # source_tree_id: original tree this was promoted from (nullable) + sa.Column('source_tree_id', UUID(), sa.ForeignKey('trees.id', ondelete='SET NULL'), nullable=True), + ) + op.create_index('ix_template_trees_tree_type', 'template_trees', ['tree_type']) + + # ── Create platform_steps ──────────────────────────────────────────────── + op.create_table( + 'platform_steps', + sa.Column('id', UUID(), primary_key=True), + sa.Column('title', sa.String(255), nullable=False), + sa.Column('step_type', sa.String(50), nullable=False), + sa.Column('content', JSONB(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + # source_step_id: original step this was promoted from (nullable) + sa.Column('source_step_id', UUID(), sa.ForeignKey('step_library.id', ondelete='SET NULL'), nullable=True), + ) + op.create_index('ix_platform_steps_step_type', 'platform_steps', ['step_type']) + + # ── Migrate is_default=TRUE trees → template_trees ───────────────────── + op.execute(""" + INSERT INTO template_trees + (id, name, description, category, tree_type, tree_structure, + is_active, created_at, updated_at, source_tree_id) + SELECT + gen_random_uuid(), name, description, category, tree_type, + tree_structure, is_active, + COALESCE(created_at, NOW()), COALESCE(updated_at, NOW()), id + FROM trees + WHERE is_default = TRUE + """) + + # ── Migrate visibility='public' steps → platform_steps ───────────────── + op.execute(""" + INSERT INTO platform_steps + (id, title, step_type, content, is_active, created_at, updated_at, source_step_id) + SELECT + gen_random_uuid(), title, step_type, content, is_active, + COALESCE(created_at, NOW()), COALESCE(updated_at, NOW()), id + FROM step_library + WHERE visibility = 'public' + """) + + # ── Create a ResolutionFlow platform account for global content ────────── + # Used to satisfy NOT NULL on trees, tree_categories, tree_tags, etc. + # This is a sentinel account — it is NOT a real customer account. + op.execute(""" + INSERT INTO accounts (id, name, display_code, created_at, updated_at) + VALUES ( + '00000000-0000-0000-0000-000000000001', + 'ResolutionFlow Platform', + 'PLATFORM', + NOW(), + NOW() + ) + ON CONFLICT (id) DO NOTHING + """) + + # ── Assign is_default trees to platform account ────────────────────────── + op.execute(""" + UPDATE trees + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE is_default = TRUE + AND account_id IS NULL + """) + + # ── Assign global tree_categories (team_id=NULL, account_id=NULL) ──────── + op.execute(""" + UPDATE tree_categories + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + # ── Assign global tree_tags (team_id=NULL, account_id=NULL) ───────────── + op.execute(""" + UPDATE tree_tags + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + # ── Assign global step_categories (account_id=NULL) ────────────────────── + op.execute(""" + UPDATE step_categories + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + # ── Assign global step_library entries (visibility='public', account_id=NULL) ─ + op.execute(""" + UPDATE step_library + SET account_id = '00000000-0000-0000-0000-000000000001' + WHERE account_id IS NULL + """) + + # ── Verify all target tables now have zero NULLs ───────────────────────── + for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows remain in {table} " + "after platform account assignment. Investigate before re-running." + ) + + +def downgrade() -> None: + # Reverse platform account assignments (set back to NULL where platform account) + platform_id = '00000000-0000-0000-0000-000000000001' + for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'): + op.execute(f"UPDATE {table} SET account_id = NULL WHERE account_id = '{platform_id}'") + + op.execute(f"DELETE FROM accounts WHERE id = '{platform_id}'") + op.drop_index('ix_platform_steps_step_type', table_name='platform_steps') + op.drop_index('ix_template_trees_tree_type', table_name='template_trees') + op.drop_table('platform_steps') + op.drop_table('template_trees') +``` + +- [ ] **Step 9.3: Create the SQLAlchemy model files** + +Create `backend/app/models/template_tree.py`: + +```python +"""Template tree model — platform-owned troubleshooting trees, readable by all users. + +No account_id. No RLS. Readable by any authenticated user. +Populated by promoting is_default=TRUE trees from the trees table. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class TemplateTree(Base): + __tablename__ = "template_trees" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + tree_type: Mapped[str] = mapped_column(String(20), nullable=False, index=True) + tree_structure: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + tags: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + source_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("trees.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) +``` + +Create `backend/app/models/platform_step.py`: + +```python +"""Platform step model — platform-owned steps, readable by all users. + +No account_id. No RLS. Readable by any authenticated user. +Populated by promoting visibility='public' steps from step_library. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class PlatformStep(Base): + __tablename__ = "platform_steps" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + title: Mapped[str] = mapped_column(String(255), nullable=False) + step_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + content: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + source_step_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("step_library.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) +``` + +In `backend/app/models/__init__.py`, add: +```python +from .template_tree import TemplateTree +from .platform_step import PlatformStep +``` + +And add `"TemplateTree"` and `"PlatformStep"` to the `__all__` list. + +- [ ] **Step 9.4: Run migration, run tests, commit** + +```bash +cd backend && alembic upgrade head +cd backend && python -m pytest tests/test_phase1_migrations.py -k "template_trees or platform_steps" -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +git add backend/alembic/versions/*create_global_content_tables* \ + backend/app/models/template_tree.py \ + backend/app/models/platform_step.py \ + backend/app/models/__init__.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 10 — create global content tables and platform account + +Creates template_trees and platform_steps (no account_id, no RLS). +Migrates is_default=TRUE trees and public steps into them. +Creates sentinel platform account (00000000-...-0001) for global +tree_categories, tree_tags, step_categories, step_library, and +is_default trees — clearing all NULL account_id rows in those tables +as prerequisite for Group 9 SET NOT NULL. + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 10: Group 9 — SET NOT NULL on existing nullable account_id columns + +**Why this runs last:** Depends on Task 9 having cleared all NULL account_id rows via platform account assignment. + +**Tables:** `users`, `trees`, `tree_categories`, `tree_tags`, `step_categories`, `step_library`, `tree_embeddings`, `feedback` + +**Action:** For each table: +1. Verify zero NULLs (if any remain, backfill or delete) +2. SET NOT NULL +3. If index doesn't already exist, CREATE INDEX + +**Special cases:** +- `users.account_id`: Any user with NULL account_id must be investigated — they are orphaned. If none, proceed. +- `tree_embeddings.account_id`: Backfill from `tree_id → trees.account_id` (trees now all have account_id after Task 9). +- `feedback.account_id`: Backfill from `user_id → users.account_id`. + +**Files:** +- Create: `backend/alembic/versions/_set_not_null_account_id_phase1.py` +- Modify: `backend/app/models/user.py`, `tree.py`, `category.py`, `tag.py`, `step_category.py`, `step_library.py` (StepLibrary), `tree_embedding.py`, `feedback.py` +- Test: `backend/tests/test_phase1_migrations.py` + +--- + +- [ ] **Step 10.1: Write failing tests** + +Append to `backend/tests/test_phase1_migrations.py`: + +```python +# ── Group 9: SET NOT NULL on existing nullable columns ──────────────────────── + +@pytest.mark.asyncio +async def test_tree_account_id_is_not_null(test_db: AsyncSession): + """trees.account_id must be NOT NULL after Phase 1 — enforced at DB level.""" + # Try to insert a tree with no account_id — must fail + from sqlalchemy.exc import IntegrityError + with pytest.raises(IntegrityError): + test_db.add(Tree( + name="Bad tree", + # account_id intentionally omitted + author_id=None, + visibility="private", + tree_type="troubleshooting", + tree_structure={}, + is_active=True, + status="draft", + )) + await test_db.flush() + + +@pytest.mark.asyncio +async def test_user_account_id_is_not_null(test_db: AsyncSession): + """users.account_id must be NOT NULL after Phase 1.""" + from sqlalchemy.exc import IntegrityError + with pytest.raises(IntegrityError): + test_db.add(User( + email=f"orphan-{uuid.uuid4().hex[:6]}@example.com", + name="Orphan", + password_hash=get_password_hash("x"), + is_active=True, + # account_id intentionally omitted + )) + await test_db.flush() +``` + +- [ ] **Step 10.2: Generate migration** + +```bash +cd backend && alembic revision -m "set_not_null_account_id_phase1" +``` + +```python +"""set NOT NULL on all previously-nullable account_id columns + +Revision ID: +Revises: +Create Date: + +All tables in this migration had account_id set to nullable previously. +Task 9 (create_global_content_tables) cleared all NULL rows. +This migration enforces the NOT NULL constraint. +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '' +down_revision: Union[str, None] = '' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # tree_embeddings: backfill from trees (must happen before SET NOT NULL) + op.execute(""" + UPDATE tree_embeddings te + SET account_id = t.account_id + FROM trees t + WHERE te.tree_id = t.id + AND te.account_id IS NULL + """) + + # feedback: backfill from users + op.execute(""" + UPDATE feedback f + SET account_id = u.account_id + FROM users u + WHERE f.user_id = u.id + AND f.account_id IS NULL + """) + + # Verify ALL tables before touching any SET NOT NULL + tables_with_account_id = [ + 'users', 'trees', 'tree_categories', 'tree_tags', + 'step_categories', 'step_library', 'tree_embeddings', 'feedback', + ] + for table in tables_with_account_id: + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows in {table}. " + "Run Task 9 (create_global_content_tables) first, or " + "manually backfill/delete orphaned rows." + ) + + # SET NOT NULL on all + for table in tables_with_account_id: + op.alter_column(table, 'account_id', nullable=False) + + # Create indexes where they don't already exist + # (some tables like trees already have ix_trees_account_id from prior work) + new_indexes = [ + ('tree_embeddings', 'ix_tree_embeddings_account_id'), + ('feedback', 'ix_feedback_account_id'), + ] + for table, index_name in new_indexes: + # Check if index exists to avoid duplicate error + result = op.get_bind().execute(sa.text( + f"SELECT 1 FROM pg_indexes WHERE tablename='{table}' AND indexname='{index_name}'" + )) + if not result.fetchone(): + op.create_index(index_name, table, ['account_id']) + + +def downgrade() -> None: + # Revert to nullable + for table in ('users', 'trees', 'tree_categories', 'tree_tags', + 'step_categories', 'step_library', 'tree_embeddings', 'feedback'): + op.alter_column(table, 'account_id', nullable=True) + for table, index_name in ( + ('tree_embeddings', 'ix_tree_embeddings_account_id'), + ('feedback', 'ix_feedback_account_id'), + ): + try: + op.drop_index(index_name, table_name=table) + except Exception: + pass +``` + +- [ ] **Step 10.3: Run migration** + +```bash +cd backend && alembic upgrade head +``` + +If this errors with "NULL account_id rows remain in users", investigate: +```sql +-- Run from VPS SSH via docker exec +SELECT id, email, account_id FROM users WHERE account_id IS NULL; +``` +These are orphaned users. Either assign them to an account or delete them if they are test/seed data. + +- [ ] **Step 10.4: Update SQLAlchemy models — change `Mapped[Optional[uuid.UUID]]` to `Mapped[uuid.UUID]` and `nullable=True` to `nullable=False`** + +**`backend/app/models/user.py`** — find `account_id` (around line 46) and change: +```python +# BEFORE: + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="RESTRICT"), + nullable=True, + ... + ) + +# AFTER: + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="RESTRICT"), + nullable=False, + index=True, + ) +``` + +**`backend/app/models/tree.py`** — find `account_id` (around line 79) and change: +```python +# BEFORE: + account_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), ..., nullable=True, ... + ) + +# AFTER: + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) +``` + +Apply the same pattern (`Optional` → required, `nullable=True` → `nullable=False`) to: +- `backend/app/models/category.py` — `TreeCategory.account_id` +- `backend/app/models/tag.py` — `TreeTag.account_id` +- `backend/app/models/step_category.py` — `StepCategory.account_id` +- `backend/app/models/step_library.py` — `StepLibrary.account_id` +- `backend/app/models/tree_embedding.py` — `TreeEmbedding.account_id` +- `backend/app/models/feedback.py` — `Feedback.account_id` + +- [ ] **Step 10.5: Run tests, full suite, commit** + +```bash +cd backend && python -m pytest tests/test_phase1_migrations.py -v --override-ini="addopts=" +cd backend && python -m pytest --override-ini="addopts=" +``` + +If any existing tests fail because they create objects without `account_id`, update those test fixtures to provide the required field. + +```bash +git add backend/alembic/versions/*set_not_null_account_id* \ + backend/app/models/user.py \ + backend/app/models/tree.py \ + backend/app/models/category.py \ + backend/app/models/tag.py \ + backend/app/models/step_category.py \ + backend/app/models/step_library.py \ + backend/app/models/tree_embedding.py \ + backend/app/models/feedback.py \ + backend/tests/test_phase1_migrations.py +git commit -m "feat: Phase 1 Group 9 — enforce NOT NULL on all account_id columns + +All previously-nullable account_id columns are now NOT NULL. +tree_embeddings and feedback backfilled before constraint applied. +Global content assigned to platform sentinel account (00000000-...-0001) +in preceding migration. + +Tables updated: users, trees, tree_categories, tree_tags, +step_categories, step_library, tree_embeddings, feedback + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Task 11: Phase 1 gate verification + +**Run the gate verification query across all tenant tables. All must return zero NULLs.** + +**Files:** No code changes — verification only. + +--- + +- [ ] **Step 11.1: Run the gate verification query** + +From VPS SSH: + +```bash +docker exec -it resolutionflow_postgres psql -U postgres -d resolutionflow -c " +SELECT tablename, null_count +FROM ( + SELECT 'sessions' AS tablename, COUNT(*) FILTER (WHERE account_id IS NULL) AS null_count FROM sessions + UNION ALL + SELECT 'attachments', COUNT(*) FILTER (WHERE account_id IS NULL) FROM attachments + UNION ALL + SELECT 'session_supporting_data', COUNT(*) FILTER (WHERE account_id IS NULL) FROM session_supporting_data + UNION ALL + SELECT 'session_resolution_outputs',COUNT(*) FILTER (WHERE account_id IS NULL) FROM session_resolution_outputs + UNION ALL + SELECT 'session_branches', COUNT(*) FILTER (WHERE account_id IS NULL) FROM session_branches + UNION ALL + SELECT 'session_handoffs', COUNT(*) FILTER (WHERE account_id IS NULL) FROM session_handoffs + UNION ALL + SELECT 'fork_points', COUNT(*) FILTER (WHERE account_id IS NULL) FROM fork_points + UNION ALL + SELECT 'ai_session_steps', COUNT(*) FILTER (WHERE account_id IS NULL) FROM ai_session_steps + UNION ALL + SELECT 'ai_suggestions', COUNT(*) FILTER (WHERE account_id IS NULL) FROM ai_suggestions + UNION ALL + SELECT 'step_ratings', COUNT(*) FILTER (WHERE account_id IS NULL) FROM step_ratings + UNION ALL + SELECT 'step_usage_log', COUNT(*) FILTER (WHERE account_id IS NULL) FROM step_usage_log + UNION ALL + SELECT 'user_folders', COUNT(*) FILTER (WHERE account_id IS NULL) FROM user_folders + UNION ALL + SELECT 'user_pinned_trees', COUNT(*) FILTER (WHERE account_id IS NULL) FROM user_pinned_trees + UNION ALL + SELECT 'psa_post_log', COUNT(*) FILTER (WHERE account_id IS NULL) FROM psa_post_log + UNION ALL + SELECT 'psa_member_mappings', COUNT(*) FILTER (WHERE account_id IS NULL) FROM psa_member_mappings + UNION ALL + SELECT 'notification_logs', COUNT(*) FILTER (WHERE account_id IS NULL) FROM notification_logs + UNION ALL + SELECT 'maintenance_schedules', COUNT(*) FILTER (WHERE account_id IS NULL) FROM maintenance_schedules + UNION ALL + SELECT 'script_builder_sessions', COUNT(*) FILTER (WHERE account_id IS NULL) FROM script_builder_sessions + UNION ALL + SELECT 'script_templates', COUNT(*) FILTER (WHERE account_id IS NULL) FROM script_templates + UNION ALL + SELECT 'script_generations', COUNT(*) FILTER (WHERE account_id IS NULL) FROM script_generations + UNION ALL + SELECT 'target_lists', COUNT(*) FILTER (WHERE account_id IS NULL) FROM target_lists + UNION ALL + SELECT 'users', COUNT(*) FILTER (WHERE account_id IS NULL) FROM users + UNION ALL + SELECT 'trees', COUNT(*) FILTER (WHERE account_id IS NULL) FROM trees + UNION ALL + SELECT 'tree_categories', COUNT(*) FILTER (WHERE account_id IS NULL) FROM tree_categories + UNION ALL + SELECT 'tree_tags', COUNT(*) FILTER (WHERE account_id IS NULL) FROM tree_tags + UNION ALL + SELECT 'step_categories', COUNT(*) FILTER (WHERE account_id IS NULL) FROM step_categories + UNION ALL + SELECT 'step_library', COUNT(*) FILTER (WHERE account_id IS NULL) FROM step_library + UNION ALL + SELECT 'tree_embeddings', COUNT(*) FILTER (WHERE account_id IS NULL) FROM tree_embeddings + UNION ALL + SELECT 'feedback', COUNT(*) FILTER (WHERE account_id IS NULL) FROM feedback +) t +ORDER BY null_count DESC, tablename; +" +``` + +Expected: all rows show `null_count = 0`. + +Any non-zero row is a blocker — do not proceed to Phase 2 until resolved. + +- [ ] **Step 11.2: Verify CI is still green** + +```bash +gh run list --limit 3 +``` + +Check that the latest CI run on `feat/tenant-isolation-phase-1` is green. The tenant filter check will now report fewer warnings (tables that gained account_id no longer trigger false positives). + +- [ ] **Step 11.3: Create PR** + +```bash +git push -u origin feat/tenant-isolation-phase-1 +gh pr create \ + --base main \ + --title "feat: tenant isolation Phase 1 — add account_id to all tenant tables" \ + --body "Adds account_id NOT NULL to all tenant tables, creates global content tables, and enforces the platform account sentinel for legacy global content. Phase 2 (RLS + SET LOCAL in get_db) is unblocked once this merges and gate query returns all zeros." +``` + +--- + +## Phase 1 Gate Checklist + +Before merging and declaring Phase 1 complete: + +- [ ] All 10 migrations in `alembic/versions/` chained correctly (`down_revision` points to previous) +- [ ] All migrations run cleanly: `alembic upgrade head` exits 0 +- [ ] All 28 tenant tables show `null_count = 0` in gate verification query +- [ ] Full test suite passes: `python -m pytest --override-ini="addopts="` +- [ ] `python scripts/check_tenant_filters.py` warning count has decreased (tables with account_id no longer flagged) +- [ ] `session_ratings` not touched (already had `account_id NOT NULL` ✓) +- [ ] `team_id` columns NOT dropped on script tables, target_lists (deferred cleanup) +- [ ] CI passes on `feat/tenant-isolation-phase-1` branch +- [ ] Gate verification query run against **production DB** (VPS SSH) and returns all zeros