From 4666c4f6d266907cba4c4a7a14da8f9555cb2ae2 Mon Sep 17 00:00:00 2001 From: chihlasm Date: Thu, 9 Apr 2026 05:07:05 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20Phase=201=20Group=201=20=E2=80=94=20add?= =?UTF-8?q?=20account=5Fid=20to=20core=20session=20tables?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migration sequence: add nullable → backfill via user_id/ai_session chain → verify zero NULLs → SET NOT NULL → CREATE INDEX. Tables: sessions, attachments, session_supporting_data, session_resolution_outputs Co-Authored-By: Claude Sonnet 4.6 --- ...214c63aa30_add_account_id_core_sessions.py | 95 ++++++++++ backend/app/models/attachment.py | 6 + backend/app/models/session.py | 6 + .../app/models/session_resolution_output.py | 6 + backend/app/models/supporting_data.py | 6 + backend/tests/test_phase1_migrations.py | 162 ++++++++++++++++++ 6 files changed, 281 insertions(+) create mode 100644 backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py create mode 100644 backend/tests/test_phase1_migrations.py diff --git a/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py b/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py new file mode 100644 index 00000000..c0e5f47c --- /dev/null +++ b/backend/alembic/versions/cc214c63aa30_add_account_id_core_sessions.py @@ -0,0 +1,95 @@ +"""add account_id to core session tables + +Revision ID: cc214c63aa30 +Revises: b8d2f4a6c091 +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'cc214c63aa30' +down_revision: Union[str, None] = 'b8d2f4a6c091' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Step 1: ADD COLUMN (nullable) ──────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', + table, 'accounts', + ['account_id'], ['id'], + ondelete='CASCADE', + ) + + # ── Step 2: BACKFILL ───────────────────────────────────────────────────── + # sessions: direct join to users + op.execute(""" + UPDATE sessions s + SET account_id = u.account_id + FROM users u + WHERE s.user_id = u.id + AND s.account_id IS NULL + """) + + # attachments: chain through sessions (now backfilled above) + op.execute(""" + UPDATE attachments a + SET account_id = s.account_id + FROM sessions s + WHERE a.session_id = s.id + AND a.account_id IS NULL + """) + + # session_supporting_data: same chain + op.execute(""" + UPDATE session_supporting_data sd + SET account_id = s.account_id + FROM sessions s + WHERE sd.session_id = s.id + AND sd.account_id IS NULL + """) + + # session_resolution_outputs: FK is to ai_sessions, not sessions + op.execute(""" + UPDATE session_resolution_outputs sro + SET account_id = ai.account_id + FROM ai_sessions ai + WHERE sro.session_id = ai.id + AND sro.account_id IS NULL + """) + + # ── Step 3: VERIFY zero NULLs — raises if any remain ──────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError( + f"ROLLBACK: {count} NULL account_id rows remain in {table}. " + f"Fix the backfill before re-running." + ) + + # ── Step 4: SET NOT NULL ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.alter_column(table, 'account_id', nullable=False) + + # ── Step 5: CREATE INDEX ───────────────────────────────────────────────── + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('sessions', 'attachments', 'session_supporting_data', + 'session_resolution_outputs'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/app/models/attachment.py b/backend/app/models/attachment.py index dc5266b6..910f697c 100644 --- a/backend/app/models/attachment.py +++ b/backend/app/models/attachment.py @@ -20,6 +20,12 @@ class Attachment(Base): ForeignKey("sessions.id"), nullable=False ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) file_name: Mapped[str] = mapped_column(String(255), nullable=False) file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) diff --git a/backend/app/models/session.py b/backend/app/models/session.py index c191572b..5bcd6241 100644 --- a/backend/app/models/session.py +++ b/backend/app/models/session.py @@ -31,6 +31,12 @@ class Session(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list) decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list) diff --git a/backend/app/models/session_resolution_output.py b/backend/app/models/session_resolution_output.py index cb56fa42..3ae32549 100644 --- a/backend/app/models/session_resolution_output.py +++ b/backend/app/models/session_resolution_output.py @@ -23,6 +23,12 @@ class SessionResolutionOutput(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) output_type: Mapped[str] = mapped_column(String(30), nullable=False) generated_content: Mapped[str] = mapped_column(Text, nullable=False) structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}") diff --git a/backend/app/models/supporting_data.py b/backend/app/models/supporting_data.py index ea04cd91..d69f66e2 100644 --- a/backend/app/models/supporting_data.py +++ b/backend/app/models/supporting_data.py @@ -14,6 +14,12 @@ class SessionSupportingData(Base): id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) label: Mapped[str] = mapped_column(String(255), nullable=False) data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/backend/tests/test_phase1_migrations.py b/backend/tests/test_phase1_migrations.py new file mode 100644 index 00000000..144ef099 --- /dev/null +++ b/backend/tests/test_phase1_migrations.py @@ -0,0 +1,162 @@ +"""Phase 1 migration tests — verify account_id backfill correctness. + +These tests create objects via ORM (which uses the updated models), +then verify account_id is populated correctly. They run against a +real PostgreSQL test DB (same as all other integration tests). +""" +import pytest +import uuid +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.models.account import Account +from app.models.user import User +from app.models.tree import Tree +from app.models.session import Session +from app.models.attachment import Attachment +from app.models.supporting_data import SessionSupportingData +from app.models.session_resolution_output import SessionResolutionOutput +from app.models.ai_session import AISession +from app.core.security import get_password_hash + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]: + account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8]) + db.add(account) + await db.flush() + user = User( + email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com", + name=f"User {suffix}", + password_hash=get_password_hash("TestPass123!"), + is_active=True, + account_id=account.id, + account_role="engineer", + ) + db.add(user) + await db.flush() + return account, user + + +async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree: + tree = Tree( + name=f"Tree {uuid.uuid4().hex[:6]}", + account_id=account.id, + author_id=user.id, + visibility="team", + tree_type="troubleshooting", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + db.add(tree) + await db.flush() + return tree + + +async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session: + s = Session( + tree_id=tree.id, + user_id=user.id, + account_id=account.id, + tree_snapshot={}, + ) + db.add(s) + await db.flush() + return s + + +# ── Group 1: Core sessions ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_session_account_id_matches_user(test_db: AsyncSession): + """sessions.account_id must equal the user's account_id.""" + account, user = await _make_account_and_user(test_db, "s1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + await test_db.commit() + + result = await test_db.execute(select(Session).where(Session.id == session.id)) + row = result.scalar_one() + assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}" + + +@pytest.mark.asyncio +async def test_attachment_account_id_matches_session(test_db: AsyncSession): + """attachments.account_id must match the parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "att1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + attachment = Attachment( + session_id=session.id, + account_id=account.id, + file_name="test.png", + file_type="image/png", + ) + test_db.add(attachment) + await test_db.commit() + + result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_session_supporting_data_account_id(test_db: AsyncSession): + """session_supporting_data.account_id must match parent session's account_id.""" + account, user = await _make_account_and_user(test_db, "sd1") + tree = await _make_tree(test_db, account, user) + session = await _make_session(test_db, account, user, tree) + + sd = SessionSupportingData( + session_id=session.id, + account_id=account.id, + label="Log snippet", + data_type="text_snippet", + content="error: connection refused", + ) + test_db.add(sd) + await test_db.commit() + + result = await test_db.execute( + select(SessionSupportingData).where(SessionSupportingData.id == sd.id) + ) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_session_resolution_output_account_id(test_db: AsyncSession): + """session_resolution_outputs.account_id must match the parent ai_session's account_id. + + NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions). + """ + account, user = await _make_account_and_user(test_db, "sro1") + + ai_session = AISession( + user_id=user.id, + account_id=account.id, + problem_summary="test resolution output", + problem_domain="networking", + status="active", + ) + test_db.add(ai_session) + await test_db.flush() + + output = SessionResolutionOutput( + session_id=ai_session.id, + account_id=account.id, + output_type="psa_ticket_notes", + generated_content="Ticket notes content", + generated_by_model="gpt-4", + ) + test_db.add(output) + await test_db.commit() + + result = await test_db.execute( + select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id) + ) + row = result.scalar_one() + assert row.account_id == account.id