feat: Phase 1 Group 9 — enforce NOT NULL on all account_id columns

All previously-nullable account_id columns are now NOT NULL.
tree_embeddings and feedback backfilled before constraint applied.
Global content assigned to platform sentinel account (00000000-...-0001)
in preceding migration.

Tables updated: users, trees, tree_categories, tree_tags,
step_categories, step_library, tree_embeddings, feedback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-04-09 05:34:32 +00:00
parent b4b8c67d3b
commit 42937b24a4
10 changed files with 139 additions and 17 deletions

View File

@@ -0,0 +1,86 @@
"""set NOT NULL on all previously-nullable account_id columns
Revision ID: 174f442795b7
Revises: 3a40fe11b427
Create Date: 2026-04-09 00:00:00.000000
All tables in this migration had account_id set to nullable previously.
Task 9 (create_global_content_tables) cleared all NULL rows.
This migration enforces the NOT NULL constraint.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '174f442795b7'
down_revision: Union[str, None] = '3a40fe11b427'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# tree_embeddings: backfill from trees (must happen before SET NOT NULL)
op.execute("""
UPDATE tree_embeddings te
SET account_id = t.account_id
FROM trees t
WHERE te.tree_id = t.id
AND te.account_id IS NULL
""")
# feedback: backfill from users
op.execute("""
UPDATE feedback f
SET account_id = u.account_id
FROM users u
WHERE f.user_id = u.id
AND f.account_id IS NULL
""")
# Verify ALL tables before touching any SET NOT NULL
tables_with_account_id = [
'users', 'trees', 'tree_categories', 'tree_tags',
'step_categories', 'step_library', 'tree_embeddings', 'feedback',
]
for table in tables_with_account_id:
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} NULL account_id rows in {table}. "
"Run Task 9 (create_global_content_tables) first, or "
"manually backfill/delete orphaned rows."
)
# SET NOT NULL on all
for table in tables_with_account_id:
op.alter_column(table, 'account_id', nullable=False)
# Create indexes where they don't already exist
new_indexes = [
('tree_embeddings', 'ix_tree_embeddings_account_id'),
('feedback', 'ix_feedback_account_id'),
]
for table, index_name in new_indexes:
result = op.get_bind().execute(sa.text(
f"SELECT 1 FROM pg_indexes WHERE tablename='{table}' AND indexname='{index_name}'"
))
if not result.fetchone():
op.create_index(index_name, table, ['account_id'])
def downgrade() -> None:
# Revert to nullable
for table in ('users', 'trees', 'tree_categories', 'tree_tags',
'step_categories', 'step_library', 'tree_embeddings', 'feedback'):
op.alter_column(table, 'account_id', nullable=True)
for table, index_name in (
('tree_embeddings', 'ix_tree_embeddings_account_id'),
('feedback', 'ix_feedback_account_id'),
):
try:
op.drop_index(index_name, table_name=table)
except Exception:
pass

View File

@@ -39,10 +39,10 @@ class TreeCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -1,6 +1,5 @@
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import String, Text, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID
@@ -11,7 +10,7 @@ class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="SET NULL"), nullable=True)
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
email: Mapped[str] = mapped_column(String(255), nullable=False)
feedback_type: Mapped[str] = mapped_column(String(50), nullable=False)

View File

@@ -38,10 +38,10 @@ class StepCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -46,10 +46,10 @@ class StepLibrary(Base):
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)

View File

@@ -51,10 +51,10 @@ class TreeTag(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -76,10 +76,10 @@ class Tree(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)

View File

@@ -37,10 +37,10 @@ class TreeEmbedding(Base):
ForeignKey("trees.id", ondelete="CASCADE"),
nullable=False,
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
)
chunk_type: Mapped[str] = mapped_column(
String(30),

View File

@@ -43,10 +43,10 @@ class User(Base):
must_change_password: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
# Account-based multi-tenancy (new)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="RESTRICT"),
nullable=True,
nullable=False,
index=True
)
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")

View File

@@ -7,7 +7,7 @@ real PostgreSQL test DB (same as all other integration tests).
import pytest
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, text
from app.models.account import Account
from app.models.user import User
@@ -506,3 +506,40 @@ async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncS
columns = {row[0] for row in result.fetchall()}
assert 'id' in columns, "platform_steps.id must exist"
assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"
# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────
@pytest.mark.asyncio
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
"""trees.account_id must be NOT NULL after Phase 1 — enforced at DB level."""
from sqlalchemy.exc import IntegrityError
with pytest.raises(IntegrityError):
test_db.add(Tree(
name="Bad tree",
# account_id intentionally omitted
author_id=None,
visibility="private",
tree_type="troubleshooting",
tree_structure={},
is_active=True,
status="draft",
))
await test_db.flush()
@pytest.mark.asyncio
async def test_user_account_id_is_not_null(test_db: AsyncSession):
"""users.account_id must be NOT NULL after Phase 1."""
from sqlalchemy.exc import IntegrityError
with pytest.raises(IntegrityError):
test_db.add(User(
email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
name="Orphan",
password_hash=get_password_hash("x"),
is_active=True,
role="engineer",
account_role="engineer",
# account_id intentionally omitted
))
await test_db.flush()