diff --git a/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py b/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py new file mode 100644 index 00000000..92d0e6e5 --- /dev/null +++ b/backend/alembic/versions/478c159e5654_add_account_id_ai_branching.py @@ -0,0 +1,77 @@ +"""add account_id to AI branching tables + +Revision ID: 478c159e5654 +Revises: cc214c63aa30 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = '478c159e5654' +down_revision: Union[str, None] = 'cc214c63aa30' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + ai_tables = ('session_branches', 'session_handoffs', 'fork_points', 'ai_session_steps') + + # Step 1: ADD COLUMN (nullable) + for table in ai_tables: + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + op.add_column('ai_suggestions', sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + 'fk_ai_suggestions_account_id', 'ai_suggestions', 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + + # Step 2: BACKFILL + for table in ai_tables: + op.execute(f""" + UPDATE {table} t + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE t.session_id = ai.id + AND t.account_id IS NULL + """) + + op.execute(""" + UPDATE ai_suggestions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # Step 3: VERIFY zero NULLs + for table in ai_tables + ('ai_suggestions',): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows in {table}." + ) + + # Step 4: SET NOT NULL + for table in ai_tables + ('ai_suggestions',): + op.alter_column(table, 'account_id', nullable=False) + + # Step 5: CREATE INDEX + for table in ai_tables + ('ai_suggestions',): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('session_branches', 'session_handoffs', 'fork_points', + 'ai_session_steps', 'ai_suggestions'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/app/models/ai_session_step.py b/backend/app/models/ai_session_step.py index 1642632b..09ffc4c1 100644 --- a/backend/app/models/ai_session_step.py +++ b/backend/app/models/ai_session_step.py @@ -50,6 +50,13 @@ class AISessionStep(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Denormalized from ai_sessions.account_id for direct tenant filtering.", + ) step_order: Mapped[int] = mapped_column( Integer, nullable=False, comment="Sequential position in the session (0-indexed)", diff --git a/backend/app/models/ai_suggestion.py b/backend/app/models/ai_suggestion.py index 8ee65dd5..12321c9a 100644 --- a/backend/app/models/ai_suggestion.py +++ b/backend/app/models/ai_suggestion.py @@ -28,6 +28,12 @@ class AISuggestion(Base): nullable=False, index=True, ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) session_id: Mapped[Optional[uuid.UUID]] = mapped_column( UUID(as_uuid=True), ForeignKey("ai_chat_sessions.id", ondelete="SET NULL"), diff --git a/backend/app/models/fork_point.py b/backend/app/models/fork_point.py index a5700774..8c89d49d 100644 --- a/backend/app/models/fork_point.py +++ b/backend/app/models/fork_point.py @@ -23,6 +23,12 @@ class ForkPoint(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) parent_branch_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=False) trigger_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True) fork_reason: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/backend/app/models/session_branch.py b/backend/app/models/session_branch.py index ab6cc50e..e3716806 100644 --- a/backend/app/models/session_branch.py +++ b/backend/app/models/session_branch.py @@ -35,6 +35,12 @@ class SessionBranch(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) parent_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=True) fork_point_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True) branch_order: Mapped[int] = mapped_column(Integer, nullable=False, default=1) diff --git a/backend/app/models/session_handoff.py b/backend/app/models/session_handoff.py index 0fd53128..1b44df56 100644 --- a/backend/app/models/session_handoff.py +++ b/backend/app/models/session_handoff.py @@ -27,6 +27,12 @@ class SessionHandoff(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) handed_off_by: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False) intent: Mapped[str] = mapped_column(String(20), nullable=False) source_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="SET NULL"), nullable=True) diff --git a/backend/tests/test_phase1_migrations.py b/backend/tests/test_phase1_migrations.py index 144ef099..b90279a9 100644 --- a/backend/tests/test_phase1_migrations.py +++ b/backend/tests/test_phase1_migrations.py @@ -160,3 +160,64 @@ async def test_session_resolution_output_account_id(test_db: AsyncSession): ) row = result.scalar_one() assert row.account_id == account.id + + +# ── Group 2: AI & branching ─────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession): + """session_branches.account_id must match parent ai_session.account_id.""" + from app.models.session_branch import SessionBranch + + account, user = await _make_account_and_user(test_db, "sb1") + ai_session = AISession( + user_id=user.id, + account_id=account.id, + problem_summary="test", + problem_domain="networking", + status="active", + ) + test_db.add(ai_session) + await test_db.flush() + + branch = SessionBranch( + session_id=ai_session.id, + account_id=account.id, + label="Branch A", + branch_order=1, + conversation_messages=[], + ) + test_db.add(branch) + await test_db.commit() + + result = await test_db.execute( + select(SessionBranch).where(SessionBranch.id == branch.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession): + """ai_suggestions.account_id must match the creating user's account_id.""" + from app.models.ai_suggestion import AISuggestion + + account, user = await _make_account_and_user(test_db, "ais1") + tree = await _make_tree(test_db, account, user) + + suggestion = AISuggestion( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + action_type="add_node", + changes_json={}, + status="pending", + ) + test_db.add(suggestion) + await test_db.commit() + + result = await test_db.execute( + select(AISuggestion).where(AISuggestion.id == suggestion.id) + ) + row = result.scalar_one() + assert row.account_id == account.id