Service layer (production code): - branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint from session.account_id; load session in create_fork for this purpose - handoff_manager: set account_id on SessionHandoff from session.account_id - ai_suggestions endpoint: set account_id on AISuggestion from current_user - steps endpoint (/feedback): set account_id on StepRating from current_user - ratings endpoint: set account_id on StepRating from current_user Test infrastructure: - conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after Base.metadata.create_all so global categories and gallery items have a valid FK - test_rls_isolation: add _ensure_rls_schema fixture that runs 'alembic upgrade head' before module tests — previous function-scoped test_db fixtures drop the schema, leaving the RLS tests with no tables - test_branding: create Account before User in helper functions - test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate - test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree, ScriptTemplate, TreeCategory - test_resolution_outputs: set account_id=session.account_id on SessionResolutionOutput - test_analytics_phase5: set account_id on PsaPostLog - test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in migration default test (NOT NULL now enforced) - test_maintenance_schedules: set account_id on other_tree - test_save_session_as_tree: set account_id on all 5 Session() constructors Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
369 lines
15 KiB
Python
369 lines
15 KiB
Python
"""Tests for the public templates gallery API.
|
|
|
|
Endpoints under /api/v1/public/templates require no authentication.
|
|
"""
|
|
|
|
import uuid
|
|
import pytest
|
|
from httpx import AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.script_template import ScriptCategory, ScriptTemplate
|
|
from app.models.tree import Tree
|
|
|
|
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_tree_structure(depth: int = 4) -> dict:
|
|
"""Build a nested tree structure with the given depth."""
|
|
node_id = str(uuid.uuid4())
|
|
|
|
def _make_node(d: int, node_id: str) -> dict:
|
|
node = {
|
|
"id": node_id,
|
|
"type": "decision" if d > 0 else "solution",
|
|
"question": f"Question at depth {depth - d}",
|
|
"children": [],
|
|
}
|
|
if d > 0:
|
|
child_id = str(uuid.uuid4())
|
|
node["children"].append(_make_node(d - 1, child_id))
|
|
return node
|
|
|
|
return _make_node(depth, node_id)
|
|
|
|
|
|
async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", featured: bool = True) -> Tree:
    """Persist a platform-owned, active, public, published flow and return it.

    Args:
        db: Session used to add, commit and refresh the new row.
        name: Display name of the tree.
        featured: Stored as ``is_gallery_featured``, so callers can create
            both gallery-visible and hidden flows.

    Returns:
        The committed ``Tree``, refreshed from the database.
    """
    # Root plus four levels beneath it, deeper than the 3-level public preview.
    structure = _make_tree_structure(4)
    flow = Tree(
        account_id=_PLATFORM_ACCOUNT_ID,
        name=name,
        description="A featured flow for the gallery",
        tree_type="troubleshooting",
        tree_structure=structure,
        visibility="public",
        status="published",
        is_active=True,
        is_gallery_featured=featured,
        usage_count=42,
    )
    db.add(flow)
    await db.commit()
    await db.refresh(flow)
    return flow
|
|
|
|
|
|
async def _create_script_category(db: AsyncSession, name: str = "Networking") -> ScriptCategory:
    """Persist an active ScriptCategory and return it after a refresh.

    The slug is ``name`` lower-cased with spaces replaced by hyphens.
    """
    slug = name.lower().replace(" ", "-")
    category = ScriptCategory(name=name, slug=slug, is_active=True)
    db.add(category)
    await db.commit()
    await db.refresh(category)
    return category
