Files
resolutionflow/backend/tests/test_public_templates.py
chihlasm 758cd61621 fix: propagate account_id through all write paths missing NOT NULL coverage
Service layer (production code):
- branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint
  from session.account_id; load session in create_fork for this purpose
- handoff_manager: set account_id on SessionHandoff from session.account_id
- ai_suggestions endpoint: set account_id on AISuggestion from current_user
- steps endpoint (/feedback): set account_id on StepRating from current_user
- ratings endpoint: set account_id on StepRating from current_user

Test infrastructure:
- conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after
  Base.metadata.create_all so global categories and gallery items have a valid FK
- test_rls_isolation: add _ensure_rls_schema fixture that runs
  'alembic upgrade head' before module tests — previous function-scoped
  test_db fixtures drop the schema, leaving the RLS tests with no tables
- test_branding: create Account before User in helper functions
- test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate
- test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree,
  ScriptTemplate, TreeCategory
- test_resolution_outputs: set account_id=session.account_id on
  SessionResolutionOutput
- test_analytics_phase5: set account_id on PsaPostLog
- test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in
  migration default test (NOT NULL now enforced)
- test_maintenance_schedules: set account_id on other_tree
- test_save_session_as_tree: set account_id on all 5 Session() constructors

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 04:24:36 +00:00

369 lines
15 KiB
Python

"""Tests for the public templates gallery API.
Endpoints under /api/v1/public/templates require no authentication.
"""
import uuid
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.script_template import ScriptCategory, ScriptTemplate
from app.models.tree import Tree
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_tree_structure(depth: int = 4) -> dict:
"""Build a nested tree structure with the given depth."""
node_id = str(uuid.uuid4())
def _make_node(d: int, node_id: str) -> dict:
node = {
"id": node_id,
"type": "decision" if d > 0 else "solution",
"question": f"Question at depth {depth - d}",
"children": [],
}
if d > 0:
child_id = str(uuid.uuid4())
node["children"].append(_make_node(d - 1, child_id))
return node
return _make_node(depth, node_id)
async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", featured: bool = True) -> Tree:
    """Insert a public, published troubleshooting ``Tree`` and return it.

    The row is owned by the platform account, is active, and wraps a
    depth-4 structure from ``_make_tree_structure``.  ``featured`` sets
    ``is_gallery_featured``.  The instance is committed and refreshed
    before being returned.
    """
    attributes = {
        "name": name,
        "description": "A featured flow for the gallery",
        "tree_type": "troubleshooting",
        "tree_structure": _make_tree_structure(4),
        "account_id": _PLATFORM_ACCOUNT_ID,
        "is_gallery_featured": featured,
        "is_active": True,
        "usage_count": 42,
        "visibility": "public",
        "status": "published",
    }
    flow = Tree(**attributes)
    db.add(flow)
    await db.commit()
    await db.refresh(flow)
    return flow
async def _create_script_category(db: AsyncSession, name: str = "Networking") -> ScriptCategory:
    """Insert an active ``ScriptCategory`` called ``name`` and return it.

    The slug is the lower-cased name with every space replaced by a hyphen.
    The instance is committed and refreshed before being returned.
    """
    slug = "-".join(name.lower().split(" "))
    category = ScriptCategory(name=name, slug=slug, is_active=True)
    db.add(category)
    await db.commit()
    await db.refresh(category)
    return category
