feat: Phase 1 Group 1 — add account_id to core session tables
Migration sequence: add nullable → backfill via user_id/ai_session chain
→ verify zero NULLs → SET NOT NULL → CREATE INDEX.
Tables: sessions, attachments, session_supporting_data,
session_resolution_outputs
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,95 @@
|
|||||||
|
"""add account_id to core session tables
|
||||||
|
|
||||||
|
Revision ID: cc214c63aa30
|
||||||
|
Revises: b8d2f4a6c091
|
||||||
|
Create Date: 2026-04-09 00:00:00.000000
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision: str = 'cc214c63aa30'
|
||||||
|
down_revision: Union[str, None] = 'b8d2f4a6c091'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Step 1: ADD COLUMN (nullable) ────────────────────────────────────────
|
||||||
|
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||||
|
'session_resolution_outputs'):
|
||||||
|
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
|
||||||
|
op.create_foreign_key(
|
||||||
|
f'fk_{table}_account_id',
|
||||||
|
table, 'accounts',
|
||||||
|
['account_id'], ['id'],
|
||||||
|
ondelete='CASCADE',
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Step 2: BACKFILL ─────────────────────────────────────────────────────
|
||||||
|
# sessions: direct join to users
|
||||||
|
op.execute("""
|
||||||
|
UPDATE sessions s
|
||||||
|
SET account_id = u.account_id
|
||||||
|
FROM users u
|
||||||
|
WHERE s.user_id = u.id
|
||||||
|
AND s.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# attachments: chain through sessions (now backfilled above)
|
||||||
|
op.execute("""
|
||||||
|
UPDATE attachments a
|
||||||
|
SET account_id = s.account_id
|
||||||
|
FROM sessions s
|
||||||
|
WHERE a.session_id = s.id
|
||||||
|
AND a.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# session_supporting_data: same chain
|
||||||
|
op.execute("""
|
||||||
|
UPDATE session_supporting_data sd
|
||||||
|
SET account_id = s.account_id
|
||||||
|
FROM sessions s
|
||||||
|
WHERE sd.session_id = s.id
|
||||||
|
AND sd.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# session_resolution_outputs: FK is to ai_sessions, not sessions
|
||||||
|
op.execute("""
|
||||||
|
UPDATE session_resolution_outputs sro
|
||||||
|
SET account_id = ai.account_id
|
||||||
|
FROM ai_sessions ai
|
||||||
|
WHERE sro.session_id = ai.id
|
||||||
|
AND sro.account_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ── Step 3: VERIFY zero NULLs — raises if any remain ────────────────────
|
||||||
|
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||||
|
'session_resolution_outputs'):
|
||||||
|
result = op.get_bind().execute(
|
||||||
|
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
|
||||||
|
)
|
||||||
|
count = result.scalar()
|
||||||
|
if count > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ROLLBACK: {count} NULL account_id rows remain in {table}. "
|
||||||
|
f"Fix the backfill before re-running."
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Step 4: SET NOT NULL ─────────────────────────────────────────────────
|
||||||
|
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||||
|
'session_resolution_outputs'):
|
||||||
|
op.alter_column(table, 'account_id', nullable=False)
|
||||||
|
|
||||||
|
# ── Step 5: CREATE INDEX ─────────────────────────────────────────────────
|
||||||
|
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||||
|
'session_resolution_outputs'):
|
||||||
|
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
for table in ('sessions', 'attachments', 'session_supporting_data',
|
||||||
|
'session_resolution_outputs'):
|
||||||
|
op.drop_index(f'ix_{table}_account_id', table_name=table)
|
||||||
|
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
|
||||||
|
op.drop_column(table, 'account_id')
|
||||||
@@ -20,6 +20,12 @@ class Attachment(Base):
|
|||||||
ForeignKey("sessions.id"),
|
ForeignKey("sessions.id"),
|
||||||
nullable=False
|
nullable=False
|
||||||
)
|
)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
|
||||||
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
|
file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ class Session(Base):
|
|||||||
nullable=False,
|
nullable=False,
|
||||||
index=True
|
index=True
|
||||||
)
|
)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
|
||||||
path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list)
|
path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list)
|
||||||
decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list)
|
decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list)
|
||||||
|
|||||||
@@ -23,6 +23,12 @@ class SessionResolutionOutput(Base):
|
|||||||
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
output_type: Mapped[str] = mapped_column(String(30), nullable=False)
|
output_type: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||||
generated_content: Mapped[str] = mapped_column(Text, nullable=False)
|
generated_content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}")
|
structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}")
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ class SessionSupportingData(Base):
|
|||||||
|
|
||||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
label: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False)
|
data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False)
|
||||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
|||||||
162
backend/tests/test_phase1_migrations.py
Normal file
162
backend/tests/test_phase1_migrations.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""Phase 1 migration tests — verify account_id backfill correctness.
|
||||||
|
|
||||||
|
These tests create objects via ORM (which uses the updated models),
|
||||||
|
then verify account_id is populated correctly. They run against a
|
||||||
|
real PostgreSQL test DB (same as all other integration tests).
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.tree import Tree
|
||||||
|
from app.models.session import Session
|
||||||
|
from app.models.attachment import Attachment
|
||||||
|
from app.models.supporting_data import SessionSupportingData
|
||||||
|
from app.models.session_resolution_output import SessionResolutionOutput
|
||||||
|
from app.models.ai_session import AISession
|
||||||
|
from app.core.security import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
|
||||||
|
account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8])
|
||||||
|
db.add(account)
|
||||||
|
await db.flush()
|
||||||
|
user = User(
|
||||||
|
email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com",
|
||||||
|
name=f"User {suffix}",
|
||||||
|
password_hash=get_password_hash("TestPass123!"),
|
||||||
|
is_active=True,
|
||||||
|
account_id=account.id,
|
||||||
|
account_role="engineer",
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
return account, user
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
|
||||||
|
tree = Tree(
|
||||||
|
name=f"Tree {uuid.uuid4().hex[:6]}",
|
||||||
|
account_id=account.id,
|
||||||
|
author_id=user.id,
|
||||||
|
visibility="team",
|
||||||
|
tree_type="troubleshooting",
|
||||||
|
tree_structure={"id": "root", "type": "start", "children": []},
|
||||||
|
is_active=True,
|
||||||
|
status="published",
|
||||||
|
)
|
||||||
|
db.add(tree)
|
||||||
|
await db.flush()
|
||||||
|
return tree
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
|
||||||
|
s = Session(
|
||||||
|
tree_id=tree.id,
|
||||||
|
user_id=user.id,
|
||||||
|
account_id=account.id,
|
||||||
|
tree_snapshot={},
|
||||||
|
)
|
||||||
|
db.add(s)
|
||||||
|
await db.flush()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── Group 1: Core sessions ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_account_id_matches_user(test_db: AsyncSession):
|
||||||
|
"""sessions.account_id must equal the user's account_id."""
|
||||||
|
account, user = await _make_account_and_user(test_db, "s1")
|
||||||
|
tree = await _make_tree(test_db, account, user)
|
||||||
|
session = await _make_session(test_db, account, user, tree)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
result = await test_db.execute(select(Session).where(Session.id == session.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
|
||||||
|
"""attachments.account_id must match the parent session's account_id."""
|
||||||
|
account, user = await _make_account_and_user(test_db, "att1")
|
||||||
|
tree = await _make_tree(test_db, account, user)
|
||||||
|
session = await _make_session(test_db, account, user, tree)
|
||||||
|
|
||||||
|
attachment = Attachment(
|
||||||
|
session_id=session.id,
|
||||||
|
account_id=account.id,
|
||||||
|
file_name="test.png",
|
||||||
|
file_type="image/png",
|
||||||
|
)
|
||||||
|
test_db.add(attachment)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.account_id == account.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_supporting_data_account_id(test_db: AsyncSession):
|
||||||
|
"""session_supporting_data.account_id must match parent session's account_id."""
|
||||||
|
account, user = await _make_account_and_user(test_db, "sd1")
|
||||||
|
tree = await _make_tree(test_db, account, user)
|
||||||
|
session = await _make_session(test_db, account, user, tree)
|
||||||
|
|
||||||
|
sd = SessionSupportingData(
|
||||||
|
session_id=session.id,
|
||||||
|
account_id=account.id,
|
||||||
|
label="Log snippet",
|
||||||
|
data_type="text_snippet",
|
||||||
|
content="error: connection refused",
|
||||||
|
)
|
||||||
|
test_db.add(sd)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
result = await test_db.execute(
|
||||||
|
select(SessionSupportingData).where(SessionSupportingData.id == sd.id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.account_id == account.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_resolution_output_account_id(test_db: AsyncSession):
|
||||||
|
"""session_resolution_outputs.account_id must match the parent ai_session's account_id.
|
||||||
|
|
||||||
|
NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
|
||||||
|
"""
|
||||||
|
account, user = await _make_account_and_user(test_db, "sro1")
|
||||||
|
|
||||||
|
ai_session = AISession(
|
||||||
|
user_id=user.id,
|
||||||
|
account_id=account.id,
|
||||||
|
problem_summary="test resolution output",
|
||||||
|
problem_domain="networking",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
test_db.add(ai_session)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
output = SessionResolutionOutput(
|
||||||
|
session_id=ai_session.id,
|
||||||
|
account_id=account.id,
|
||||||
|
output_type="psa_ticket_notes",
|
||||||
|
generated_content="Ticket notes content",
|
||||||
|
generated_by_model="gpt-4",
|
||||||
|
)
|
||||||
|
test_db.add(output)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
result = await test_db.execute(
|
||||||
|
select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.account_id == account.id
|
||||||
Reference in New Issue
Block a user