|
|
|
|
|
|
async def _create_featured_script(
    db: AsyncSession,
    category: ScriptCategory,
    name: str = "Featured Script",
    featured: bool = True,
    script_body: str = "Get-NetAdapter | Format-Table",
) -> ScriptTemplate:
    """Persist a verified, active, platform-owned ScriptTemplate in ``category``.

    Args:
        db: Session used to add, commit and refresh the new row.
        category: Category whose ``id`` becomes ``category_id``.
        name: Display name; the slug is derived from it (lower-cased,
            spaces replaced by hyphens).
        featured: Stored as ``is_gallery_featured``.
        script_body: Raw script text. Tests pass a sentinel here to check
            that it never leaks through the public endpoints.

    Returns:
        The committed ``ScriptTemplate``, refreshed from the database.
    """
    # A single optional parameter so detail responses have something to list.
    schema = {
        "parameters": [
            {"name": "ComputerName", "description": "Target computer", "type": "string", "required": False},
        ]
    }
    template = ScriptTemplate(
        account_id=_PLATFORM_ACCOUNT_ID,
        category_id=category.id,
        name=name,
        slug=name.lower().replace(" ", "-"),
        description="A gallery-featured script",
        script_body=script_body,
        parameters_schema=schema,
        default_values={},
        validation_rules={},
        tags=["networking", "diagnostics"],
        complexity="beginner",
        requires_elevation=False,
        requires_modules=[],
        is_gallery_featured=featured,
        is_active=True,
        is_verified=True,
        usage_count=10,
    )
    db.add(template)
    await db.commit()
    await db.refresh(template)
    return template
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test classes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
class TestGalleryAccessibility:
    """Gallery endpoints must work without any authentication.

    None of these tests set auth headers themselves. ``test_db`` is requested
    even where unused, presumably so the schema exists for the request; confirm
    against conftest.
    """

    async def test_gallery_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """GET /public/templates requires no auth token."""
        response = await client.get("/api/v1/public/templates")
        assert response.status_code == 200

    async def test_gallery_returns_json(self, client: AsyncClient, test_db: AsyncSession):
        """The gallery listing exposes the expected top-level keys."""
        response = await client.get("/api/v1/public/templates")
        # Check the status first so a server error fails with a clear message
        # rather than a JSON decode error or KeyError further down.
        assert response.status_code == 200
        data = response.json()
        for key in ("flow_templates", "script_templates", "total_flows", "total_scripts", "categories"):
            assert key in data, f"missing key {key!r} in gallery response"

    async def test_categories_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """GET /public/templates/categories requires no auth token."""
        response = await client.get("/api/v1/public/templates/categories")
        assert response.status_code == 200

    async def test_search_accessible_without_auth(self, client: AsyncClient, test_db: AsyncSession):
        """GET /public/templates/search requires no auth token."""
        response = await client.get("/api/v1/public/templates/search?q=network")
        assert response.status_code == 200
|
|
|
|
|
|
@pytest.mark.asyncio
class TestGalleryFeatureFilter:
    """Gallery must only return items where is_gallery_featured=True.

    Inactive flows must be hidden as well, even when featured.
    """

    async def test_featured_flow_appears_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        tree = await _create_featured_tree(test_db, name="Should Appear", featured=True)
        tree_id = str(tree.id)
        response = await client.get("/api/v1/public/templates?type=flows")
        assert response.status_code == 200
        ids = [t["id"] for t in response.json()["flow_templates"]]
        assert tree_id in ids

    async def test_unfeatured_flow_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        tree = await _create_featured_tree(test_db, name="Should Not Appear", featured=False)
        tree_id = str(tree.id)
        response = await client.get("/api/v1/public/templates?type=flows")
        assert response.status_code == 200
        ids = [t["id"] for t in response.json()["flow_templates"]]
        assert tree_id not in ids

    async def test_inactive_flow_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        tree = await _create_featured_tree(test_db, name="Inactive Flow", featured=True)
        # Read the id before committing. If the session expires objects on
        # commit (SQLAlchemy's default), accessing ``tree.id`` afterwards would
        # need an implicit lazy refresh, which an AsyncSession cannot perform.
        tree_id = str(tree.id)
        tree.is_active = False
        await test_db.commit()
        response = await client.get("/api/v1/public/templates?type=flows")
        assert response.status_code == 200
        ids = [t["id"] for t in response.json()["flow_templates"]]
        assert tree_id not in ids

    async def test_featured_script_appears_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, featured=True)
        script_id = str(script.id)
        response = await client.get("/api/v1/public/templates?type=scripts")
        assert response.status_code == 200
        ids = [s["id"] for s in response.json()["script_templates"]]
        assert script_id in ids

    async def test_unfeatured_script_not_in_gallery(self, client: AsyncClient, test_db: AsyncSession):
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, featured=False)
        script_id = str(script.id)
        response = await client.get("/api/v1/public/templates?type=scripts")
        assert response.status_code == 200
        ids = [s["id"] for s in response.json()["script_templates"]]
        assert script_id not in ids