async def _create_featured_script(
    db: AsyncSession,
    category: ScriptCategory,
    name: str = "Featured Script",
    featured: bool = True,
    script_body: str = "Get-NetAdapter | Format-Table",
) -> ScriptTemplate:
    """Insert an active, verified ``ScriptTemplate`` in ``category`` and return it.

    The template is owned by the platform account and declares a single
    optional ``ComputerName`` string parameter.  ``featured`` sets
    ``is_gallery_featured`` and ``script_body`` is stored as given.  The
    slug is the lower-cased name with spaces turned into hyphens.  The
    instance is committed and refreshed before being returned.
    """
    computer_name_param = {
        "name": "ComputerName",
        "description": "Target computer",
        "type": "string",
        "required": False,
    }
    attributes = {
        "category_id": category.id,
        "account_id": _PLATFORM_ACCOUNT_ID,
        "name": name,
        "slug": "-".join(name.lower().split(" ")),
        "description": "A gallery-featured script",
        "script_body": script_body,
        "parameters_schema": {"parameters": [computer_name_param]},
        "default_values": {},
        "validation_rules": {},
        "tags": ["networking", "diagnostics"],
        "complexity": "beginner",
        "requires_elevation": False,
        "requires_modules": [],
        "is_gallery_featured": featured,
        "is_active": True,
        "is_verified": True,
        "usage_count": 10,
    }
    template = ScriptTemplate(**attributes)
    db.add(template)
    await db.commit()
    await db.refresh(template)
    return template
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestGalleryAccessibility:
    """Public gallery routes must answer without any authentication."""

    # NOTE(review): ``test_db`` is requested by every test but never used in
    # the body; presumably the fixture provisions the schema - confirm in
    # conftest.

    async def test_gallery_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """GET /public/templates needs no auth token."""
        listing = await client.get("/api/v1/public/templates")
        assert listing.status_code == 200

    async def test_gallery_returns_json(self, client: AsyncClient, test_db: AsyncSession):
        """The listing payload exposes every expected top-level key."""
        payload = (await client.get("/api/v1/public/templates")).json()
        expected_keys = (
            "flow_templates",
            "script_templates",
            "total_flows",
            "total_scripts",
            "categories",
        )
        for key in expected_keys:
            assert key in payload

    async def test_categories_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """The categories route answers 200."""
        categories = await client.get("/api/v1/public/templates/categories")
        assert categories.status_code == 200

    async def test_search_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """The search route answers 200 when given a query."""
        results = await client.get("/api/v1/public/templates/search?q=network")
        assert results.status_code == 200
@pytest.mark.asyncio
class TestGalleryFeatureFilter:
    """Only items flagged ``is_gallery_featured=True`` may be listed."""

    async def test_featured_flow_appears_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        """A featured flow is present in the flows listing."""
        flow = await _create_featured_tree(test_db, name="Should Appear", featured=True)
        payload = (await client.get("/api/v1/public/templates?type=flows")).json()
        listed = [entry["id"] for entry in payload["flow_templates"]]
        assert str(flow.id) in listed

    async def test_unfeatured_flow_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        """A flow created with featured=False is absent from the listing."""
        flow = await _create_featured_tree(test_db, name="Should Not Appear", featured=False)
        payload = (await client.get("/api/v1/public/templates?type=flows")).json()
        listed = [entry["id"] for entry in payload["flow_templates"]]
        assert str(flow.id) not in listed

    async def test_inactive_flow_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        """A featured flow that is later deactivated is absent from the listing."""
        flow = await _create_featured_tree(test_db, name="Inactive Flow", featured=True)
        # Deactivate after creation; the helper always inserts active rows.
        flow.is_active = False
        await test_db.commit()
        payload = (await client.get("/api/v1/public/templates?type=flows")).json()
        listed = [entry["id"] for entry in payload["flow_templates"]]
        assert str(flow.id) not in listed

    async def test_featured_script_appears_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        """A featured script is present in the scripts listing."""
        category = await _create_script_category(test_db)
        template = await _create_featured_script(test_db, category, featured=True)
        payload = (await client.get("/api/v1/public/templates?type=scripts")).json()
        listed = [entry["id"] for entry in payload["script_templates"]]
        assert str(template.id) in listed

    async def test_unfeatured_script_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        """A script created with featured=False is absent from the listing."""
        category = await _create_script_category(test_db)
        template = await _create_featured_script(test_db, category, featured=False)
        payload = (await client.get("/api/v1/public/templates?type=scripts")).json()
        listed = [entry["id"] for entry in payload["script_templates"]]
        assert str(template.id) not in listed
