diff --git a/backend/app/services/copilot_service.py b/backend/app/services/copilot_service.py index 175735ee..76327ff4 100644 --- a/backend/app/services/copilot_service.py +++ b/backend/app/services/copilot_service.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta from typing import Optional, Any from uuid import UUID -from sqlalchemy import select +from sqlalchemy import select, or_ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -103,13 +103,23 @@ async def start_conversation( Returns (conversation, greeting_message). """ - # Load tree + # Load tree — must be accessible to this account. + # Allows own account's trees, default trees, and public trees. + # Raises ValueError (caught by endpoint as 404) if not found or not accessible. result = await db.execute( - select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id) + select(Tree).options(selectinload(Tree.tags)).where( + Tree.id == tree_id, + or_( + Tree.account_id == account_id, + Tree.author_id == user_id, + Tree.is_default == True, + Tree.is_public == True, + ), + ) ) tree = result.scalar_one_or_none() if not tree: - raise ValueError(f"Tree {tree_id} not found") + raise ValueError(f"Tree {tree_id} not found or not accessible") conversation = CopilotConversation( user_id=user_id, diff --git a/backend/tests/test_tenant_isolation_p0.py b/backend/tests/test_tenant_isolation_p0.py new file mode 100644 index 00000000..04018f30 --- /dev/null +++ b/backend/tests/test_tenant_isolation_p0.py @@ -0,0 +1,107 @@ +"""Cross-tenant isolation tests for Phase 0 gap fixes.""" +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +import uuid + +from app.models.account import Account +from app.models.user import User +from app.models.tree import Tree +from app.core.security import get_password_hash + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +async def _create_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User, str]: + """Create an account + owner user. Returns (account, user, plain_password).""" + account = Account( + name=f"Test Corp {suffix}", + display_code=uuid.uuid4().hex[:8], + ) + db.add(account) + await db.flush() + + password = "TestPass123!" + user = User( + email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com", + name=f"User {suffix}", + password_hash=get_password_hash(password), + is_active=True, + account_id=account.id, + account_role="owner", + ) + db.add(user) + await db.flush() + return account, user, password + + +async def _login(client: AsyncClient, email: str, password: str) -> dict: + """Return auth headers for a user.""" + resp = await client.post( + "/api/v1/auth/login/json", + json={"email": email, "password": password}, + ) + assert resp.status_code == 200, resp.text + return {"Authorization": f"Bearer {resp.json()['access_token']}"} + + +async def _create_private_tree(db: AsyncSession, account: Account, user: User) -> Tree: + """Create a private tree owned by account.""" + tree = Tree( + name=f"Private Tree {uuid.uuid4().hex[:6]}", + account_id=account.id, + author_id=user.id, + visibility="private", + tree_type="troubleshooting", + tree_structure={"id": "root", "type": "start", "children": []}, + is_active=True, + status="published", + ) + db.add(tree) + await db.flush() + return tree + + +# ── Task 1: Copilot bypass ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_copilot_cannot_start_conversation_with_other_account_tree( + client: AsyncClient, test_db: AsyncSession +): + """Account A cannot start a copilot conversation using Account B's private tree UUID.""" + acct_a, user_a, pass_a = await _create_account_and_user(test_db, "a") + acct_b, user_b, pass_b = await _create_account_and_user(test_db, "b") + tree_b = await _create_private_tree(test_db, acct_b, user_b) + await test_db.commit() + + headers_a = await _login(client, user_a.email, pass_a) + + resp = await client.post( + "/api/v1/copilot/conversations", + json={"tree_id": str(tree_b.id), "session_id": None, "current_node_id": None}, + headers=headers_a, + ) + # Must be 404 (not 200, not 403 — never confirm existence) + assert resp.status_code == 404, f"Expected 404, got {resp.status_code}: {resp.text}" + + +@pytest.mark.asyncio +async def test_copilot_service_rejects_cross_tenant_tree(test_db: AsyncSession): + """copilot_service.start_conversation raises ValueError for cross-tenant tree.""" + from app.services import copilot_service + + acct_a, user_a, pass_a = await _create_account_and_user(test_db, "svc-a") + acct_b, user_b, pass_b = await _create_account_and_user(test_db, "svc-b") + tree_b = await _create_private_tree(test_db, acct_b, user_b) + await test_db.commit() + + with pytest.raises(ValueError, match="not found"): + await copilot_service.start_conversation( + user_id=user_a.id, + account_id=user_a.account_id, + tree_id=tree_b.id, + session_id=None, + current_node_id=None, + db=test_db, + )