diff --git a/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py b/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py new file mode 100644 index 00000000..ca32f0d2 --- /dev/null +++ b/backend/alembic/versions/a1d2a84b9abb_add_account_id_user_personalization.py @@ -0,0 +1,45 @@ +"""add account_id to user personalization tables + +Revision ID: a1d2a84b9abb +Revises: 7167e9374b0c +Create Date: 2026-04-09 00:00:00.000000 +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +revision: str = 'a1d2a84b9abb' +down_revision: Union[str, None] = '7167e9374b0c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True)) + op.create_foreign_key( + f'fk_{table}_account_id', table, 'accounts', + ['account_id'], ['id'], ondelete='CASCADE', + ) + op.execute(f""" + UPDATE {table} t + SET account_id = u.account_id + FROM users u + WHERE t.user_id = u.id + AND t.account_id IS NULL + """) + result = op.get_bind().execute( + sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL") + ) + count = result.scalar() + if count > 0: + raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.") + op.alter_column(table, 'account_id', nullable=False) + op.create_index(f'ix_{table}_account_id', table, ['account_id']) + + +def downgrade() -> None: + for table in ('user_folders', 'user_pinned_trees'): + op.drop_index(f'ix_{table}_account_id', table_name=table) + op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey') + op.drop_column(table, 'account_id') diff --git a/backend/app/models/folder.py b/backend/app/models/folder.py index 7edaeaef..50923c86 100644 --- a/backend/app/models/folder.py +++ b/backend/app/models/folder.py @@ -46,6 +46,12 @@ class UserFolder(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) name: Mapped[str] = mapped_column(String(100), nullable=False) color: Mapped[str] = mapped_column(String(7), nullable=False, default="#6366f1") icon: Mapped[str] = mapped_column(String(50), nullable=False, default="folder") diff --git a/backend/app/models/user_pinned_tree.py b/backend/app/models/user_pinned_tree.py index c27edd08..d23b463a 100644 --- a/backend/app/models/user_pinned_tree.py +++ b/backend/app/models/user_pinned_tree.py @@ -24,6 +24,12 @@ class UserPinnedTree(Base): nullable=False, index=True ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) tree_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("trees.id", ondelete="CASCADE"), diff --git a/backend/tests/test_phase1_migrations.py b/backend/tests/test_phase1_migrations.py index 416ca4ab..a4f33de6 100644 --- a/backend/tests/test_phase1_migrations.py +++ b/backend/tests/test_phase1_migrations.py @@ -298,3 +298,48 @@ async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession assert row.account_id == account.id, ( f"account_id should be logger's account ({account.id}), got {row.account_id}" ) + + +# ── Group 4: User personalization ──────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_user_folder_account_id_matches_user(test_db: AsyncSession): + """user_folders.account_id must match the owning user's account_id.""" + from app.models.folder import UserFolder + + account, user = await _make_account_and_user(test_db, "uf1") + folder = UserFolder( + user_id=user.id, + account_id=account.id, + name="My Folder", + color="#6366f1", + icon="folder", + display_order=0, + ) + test_db.add(folder) + await test_db.commit() + + result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id)) + row = result.scalar_one() + assert row.account_id == account.id + + +@pytest.mark.asyncio +async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession): + """user_pinned_trees.account_id must match the pinning user's account_id.""" + from app.models.user_pinned_tree import UserPinnedTree + + account, user = await _make_account_and_user(test_db, "pt1") + tree = await _make_tree(test_db, account, user) + pin = UserPinnedTree( + user_id=user.id, + tree_id=tree.id, + account_id=account.id, + display_order=0, + ) + test_db.add(pin) + await test_db.commit() + + result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id)) + row = result.scalar_one() + assert row.account_id == account.id