@pytest.mark.asyncio
class TestTreeStructureTruncation:
    """The full tree structure must be truncated to 3 levels for the public preview."""

    @staticmethod
    def _max_depth(node: dict) -> int:
        """Return the deepest level reachable from ``node``, the root being 0.

        Descends through each node's ``"children"`` list and through the
        ``"children"`` of every dict entry in its ``"options"`` list.  A
        falsy node still counts for its own level but is not descended into.
        """
        deepest = 0
        pending = [(node, 0)]
        while pending:
            current, level = pending.pop()
            deepest = max(deepest, level)
            if not current:
                continue
            for child in current.get("children", []):
                pending.append((child, level + 1))
            for opt in current.get("options", []):
                if isinstance(opt, dict):
                    for child in opt.get("children", []):
                        pending.append((child, level + 1))
        return deepest

    async def test_preview_structure_not_null(self, client: AsyncClient, test_db: AsyncSession):
        """The first listed flow carries a non-null preview structure."""
        await _create_featured_tree(test_db, name="Truncation Test")
        response = await client.get("/api/v1/public/templates?type=flows")
        data = response.json()
        assert len(data["flow_templates"]) > 0
        template = data["flow_templates"][0]
        assert template["preview_structure"] is not None

    async def test_preview_structure_truncated_to_3_levels(self, client: AsyncClient, test_db: AsyncSession):
        """Full tree has depth 4, preview should be truncated to depth 3."""
        tree = await _create_featured_tree(test_db, name="Deep Tree")
        response = await client.get(f"/api/v1/public/templates/flows/{tree.id}")
        assert response.status_code == 200
        data = response.json()
        preview = data["preview_structure"]
        assert preview is not None
        # Walk the structure and confirm depth does not exceed 3
        max_d = self._max_depth(preview)
        assert max_d <= 3, f"Preview depth {max_d} exceeds 3 levels"

    async def test_flow_detail_does_not_return_full_structure_beyond_3_levels(
        self, client: AsyncClient, test_db: AsyncSession
    ):
        """The flow detail endpoint must truncate tree_structure."""
        tree = await _create_featured_tree(test_db, name="Depth Check Flow")
        response = await client.get(f"/api/v1/public/templates/flows/{tree.id}")
        assert response.status_code == 200
        data = response.json()
        # Full structure has 4 levels, preview must be capped at 3
        assert "preview_structure" in data
        assert "tree_structure" not in data  # raw full structure key should not appear
        # Measure the depth so the test checks what its name promises,
        # not just which keys are present.
        max_d = self._max_depth(data["preview_structure"])
        assert max_d <= 3, f"Preview depth {max_d} exceeds 3 levels"
@pytest.mark.asyncio
class TestScriptBodyProtection:
    """script_body must never be exposed in public endpoints."""

    async def test_script_body_not_in_gallery_listing(self, client: AsyncClient, test_db: AsyncSession):
        """The scripts listing includes the script but leaks neither its body nor the field name."""
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, script_body="SUPER SECRET SCRIPT BODY")
        response = await client.get("/api/v1/public/templates?type=scripts")
        # Guard against a vacuous pass: an error response or an empty listing
        # would also contain no script body, so first prove the request
        # succeeded and the script really is listed.
        assert response.status_code == 200
        listed_ids = [s["id"] for s in response.json()["script_templates"]]
        assert str(script.id) in listed_ids
        text = response.text
        assert "SUPER SECRET SCRIPT BODY" not in text
        assert "script_body" not in text

    async def test_script_body_not_in_detail_response(self, client: AsyncClient, test_db: AsyncSession):
        """The script detail response leaks neither the body nor the field name."""
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, script_body="CONFIDENTIAL_BODY_XYZ")
        response = await client.get(f"/api/v1/public/templates/scripts/{script.id}")
        assert response.status_code == 200
        text = response.text
        assert "CONFIDENTIAL_BODY_XYZ" not in text
        assert "script_body" not in text

    async def test_script_detail_includes_parameters_without_body(self, client: AsyncClient, test_db: AsyncSession):
        """The script detail lists named parameters and has no script_body key."""
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat)
        response = await client.get(f"/api/v1/public/templates/scripts/{script.id}")
        assert response.status_code == 200
        data = response.json()
        # Parameters should be present (name/description only)
        assert "parameters" in data
        assert len(data["parameters"]) > 0
        param = data["parameters"][0]
        assert "name" in param
        # script_body must not appear anywhere
        assert "script_body" not in data