|
|
|
|
|
|
@pytest.mark.asyncio
class TestTreeStructureTruncation:
    """The full tree structure must be truncated to 3 levels for the public preview."""

    async def test_preview_structure_not_null(self, client: AsyncClient, test_db: AsyncSession):
        """The listing entry for the created featured flow carries a preview_structure."""
        tree = await _create_featured_tree(test_db, name="Truncation Test")
        tree_id = str(tree.id)
        response = await client.get("/api/v1/public/templates?type=flows")
        assert response.status_code == 200
        data = response.json()
        # Select the entry by id rather than taking element [0]. The listing may
        # hold other featured flows, and its ordering is not part of this test.
        matches = [t for t in data["flow_templates"] if t["id"] == tree_id]
        assert matches, f"flow {tree_id} missing from gallery listing"
        assert matches[0]["preview_structure"] is not None

    async def test_preview_structure_truncated_to_3_levels(self, client: AsyncClient, test_db: AsyncSession):
        """Full tree has depth 4, preview should be truncated to depth 3."""
        tree = await _create_featured_tree(test_db, name="Deep Tree")

        response = await client.get(f"/api/v1/public/templates/flows/{tree.id}")
        assert response.status_code == 200
        data = response.json()
        preview = data["preview_structure"]
        assert preview is not None

        # Walk the structure and confirm depth does not exceed 3. Children may
        # hang directly off a node ("children") or off its "options" entries.
        def _max_depth(node: dict, current: int = 0) -> int:
            if not node:
                return current
            d = current
            for child in node.get("children", []):
                d = max(d, _max_depth(child, current + 1))
            for opt in node.get("options", []):
                if isinstance(opt, dict):
                    for child in opt.get("children", []):
                        d = max(d, _max_depth(child, current + 1))
            return d

        max_d = _max_depth(preview)
        assert max_d <= 3, f"Preview depth {max_d} exceeds 3 levels"

    async def test_flow_detail_does_not_return_full_structure_beyond_3_levels(
        self, client: AsyncClient, test_db: AsyncSession
    ):
        """The flow detail endpoint must truncate tree_structure."""
        tree = await _create_featured_tree(test_db, name="Depth Check Flow")
        response = await client.get(f"/api/v1/public/templates/flows/{tree.id}")
        assert response.status_code == 200
        data = response.json()
        # Only the truncated preview is exposed; the raw full-structure key must
        # not appear. The depth cap itself is asserted in
        # test_preview_structure_truncated_to_3_levels.
        assert "preview_structure" in data
        assert "tree_structure" not in data
|
|
|
|
|
|
@pytest.mark.asyncio
class TestScriptBodyProtection:
    """script_body must never be exposed in public endpoints."""

    async def test_script_body_not_in_gallery_listing(self, client: AsyncClient, test_db: AsyncSession):
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, script_body="SUPER SECRET SCRIPT BODY")
        script_id = str(script.id)
        response = await client.get("/api/v1/public/templates?type=scripts")
        # Without these two checks, an error response or a listing that omits
        # the script would satisfy the negative assertions below vacuously.
        assert response.status_code == 200
        assert script_id in [s["id"] for s in response.json()["script_templates"]]
        text = response.text
        assert "SUPER SECRET SCRIPT BODY" not in text
        assert "script_body" not in text

    async def test_script_body_not_in_detail_response(self, client: AsyncClient, test_db: AsyncSession):
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat, script_body="CONFIDENTIAL_BODY_XYZ")
        response = await client.get(f"/api/v1/public/templates/scripts/{script.id}")
        assert response.status_code == 200
        text = response.text
        assert "CONFIDENTIAL_BODY_XYZ" not in text
        assert "script_body" not in text

    async def test_script_detail_includes_parameters_without_body(self, client: AsyncClient, test_db: AsyncSession):
        cat = await _create_script_category(test_db)
        script = await _create_featured_script(test_db, cat)
        response = await client.get(f"/api/v1/public/templates/scripts/{script.id}")
        assert response.status_code == 200
        data = response.json()
        # Parameters should be present (name/description only)
        assert "parameters" in data
        assert len(data["parameters"]) > 0
        param = data["parameters"][0]
        assert "name" in param
        # script_body must not appear anywhere. Check the raw text, not just the
        # top-level keys, so a nested occurrence is caught too.
        assert "script_body" not in response.text
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSearch:
    """Search over gallery items: featured content matches, unfeatured does not."""

    async def test_search_returns_matching_flow(self, client: AsyncClient, test_db: AsyncSession):
        """A featured flow whose name contains the query term is returned."""
        await _create_featured_tree(test_db, name="VPN Connectivity Troubleshooting")

        resp = await client.get("/api/v1/public/templates/search?q=VPN")
        assert resp.status_code == 200

        body = resp.json()
        assert body["total_flows"] >= 1
        flow_names = [flow["name"] for flow in body["flow_templates"]]
        assert any("VPN" in flow_name for flow_name in flow_names)

    async def test_search_returns_matching_script(self, client: AsyncClient, test_db: AsyncSession):
        """A featured script is found by a two-word query (``+`` is a form-encoded space)."""
        category = await _create_script_category(test_db)
        await _create_featured_script(test_db, category, name="DNS Flush Script")

        resp = await client.get("/api/v1/public/templates/search?q=DNS+Flush")
        assert resp.status_code == 200

        body = resp.json()
        assert body["total_scripts"] >= 1
        script_names = [script["name"] for script in body["script_templates"]]
        assert any("DNS" in script_name for script_name in script_names)

    async def test_search_excludes_unfeatured_items(self, client: AsyncClient, test_db: AsyncSession):
        """An unfeatured flow is not counted even when its name matches exactly."""
        hidden_name = "UniqueName_NotFeatured_XYZ"
        await _create_featured_tree(test_db, name=hidden_name, featured=False)

        resp = await client.get(f"/api/v1/public/templates/search?q={hidden_name}")
        assert resp.status_code == 200
        assert resp.json()["total_flows"] == 0

    async def test_search_requires_query_param(self, client: AsyncClient, test_db: AsyncSession):
        """Omitting ``q`` is rejected as a validation error (422)."""
        resp = await client.get("/api/v1/public/templates/search")
        assert resp.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
