feat: Phase 1 tenant isolation — add account_id to all tenant tables #133
102
backend/alembic/versions/0b470d9e6cf1_create_db_roles.py
Normal file
102
backend/alembic/versions/0b470d9e6cf1_create_db_roles.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""create_db_roles
|
||||
|
||||
Revision ID: 0b470d9e6cf1
|
||||
Revises: a9f3b2c1d4e5
|
||||
Create Date: 2026-04-10 03:58:10.207919
|
||||
|
||||
"""
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0b470d9e6cf1'
|
||||
down_revision: Union[str, None] = 'a9f3b2c1d4e5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Passwords from env vars. For local dev, defaults are sufficient.
|
||||
# For production (Railway), set DB_APP_ROLE_PASSWORD and
|
||||
# DB_ADMIN_ROLE_PASSWORD as environment variables before running migrations.
|
||||
# Passwords must not contain single quotes.
|
||||
app_pw = os.environ.get("DB_APP_ROLE_PASSWORD", "app_secret_change_me")
|
||||
admin_pw = os.environ.get("DB_ADMIN_ROLE_PASSWORD", "admin_secret_change_me")
|
||||
|
||||
# Fetch the current database name dynamically — avoids hardcoding
|
||||
# (the DB is named 'resolutionflow' in dev, potentially different elsewhere).
|
||||
conn = op.get_bind()
|
||||
db_name = conn.execute(text("SELECT current_database()")).scalar()
|
||||
|
||||
# ── Application role ────────────────────────────────────────────────────
|
||||
# Subject to RLS. Used by FastAPI at runtime via DATABASE_URL.
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_app') THEN
|
||||
CREATE ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}';
|
||||
ELSE
|
||||
ALTER ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}';
|
||||
END IF;
|
||||
END $$
|
||||
""")
|
||||
op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_app")
|
||||
op.execute("GRANT USAGE ON SCHEMA public TO resolutionflow_app")
|
||||
op.execute(
|
||||
"GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public "
|
||||
"TO resolutionflow_app"
|
||||
)
|
||||
op.execute(
|
||||
"GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO resolutionflow_app"
|
||||
)
|
||||
# Ensure future tables automatically get the same permissions
|
||||
op.execute(
|
||||
"ALTER DEFAULT PRIVILEGES IN SCHEMA public "
|
||||
"GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO resolutionflow_app"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER DEFAULT PRIVILEGES IN SCHEMA public "
|
||||
"GRANT USAGE, SELECT ON SEQUENCES TO resolutionflow_app"
|
||||
)
|
||||
|
||||
# ── Admin role ──────────────────────────────────────────────────────────
|
||||
# BYPASSRLS. Used by Alembic (DATABASE_URL_SYNC) and /admin/* endpoints
|
||||
# (ADMIN_DATABASE_URL) after Task 11.
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_admin') THEN
|
||||
CREATE ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}';
|
||||
ELSE
|
||||
ALTER ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}';
|
||||
END IF;
|
||||
END $$
|
||||
""")
|
||||
op.execute("GRANT resolutionflow_app TO resolutionflow_admin")
|
||||
op.execute("ALTER ROLE resolutionflow_admin BYPASSRLS")
|
||||
op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_admin")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
db_name = conn.execute(text("SELECT current_database()")).scalar()
|
||||
|
||||
op.execute(
|
||||
"REVOKE ALL ON ALL TABLES IN SCHEMA public FROM resolutionflow_app"
|
||||
)
|
||||
op.execute(
|
||||
"REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM resolutionflow_app"
|
||||
)
|
||||
op.execute(
|
||||
f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_app"
|
||||
)
|
||||
op.execute(
|
||||
f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_admin"
|
||||
)
|
||||
op.execute("DROP ROLE IF EXISTS resolutionflow_admin")
|
||||
op.execute("DROP ROLE IF EXISTS resolutionflow_app")
|
||||
@@ -0,0 +1,86 @@
|
||||
"""set NOT NULL on all previously-nullable account_id columns
|
||||
|
||||
Revision ID: 174f442795b7
|
||||
Revises: 3a40fe11b427
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
|
||||
All tables in this migration had account_id set to nullable previously.
|
||||
Task 9 (create_global_content_tables) cleared all NULL rows.
|
||||
This migration enforces the NOT NULL constraint.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '174f442795b7'
|
||||
down_revision: Union[str, None] = '3a40fe11b427'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# tree_embeddings: backfill from trees (must happen before SET NOT NULL)
|
||||
op.execute("""
|
||||
UPDATE tree_embeddings te
|
||||
SET account_id = t.account_id
|
||||
FROM trees t
|
||||
WHERE te.tree_id = t.id
|
||||
AND te.account_id IS NULL
|
||||
""")
|
||||
|
||||
# feedback: backfill from users
|
||||
op.execute("""
|
||||
UPDATE feedback f
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE f.user_id = u.id
|
||||
AND f.account_id IS NULL
|
||||
""")
|
||||
|
||||
# Verify ALL tables before touching any SET NOT NULL
|
||||
tables_with_account_id = [
|
||||
'users', 'trees', 'tree_categories', 'tree_tags',
|
||||
'step_categories', 'step_library', 'tree_embeddings', 'feedback',
|
||||
]
|
||||
for table in tables_with_account_id:
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} NULL account_id rows in {table}. "
|
||||
"Run Task 9 (create_global_content_tables) first, or "
|
||||
"manually backfill/delete orphaned rows."
|
||||
)
|
||||
|
||||
# SET NOT NULL on all
|
||||
for table in tables_with_account_id:
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
|
||||
# Create indexes where they don't already exist
|
||||
new_indexes = [
|
||||
('tree_embeddings', 'ix_tree_embeddings_account_id'),
|
||||
('feedback', 'ix_feedback_account_id'),
|
||||
]
|
||||
for table, index_name in new_indexes:
|
||||
result = op.get_bind().execute(sa.text(
|
||||
f"SELECT 1 FROM pg_indexes WHERE tablename='{table}' AND indexname='{index_name}'"
|
||||
))
|
||||
if not result.fetchone():
|
||||
op.create_index(index_name, table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert to nullable
|
||||
for table in ('users', 'trees', 'tree_categories', 'tree_tags',
|
||||
'step_categories', 'step_library', 'tree_embeddings', 'feedback'):
|
||||
op.alter_column(table, 'account_id', nullable=True)
|
||||
for table, index_name in (
|
||||
('tree_embeddings', 'ix_tree_embeddings_account_id'),
|
||||
('feedback', 'ix_feedback_account_id'),
|
||||
):
|
||||
try:
|
||||
op.drop_index(index_name, table_name=table)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,62 @@
|
||||
"""add account_id to target_lists (keep team_id)
|
||||
|
||||
Revision ID: 2c6aabd89bc6
|
||||
Revises: 78fc200abac1
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '2c6aabd89bc6'
|
||||
down_revision: Union[str, None] = '78fc200abac1'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('target_lists', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_target_lists_account_id', 'target_lists', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Primary: team_id → team admin user → account_id
|
||||
op.execute("""
|
||||
UPDATE target_lists tl
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE u.team_id = tl.team_id
|
||||
AND u.is_team_admin = TRUE
|
||||
AND u.account_id IS NOT NULL
|
||||
AND tl.account_id IS NULL
|
||||
""")
|
||||
|
||||
# Fallback: created_by → users.account_id
|
||||
op.execute("""
|
||||
UPDATE target_lists tl
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE tl.created_by = u.id
|
||||
AND u.account_id IS NOT NULL
|
||||
AND tl.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM target_lists WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} target_lists rows have NULL account_id. "
|
||||
"No team admin found for these teams. Resolve before re-running."
|
||||
)
|
||||
|
||||
op.alter_column('target_lists', 'account_id', nullable=False)
|
||||
op.create_index('ix_target_lists_account_id', 'target_lists', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_target_lists_account_id', table_name='target_lists')
|
||||
op.drop_constraint('fk_target_lists_account_id', 'target_lists', type_='foreignkey')
|
||||
op.drop_column('target_lists', 'account_id')
|
||||
@@ -0,0 +1,155 @@
|
||||
"""create template_trees and platform_steps global content tables
|
||||
|
||||
Revision ID: 3a40fe11b427
|
||||
Revises: 2c6aabd89bc6
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
|
||||
These tables hold platform-owned content that is readable by all
|
||||
authenticated users. No account_id. No RLS. Ever.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
revision: str = '3a40fe11b427'
|
||||
down_revision: Union[str, None] = '2c6aabd89bc6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Create template_trees ─────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
'template_trees',
|
||||
sa.Column('id', UUID(), primary_key=True),
|
||||
sa.Column('name', sa.String(255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('category', sa.String(100), nullable=True),
|
||||
sa.Column('tree_type', sa.String(20), nullable=False),
|
||||
sa.Column('tree_structure', JSONB(), nullable=False),
|
||||
sa.Column('tags', JSONB(), nullable=False, server_default='[]'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('source_tree_id', UUID(), sa.ForeignKey('trees.id', ondelete='SET NULL'), nullable=True),
|
||||
)
|
||||
op.create_index('ix_template_trees_tree_type', 'template_trees', ['tree_type'])
|
||||
|
||||
# ── Create platform_steps ────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
'platform_steps',
|
||||
sa.Column('id', UUID(), primary_key=True),
|
||||
sa.Column('title', sa.String(255), nullable=False),
|
||||
sa.Column('step_type', sa.String(50), nullable=False),
|
||||
sa.Column('content', JSONB(), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('source_step_id', UUID(), sa.ForeignKey('step_library.id', ondelete='SET NULL'), nullable=True),
|
||||
)
|
||||
op.create_index('ix_platform_steps_step_type', 'platform_steps', ['step_type'])
|
||||
|
||||
# ── Copy is_default=TRUE trees → template_trees ─────────────────────────
|
||||
# Note: trees.tags is a relationship via tree_tags join table — no direct column.
|
||||
# Aggregate tag names via a correlated subquery.
|
||||
op.execute("""
|
||||
INSERT INTO template_trees
|
||||
(id, name, description, category, tree_type, tree_structure,
|
||||
tags, is_active, created_at, updated_at, source_tree_id)
|
||||
SELECT
|
||||
gen_random_uuid(), t.name, t.description, t.category, t.tree_type,
|
||||
t.tree_structure,
|
||||
COALESCE(
|
||||
(SELECT jsonb_agg(tt.name ORDER BY tt.name)
|
||||
FROM tree_tag_assignments ta
|
||||
JOIN tree_tags tt ON tt.id = ta.tag_id
|
||||
WHERE ta.tree_id = t.id),
|
||||
'[]'::jsonb
|
||||
),
|
||||
t.is_active,
|
||||
COALESCE(t.created_at, NOW()), COALESCE(t.updated_at, NOW()), t.id
|
||||
FROM trees t
|
||||
WHERE t.is_default = TRUE
|
||||
""")
|
||||
|
||||
# ── Copy visibility='public' steps → platform_steps ─────────────────────
|
||||
op.execute("""
|
||||
INSERT INTO platform_steps
|
||||
(id, title, step_type, content, is_active, created_at, updated_at, source_step_id)
|
||||
SELECT
|
||||
gen_random_uuid(), title, step_type, content, is_active,
|
||||
COALESCE(created_at, NOW()), COALESCE(updated_at, NOW()), id
|
||||
FROM step_library
|
||||
WHERE visibility = 'public'
|
||||
""")
|
||||
|
||||
# ── Create platform sentinel account ─────────────────────────────────────
|
||||
op.execute("""
|
||||
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
|
||||
VALUES (
|
||||
'00000000-0000-0000-0000-000000000001',
|
||||
'ResolutionFlow Platform',
|
||||
'PLATFORM',
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
""")
|
||||
|
||||
# ── Assign is_default trees to platform account ──────────────────────────
|
||||
op.execute("""
|
||||
UPDATE trees
|
||||
SET account_id = '00000000-0000-0000-0000-000000000001'
|
||||
WHERE is_default = TRUE
|
||||
AND account_id IS NULL
|
||||
""")
|
||||
|
||||
# ── Assign global categories/tags/steps to platform account ─────────────
|
||||
op.execute("""
|
||||
UPDATE tree_categories
|
||||
SET account_id = '00000000-0000-0000-0000-000000000001'
|
||||
WHERE account_id IS NULL
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
UPDATE tree_tags
|
||||
SET account_id = '00000000-0000-0000-0000-000000000001'
|
||||
WHERE account_id IS NULL
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
UPDATE step_categories
|
||||
SET account_id = '00000000-0000-0000-0000-000000000001'
|
||||
WHERE account_id IS NULL
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
UPDATE step_library
|
||||
SET account_id = '00000000-0000-0000-0000-000000000001'
|
||||
WHERE account_id IS NULL
|
||||
""")
|
||||
|
||||
# ── Verify zero NULLs in all 5 tables ───────────────────────────────────
|
||||
for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'):
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} NULL account_id rows remain in {table} "
|
||||
"after platform account assignment. Investigate before re-running."
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
platform_id = '00000000-0000-0000-0000-000000000001'
|
||||
for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'):
|
||||
op.execute(f"UPDATE {table} SET account_id = NULL WHERE account_id = '{platform_id}'")
|
||||
|
||||
op.execute(f"DELETE FROM accounts WHERE id = '{platform_id}'")
|
||||
op.drop_index('ix_platform_steps_step_type', table_name='platform_steps')
|
||||
op.drop_index('ix_template_trees_tree_type', table_name='template_trees')
|
||||
op.drop_table('platform_steps')
|
||||
op.drop_table('template_trees')
|
||||
@@ -0,0 +1,77 @@
|
||||
"""add account_id to AI branching tables
|
||||
|
||||
Revision ID: 478c159e5654
|
||||
Revises: cc214c63aa30
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '478c159e5654'
|
||||
down_revision: Union[str, None] = 'cc214c63aa30'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
ai_tables = ('session_branches', 'session_handoffs', 'fork_points', 'ai_session_steps')
|
||||
|
||||
# Step 1: ADD COLUMN (nullable)
|
||||
for table in ai_tables:
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
op.add_column('ai_suggestions', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_ai_suggestions_account_id', 'ai_suggestions', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Step 2: BACKFILL
|
||||
for table in ai_tables:
|
||||
op.execute(f"""
|
||||
UPDATE {table} t
|
||||
SET account_id = ai.account_id
|
||||
FROM ai_sessions ai
|
||||
WHERE t.session_id = ai.id
|
||||
AND t.account_id IS NULL
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
UPDATE ai_suggestions s
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE s.user_id = u.id
|
||||
AND s.account_id IS NULL
|
||||
""")
|
||||
|
||||
# Step 3: VERIFY zero NULLs
|
||||
for table in ai_tables + ('ai_suggestions',):
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} NULL account_id rows in {table}."
|
||||
)
|
||||
|
||||
# Step 4: SET NOT NULL
|
||||
for table in ai_tables + ('ai_suggestions',):
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
|
||||
# Step 5: CREATE INDEX
|
||||
for table in ai_tables + ('ai_suggestions',):
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('session_branches', 'session_handoffs', 'fork_points',
|
||||
'ai_session_steps', 'ai_suggestions'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -0,0 +1,46 @@
|
||||
"""add account_id to step_ratings and step_usage_log
|
||||
|
||||
Revision ID: 7167e9374b0c
|
||||
Revises: 478c159e5654
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '7167e9374b0c'
|
||||
down_revision: Union[str, None] = '478c159e5654'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in ('step_ratings', 'step_usage_log'):
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
# Backfill: from the RATER/LOGGER user's account (not the step's account)
|
||||
op.execute(f"""
|
||||
UPDATE {table} t
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE t.user_id = u.id
|
||||
AND t.account_id IS NULL
|
||||
""")
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('step_ratings', 'step_usage_log'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -0,0 +1,103 @@
|
||||
"""add account_id to script_builder_sessions, script_templates, script_generations
|
||||
|
||||
Revision ID: 78fc200abac1
|
||||
Revises: 7f136778f5a8
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '78fc200abac1'
|
||||
down_revision: Union[str, None] = '7f136778f5a8'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
PLATFORM_ACCOUNT_ID = '00000000-0000-0000-0000-000000000001'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Ensure the platform sentinel account exists before any fallback assignments.
|
||||
# Migration 3a40fe11b427 also inserts this with ON CONFLICT DO NOTHING — safe.
|
||||
op.execute(f"""
|
||||
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
|
||||
VALUES (
|
||||
'{PLATFORM_ACCOUNT_ID}',
|
||||
'ResolutionFlow Platform',
|
||||
'PLATFORM',
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
""")
|
||||
|
||||
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# script_builder_sessions: user_id → users.account_id
|
||||
op.execute("""
|
||||
UPDATE script_builder_sessions sbs
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE sbs.user_id = u.id
|
||||
AND sbs.account_id IS NULL
|
||||
""")
|
||||
|
||||
# script_templates: created_by → users.account_id (nullable created_by)
|
||||
op.execute("""
|
||||
UPDATE script_templates st
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE st.created_by = u.id
|
||||
AND st.account_id IS NULL
|
||||
""")
|
||||
# Fallback: team_id → team admin user
|
||||
op.execute("""
|
||||
UPDATE script_templates st
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE u.team_id = st.team_id
|
||||
AND u.is_team_admin = TRUE
|
||||
AND u.account_id IS NOT NULL
|
||||
AND st.account_id IS NULL
|
||||
""")
|
||||
# Final fallback: platform-seeded templates with NULL team_id AND NULL created_by
|
||||
# (e.g. the 6 AD templates inserted by migration 057) → platform sentinel account
|
||||
op.execute(f"""
|
||||
UPDATE script_templates
|
||||
SET account_id = '{PLATFORM_ACCOUNT_ID}'
|
||||
WHERE account_id IS NULL
|
||||
""")
|
||||
|
||||
# script_generations: user_id → users.account_id
|
||||
op.execute("""
|
||||
UPDATE script_generations sg
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE sg.user_id = u.id
|
||||
AND sg.account_id IS NULL
|
||||
""")
|
||||
|
||||
# VERIFY
|
||||
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
|
||||
|
||||
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -0,0 +1,62 @@
|
||||
"""add account_id to maintenance_schedules
|
||||
|
||||
Revision ID: 7f136778f5a8
|
||||
Revises: 8aac5b372402
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '7f136778f5a8'
|
||||
down_revision: Union[str, None] = '8aac5b372402'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('maintenance_schedules',
|
||||
sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_maintenance_schedules_account_id', 'maintenance_schedules', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Primary: tree_id → trees.account_id (only where tree.account_id is NOT NULL)
|
||||
op.execute("""
|
||||
UPDATE maintenance_schedules ms
|
||||
SET account_id = t.account_id
|
||||
FROM trees t
|
||||
WHERE ms.tree_id = t.id
|
||||
AND t.account_id IS NOT NULL
|
||||
AND ms.account_id IS NULL
|
||||
""")
|
||||
|
||||
# Fallback: created_by → users.account_id (for is_default trees with NULL account_id)
|
||||
op.execute("""
|
||||
UPDATE maintenance_schedules ms
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE ms.created_by = u.id
|
||||
AND u.account_id IS NOT NULL
|
||||
AND ms.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM maintenance_schedules WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} maintenance_schedules rows have NULL account_id. "
|
||||
"Check if created_by is NULL — those rows need manual resolution."
|
||||
)
|
||||
|
||||
op.alter_column('maintenance_schedules', 'account_id', nullable=False)
|
||||
op.create_index('ix_maintenance_schedules_account_id', 'maintenance_schedules', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_maintenance_schedules_account_id', table_name='maintenance_schedules')
|
||||
op.drop_constraint('fk_maintenance_schedules_account_id', 'maintenance_schedules', type_='foreignkey')
|
||||
op.drop_column('maintenance_schedules', 'account_id')
|
||||
@@ -0,0 +1,81 @@
|
||||
"""add account_id to PSA and notification tables
|
||||
|
||||
Revision ID: 8aac5b372402
|
||||
Revises: a1d2a84b9abb
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '8aac5b372402'
|
||||
down_revision: Union[str, None] = 'a1d2a84b9abb'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: ADD COLUMN
|
||||
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Step 2: BACKFILL
|
||||
# psa_post_log: prefer psa_connection → fallback to posted_by user
|
||||
# Note: cannot reference the updated table (ppl) inside the FROM clause JOIN,
|
||||
# so use a correlated subquery for psa_connections lookup instead.
|
||||
op.execute("""
|
||||
UPDATE psa_post_log ppl
|
||||
SET account_id = COALESCE(
|
||||
(SELECT account_id FROM psa_connections WHERE id = ppl.psa_connection_id),
|
||||
u.account_id
|
||||
)
|
||||
FROM users u
|
||||
WHERE ppl.posted_by = u.id
|
||||
AND ppl.account_id IS NULL
|
||||
""")
|
||||
|
||||
# psa_member_mappings: via psa_connection
|
||||
op.execute("""
|
||||
UPDATE psa_member_mappings pmm
|
||||
SET account_id = pc.account_id
|
||||
FROM psa_connections pc
|
||||
WHERE pmm.psa_connection_id = pc.id
|
||||
AND pmm.account_id IS NULL
|
||||
""")
|
||||
|
||||
# notification_logs: via notification_config
|
||||
op.execute("""
|
||||
UPDATE notification_logs nl
|
||||
SET account_id = nc.account_id
|
||||
FROM notification_configs nc
|
||||
WHERE nl.notification_config_id = nc.id
|
||||
AND nl.account_id IS NULL
|
||||
""")
|
||||
|
||||
# Step 3: VERIFY
|
||||
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
|
||||
|
||||
# Step 4: SET NOT NULL
|
||||
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
|
||||
# Step 5: CREATE INDEX
|
||||
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add account_id to user personalization tables
|
||||
|
||||
Revision ID: a1d2a84b9abb
|
||||
Revises: 7167e9374b0c
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = 'a1d2a84b9abb'
|
||||
down_revision: Union[str, None] = '7167e9374b0c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in ('user_folders', 'user_pinned_trees'):
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id', table, 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
op.execute(f"""
|
||||
UPDATE {table} t
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE t.user_id = u.id
|
||||
AND t.account_id IS NULL
|
||||
""")
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('user_folders', 'user_pinned_trees'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -0,0 +1,24 @@
|
||||
"""merge Phase 1 tenant isolation chain with main head
|
||||
|
||||
Revision ID: a9f3b2c1d4e5
|
||||
Revises: 070, 174f442795b7
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
|
||||
Merge migration: consolidates the Phase 1 account_id chain (cc214c63aa30 → … → 174f442795b7)
|
||||
with the main sequential chain (… → 070) into a single head so that
|
||||
`alembic upgrade head` works without ambiguity.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
revision: str = 'a9f3b2c1d4e5'
|
||||
down_revision: Union[str, tuple] = ('070', '174f442795b7')
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
108
backend/alembic/versions/c5f48b9890f9_enable_rls_phase1.py
Normal file
108
backend/alembic/versions/c5f48b9890f9_enable_rls_phase1.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""enable_rls_phase1
|
||||
|
||||
Revision ID: c5f48b9890f9
|
||||
Revises: 0b470d9e6cf1
|
||||
Create Date: 2026-04-10 04:01:13.043321
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c5f48b9890f9'
|
||||
down_revision: Union[str, None] = '0b470d9e6cf1'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
_CURRENT_ACCOUNT = (
|
||||
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
f"'{_NULL_UUID}')::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── trees ───────────────────────────────────────────────────────────────
|
||||
# Extended policy mirrors can_access_tree() in app/core/permissions.py.
|
||||
# Tenant sees: own rows, platform rows, any default tree, any public tree,
|
||||
# any gallery-featured tree.
|
||||
# is_gallery_featured = TRUE is included because /public/templates is a
|
||||
# no-auth endpoint — no tenant context is set, so gallery trees must pass
|
||||
# RLS on their own flag rather than relying on account_id or is_public.
|
||||
# Private/team trees from other accounts are hidden.
|
||||
op.execute("ALTER TABLE trees ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE trees FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON trees
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
OR is_default = TRUE
|
||||
OR is_public = TRUE
|
||||
OR is_gallery_featured = TRUE
|
||||
)
|
||||
""")
|
||||
|
||||
# ── tree_tags ────────────────────────────────────────────────────────────
|
||||
# Own account + platform tags (global tags visible to all tenants).
|
||||
op.execute("ALTER TABLE tree_tags ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE tree_tags FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON tree_tags
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
)
|
||||
""")
|
||||
|
||||
# ── tree_categories ──────────────────────────────────────────────────────
|
||||
op.execute("ALTER TABLE tree_categories ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE tree_categories FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON tree_categories
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
)
|
||||
""")
|
||||
|
||||
# ── step_categories ──────────────────────────────────────────────────────
|
||||
op.execute("ALTER TABLE step_categories ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE step_categories FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON step_categories
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
)
|
||||
""")
|
||||
|
||||
# ── psa_connections ──────────────────────────────────────────────────────
|
||||
# Tenant-only — PSA credentials must never cross tenant boundaries.
|
||||
op.execute("ALTER TABLE psa_connections ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE psa_connections FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON psa_connections
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
# ── flow_proposals ────────────────────────────────────────────────────────
|
||||
# Tenant-only.
|
||||
op.execute("ALTER TABLE flow_proposals ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE flow_proposals FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON flow_proposals
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ["trees", "tree_tags", "tree_categories", "step_categories",
|
||||
"psa_connections", "flow_proposals"]:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
@@ -0,0 +1,95 @@
|
||||
"""add account_id to core session tables
|
||||
|
||||
Revision ID: cc214c63aa30
|
||||
Revises: b8d2f4a6c091
|
||||
Create Date: 2026-04-09 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = 'cc214c63aa30'
|
||||
down_revision: Union[str, None] = '064'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = ('067',)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Step 1: ADD COLUMN (nullable) ────────────────────────────────────────
|
||||
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||
'session_resolution_outputs'):
|
||||
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
f'fk_{table}_account_id',
|
||||
table, 'accounts',
|
||||
['account_id'], ['id'],
|
||||
ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# ── Step 2: BACKFILL ─────────────────────────────────────────────────────
|
||||
# sessions: direct join to users
|
||||
op.execute("""
|
||||
UPDATE sessions s
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE s.user_id = u.id
|
||||
AND s.account_id IS NULL
|
||||
""")
|
||||
|
||||
# attachments: chain through sessions (now backfilled above)
|
||||
op.execute("""
|
||||
UPDATE attachments a
|
||||
SET account_id = s.account_id
|
||||
FROM sessions s
|
||||
WHERE a.session_id = s.id
|
||||
AND a.account_id IS NULL
|
||||
""")
|
||||
|
||||
# session_supporting_data: same chain
|
||||
op.execute("""
|
||||
UPDATE session_supporting_data sd
|
||||
SET account_id = s.account_id
|
||||
FROM sessions s
|
||||
WHERE sd.session_id = s.id
|
||||
AND sd.account_id IS NULL
|
||||
""")
|
||||
|
||||
# session_resolution_outputs: FK is to ai_sessions, not sessions
|
||||
op.execute("""
|
||||
UPDATE session_resolution_outputs sro
|
||||
SET account_id = ai.account_id
|
||||
FROM ai_sessions ai
|
||||
WHERE sro.session_id = ai.id
|
||||
AND sro.account_id IS NULL
|
||||
""")
|
||||
|
||||
# ── Step 3: VERIFY zero NULLs — raises if any remain ────────────────────
|
||||
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||
'session_resolution_outputs'):
|
||||
result = op.get_bind().execute(
|
||||
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} NULL account_id rows remain in {table}. "
|
||||
f"Fix the backfill before re-running."
|
||||
)
|
||||
|
||||
# ── Step 4: SET NOT NULL ─────────────────────────────────────────────────
|
||||
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||
'session_resolution_outputs'):
|
||||
op.alter_column(table, 'account_id', nullable=False)
|
||||
|
||||
# ── Step 5: CREATE INDEX ─────────────────────────────────────────────────
|
||||
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||
'session_resolution_outputs'):
|
||||
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||
'session_resolution_outputs'):
|
||||
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||
op.drop_column(table, 'account_id')
|
||||
@@ -10,6 +10,8 @@ from app.core.database import get_db
|
||||
from app.core.security import decode_token
|
||||
from app.models.user import User
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.core.tenant_context import set_current_account_id, clear_current_account_id
|
||||
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
|
||||
|
||||
# Routes that are allowed even when must_change_password is True
|
||||
_PASSWORD_CHANGE_ALLOWLIST = {
|
||||
@@ -192,18 +194,42 @@ async def get_plan_limits_for_user(
|
||||
return await get_user_plan_limits(current_user.account_id, db)
|
||||
|
||||
|
||||
async def get_tenant_context(
|
||||
async def require_tenant_context(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> UUID:
|
||||
"""Return the current user's account_id.
|
||||
):
|
||||
"""Set per-request tenant context for RLS.
|
||||
|
||||
Use this dependency instead of reading current_user.account_id directly.
|
||||
Raises 403 if the user has no account association (should not happen in
|
||||
normal flows — users are always associated with an account on registration).
|
||||
Raises 403 if the authenticated user has no account_id — never falls back
|
||||
to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a
|
||||
malformed account).
|
||||
|
||||
Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to
|
||||
issue set_config('app.current_account_id', …, true) on every transaction.
|
||||
|
||||
Applied to every user-facing router. NOT applied to /admin/* routers or
|
||||
public endpoints (auth, shared, webhooks).
|
||||
"""
|
||||
if current_user.account_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User not associated with any account",
|
||||
detail="User account required",
|
||||
)
|
||||
return current_user.account_id
|
||||
token = set_current_account_id(current_user.account_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
clear_current_account_id(token)
|
||||
|
||||
|
||||
async def require_admin_db(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
) -> AsyncSession:
|
||||
"""Return a BYPASSRLS admin DB session after verifying super_admin role.
|
||||
|
||||
Use on /admin/* endpoints that query RLS-protected tables. Replaces
|
||||
Depends(get_db) on the db parameter of those endpoints.
|
||||
The current_user dep is still declared separately on the endpoint if
|
||||
the user object is needed in the handler.
|
||||
"""
|
||||
return db
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.audit import log_audit
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_password_hash, generate_temp_password, create_password_reset_token, decode_token, hash_token
|
||||
@@ -37,7 +37,7 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
@router.get("/users", response_model=list[UserResponse])
|
||||
async def list_users(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
@@ -74,7 +74,7 @@ def _generate_display_code() -> str:
|
||||
@router.post("/users", response_model=AdminUserCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
data: AdminUserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Create a new user with a temporary password (super admin only).
|
||||
@@ -199,7 +199,7 @@ async def create_user(
|
||||
@router.get("/users/{user_id}", response_model=UserDetailResponse)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Get enriched user details (super admin only)."""
|
||||
@@ -317,7 +317,7 @@ async def get_user(
|
||||
async def update_user_role(
|
||||
user_id: UUID,
|
||||
role_data: RoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Change user role (super admin only)."""
|
||||
@@ -349,7 +349,7 @@ async def update_user_role(
|
||||
async def update_account_role(
|
||||
user_id: UUID,
|
||||
data: AccountRoleUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Change a user's account role (super admin only)."""
|
||||
@@ -375,7 +375,7 @@ async def update_account_role(
|
||||
async def update_super_admin_status(
|
||||
user_id: UUID,
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Promote or demote a user to/from super admin (super admin only)."""
|
||||
@@ -414,7 +414,7 @@ async def update_super_admin_status(
|
||||
@router.put("/users/{user_id}/deactivate", response_model=UserResponse)
|
||||
async def deactivate_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Deactivate a user account (super admin only)."""
|
||||
@@ -443,7 +443,7 @@ async def deactivate_user(
|
||||
@router.put("/users/{user_id}/activate", response_model=UserResponse)
|
||||
async def activate_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)]
|
||||
):
|
||||
"""Reactivate a user account (super admin only)."""
|
||||
@@ -467,7 +467,7 @@ async def activate_user(
|
||||
async def move_user_account(
|
||||
user_id: UUID,
|
||||
data: MoveUserAccount,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Move a user to a different account (super admin only)."""
|
||||
@@ -520,7 +520,7 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
|
||||
async def update_user_plan(
|
||||
user_id: UUID,
|
||||
data: SubscriptionPlanUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Change a user's subscription plan (super admin only)."""
|
||||
@@ -539,7 +539,7 @@ async def update_user_plan(
|
||||
async def extend_user_trial(
|
||||
user_id: UUID,
|
||||
data: ExtendTrialRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Extend or start a trial for a user's subscription (super admin only)."""
|
||||
@@ -569,7 +569,7 @@ async def extend_user_trial(
|
||||
async def admin_reset_password(
|
||||
user_id: UUID,
|
||||
data: AdminPasswordReset,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Admin-triggered password reset (super admin only).
|
||||
@@ -640,7 +640,7 @@ async def admin_reset_password(
|
||||
@router.put("/users/{user_id}/archive", response_model=UserResponse)
|
||||
async def archive_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Archive (soft delete) a user (super admin only)."""
|
||||
@@ -675,7 +675,7 @@ async def archive_user(
|
||||
@router.put("/users/{user_id}/restore", response_model=UserResponse)
|
||||
async def restore_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Restore an archived user (super admin only)."""
|
||||
@@ -700,7 +700,7 @@ async def restore_user(
|
||||
@router.get("/users/{user_id}/hard-delete-check", response_model=HardDeleteCheckResponse)
|
||||
async def hard_delete_check(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Check if a user can be hard-deleted (super admin only). Returns blockers."""
|
||||
@@ -773,7 +773,7 @@ async def hard_delete_check(
|
||||
@router.delete("/users/{user_id}/hard-delete", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def hard_delete_user(
|
||||
user_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Permanently delete a user (super admin only). User must be archived first."""
|
||||
@@ -833,7 +833,7 @@ async def hard_delete_user(
|
||||
@router.post("/invites", status_code=status.HTTP_201_CREATED)
|
||||
async def admin_create_invite(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Quick-invite a user to an account (super admin only).
|
||||
|
||||
@@ -4,25 +4,26 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.audit import log_audit
|
||||
from app.models.user import User
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.tree import Tree
|
||||
from app.schemas.admin import GlobalCategoryCreate, GlobalCategoryUpdate, GlobalCategoryResponse
|
||||
from app.api.deps import require_admin
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
router = APIRouter(prefix="/admin/categories", tags=["admin-categories"])
|
||||
|
||||
|
||||
@router.get("/global", response_model=list[GlobalCategoryResponse])
|
||||
async def list_global_categories(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List all global categories (account_id IS NULL)."""
|
||||
result = await db.execute(
|
||||
select(TreeCategory).where(TreeCategory.account_id.is_(None)).order_by(TreeCategory.name)
|
||||
select(TreeCategory).where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID).order_by(TreeCategory.name)
|
||||
)
|
||||
categories = result.scalars().all()
|
||||
|
||||
@@ -45,36 +46,36 @@ async def list_global_categories(
|
||||
@router.post("/global", response_model=GlobalCategoryResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_global_category(
|
||||
data: GlobalCategoryCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Create a global category."""
|
||||
# Check slug uniqueness for global categories
|
||||
existing = await db.execute(
|
||||
select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None))
|
||||
select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Global category with this slug already exists")
|
||||
|
||||
category = TreeCategory(name=data.name, slug=data.slug, account_id=None)
|
||||
category = TreeCategory(name=data.name, slug=data.slug, account_id=PLATFORM_ACCOUNT_ID)
|
||||
db.add(category)
|
||||
await log_audit(db, current_user.id, "global_category.create", "category", details={"name": data.name})
|
||||
await db.commit()
|
||||
await db.refresh(category)
|
||||
|
||||
return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=None, tree_count=0)
|
||||
return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=PLATFORM_ACCOUNT_ID, tree_count=0)
|
||||
|
||||
|
||||
@router.put("/global/{category_id}", response_model=GlobalCategoryResponse)
|
||||
async def update_global_category(
|
||||
category_id: UUID,
|
||||
data: GlobalCategoryUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update a global category."""
|
||||
result = await db.execute(
|
||||
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None))
|
||||
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
|
||||
)
|
||||
category = result.scalar_one_or_none()
|
||||
if not category:
|
||||
@@ -86,7 +87,7 @@ async def update_global_category(
|
||||
# Check slug uniqueness
|
||||
existing = await db.execute(
|
||||
select(TreeCategory).where(
|
||||
TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None), TreeCategory.id != category_id
|
||||
TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID, TreeCategory.id != category_id
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
@@ -103,19 +104,19 @@ async def update_global_category(
|
||||
|
||||
return GlobalCategoryResponse(
|
||||
id=category.id, name=category.name, slug=category.slug,
|
||||
account_id=None, tree_count=tree_count,
|
||||
account_id=PLATFORM_ACCOUNT_ID, tree_count=tree_count,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/global/{category_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_global_category(
|
||||
category_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Delete (archive) a global category."""
|
||||
result = await db.execute(
|
||||
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None))
|
||||
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
|
||||
)
|
||||
category = result.scalar_one_or_none()
|
||||
if not category:
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.user import User
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.tree import Tree
|
||||
@@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin/dashboard", tags=["admin-dashboard"])
|
||||
|
||||
@router.get("/metrics", response_model=DashboardMetrics)
|
||||
async def get_dashboard_metrics(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Get platform overview metrics."""
|
||||
@@ -45,7 +45,7 @@ async def get_dashboard_metrics(
|
||||
|
||||
@router.get("/activity", response_model=list[ActivityEntry])
|
||||
async def get_dashboard_activity(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Get recent audit log entries for activity feed."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import require_admin
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.script_template import ScriptTemplate
|
||||
from app.models.tree import Tree
|
||||
from app.models.user import User
|
||||
@@ -66,7 +66,7 @@ def _script_summary(script: ScriptTemplate) -> dict:
|
||||
|
||||
@router.get("/featured")
|
||||
async def list_featured(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List all featured flows and scripts (super admin only)."""
|
||||
@@ -92,7 +92,7 @@ async def list_featured(
|
||||
|
||||
@router.get("/items")
|
||||
async def list_all_items(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""List ALL flows and scripts with their gallery status (super admin only)."""
|
||||
@@ -119,7 +119,7 @@ async def list_all_items(
|
||||
async def toggle_flow_featured(
|
||||
flow_id: UUID,
|
||||
body: FeatureToggle,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Toggle is_gallery_featured on a flow (super admin only)."""
|
||||
@@ -138,7 +138,7 @@ async def toggle_flow_featured(
|
||||
async def update_flow_sort_order(
|
||||
flow_id: UUID,
|
||||
body: SortOrderUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update gallery_sort_order on a flow (super admin only)."""
|
||||
@@ -157,7 +157,7 @@ async def update_flow_sort_order(
|
||||
async def toggle_script_featured(
|
||||
script_id: UUID,
|
||||
body: FeatureToggle,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Toggle is_gallery_featured on a script (super admin only)."""
|
||||
@@ -176,7 +176,7 @@ async def toggle_script_featured(
|
||||
async def update_script_sort_order(
|
||||
script_id: UUID,
|
||||
body: SortOrderUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update gallery_sort_order on a script (super admin only)."""
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.models.user import User
|
||||
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_category, can_create_category
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
from app.core.filters import tenant_filter
|
||||
|
||||
router = APIRouter(prefix="/categories", tags=["categories"])
|
||||
@@ -48,13 +49,13 @@ async def list_categories(
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeCategory.account_id.is_(None), # Global
|
||||
TreeCategory.account_id == PLATFORM_ACCOUNT_ID, # Global
|
||||
TreeCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no account, only show global categories
|
||||
query = query.where(TreeCategory.account_id.is_(None))
|
||||
query = query.where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
|
||||
|
||||
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
|
||||
|
||||
@@ -176,7 +177,7 @@ async def create_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
account_id=category_data.account_id,
|
||||
account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
|
||||
@@ -197,6 +197,7 @@ async def create_template(
|
||||
template = ScriptTemplate(
|
||||
category_id=data.category_id,
|
||||
team_id=current_user.team_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
name=data.name,
|
||||
slug=slug,
|
||||
@@ -364,6 +365,7 @@ async def generate_script(
|
||||
generation = ScriptGeneration(
|
||||
template_id=template.id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
session_id=data.session_id,
|
||||
ai_session_id=data.ai_session_id,
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.schemas.step_category import (
|
||||
)
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_step_category, can_create_step_category
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
router = APIRouter(prefix="/step-categories", tags=["step-categories"])
|
||||
|
||||
@@ -44,13 +45,13 @@ async def list_step_categories(
|
||||
elif current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
StepCategory.account_id.is_(None), # Global
|
||||
StepCategory.account_id == PLATFORM_ACCOUNT_ID, # Global
|
||||
StepCategory.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User has no account, only show global categories
|
||||
query = query.where(StepCategory.account_id.is_(None))
|
||||
query = query.where(StepCategory.account_id == PLATFORM_ACCOUNT_ID)
|
||||
|
||||
query = query.order_by(StepCategory.display_order, StepCategory.name)
|
||||
|
||||
@@ -155,7 +156,7 @@ async def create_step_category(
|
||||
name=category_data.name,
|
||||
slug=slug,
|
||||
description=category_data.description,
|
||||
account_id=category_data.account_id,
|
||||
account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID,
|
||||
display_order=max_order + 1,
|
||||
created_by=current_user.id
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.models.user import User
|
||||
from app.schemas.tag import TagCreate, TagResponse, TagListResponse, TagAssignment
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.permissions import can_manage_tree_tags, can_create_tag
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
router = APIRouter(prefix="/tags", tags=["tags"])
|
||||
|
||||
@@ -33,13 +34,13 @@ async def list_tags(
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.account_id.is_(None), # Global
|
||||
TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global
|
||||
TreeTag.account_id == current_user.account_id # User's account
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only show global tags
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID)
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
|
||||
|
||||
@@ -71,12 +72,12 @@ async def search_tags(
|
||||
if include_account and current_user.account_id:
|
||||
query = query.where(
|
||||
or_(
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == PLATFORM_ACCOUNT_ID,
|
||||
TreeTag.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.where(TreeTag.account_id.is_(None))
|
||||
query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID)
|
||||
|
||||
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
|
||||
|
||||
@@ -147,7 +148,7 @@ async def create_tag(
|
||||
new_tag = TreeTag(
|
||||
name=tag_data.name,
|
||||
slug=slug,
|
||||
account_id=tag_data.account_id,
|
||||
account_id=tag_data.account_id if tag_data.account_id is not None else PLATFORM_ACCOUNT_ID,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.add(new_tag)
|
||||
@@ -206,7 +207,7 @@ async def add_tags_to_tree(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.account_id.is_(None), # Global tag
|
||||
TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global tag
|
||||
TreeTag.account_id == tag_account_id # Account tag
|
||||
)
|
||||
)
|
||||
@@ -340,7 +341,7 @@ async def replace_tree_tags(
|
||||
tag_query = select(TreeTag).where(
|
||||
TreeTag.slug == slug,
|
||||
or_(
|
||||
TreeTag.account_id.is_(None),
|
||||
TreeTag.account_id == PLATFORM_ACCOUNT_ID,
|
||||
TreeTag.account_id == tag_account_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from app.core.subscriptions import check_tree_limit, get_account_subscription, g
|
||||
from app.core.audit import log_audit
|
||||
from app.core.config import settings
|
||||
from app.core.tree_validation import can_publish_tree
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
from app.core.step_sync import sync_steps_from_tree, deactivate_synced_steps_for_tree
|
||||
from app.services.rag_service import index_tree as rag_index_tree
|
||||
|
||||
@@ -391,6 +392,7 @@ async def get_tree(
|
||||
)
|
||||
|
||||
if not tree.is_active or not can_access_tree(current_user, tree):
|
||||
# Always 404, never 403. A 403 confirms the resource exists.
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
@@ -470,7 +472,7 @@ async def create_tree(
|
||||
tree_structure=tree_data.tree_structure,
|
||||
intake_form=intake_form_data,
|
||||
author_id=service_account_id if is_default else current_user.id,
|
||||
account_id=None if is_default else current_user.account_id,
|
||||
account_id=PLATFORM_ACCOUNT_ID if is_default else current_user.account_id,
|
||||
is_public=True if is_default else tree_data.is_public, # Default trees are always public
|
||||
is_default=is_default,
|
||||
status=tree_data.status
|
||||
|
||||
@@ -1,51 +1,89 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, sidebar, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares, shared, tree_markdown
|
||||
from app.api.endpoints import admin_dashboard, admin_audit, admin_plan_limits, admin_feature_flags, admin_settings, admin_categories
|
||||
from app.api.endpoints import ratings, analytics
|
||||
from app.api.endpoints import target_lists
|
||||
from app.api.endpoints import maintenance_schedules
|
||||
from app.api.endpoints import feedback
|
||||
from app.api.endpoints import ai_builder
|
||||
from app.api.endpoints import ai_fix
|
||||
from app.api.endpoints import ai_chat
|
||||
from app.api.endpoints import copilot
|
||||
from app.api.endpoints import assistant_chat
|
||||
from app.api.endpoints import survey
|
||||
from app.api.endpoints import admin_survey
|
||||
from app.api.endpoints import tree_transfer
|
||||
from app.api.endpoints import ai_suggestions
|
||||
from app.api.endpoints import kb_accelerator
|
||||
from app.api.endpoints import beta_signup
|
||||
from app.api.endpoints import scripts
|
||||
from app.api.endpoints import integrations
|
||||
from app.api.endpoints import onboarding
|
||||
from app.api.endpoints import branding
|
||||
from app.api.endpoints import supporting_data
|
||||
from app.api.endpoints import ai_sessions
|
||||
from app.api.endpoints import flow_proposals
|
||||
from app.api.endpoints import flowpilot_analytics
|
||||
from app.api.endpoints import notifications
|
||||
from app.api.endpoints import public_templates
|
||||
from app.api.endpoints import admin_gallery
|
||||
from app.api.endpoints import uploads
|
||||
from app.api.endpoints import script_builder
|
||||
from app.api.endpoints import beta_feedback
|
||||
from app.api.endpoints import session_branches
|
||||
from app.api.endpoints import session_handoffs
|
||||
from app.api.endpoints import session_resolutions
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.deps import require_tenant_context
|
||||
from app.api.endpoints import (
|
||||
admin,
|
||||
admin_audit,
|
||||
admin_categories,
|
||||
admin_dashboard,
|
||||
admin_feature_flags,
|
||||
admin_gallery,
|
||||
admin_plan_limits,
|
||||
admin_settings,
|
||||
admin_survey,
|
||||
ai_builder,
|
||||
ai_chat,
|
||||
ai_fix,
|
||||
ai_sessions,
|
||||
ai_suggestions,
|
||||
analytics,
|
||||
assistant_chat,
|
||||
auth,
|
||||
beta_feedback,
|
||||
beta_signup,
|
||||
branding,
|
||||
categories,
|
||||
copilot,
|
||||
feedback,
|
||||
flow_proposals,
|
||||
flowpilot_analytics,
|
||||
folders,
|
||||
integrations,
|
||||
invite,
|
||||
kb_accelerator,
|
||||
maintenance_schedules,
|
||||
notifications,
|
||||
onboarding,
|
||||
public_templates,
|
||||
ratings,
|
||||
scripts,
|
||||
script_builder,
|
||||
session_branches,
|
||||
session_handoffs,
|
||||
session_resolutions,
|
||||
sessions,
|
||||
shared,
|
||||
shares,
|
||||
sidebar,
|
||||
step_categories,
|
||||
steps,
|
||||
supporting_data,
|
||||
survey,
|
||||
tags,
|
||||
target_lists,
|
||||
tree_markdown,
|
||||
tree_transfer,
|
||||
trees,
|
||||
uploads,
|
||||
webhooks,
|
||||
accounts,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public / unauthenticated endpoints — no tenant context
|
||||
#
|
||||
# Note: auth.router contains both public endpoints (register, login,
|
||||
# forgot-password, reset-password, email/verify) and authenticated endpoints
|
||||
# (GET/PATCH /me, logout, change-password, email/send-verification).
|
||||
# The authenticated auth endpoints only query the `users` table, which is
|
||||
# excluded from Phase 1 RLS. They work correctly without tenant context
|
||||
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
||||
# ---------------------------------------------------------------------------
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(trees.router)
|
||||
api_router.include_router(sidebar.router)
|
||||
api_router.include_router(sessions.router)
|
||||
api_router.include_router(invite.router)
|
||||
api_router.include_router(categories.router)
|
||||
api_router.include_router(tags.router)
|
||||
api_router.include_router(folders.router)
|
||||
api_router.include_router(step_categories.router)
|
||||
api_router.include_router(steps.router)
|
||||
api_router.include_router(shared.router) # Public share links (no auth)
|
||||
api_router.include_router(beta_signup.router)
|
||||
api_router.include_router(webhooks.router) # Stripe webhook receiver
|
||||
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin endpoints — super_admin only
|
||||
# admin_categories, admin_gallery, admin_dashboard, admin query Phase 1 RLS
|
||||
# tables and MUST use get_admin_db (migrated in Task 8). The remaining admin
|
||||
# endpoints (admin_audit, admin_plan_limits, admin_feature_flags,
|
||||
# admin_settings, admin_survey) are safe until Phase 2 extends RLS.
|
||||
# ---------------------------------------------------------------------------
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(admin_dashboard.router)
|
||||
api_router.include_router(admin_audit.router)
|
||||
@@ -53,42 +91,54 @@ api_router.include_router(admin_plan_limits.router)
|
||||
api_router.include_router(admin_feature_flags.router)
|
||||
api_router.include_router(admin_settings.router)
|
||||
api_router.include_router(admin_categories.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
api_router.include_router(shares.router)
|
||||
api_router.include_router(shared.router) # Public endpoints (no auth)
|
||||
api_router.include_router(tree_markdown.router)
|
||||
api_router.include_router(ratings.router)
|
||||
api_router.include_router(analytics.router)
|
||||
api_router.include_router(target_lists.router)
|
||||
api_router.include_router(maintenance_schedules.router)
|
||||
api_router.include_router(feedback.router)
|
||||
api_router.include_router(ai_builder.router)
|
||||
api_router.include_router(ai_fix.router)
|
||||
api_router.include_router(ai_chat.router)
|
||||
api_router.include_router(copilot.router)
|
||||
api_router.include_router(assistant_chat.router)
|
||||
api_router.include_router(survey.router)
|
||||
api_router.include_router(admin_survey.router)
|
||||
api_router.include_router(tree_transfer.router)
|
||||
api_router.include_router(ai_suggestions.router)
|
||||
api_router.include_router(kb_accelerator.router)
|
||||
api_router.include_router(beta_signup.router)
|
||||
api_router.include_router(scripts.router)
|
||||
api_router.include_router(integrations.router)
|
||||
api_router.include_router(onboarding.router)
|
||||
api_router.include_router(branding.router)
|
||||
api_router.include_router(supporting_data.router)
|
||||
api_router.include_router(session_handoffs.queue_router) # Must be before ai_sessions to avoid /{session_id} conflict
|
||||
api_router.include_router(session_resolutions.router) # Must be before ai_sessions to avoid /{session_id} conflict
|
||||
api_router.include_router(ai_sessions.router)
|
||||
api_router.include_router(flow_proposals.router)
|
||||
api_router.include_router(flowpilot_analytics.router)
|
||||
api_router.include_router(notifications.router)
|
||||
api_router.include_router(public_templates.router)
|
||||
api_router.include_router(admin_gallery.router)
|
||||
api_router.include_router(uploads.router)
|
||||
api_router.include_router(script_builder.router)
|
||||
api_router.include_router(beta_feedback.router)
|
||||
api_router.include_router(session_branches.router)
|
||||
api_router.include_router(session_handoffs.router)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-facing endpoints — tenant context required
|
||||
# ---------------------------------------------------------------------------
|
||||
_tenant_deps = [Depends(require_tenant_context)]
|
||||
|
||||
api_router.include_router(trees.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(sidebar.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(sessions.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(invite.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(categories.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(tags.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(folders.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(step_categories.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(steps.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(accounts.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(shares.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(tree_markdown.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ratings.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(analytics.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(target_lists.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(feedback.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ai_builder.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ai_fix.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ai_chat.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(copilot.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(assistant_chat.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(survey.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(tree_transfer.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(scripts.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(integrations.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(branding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
|
||||
# session_handoffs queue router must come before ai_sessions to avoid conflict
|
||||
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(ai_sessions.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(flow_proposals.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(notifications.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(uploads.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(script_builder.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)
|
||||
|
||||
36
backend/app/core/admin_database.py
Normal file
36
backend/app/core/admin_database.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# backend/app/core/admin_database.py
|
||||
"""
|
||||
Admin database engine — connects as resolutionflow_admin (BYPASSRLS).
|
||||
|
||||
Use ONLY for /admin/* endpoints and internal tooling.
|
||||
Never use this engine from user-facing endpoints.
|
||||
"""
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
admin_engine = create_async_engine(
|
||||
settings.ADMIN_DATABASE_URL,
|
||||
echo=settings.DEBUG,
|
||||
future=True,
|
||||
)
|
||||
|
||||
_admin_session_factory = async_sessionmaker(
|
||||
admin_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints."""
|
||||
async with _admin_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -23,10 +23,33 @@ class Settings(BaseSettings):
|
||||
return v.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
return v
|
||||
|
||||
@property
|
||||
def DATABASE_URL_SYNC(self) -> str:
|
||||
"""Get sync URL by removing asyncpg prefix from DATABASE_URL."""
|
||||
return self.DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://", 1)
|
||||
# Sync URL for Alembic migrations. Defaults to DATABASE_URL (sync-converted).
|
||||
# Set explicitly in .env to use a different role for migrations (e.g. superuser)
|
||||
# when DATABASE_URL has been switched to the app role.
|
||||
DATABASE_URL_SYNC: str = ""
|
||||
|
||||
@field_validator("DATABASE_URL_SYNC", mode="before")
|
||||
@classmethod
|
||||
def default_database_url_sync(cls, v: str, info) -> str:
|
||||
"""Fall back to sync-converted DATABASE_URL if not explicitly set."""
|
||||
if not v:
|
||||
base = info.data.get("DATABASE_URL", "")
|
||||
return base.replace("postgresql+asyncpg://", "postgresql://", 1)
|
||||
return v
|
||||
|
||||
# Admin database — resolutionflow_admin role, BYPASSRLS.
|
||||
# Used by /admin/* endpoints. Defaults to DATABASE_URL for local dev.
|
||||
ADMIN_DATABASE_URL: str = ""
|
||||
|
||||
@field_validator("ADMIN_DATABASE_URL", mode="before")
|
||||
@classmethod
|
||||
def default_admin_database_url(cls, v: str, info) -> str:
|
||||
"""Fall back to DATABASE_URL if ADMIN_DATABASE_URL is not set."""
|
||||
if not v:
|
||||
return info.data.get("DATABASE_URL", "")
|
||||
if v.startswith("postgresql://"):
|
||||
return v.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
return v
|
||||
|
||||
# JWT Settings
|
||||
SECRET_KEY: str = _DEFAULT_SECRET_KEY
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from .config import settings
|
||||
from app.core.tenant_context import register_tenant_listener
|
||||
|
||||
# Create async engine
|
||||
engine = create_async_engine(
|
||||
@@ -16,6 +17,11 @@ async_session_maker = async_sessionmaker(
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# Register the RLS tenant context listener on the app engine.
|
||||
# Fires at the start of every transaction; issues set_config automatically.
|
||||
# Must NOT be called on admin_engine — admin connections bypass RLS.
|
||||
register_tenant_listener(engine)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
|
||||
@@ -18,6 +18,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
||||
SERVICE_ACCOUNT_NAME = "ResolutionFlow"
|
||||
|
||||
# Well-known UUID for the platform account — owns all default/global content.
|
||||
# Created by migration 3a40fe11b427_create_global_content_tables.
|
||||
PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
SYSTEM_ACCOUNT_NAME = "ResolutionFlow System"
|
||||
SYSTEM_ACCOUNT_DISPLAY_CODE = "RF-SYS-1"
|
||||
|
||||
|
||||
92
backend/app/core/tenant_context.py
Normal file
92
backend/app/core/tenant_context.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# backend/app/core/tenant_context.py
|
||||
"""
|
||||
Per-request tenant context for row-level security.
|
||||
|
||||
Flow:
|
||||
1. require_tenant_context (FastAPI dep) calls set_current_account_id().
|
||||
2. The SQLAlchemy transaction-begin listener fires on every new transaction
|
||||
and calls set_config('app.current_account_id', <id>, true) automatically.
|
||||
3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE)
|
||||
to filter rows.
|
||||
|
||||
The ContextVar is asyncio-task-scoped: each concurrent request has its own value.
|
||||
set_config with is_local=true is transaction-scoped: it resets on COMMIT or
|
||||
ROLLBACK, so the listener re-applies it at the start of every transaction.
|
||||
"""
|
||||
import contextvars
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import event, or_, text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
# One slot per async task — each concurrent request gets its own value.
|
||||
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"current_account_id", default=None
|
||||
)
|
||||
|
||||
# Platform account — global content visible to all tenants.
|
||||
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
def set_current_account_id(account_id: UUID) -> contextvars.Token:
|
||||
"""Set tenant context for the current request coroutine.
|
||||
|
||||
Returns a token so the caller can reset it after the request.
|
||||
"""
|
||||
return _current_account_id.set(str(account_id))
|
||||
|
||||
|
||||
def clear_current_account_id(token: contextvars.Token) -> None:
|
||||
"""Reset the ContextVar to its previous value (call in finally block)."""
|
||||
_current_account_id.reset(token)
|
||||
|
||||
|
||||
def get_current_account_id() -> str | None:
|
||||
"""Return the account_id string for the current request, or None."""
|
||||
return _current_account_id.get()
|
||||
|
||||
|
||||
def register_tenant_listener(engine: AsyncEngine) -> None:
|
||||
"""Register the transaction-begin listener on the given engine.
|
||||
|
||||
Must be called once at application startup, AFTER the engine is created.
|
||||
The listener issues set_config() at the start of every transaction so that
|
||||
the setting is re-applied automatically even when a request commits
|
||||
mid-flight and starts a new transaction.
|
||||
|
||||
Do NOT call this on admin_engine — admin connections must never set tenant
|
||||
context automatically.
|
||||
"""
|
||||
|
||||
@event.listens_for(engine.sync_engine, "begin")
|
||||
def _on_transaction_begin(conn) -> None: # noqa: ANN001
|
||||
account_id = _current_account_id.get()
|
||||
if account_id:
|
||||
# set_config(name, value, is_local=true) ≡ SET LOCAL.
|
||||
# Unlike SET LOCAL, set_config IS parameterisable.
|
||||
conn.execute(
|
||||
text("SELECT set_config('app.current_account_id', :id, true)"),
|
||||
{"id": account_id},
|
||||
)
|
||||
# If no account_id is set, do nothing. The RLS policy falls back to a
|
||||
# null-matching UUID and returns zero rows — fail-closed behaviour.
|
||||
|
||||
|
||||
def tenant_filter(Model, current_user: "User"): # noqa: ANN001
|
||||
"""SQLAlchemy filter clause for tables that contain platform-owned rows.
|
||||
|
||||
Use for: tree_tags, tree_categories, step_categories, step_library,
|
||||
template_trees, platform_steps.
|
||||
|
||||
For tenant-only tables (trees, sessions, psa_connections, etc.) use:
|
||||
Model.account_id == current_user.account_id
|
||||
directly.
|
||||
"""
|
||||
return or_(
|
||||
Model.account_id == current_user.account_id,
|
||||
Model.account_id == PLATFORM_ACCOUNT_ID,
|
||||
)
|
||||
@@ -54,6 +54,8 @@ from .session_branch import SessionBranch
|
||||
from .fork_point import ForkPoint
|
||||
from .session_handoff import SessionHandoff
|
||||
from .session_resolution_output import SessionResolutionOutput
|
||||
from .template_tree import TemplateTree
|
||||
from .platform_step import PlatformStep
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -122,4 +124,6 @@ __all__ = [
|
||||
"ForkPoint",
|
||||
"SessionHandoff",
|
||||
"SessionResolutionOutput",
|
||||
"TemplateTree",
|
||||
"PlatformStep",
|
||||
]
|
||||
|
||||
@@ -50,6 +50,13 @@ class AISessionStep(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Denormalized from ai_sessions.account_id for direct tenant filtering.",
|
||||
)
|
||||
step_order: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False,
|
||||
comment="Sequential position in the session (0-indexed)",
|
||||
|
||||
@@ -28,6 +28,12 @@ class AISuggestion(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
session_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("ai_chat_sessions.id", ondelete="SET NULL"),
|
||||
|
||||
@@ -20,6 +20,12 @@ class Attachment(Base):
|
||||
ForeignKey("sessions.id"),
|
||||
nullable=False
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
|
||||
|
||||
@@ -39,10 +39,10 @@ class TreeCategory(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
@@ -11,7 +10,7 @@ class Feedback(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="SET NULL"), nullable=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
feedback_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
|
||||
@@ -46,6 +46,12 @@ class UserFolder(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
color: Mapped[str] = mapped_column(String(7), nullable=False, default="#6366f1")
|
||||
icon: Mapped[str] = mapped_column(String(50), nullable=False, default="folder")
|
||||
|
||||
@@ -23,6 +23,12 @@ class ForkPoint(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
parent_branch_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=False)
|
||||
trigger_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True)
|
||||
fork_reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
@@ -23,6 +23,12 @@ class MaintenanceSchedule(Base):
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
cron_expression: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
timezone: Mapped[str] = mapped_column(String(100), nullable=False, default="UTC")
|
||||
target_list_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
|
||||
@@ -31,6 +31,12 @@ class NotificationLog(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
event: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
payload: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="sent")
|
||||
|
||||
37
backend/app/models/platform_step.py
Normal file
37
backend/app/models/platform_step.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Platform step model — platform-owned steps, readable by all users.
|
||||
|
||||
No account_id. No RLS. Readable by any authenticated user.
|
||||
Populated by promoting visibility='public' steps from step_library.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
|
||||
from sqlalchemy import String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class PlatformStep(Base):
|
||||
__tablename__ = "platform_steps"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
step_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
content: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
source_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("step_library.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -25,6 +25,12 @@ class PsaMemberMapping(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -35,6 +35,12 @@ class PsaPostLog(Base):
|
||||
ForeignKey("psa_connections.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
ticket_id: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
note_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
content_posted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
@@ -29,6 +29,12 @@ class ScriptBuilderSession(Base):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("teams.id", ondelete="SET NULL"),
|
||||
|
||||
@@ -44,6 +44,12 @@ class ScriptTemplate(Base):
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
@@ -97,6 +103,12 @@ class ScriptGeneration(Base):
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
|
||||
@@ -31,6 +31,12 @@ class Session(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
||||
path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list)
|
||||
decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list)
|
||||
|
||||
@@ -35,6 +35,12 @@ class SessionBranch(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
parent_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=True)
|
||||
fork_point_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True)
|
||||
branch_order: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
@@ -27,6 +27,12 @@ class SessionHandoff(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
handed_off_by: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
intent: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
source_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="SET NULL"), nullable=True)
|
||||
|
||||
@@ -23,6 +23,12 @@ class SessionResolutionOutput(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
output_type: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||
generated_content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}")
|
||||
|
||||
@@ -38,10 +38,10 @@ class StepCategory(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
@@ -46,10 +46,10 @@ class StepLibrary(Base):
|
||||
ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
@@ -143,6 +143,13 @@ class StepRating(Base):
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Account of the RATER (not the step owner).",
|
||||
)
|
||||
rating: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
was_helpful: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
review_text: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
@@ -187,6 +194,13 @@ class StepUsageLog(Base):
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Account of the user who logged this usage.",
|
||||
)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("sessions.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -14,6 +14,12 @@ class SessionSupportingData(Base):
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
@@ -51,10 +51,10 @@ class TreeTag(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.core.database import Base
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
class TargetList(Base):
|
||||
@@ -21,6 +22,12 @@ class TargetList(Base):
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
40
backend/app/models/template_tree.py
Normal file
40
backend/app/models/template_tree.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Template tree model — platform-owned troubleshooting trees, readable by all users.
|
||||
|
||||
No account_id. No RLS. Readable by any authenticated user.
|
||||
Populated by promoting is_default=TRUE trees from the trees table.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TemplateTree(Base):
|
||||
__tablename__ = "template_trees"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||
tree_type: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
|
||||
tree_structure: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
||||
tags: Mapped[list] = mapped_column(JSONB, nullable=False, default=list)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
source_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -76,10 +76,10 @@ class Tree(Base):
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
@@ -37,10 +37,10 @@ class TreeEmbedding(Base):
|
||||
ForeignKey("trees.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
)
|
||||
chunk_type: Mapped[str] = mapped_column(
|
||||
String(30),
|
||||
|
||||
@@ -43,10 +43,10 @@ class User(Base):
|
||||
must_change_password: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
|
||||
|
||||
# Account-based multi-tenancy (new)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="RESTRICT"),
|
||||
nullable=True,
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")
|
||||
|
||||
@@ -24,6 +24,12 @@ class UserPinnedTree(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
tree_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="CASCADE"),
|
||||
|
||||
545
backend/tests/test_phase1_migrations.py
Normal file
545
backend/tests/test_phase1_migrations.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""Phase 1 migration tests — verify account_id backfill correctness.
|
||||
|
||||
These tests create objects via ORM (which uses the updated models),
|
||||
then verify account_id is populated correctly. They run against a
|
||||
real PostgreSQL test DB (same as all other integration tests).
|
||||
"""
|
||||
import pytest
|
||||
import uuid
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.models.tree import Tree
|
||||
from app.models.session import Session
|
||||
from app.models.attachment import Attachment
|
||||
from app.models.supporting_data import SessionSupportingData
|
||||
from app.models.session_resolution_output import SessionResolutionOutput
|
||||
from app.models.ai_session import AISession
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
|
||||
account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8])
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
user = User(
|
||||
email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com",
|
||||
name=f"User {suffix}",
|
||||
password_hash=get_password_hash("TestPass123!"),
|
||||
is_active=True,
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
return account, user
|
||||
|
||||
|
||||
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
|
||||
tree = Tree(
|
||||
name=f"Tree {uuid.uuid4().hex[:6]}",
|
||||
account_id=account.id,
|
||||
author_id=user.id,
|
||||
visibility="team",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
db.add(tree)
|
||||
await db.flush()
|
||||
return tree
|
||||
|
||||
|
||||
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
|
||||
s = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
tree_snapshot={},
|
||||
)
|
||||
db.add(s)
|
||||
await db.flush()
|
||||
return s
|
||||
|
||||
|
||||
# ── Group 1: Core sessions ────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_account_id_matches_user(test_db: AsyncSession):
|
||||
"""sessions.account_id must equal the user's account_id."""
|
||||
account, user = await _make_account_and_user(test_db, "s1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
session = await _make_session(test_db, account, user, tree)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(Session).where(Session.id == session.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
|
||||
"""attachments.account_id must match the parent session's account_id."""
|
||||
account, user = await _make_account_and_user(test_db, "att1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
session = await _make_session(test_db, account, user, tree)
|
||||
|
||||
attachment = Attachment(
|
||||
session_id=session.id,
|
||||
account_id=account.id,
|
||||
file_name="test.png",
|
||||
file_type="image/png",
|
||||
)
|
||||
test_db.add(attachment)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_supporting_data_account_id(test_db: AsyncSession):
|
||||
"""session_supporting_data.account_id must match parent session's account_id."""
|
||||
account, user = await _make_account_and_user(test_db, "sd1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
session = await _make_session(test_db, account, user, tree)
|
||||
|
||||
sd = SessionSupportingData(
|
||||
session_id=session.id,
|
||||
account_id=account.id,
|
||||
label="Log snippet",
|
||||
data_type="text_snippet",
|
||||
content="error: connection refused",
|
||||
)
|
||||
test_db.add(sd)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(SessionSupportingData).where(SessionSupportingData.id == sd.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_resolution_output_account_id(test_db: AsyncSession):
|
||||
"""session_resolution_outputs.account_id must match the parent ai_session's account_id.
|
||||
|
||||
NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
|
||||
"""
|
||||
account, user = await _make_account_and_user(test_db, "sro1")
|
||||
|
||||
ai_session = AISession(
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
problem_summary="test resolution output",
|
||||
problem_domain="networking",
|
||||
status="active",
|
||||
)
|
||||
test_db.add(ai_session)
|
||||
await test_db.flush()
|
||||
|
||||
output = SessionResolutionOutput(
|
||||
session_id=ai_session.id,
|
||||
account_id=account.id,
|
||||
output_type="psa_ticket_notes",
|
||||
generated_content="Ticket notes content",
|
||||
generated_by_model="gpt-4",
|
||||
)
|
||||
test_db.add(output)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 2: AI & branching ───────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
|
||||
"""session_branches.account_id must match parent ai_session.account_id."""
|
||||
from app.models.session_branch import SessionBranch
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "sb1")
|
||||
ai_session = AISession(
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
problem_summary="test",
|
||||
problem_domain="networking",
|
||||
status="active",
|
||||
)
|
||||
test_db.add(ai_session)
|
||||
await test_db.flush()
|
||||
|
||||
branch = SessionBranch(
|
||||
session_id=ai_session.id,
|
||||
account_id=account.id,
|
||||
label="Branch A",
|
||||
branch_order=1,
|
||||
conversation_messages=[],
|
||||
)
|
||||
test_db.add(branch)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(SessionBranch).where(SessionBranch.id == branch.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
|
||||
"""ai_suggestions.account_id must match the creating user's account_id."""
|
||||
from app.models.ai_suggestion import AISuggestion
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "ais1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
|
||||
suggestion = AISuggestion(
|
||||
tree_id=tree.id,
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
action_type="add_node",
|
||||
changes_json={},
|
||||
status="pending",
|
||||
)
|
||||
test_db.add(suggestion)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(AISuggestion).where(AISuggestion.id == suggestion.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 3: Steps & ratings ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
|
||||
"""step_ratings.account_id must be the RATER's account, not the step's account."""
|
||||
from app.models.step_library import StepLibrary, StepRating
|
||||
|
||||
account_a, user_a = await _make_account_and_user(test_db, "sr-rater")
|
||||
account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner")
|
||||
|
||||
# Step owned by account_b
|
||||
step = StepLibrary(
|
||||
title="A step",
|
||||
step_type="action",
|
||||
content={"text": "do something"},
|
||||
created_by=user_b.id,
|
||||
account_id=account_b.id,
|
||||
visibility="public",
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.flush()
|
||||
|
||||
# user_a (account_a) rates the step
|
||||
rating = StepRating(
|
||||
step_id=step.id,
|
||||
user_id=user_a.id,
|
||||
account_id=account_a.id, # rater's account, not step owner's
|
||||
was_helpful=True,
|
||||
is_verified_use=False,
|
||||
is_visible=True,
|
||||
)
|
||||
test_db.add(rating)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(StepRating).where(StepRating.id == rating.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account_a.id, (
|
||||
f"account_id should be rater's account ({account_a.id}), got {row.account_id}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
|
||||
"""step_usage_log.account_id must be the LOGGER's account (user who used the step)."""
|
||||
from app.models.step_library import StepLibrary, StepUsageLog
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "sul1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
session = await _make_session(test_db, account, user, tree)
|
||||
|
||||
step = StepLibrary(
|
||||
title="A usage step",
|
||||
step_type="action",
|
||||
content={"text": "do something"},
|
||||
created_by=user.id,
|
||||
account_id=account.id,
|
||||
visibility="team",
|
||||
)
|
||||
test_db.add(step)
|
||||
await test_db.flush()
|
||||
|
||||
log = StepUsageLog(
|
||||
step_id=step.id,
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
test_db.add(log)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(StepUsageLog).where(StepUsageLog.id == log.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id, (
|
||||
f"account_id should be logger's account ({account.id}), got {row.account_id}"
|
||||
)
|
||||
|
||||
|
||||
# ── Group 4: User personalization ────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_folder_account_id_matches_user(test_db: AsyncSession):
|
||||
"""user_folders.account_id must match the owning user's account_id."""
|
||||
from app.models.folder import UserFolder
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "uf1")
|
||||
folder = UserFolder(
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
name="My Folder",
|
||||
color="#6366f1",
|
||||
icon="folder",
|
||||
display_order=0,
|
||||
)
|
||||
test_db.add(folder)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession):
|
||||
"""user_pinned_trees.account_id must match the pinning user's account_id."""
|
||||
from app.models.user_pinned_tree import UserPinnedTree
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "pt1")
|
||||
tree = await _make_tree(test_db, account, user)
|
||||
pin = UserPinnedTree(
|
||||
user_id=user.id,
|
||||
tree_id=tree.id,
|
||||
account_id=account.id,
|
||||
display_order=0,
|
||||
)
|
||||
test_db.add(pin)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id))
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 5: PSA & notifications ─────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession):
|
||||
"""psa_member_mappings.account_id must match psa_connection's account_id."""
|
||||
from app.models.psa_connection import PsaConnection
|
||||
from app.models.psa_member_mapping import PsaMemberMapping
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "psa1")
|
||||
conn = PsaConnection(
|
||||
account_id=account.id,
|
||||
provider="connectwise",
|
||||
display_name="Test CW",
|
||||
site_url="https://cw.example.com",
|
||||
company_id="TEST",
|
||||
credentials_encrypted="placeholder",
|
||||
)
|
||||
test_db.add(conn)
|
||||
await test_db.flush()
|
||||
|
||||
mapping = PsaMemberMapping(
|
||||
psa_connection_id=conn.id,
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
external_member_id="cw-123",
|
||||
external_member_name="Test User",
|
||||
matched_by="manual_admin",
|
||||
)
|
||||
test_db.add(mapping)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(PsaMemberMapping).where(PsaMemberMapping.id == mapping.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 6: Maintenance ──────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession):
|
||||
"""maintenance_schedules.account_id must match the tree's account_id."""
|
||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "ms1")
|
||||
tree = Tree(
|
||||
name="Maintenance Flow",
|
||||
account_id=account.id,
|
||||
author_id=user.id,
|
||||
visibility="team",
|
||||
tree_type="maintenance",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
status="published",
|
||||
)
|
||||
test_db.add(tree)
|
||||
await test_db.flush()
|
||||
|
||||
schedule = MaintenanceSchedule(
|
||||
tree_id=tree.id,
|
||||
account_id=account.id,
|
||||
created_by=user.id,
|
||||
cron_expression="0 9 * * 1",
|
||||
timezone="UTC",
|
||||
is_active=True,
|
||||
)
|
||||
test_db.add(schedule)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 7: Legacy team_id tables ───────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_script_builder_session_account_id(test_db: AsyncSession):
|
||||
"""script_builder_sessions.account_id must match user's account_id."""
|
||||
from app.models.script_builder_session import ScriptBuilderSession
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "sbs1")
|
||||
sbs = ScriptBuilderSession(
|
||||
user_id=user.id,
|
||||
account_id=account.id,
|
||||
language="powershell",
|
||||
)
|
||||
test_db.add(sbs)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(ScriptBuilderSession).where(ScriptBuilderSession.id == sbs.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 8: TargetList ────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
|
||||
"""target_lists.account_id must be set to the team admin's account_id."""
|
||||
from app.models.target_list import TargetList
|
||||
from app.models.team import Team
|
||||
|
||||
account, user = await _make_account_and_user(test_db, "tl1")
|
||||
# Make user a team admin
|
||||
team = Team(name=f"Team {uuid.uuid4().hex[:6]}")
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
|
||||
user.team_id = team.id
|
||||
user.is_team_admin = True
|
||||
await test_db.flush()
|
||||
|
||||
target_list = TargetList(
|
||||
team_id=team.id,
|
||||
account_id=account.id,
|
||||
created_by=user.id,
|
||||
name="Server Targets",
|
||||
targets=[{"label": "SRV-01"}],
|
||||
)
|
||||
test_db.add(target_list)
|
||||
await test_db.commit()
|
||||
|
||||
result = await test_db.execute(
|
||||
select(TargetList).where(TargetList.id == target_list.id)
|
||||
)
|
||||
row = result.scalar_one()
|
||||
assert row.account_id == account.id
|
||||
|
||||
|
||||
# ── Group 10 (runs first): Global content tables ──────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession):
|
||||
"""template_trees must exist and must NOT have an account_id column."""
|
||||
result = await test_db.execute(text("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'template_trees'
|
||||
"""))
|
||||
columns = {row[0] for row in result.fetchall()}
|
||||
assert 'id' in columns, "template_trees.id must exist"
|
||||
assert 'account_id' not in columns, "template_trees must not have account_id (global content)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession):
|
||||
"""platform_steps must exist and must NOT have an account_id column."""
|
||||
result = await test_db.execute(text("""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'platform_steps'
|
||||
"""))
|
||||
columns = {row[0] for row in result.fetchall()}
|
||||
assert 'id' in columns, "platform_steps.id must exist"
|
||||
assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"
|
||||
|
||||
|
||||
# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
|
||||
"""trees.account_id must be NOT NULL after Phase 1 — enforced at DB level."""
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
with pytest.raises(IntegrityError):
|
||||
test_db.add(Tree(
|
||||
name="Bad tree",
|
||||
# account_id intentionally omitted
|
||||
author_id=None,
|
||||
visibility="private",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={},
|
||||
is_active=True,
|
||||
status="draft",
|
||||
))
|
||||
await test_db.flush()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_account_id_is_not_null(test_db: AsyncSession):
|
||||
"""users.account_id must be NOT NULL after Phase 1."""
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
with pytest.raises(IntegrityError):
|
||||
test_db.add(User(
|
||||
email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
|
||||
name="Orphan",
|
||||
password_hash=get_password_hash("x"),
|
||||
is_active=True,
|
||||
role="engineer",
|
||||
account_role="engineer",
|
||||
# account_id intentionally omitted
|
||||
))
|
||||
await test_db.flush()
|
||||
266
backend/tests/test_rls_isolation.py
Normal file
266
backend/tests/test_rls_isolation.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# backend/tests/test_rls_isolation.py
|
||||
"""
|
||||
RLS foundation tests.
|
||||
|
||||
Connect directly as resolutionflow_app (not superuser) and verify:
|
||||
- Tenant A cannot read Tenant B's rows
|
||||
- No tenant context set → zero rows for private data (fail-closed)
|
||||
- Platform rows (PLATFORM_ACCOUNT_ID) are visible to all tenants
|
||||
|
||||
Tests bypass FastAPI entirely — raw asyncpg connections only.
|
||||
MUST FAIL before Task 10 (RLS migration) and PASS after it.
|
||||
|
||||
Run with:
|
||||
DB_APP_ROLE_PASSWORD=app_secret_change_me pytest tests/test_rls_isolation.py -v
|
||||
|
||||
The test DB is patherly_test (matches conftest.py default).
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
||||
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
||||
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
||||
_APP_PASSWORD = os.getenv("DB_APP_ROLE_PASSWORD", "app_secret_change_me")
|
||||
_ADMIN_DSN = f"postgresql://postgres:postgres@{_DB_HOST}:{_DB_PORT}/{_DB_NAME}"
|
||||
|
||||
PLATFORM_ACCOUNT_ID = "00000000-0000-0000-0000-000000000001"
|
||||
ACCOUNT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def admin_conn():
|
||||
"""Superuser asyncpg connection for fixture setup and teardown."""
|
||||
conn = await asyncpg.connect(_ADMIN_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
async def seed_rls_test_data(admin_conn):
|
||||
"""
|
||||
Create two isolated test accounts, one user per account, and one private
|
||||
tree per account. Trees require a valid author_id FK to users, so users
|
||||
must be created first.
|
||||
|
||||
accounts.display_code must be unique and 8 chars (NOT NULL constraint).
|
||||
"""
|
||||
# Insert accounts
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
|
||||
VALUES
|
||||
('{ACCOUNT_A_ID}', 'RLS Tenant A', 'RLSA0001', NOW(), NOW()),
|
||||
('{ACCOUNT_B_ID}', 'RLS Tenant B', 'RLSB0001', NOW(), NOW())
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
""")
|
||||
|
||||
# Insert one user per account (users.account_id NOT NULL, password_hash NOT NULL)
|
||||
user_a_id = str(uuid.uuid4())
|
||||
user_b_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO users (
|
||||
id, email, password_hash, name, role, is_active, account_id,
|
||||
account_role, created_at
|
||||
) VALUES
|
||||
('{user_a_id}', 'rls-user-a@example.com',
|
||||
'placeholder', 'RLS User A', 'engineer', TRUE,
|
||||
'{ACCOUNT_A_ID}', 'engineer', NOW()),
|
||||
('{user_b_id}', 'rls-user-b@example.com',
|
||||
'placeholder', 'RLS User B', 'engineer', TRUE,
|
||||
'{ACCOUNT_B_ID}', 'engineer', NOW())
|
||||
ON CONFLICT (email) DO NOTHING
|
||||
""")
|
||||
|
||||
# Look up the user IDs we just inserted (ON CONFLICT may have skipped)
|
||||
row_a = await admin_conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = 'rls-user-a@example.com'"
|
||||
)
|
||||
row_b = await admin_conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
|
||||
)
|
||||
actual_user_a = str(row_a["id"])
|
||||
actual_user_b = str(row_b["id"])
|
||||
|
||||
# Insert one private tree per account with explicit author_id
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO trees (
|
||||
id, name, tree_structure, account_id, author_id, is_active, is_default,
|
||||
is_public, visibility, tree_type, created_at, updated_at
|
||||
) VALUES
|
||||
(gen_random_uuid(), 'RLS Tree A', '[]'::jsonb, '{ACCOUNT_A_ID}', '{actual_user_a}',
|
||||
TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW()),
|
||||
(gen_random_uuid(), 'RLS Tree B', '[]'::jsonb, '{ACCOUNT_B_ID}', '{actual_user_b}',
|
||||
TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW())
|
||||
""")
|
||||
|
||||
# One platform-owned tree_tag (global, visible to all tenants)
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO tree_tags (
|
||||
id, name, slug, account_id, usage_count, created_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), 'rls-global-tag', 'rls-global-tag',
|
||||
'{PLATFORM_ACCOUNT_ID}', 0, NOW()
|
||||
) ON CONFLICT DO NOTHING
|
||||
""")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM trees WHERE account_id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
"DELETE FROM users WHERE email IN "
|
||||
"('rls-user-a@example.com', 'rls-user-b@example.com')"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM accounts WHERE id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')"
|
||||
)
|
||||
await admin_conn.execute("DELETE FROM tree_tags WHERE slug = 'rls-global-tag'")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def conn_a():
|
||||
"""App-role connection, tenant context = Account A."""
|
||||
conn = await asyncpg.connect(
|
||||
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
|
||||
user="resolutionflow_app", password=_APP_PASSWORD,
|
||||
)
|
||||
await conn.execute(
|
||||
"SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_A_ID
|
||||
)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def conn_b():
|
||||
"""App-role connection, tenant context = Account B."""
|
||||
conn = await asyncpg.connect(
|
||||
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
|
||||
user="resolutionflow_app", password=_APP_PASSWORD,
|
||||
)
|
||||
await conn.execute(
|
||||
"SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_B_ID
|
||||
)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def conn_no_context():
|
||||
"""App-role connection with NO tenant context set."""
|
||||
conn = await asyncpg.connect(
|
||||
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
|
||||
user="resolutionflow_app", password=_APP_PASSWORD,
|
||||
)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# trees
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
)
|
||||
assert len(rows) >= 1, "Account A should see its own trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
rows = await conn_no_context.fetch(
|
||||
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
||||
)
|
||||
assert len(rows) == 0, "No-context connection should see no private trees"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tree_tags — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
rows_a = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||
)
|
||||
rows_b = await conn_b.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||
)
|
||||
assert len(rows_a) >= 1, "Account A should see platform tags"
|
||||
assert len(rows_b) >= 1, "Account B should see platform tags"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tree_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# psa_connections — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# flow_proposals — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Integration tests for Script Template Editor permissions and share endpoint."""
|
||||
from uuid import UUID as PyUUID
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
@@ -65,6 +67,9 @@ class TestScriptTemplatePermissions:
|
||||
data = resp.json()
|
||||
assert data["name"] == "Test Template"
|
||||
assert data["created_by"] is not None
|
||||
result = await test_db.execute(select(ScriptTemplate).where(ScriptTemplate.id == PyUUID(data["id"])))
|
||||
template = result.scalar_one()
|
||||
assert template.account_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_can_edit_own_template(self, client, auth_headers, test_db):
|
||||
|
||||
@@ -6,14 +6,18 @@ from datetime import datetime, timezone
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.models.script_template import ScriptGeneration
|
||||
from app.models.user import User
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
async def seed_script_data(test_db):
|
||||
async def seed_script_data(test_db, test_user):
|
||||
"""Seed script categories and templates into the test database."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cat_id = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"]))
|
||||
user = user_result.scalar_one()
|
||||
|
||||
# Insert category
|
||||
await test_db.execute(
|
||||
@@ -142,20 +146,20 @@ async def seed_script_data(test_db):
|
||||
await test_db.execute(
|
||||
sa.text("""
|
||||
INSERT INTO script_templates (
|
||||
id, category_id, name, slug, description,
|
||||
id, category_id, account_id, name, slug, description,
|
||||
script_body, parameters_schema, default_values, validation_rules,
|
||||
tags, complexity, estimated_runtime, requires_elevation,
|
||||
requires_modules, version, is_verified, is_active, usage_count,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
:id, :category_id, :name, :slug, :description,
|
||||
:id, :category_id, :account_id, :name, :slug, :description,
|
||||
:script_body, CAST(:parameters_schema AS jsonb), '{}'::jsonb, '{}'::jsonb,
|
||||
CAST(:tags AS jsonb), :complexity, :estimated_runtime, :requires_elevation,
|
||||
'[]'::jsonb, 1, true, true, 0,
|
||||
:now, :now
|
||||
)
|
||||
"""),
|
||||
{**tmpl, "category_id": cat_id, "now": now},
|
||||
{**tmpl, "category_id": cat_id, "account_id": user.account_id, "now": now},
|
||||
)
|
||||
|
||||
await test_db.commit()
|
||||
@@ -245,7 +249,7 @@ async def test_get_template_detail_not_found(client, auth_headers):
|
||||
# ── Generate ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_script_success(client, auth_headers, seed_script_data):
|
||||
async def test_generate_script_success(client, auth_headers, seed_script_data, test_db, test_user):
|
||||
list_resp = await client.get(
|
||||
"/api/v1/scripts/templates?search=unlock",
|
||||
headers=auth_headers,
|
||||
@@ -265,6 +269,13 @@ async def test_generate_script_success(client, auth_headers, seed_script_data):
|
||||
assert "script" in data
|
||||
assert "jsmith" in data["script"]
|
||||
assert "id" in data
|
||||
generation_result = await test_db.execute(
|
||||
sa.select(ScriptGeneration).where(ScriptGeneration.id == uuid.UUID(data["id"]))
|
||||
)
|
||||
generation = generation_result.scalar_one()
|
||||
user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"]))
|
||||
user = user_result.scalar_one()
|
||||
assert generation.account_id == user.account_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
58
backend/tests/test_tenant_context.py
Normal file
58
backend/tests/test_tenant_context.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.core.tenant_context import (
|
||||
set_current_account_id,
|
||||
clear_current_account_id,
|
||||
get_current_account_id,
|
||||
)
|
||||
|
||||
|
||||
def test_contextvar_is_none_by_default():
|
||||
assert get_current_account_id() is None
|
||||
|
||||
|
||||
def test_set_and_clear():
|
||||
account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001")
|
||||
token = set_current_account_id(account_id)
|
||||
assert get_current_account_id() == str(account_id)
|
||||
clear_current_account_id(token)
|
||||
assert get_current_account_id() is None
|
||||
|
||||
|
||||
def test_tasks_are_isolated():
|
||||
"""Each asyncio task has its own ContextVar value."""
|
||||
results = {}
|
||||
|
||||
async def set_in_task(name: str, value: str):
|
||||
token = set_current_account_id(UUID(value))
|
||||
await asyncio.sleep(0)
|
||||
results[name] = get_current_account_id()
|
||||
clear_current_account_id(token)
|
||||
|
||||
async def run():
|
||||
await asyncio.gather(
|
||||
set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"),
|
||||
set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"),
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
|
||||
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_tenant_context_raises_403_when_no_account():
|
||||
from fastapi import HTTPException
|
||||
from app.api.deps import require_tenant_context
|
||||
|
||||
user = MagicMock()
|
||||
user.account_id = None
|
||||
|
||||
gen = require_tenant_context(current_user=user)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await gen.__anext__()
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "account required" in exc_info.value.detail.lower()
|
||||
@@ -447,3 +447,55 @@ class TestVisibilityFilter:
|
||||
assert "author_name" in trees[0]
|
||||
# visibility key should be present
|
||||
assert "visibility" in trees[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tree_returns_404_not_403_for_other_account_tree(
|
||||
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
|
||||
):
|
||||
"""Account A must not learn that Account B's private tree exists."""
|
||||
from app.models.tree import Tree
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash
|
||||
import uuid
|
||||
|
||||
# Create a second account and user
|
||||
account_b = Account(name="Other Corp", display_code="OTH00001")
|
||||
test_db.add(account_b)
|
||||
await test_db.flush()
|
||||
|
||||
user_b = User(
|
||||
email=f"user-b-{uuid.uuid4().hex[:6]}@example.com",
|
||||
name="User B",
|
||||
password_hash=get_password_hash("TestPass123!"),
|
||||
is_active=True,
|
||||
account_id=account_b.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(user_b)
|
||||
await test_db.flush()
|
||||
|
||||
# Create a private tree belonging to account_b
|
||||
private_tree = Tree(
|
||||
name="Secret Tree",
|
||||
account_id=account_b.id,
|
||||
author_id=user_b.id,
|
||||
visibility="private",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure={"id": "root", "type": "start", "children": []},
|
||||
is_active=True,
|
||||
is_default=False,
|
||||
is_public=False,
|
||||
status="published",
|
||||
)
|
||||
test_db.add(private_tree)
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/trees/{private_tree.id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404, (
|
||||
f"Expected 404 but got {response.status_code} — "
|
||||
"leaking tree existence to wrong tenant"
|
||||
)
|
||||
|
||||
2527
docs/superpowers/plans/2026-04-09-tenant-isolation-phase-1.md
Normal file
2527
docs/superpowers/plans/2026-04-09-tenant-isolation-phase-1.md
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user