@pytest.mark.asyncio
class TestSearch:
    """Search over the featured gallery items."""

    async def test_search_returns_matching_flow(self, client: AsyncClient, test_db: AsyncSession):
        """Searching for a word in a featured flow's name returns that flow."""
        await _create_featured_tree(test_db, name="VPN Connectivity Troubleshooting")
        result = await client.get("/api/v1/public/templates/search?q=VPN")
        assert result.status_code == 200
        payload = result.json()
        assert payload["total_flows"] >= 1
        flow_names = [entry["name"] for entry in payload["flow_templates"]]
        assert any("VPN" in flow_name for flow_name in flow_names)

    async def test_search_returns_matching_script(self, client: AsyncClient, test_db: AsyncSession):
        """Searching for words in a featured script's name returns that script."""
        category = await _create_script_category(test_db)
        await _create_featured_script(test_db, category, name="DNS Flush Script")
        result = await client.get("/api/v1/public/templates/search?q=DNS+Flush")
        assert result.status_code == 200
        payload = result.json()
        assert payload["total_scripts"] >= 1
        script_names = [entry["name"] for entry in payload["script_templates"]]
        assert any("DNS" in script_name for script_name in script_names)

    async def test_search_excludes_unfeatured_items(self, client: AsyncClient, test_db: AsyncSession):
        """An unfeatured flow is not counted even when its exact name is searched."""
        await _create_featured_tree(test_db, name="UniqueName_NotFeatured_XYZ", featured=False)
        result = await client.get("/api/v1/public/templates/search?q=UniqueName_NotFeatured_XYZ")
        assert result.status_code == 200
        payload = result.json()
        assert payload["total_flows"] == 0

    async def test_search_requires_query_param(self, client: AsyncClient, test_db: AsyncSession):
        """Calling search without ``q`` is rejected with 422."""
        result = await client.get("/api/v1/public/templates/search")
        assert result.status_code == 422
@pytest.mark.asyncio
class TestCategoriesEndpoint:
    """The categories route returns a list of categories."""

    async def test_categories_returns_list(self, client: AsyncClient, test_db: AsyncSession):
        """The payload has a ``categories`` key holding a list."""
        result = await client.get("/api/v1/public/templates/categories")
        assert result.status_code == 200
        payload = result.json()
        assert "categories" in payload
        assert isinstance(payload["categories"], list)

    async def test_categories_reflect_featured_content(self, client: AsyncClient, test_db: AsyncSession):
        """A category that holds a featured tree is reported by name."""
        from app.models.category import TreeCategory

        # Commit and refresh the category first so its generated id is
        # available for the tree row below.
        category = TreeCategory(
            name="Networking",
            slug="networking",
            is_active=True,
            account_id=_PLATFORM_ACCOUNT_ID,
        )
        test_db.add(category)
        await test_db.commit()
        await test_db.refresh(category)

        tree_attributes = {
            "name": "Router Diagnostics",
            "tree_type": "troubleshooting",
            "tree_structure": _make_tree_structure(2),
            "account_id": _PLATFORM_ACCOUNT_ID,
            "is_gallery_featured": True,
            "is_active": True,
            "usage_count": 5,
            "visibility": "public",
            "status": "published",
            "category_id": category.id,
        }
        test_db.add(Tree(**tree_attributes))
        await test_db.commit()

        payload = (await client.get("/api/v1/public/templates/categories")).json()
        reported_names = [entry["name"] for entry in payload["categories"]]
        assert "Networking" in reported_names
@pytest.mark.asyncio
class TestNotFoundBehavior:
    """Detail routes answer 404 for missing or unfeatured items."""

    async def test_flow_detail_404_for_nonexistent(self, client: AsyncClient, test_db: AsyncSession):
        """A random UUID that matches no flow yields 404."""
        missing_id = str(uuid.uuid4())
        result = await client.get(f"/api/v1/public/templates/flows/{missing_id}")
        assert result.status_code == 404

    async def test_flow_detail_404_for_unfeatured(self, client: AsyncClient, test_db: AsyncSession):
        """An existing flow that is not featured yields 404."""
        flow = await _create_featured_tree(test_db, name="Not Featured", featured=False)
        result = await client.get(f"/api/v1/public/templates/flows/{flow.id}")
        assert result.status_code == 404

    async def test_script_detail_404_for_nonexistent(self, client: AsyncClient, test_db: AsyncSession):
        """A random UUID that matches no script yields 404."""
        missing_id = str(uuid.uuid4())
        result = await client.get(f"/api/v1/public/templates/scripts/{missing_id}")
        assert result.status_code == 404

    async def test_script_detail_404_for_unfeatured(self, client: AsyncClient, test_db: AsyncSession):
        """An existing script that is not featured yields 404."""
        category = await _create_script_category(test_db)
        template = await _create_featured_script(test_db, category, featured=False)
        result = await client.get(f"/api/v1/public/templates/scripts/{template.id}")
        assert result.status_code == 404