class TestCategoriesEndpoint:
    """Categories endpoint returns a list of categories with counts."""

    async def test_categories_returns_list(self, client: AsyncClient, test_db: AsyncSession):
        """The endpoint answers 200 with a ``categories`` list."""
        response = await client.get("/api/v1/public/templates/categories")
        assert response.status_code == 200
        data = response.json()
        assert "categories" in data
        assert isinstance(data["categories"], list)

    async def test_categories_reflect_featured_content(self, client: AsyncClient, test_db: AsyncSession):
        """A category that holds a featured, active, public flow is listed."""
        from app.models.category import TreeCategory

        # Create a platform-owned category and a featured tree in that category
        cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID)
        test_db.add(cat)
        await test_db.commit()
        await test_db.refresh(cat)

        tree = Tree(
            name="Router Diagnostics",
            tree_type="troubleshooting",
            tree_structure=_make_tree_structure(2),
            account_id=_PLATFORM_ACCOUNT_ID,
            is_gallery_featured=True,
            is_active=True,
            usage_count=5,
            visibility="public",
            status="published",
            category_id=cat.id,
        )
        test_db.add(tree)
        await test_db.commit()

        response = await client.get("/api/v1/public/templates/categories")
        # Check the status first so a server error fails clearly instead of
        # surfacing as a KeyError on the parsed body.
        assert response.status_code == 200
        data = response.json()
        cat_names = [c["name"] for c in data["categories"]]
        assert "Networking" in cat_names
|
|
|
|
|
|
@pytest.mark.asyncio
class TestNotFoundBehavior:
    """Detail endpoints answer 404 for ids that are unknown or not gallery-featured."""

    async def test_flow_detail_404_for_nonexistent(self, client: AsyncClient, test_db: AsyncSession):
        """A random UUID that matches no flow yields 404."""
        url = f"/api/v1/public/templates/flows/{uuid.uuid4()}"
        resp = await client.get(url)
        assert resp.status_code == 404

    async def test_flow_detail_404_for_unfeatured(self, client: AsyncClient, test_db: AsyncSession):
        """An existing but unfeatured flow is hidden behind a 404."""
        hidden = await _create_featured_tree(test_db, name="Not Featured", featured=False)
        url = f"/api/v1/public/templates/flows/{hidden.id}"
        resp = await client.get(url)
        assert resp.status_code == 404

    async def test_script_detail_404_for_nonexistent(self, client: AsyncClient, test_db: AsyncSession):
        """A random UUID that matches no script yields 404."""
        url = f"/api/v1/public/templates/scripts/{uuid.uuid4()}"
        resp = await client.get(url)
        assert resp.status_code == 404

    async def test_script_detail_404_for_unfeatured(self, client: AsyncClient, test_db: AsyncSession):
        """An existing but unfeatured script is hidden behind a 404."""
        category = await _create_script_category(test_db)
        hidden = await _create_featured_script(test_db, category, featured=False)
        url = f"/api/v1/public/templates/scripts/{hidden.id}"
        resp = await client.get(url)
        assert resp.status_code == 404
|