Files
resolutionflow/backend/app/api/endpoints/kb_accelerator.py
Michael Chihlas 53b6878742 fix: KB procedural import — map step content to description field
Steps built by _build_procedural_tree() were stored under the "content"
key but StepDetail.tsx reads "description". Renamed the key so step
text and commands display correctly during session execution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-11 12:50:13 -04:00

959 lines
34 KiB
Python

"""KB Accelerator endpoints.
Upload KB articles, convert to flows via AI, review, and commit.
POST /kb-accelerator/upload — Upload file or paste text
GET /kb-accelerator/{id} — Get import with nodes
GET /kb-accelerator — List imports for account
POST /kb-accelerator/{id}/convert — Re-trigger AI conversion
PATCH /kb-accelerator/{id}/nodes/{nid} — Edit a node
POST /kb-accelerator/{id}/commit — Commit to flow library
DELETE /kb-accelerator/{id} — Cancel/cleanup
GET /kb-accelerator/quota — Plan entitlements + usage
"""
import logging
import mimetypes
from datetime import datetime, timezone
from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, UploadFile, File, Form, status
from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
from app.core.config import settings
from app.core.rate_limit import limiter
from app.core.subscriptions import get_plan_limits
from app.core.ai_quota_service import get_user_plan
from app.core.ai_tree_validator import validate_generated_tree
from app.core.tree_validation import validate_procedural_structure
from app.core.kb_extraction_service import extract_text
from app.core.kb_conversion_service import convert_document
from app.models.kb_import import KBImport, KBImportNode
from app.models.plan_limits import PlanLimits
from app.models.tree import Tree
from app.models.user import User
from app.schemas.kb_accelerator import (
KBUploadTextRequest,
KBNodeEditRequest,
KBCommitRequest,
KBUploadResponse,
KBImportResponse,
KBImportNodeResponse,
KBImportSummary,
KBImportListResponse,
KBCommitResponse,
KBQuotaResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/kb-accelerator", tags=["kb-accelerator"])
# Upload size ceiling in bytes (10MB).
MAX_UPLOAD_SIZE = 10 * 1024**2

# Formats listed as "Supported" in upload error messages: extension -> MIME types.
ALLOWED_EXTENSIONS = dict(
    txt=["text/plain"],
    docx=["application/vnd.openxmlformats-officedocument.wordprocessingml.document"],
)

# Phase 2 formats (not yet enabled): extension -> MIME types.
PHASE2_EXTENSIONS = dict(
    pdf=["application/pdf"],
    html=["text/html"],
    md=["text/markdown", "text/plain"],
)
def _detect_format(filename: str) -> str | None:
    """Return the lowercased extension of *filename* if it is a known format.

    Extensions listed in either ALLOWED_EXTENSIONS or PHASE2_EXTENSIONS are
    considered known. An empty name, a name without a dot, or an unknown
    extension yields None.
    """
    if not filename or "." not in filename:
        return None
    # Only the text after the last dot counts as the extension.
    _, _, suffix = filename.rpartition(".")
    suffix = suffix.lower()
    is_known = suffix in ALLOWED_EXTENSIONS or suffix in PHASE2_EXTENSIONS
    return suffix if is_known else None
async def _get_kb_limits(user: User, db: AsyncSession) -> PlanLimits | None:
    """Resolve the plan of the user's account and return that plan's limits."""
    account_plan = await get_user_plan(user.account_id, db)
    plan_limits = await get_plan_limits(account_plan, db)
    return plan_limits
async def _check_kb_enabled(user: User, db: AsyncSession) -> PlanLimits:
    """Return the account's plan limits, or raise if KB Accelerator is unavailable.

    Raises:
        HTTPException: 403 when no limits are found for the plan or the
            plan does not have ``kb_accelerator_enabled`` set.
    """
    plan_limits = await _get_kb_limits(user, db)
    if plan_limits and plan_limits.kb_accelerator_enabled:
        return plan_limits
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="KB Accelerator is not available on your plan.",
    )
async def _check_lifetime_limit(user: User, limits: PlanLimits, db: AsyncSession) -> None:
    """Raise 403 once the account's committed imports reach the plan's lifetime cap.

    Only imports whose status is "committed" count toward the cap. A cap of
    None means the plan is unlimited and no query is issued.
    """
    if limits.kb_max_lifetime_conversions is None:
        return  # Unlimited

    committed_stmt = select(func.count(KBImport.id)).where(
        KBImport.account_id == user.account_id,
        KBImport.status == "committed",
    )
    committed_total = await db.scalar(committed_stmt) or 0
    if committed_total < limits.kb_max_lifetime_conversions:
        return

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"You have reached your lifetime limit of {limits.kb_max_lifetime_conversions} KB conversions. Upgrade your plan for unlimited conversions.",
    )
async def _check_format_allowed(source_format: str, limits: PlanLimits) -> None:
    """Raise 403 unless *source_format* is among the plan's allowed formats.

    When the plan lists no formats, "txt" and "paste" are allowed.
    """
    allowed = limits.kb_allowed_formats or ["txt", "paste"]
    if source_format in allowed:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Format '{source_format}' is not available on your plan. Allowed: {', '.join(allowed)}",
    )
async def _get_import_or_404(
    import_id: UUID, user: User, db: AsyncSession, *, load_nodes: bool = True
) -> KBImport:
    """Fetch a KB import belonging to the user's account, or raise 404.

    Args:
        import_id: Primary key of the import.
        user: Caller; the lookup is restricted to ``user.account_id``.
        db: Session used for the query.
        load_nodes: When true, the import's nodes are loaded with it.

    Raises:
        HTTPException: 404 when no matching import exists for the account.
    """
    stmt = select(KBImport).where(
        KBImport.id == import_id,
        KBImport.account_id == user.account_id,
    )
    if load_nodes:
        stmt = stmt.options(selectinload(KBImport.nodes))

    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="KB import not found")
    return found
async def _run_conversion(import_id: UUID, db_url: str) -> None:
    """Background task: run AI conversion on a KB import.

    Opens its own session, loads the import, and runs ``convert_document``
    on it. Nothing happens unless the import exists and is still in
    "processing" status. On failure the import is marked "failed" with the
    error text stored in ``error_message``.

    Args:
        import_id: Primary key of the KBImport to convert.
        db_url: Not used by this function; kept so existing ``add_task``
            call sites keep working.
    """
    # Function-scope import as in the original code (presumably to avoid a
    # circular import — TODO confirm).
    from app.core.database import async_session_maker

    async with async_session_maker() as db:
        result = await db.execute(
            select(KBImport).where(KBImport.id == import_id)
        )
        kb_import = result.scalar_one_or_none()
        if not kb_import or kb_import.status != "processing":
            return
        try:
            await convert_document(kb_import, db)
            await db.commit()
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Background KB conversion failed: %s", e)
            # Discard whatever the failed conversion left pending. Without a
            # rollback the session may be unusable after a flush error, and
            # partially generated nodes could be committed together with the
            # "failed" status.
            await db.rollback()
            kb_import.status = "failed"
            kb_import.error_message = f"Conversion error: {str(e)}"
            await db.commit()
def _serialize_import(kb_import: KBImport) -> dict:
    """Build the response dict for a KBImport.

    Scalar columns are copied as-is, nodes are validated into
    KBImportNodeResponse objects (an empty list when there are none), and
    the two timestamps are rendered as ISO-8601 strings.
    """
    copied_fields = (
        "id",
        "account_id",
        "created_by",
        "source_filename",
        "source_format",
        "source_text",
        "source_metadata",
        "target_type",
        "status",
        "confidence_avg",
        "error_message",
        "processing_time_ms",
        "ai_tokens_input",
        "ai_tokens_output",
        "tree_id",
    )
    payload: dict = {name: getattr(kb_import, name) for name in copied_fields}

    if kb_import.nodes:
        payload["nodes"] = [
            KBImportNodeResponse.model_validate(item) for item in kb_import.nodes
        ]
    else:
        payload["nodes"] = []

    payload["created_at"] = kb_import.created_at.isoformat()
    payload["updated_at"] = kb_import.updated_at.isoformat()
    return payload
# ── Endpoints ──
@router.get("/quota", response_model=KBQuotaResponse)
async def get_quota(
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Report the account's KB Accelerator entitlements and usage.

    Usage is the number of imports in "committed" status. When the plan has
    no limits record, a fully disabled quota is returned.
    """
    plan = await get_user_plan(user.account_id, db)
    limits = await get_plan_limits(plan, db)

    committed_stmt = select(func.count(KBImport.id)).where(
        KBImport.account_id == user.account_id,
        KBImport.status == "committed",
    )
    used = await db.scalar(committed_stmt) or 0

    if not limits:
        # No limits for this plan: everything off, conversions not possible.
        return KBQuotaResponse(
            plan=plan,
            kb_accelerator_enabled=False,
            lifetime_conversions_used=used,
            lifetime_conversions_limit=0,
            allowed_formats=["txt", "paste"],
            detailed_analysis=False,
            conversational_refinement=False,
            step_library_matching=False,
            history_limit=3,
            can_convert=False,
        )

    # Converting requires the feature to be enabled and, when a lifetime cap
    # exists, usage to still be below it.
    cap = limits.kb_max_lifetime_conversions
    under_cap = cap is None or used < cap
    can_convert = limits.kb_accelerator_enabled and under_cap

    return KBQuotaResponse(
        plan=plan,
        kb_accelerator_enabled=limits.kb_accelerator_enabled,
        lifetime_conversions_used=used,
        lifetime_conversions_limit=cap,
        allowed_formats=limits.kb_allowed_formats or ["txt", "paste"],
        detailed_analysis=limits.kb_detailed_analysis,
        conversational_refinement=limits.kb_conversational_refinement,
        step_library_matching=limits.kb_step_library_matching,
        history_limit=limits.kb_history_limit,
        can_convert=can_convert,
    )
@router.post("/upload", response_model=KBUploadResponse, status_code=201)
@limiter.limit("10/minute")
async def upload_kb_article(
    request: Request,
    background_tasks: BackgroundTasks,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
    file: Optional[UploadFile] = File(None),
    content: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    target_type: Optional[str] = Form(None),
):
    """Upload a KB article file or paste text for conversion.

    Either ``file`` (with a filename) or ``content`` must be supplied; a file
    takes precedence. The import is stored in "processing" status and the AI
    conversion is queued as a background task.

    Raises:
        HTTPException: 503 when AI is not configured; 403 when the plan does
            not allow the feature, the format, or further conversions; 400
            for an unsupported/empty file, too-short pasted content, missing
            input, or an invalid ``target_type``; 413 when the file exceeds
            MAX_UPLOAD_SIZE; 500 when text extraction fails at runtime.
    """
    if not settings.ai_enabled:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="AI is not configured.",
        )
    limits = await _check_kb_enabled(user, db)
    await _check_lifetime_limit(user, limits, db)

    # Determine source format and extract text
    if file and file.filename:
        source_format = _detect_format(file.filename)
        if not source_format:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file format. Supported: {', '.join(ALLOWED_EXTENSIONS.keys())}",
            )
        await _check_format_allowed(source_format, limits)

        # Read at most one byte past the cap: enough to detect an oversized
        # upload without buffering an arbitrarily large body in memory.
        file_bytes = await file.read(MAX_UPLOAD_SIZE + 1)
        if len(file_bytes) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="File exceeds 10MB limit.")
        if not file_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")

        source_filename = file.filename
        try:
            source_text, source_metadata = extract_text(file_bytes, source_format)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        except RuntimeError as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
    elif content:
        source_format = "paste"
        await _check_format_allowed(source_format, limits)
        source_filename = title
        source_text = content.strip()
        source_metadata = None
        if len(source_text) < 10:
            raise HTTPException(status_code=400, detail="Content must be at least 10 characters.")
    else:
        raise HTTPException(status_code=400, detail="Provide either a file or content text.")

    # Validate target_type
    if target_type and target_type not in ("troubleshooting", "procedural"):
        raise HTTPException(status_code=400, detail="target_type must be 'troubleshooting' or 'procedural'.")
    if not target_type:
        target_type = "troubleshooting"  # Default; Phase 2 adds "let AI decide"

    # Create KB import record
    kb_import = KBImport(
        account_id=user.account_id,
        created_by=user.id,
        source_filename=source_filename,
        source_format=source_format,
        source_text=source_text,
        source_metadata=source_metadata,
        target_type=target_type,
        status="processing",
    )
    db.add(kb_import)
    # Flush so the import has an id to hand to the background task.
    await db.flush()

    # Trigger AI conversion in background
    background_tasks.add_task(_run_conversion, kb_import.id, settings.DATABASE_URL)
    await db.commit()

    return KBUploadResponse(
        id=kb_import.id,
        status=kb_import.status,
        source_format=kb_import.source_format,
    )
@router.get("/{import_id}", response_model=KBImportResponse)
async def get_kb_import(
    import_id: UUID,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Return one KB import of the caller's account together with its nodes."""
    record = await _get_import_or_404(import_id, user, db)
    payload = _serialize_import(record)
    return payload
@router.get("", response_model=KBImportListResponse)
async def list_kb_imports(
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
    skip: int = 0,
    limit: int = 20,
    status_filter: Optional[str] = None,
):
    """List the account's KB imports, newest first.

    ``total`` counts every import matching the filter, while the returned
    page is additionally capped by the plan's history limit (3 when the plan
    has no limits record; no cap when the limit is None).
    """
    limits = await _get_kb_limits(user, db)
    history_limit = limits.kb_history_limit if limits else 3

    query = select(KBImport).where(KBImport.account_id == user.account_id)
    count_query = select(func.count(KBImport.id)).where(KBImport.account_id == user.account_id)
    if status_filter:
        query = query.where(KBImport.status == status_filter)
        count_query = count_query.where(KBImport.status == status_filter)

    total = await db.scalar(count_query) or 0

    # Apply history limit for free tier: never page past the first
    # ``history_limit`` rows.
    effective_limit = limit
    if history_limit is not None:
        if skip < history_limit:
            effective_limit = min(limit, history_limit - skip)
        else:
            effective_limit = 0
    if effective_limit <= 0:
        return KBImportListResponse(items=[], total=total, skip=skip, limit=limit)

    query = (
        query.order_by(KBImport.created_at.desc())
        .offset(skip)
        .limit(effective_limit)
        .options(selectinload(KBImport.nodes))
    )
    rows = (await db.execute(query)).scalars().all()

    items = [
        KBImportSummary(
            id=row.id,
            source_filename=row.source_filename,
            source_format=row.source_format,
            target_type=row.target_type,
            status=row.status,
            confidence_avg=row.confidence_avg,
            node_count=len(row.nodes) if row.nodes else 0,
            created_at=row.created_at.isoformat(),
        )
        for row in rows
    ]
    return KBImportListResponse(items=items, total=total, skip=skip, limit=limit)
@router.post("/{import_id}/convert", response_model=KBUploadResponse)
@limiter.limit("30/minute")
async def reconvert(
    request: Request,
    import_id: UUID,
    background_tasks: BackgroundTasks,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Re-run AI conversion for an existing import (retry/regenerate).

    Existing nodes are removed, the import is reset to "processing", and the
    conversion is queued as a background task. Committed imports are refused.
    """
    if not settings.ai_enabled:
        raise HTTPException(status_code=503, detail="AI is not configured.")

    target = await _get_import_or_404(import_id, user, db, load_nodes=False)
    if target.status == "committed":
        raise HTTPException(status_code=400, detail="Cannot reconvert a committed import.")

    # Drop the nodes produced by the previous conversion.
    purge_nodes = delete(KBImportNode).where(KBImportNode.kb_import_id == target.id)
    await db.execute(purge_nodes)

    # Reset the conversion outcome before re-queuing.
    target.status = "processing"
    target.error_message = None
    target.confidence_avg = None
    await db.flush()

    background_tasks.add_task(_run_conversion, target.id, settings.DATABASE_URL)
    await db.commit()

    return KBUploadResponse(
        id=target.id, status="processing", source_format=target.source_format
    )
@router.patch("/{import_id}/nodes/{node_id}", response_model=KBImportNodeResponse)
async def edit_node(
    import_id: UUID,
    node_id: UUID,
    data: KBNodeEditRequest,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Edit a specific node in a KB import during review.

    Supported operations (``data.operation``):
        approve / reject: set or clear ``user_approved``.
        edit: replace the node content and mark it user-edited.
        delete: remove the node, renumber the rest, and return a placeholder.
        insert_after: add a new user-authored node right after this one.
        regenerate: clear the edited/approved flags (placeholder; no AI call).

    Raises:
        HTTPException: 404 when the import or node is not found; 400 when
            the import is not in "ready" status, required content is
            missing, or the operation is not one of the above.
    """
    kb_import = await _get_import_or_404(import_id, user, db, load_nodes=False)
    if kb_import.status != "ready":
        raise HTTPException(status_code=400, detail="Import must be in 'ready' status to edit nodes.")

    result = await db.execute(
        select(KBImportNode).where(
            KBImportNode.id == node_id,
            KBImportNode.kb_import_id == import_id,
        )
    )
    node = result.scalar_one_or_none()
    if not node:
        raise HTTPException(status_code=404, detail="Node not found")

    op = data.operation
    if op == "approve":
        node.user_approved = True
    elif op == "reject":
        node.user_approved = False
    elif op == "edit":
        if not data.content:
            raise HTTPException(status_code=400, detail="Content required for edit operation.")
        node.content = data.content
        node.user_edited = True
    elif op == "delete":
        await db.delete(node)
        # Flush the delete explicitly and also filter the row out, so the
        # renumbering below never depends on the session's autoflush setting.
        await db.flush()
        # Reorder remaining nodes
        remaining = await db.execute(
            select(KBImportNode)
            .where(
                KBImportNode.kb_import_id == import_id,
                KBImportNode.id != node_id,
            )
            .order_by(KBImportNode.node_order)
        )
        for idx, n in enumerate(remaining.scalars().all()):
            n.node_order = idx
        await db.flush()
        await db.commit()
        # Return a placeholder response for deleted node
        return KBImportNodeResponse(
            id=node_id,
            kb_import_id=import_id,
            node_order=-1,
            node_type="step",
            content={"deleted": True},
            confidence_score=0,
            user_edited=False,
            user_approved=False,
        )
    elif op == "insert_after":
        if not data.content:
            raise HTTPException(status_code=400, detail="Content required for insert_after operation.")
        # Shift subsequent nodes to open a slot right after this node.
        subsequent = await db.execute(
            select(KBImportNode)
            .where(
                KBImportNode.kb_import_id == import_id,
                KBImportNode.node_order > node.node_order,
            )
            .order_by(KBImportNode.node_order)
        )
        for n in subsequent.scalars().all():
            n.node_order += 1
        new_node = KBImportNode(
            kb_import_id=import_id,
            node_order=node.node_order + 1,
            node_type=data.content.get("type", "step"),
            content=data.content,
            confidence_score=1.0,  # User-created nodes are fully trusted
            user_edited=True,
            user_approved=True,
        )
        db.add(new_node)
        await db.flush()
        await db.commit()
        return KBImportNodeResponse.model_validate(new_node)
    elif op == "regenerate":
        # Re-run AI for just this node (simplified: update placeholder)
        # Full implementation would call AI with node context + guidance
        node.user_edited = False
        node.user_approved = False
        node.updated_at = datetime.now(timezone.utc)
    else:
        # Previously an unrecognised operation fell through, committed
        # nothing, and still returned 200 as if it had succeeded.
        raise HTTPException(status_code=400, detail=f"Unsupported operation '{op}'.")

    await db.flush()
    await db.commit()
    return KBImportNodeResponse.model_validate(node)
@router.post("/{import_id}/commit", response_model=KBCommitResponse)
async def commit_import(
    import_id: UUID,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
    data: Optional[KBCommitRequest] = None,
):
    """Commit a reviewed KB import to the flow library as a Tree.

    The import must be in "ready" status and have nodes. The tree structure
    is built from the nodes according to ``target_type``, validated, and
    stored as a draft Tree; the import is then marked "committed".

    Raises:
        HTTPException: 404 when the import is not found; 400 when it is not
            "ready" or has no nodes; 403 when the plan does not allow the
            feature or the lifetime conversion cap is reached; 422 when the
            built tree fails validation.
    """
    kb_import = await _get_import_or_404(import_id, user, db)
    if kb_import.status != "ready":
        raise HTTPException(status_code=400, detail="Import must be in 'ready' status to commit.")
    if not kb_import.nodes:
        raise HTTPException(status_code=400, detail="No nodes to commit.")

    # The lifetime cap counts *committed* imports but was only checked at
    # upload time, so several imports uploaded before any commit could all
    # be committed past the cap. Re-check at the point the count changes.
    limits = await _check_kb_enabled(user, db)
    await _check_lifetime_limit(user, limits, db)

    # Extract title/description from conversion metadata. `or` fallbacks also
    # cover keys that are present but null.
    source_metadata = kb_import.source_metadata or {}
    conversion_meta = source_metadata.get("_conversion") or {}
    tree_name = (data.name if data else None) or conversion_meta.get("title") or "Imported Flow"
    tree_description = (data.description if data else None) or conversion_meta.get("description")

    # Build tree_structure from nodes and validate it before committing
    if kb_import.target_type == "troubleshooting":
        tree_structure = _build_troubleshooting_tree(kb_import.nodes)
        validation_errors = validate_generated_tree(tree_structure)
        if validation_errors:
            logger.warning(
                "KB commit blocked: tree failed validation with %d errors: %s",
                len(validation_errors), "; ".join(validation_errors[:5]),
            )
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={
                    "message": "The converted flow has structural issues that need to be fixed before committing.",
                    "validation_errors": validation_errors,
                },
            )
    else:
        # Procedural/maintenance validation
        tree_structure = _build_procedural_tree(kb_import.nodes)
        is_valid, proc_errors = validate_procedural_structure(tree_structure)
        if not is_valid:
            error_messages = [e.get("message", str(e)) for e in proc_errors]
            logger.warning(
                "KB commit blocked: procedural flow failed validation with %d errors: %s",
                len(proc_errors), "; ".join(error_messages[:5]),
            )
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={
                    "message": "The converted flow has structural issues that need to be fixed before committing.",
                    "validation_errors": error_messages,
                },
            )

    # Build intake_form for procedural flows
    intake_form = None
    if kb_import.target_type == "procedural":
        intake_form = source_metadata.get("_intake_form")

    # Create the Tree record
    tree = Tree(
        name=tree_name,
        description=tree_description,
        tree_type=kb_import.target_type,
        tree_structure=tree_structure,
        intake_form=intake_form,
        author_id=user.id,
        account_id=user.account_id,
        status="draft",
        import_metadata={
            "source": "kb_accelerator",
            "kb_import_id": str(kb_import.id),
            "source_filename": kb_import.source_filename,
            "source_format": kb_import.source_format,
            "confidence_avg": kb_import.confidence_avg,
            "node_count": len(kb_import.nodes),
            "converted_at": datetime.now(timezone.utc).isoformat(),
        },
    )
    if data and data.category_id:
        tree.category_id = data.category_id
    db.add(tree)
    # Flush so the tree has an id to record on the import.
    await db.flush()

    kb_import.status = "committed"
    kb_import.tree_id = tree.id
    await db.commit()

    return KBCommitResponse(
        tree_id=tree.id,
        import_id=kb_import.id,
        tree_type=kb_import.target_type,
    )
@router.delete("/{import_id}", status_code=204)
async def delete_import(
    import_id: UUID,
    user: Annotated[User, Depends(require_engineer_or_admin)],
    db: Annotated[AsyncSession, Depends(get_db)],
):
    """Cancel a KB import: remove its nodes and the import record itself.

    Committed imports cannot be deleted (400).
    """
    doomed = await _get_import_or_404(import_id, user, db, load_nodes=False)
    if doomed.status == "committed":
        raise HTTPException(status_code=400, detail="Cannot delete a committed import.")

    # Remove the child nodes first, then the import.
    purge_nodes = delete(KBImportNode).where(KBImportNode.kb_import_id == import_id)
    await db.execute(purge_nodes)
    await db.delete(doomed)
    await db.commit()
# ── Tree Structure Builders ──
def _build_troubleshooting_tree(nodes: list[KBImportNode]) -> dict:
    """Build a troubleshooting tree_structure from import nodes.

    The tree editor expects a deeply nested structure where each decision
    node's `children` array contains all reachable descendant nodes.
    Action/solution nodes use `title`/`description` (not `question`).

    The AI generates a DAG (shared nodes reachable from multiple paths),
    but the tree editor requires unique IDs — each node can only appear
    once. We embed each node the first time it's encountered; subsequent
    references just use next_node_id / options[].next_node_id to point
    back to the already-embedded node.

    The root is the node with the lowest ``node_order``, matching how the
    procedural builder orders nodes, rather than whichever node happens to
    come first in the collection. Nodes that cannot be reached from the root
    are left out of the tree and reported in a warning log.

    Args:
        nodes: Import nodes; each ``content`` dict may carry ``original_id``,
            ``question``, ``description``, ``next_node_id`` and ``options``.

    Returns:
        The root node dict, or an "Empty" placeholder decision node when
        there are no nodes.
    """
    if not nodes:
        return {"id": "root", "type": "decision", "question": "Empty", "children": []}

    # Map original IDs to import nodes
    original_id_map: dict[str, KBImportNode] = {}
    for node in nodes:
        orig_id = node.content.get("original_id", str(node.id))
        original_id_map[orig_id] = node

    # Track which nodes have been placed in the tree to avoid duplicates
    placed: set[str] = set()

    def _build_node(import_node: KBImportNode) -> dict | None:
        content = import_node.content
        node_type = import_node.node_type
        node_id = content.get("original_id", str(import_node.id))

        # Already placed in the tree — don't create a duplicate.
        # The reference (next_node_id / options) is sufficient.
        if node_id in placed:
            return None
        placed.add(node_id)

        question_text = content.get("question", "")

        if node_type == "resolution":
            return {
                "id": node_id,
                "type": "solution",
                "title": question_text,
                "description": content.get("description", ""),
            }

        if node_type in ("action", "warning"):
            result: dict = {
                "id": node_id,
                "type": "action",
                "title": question_text,
                "description": content.get("description", ""),
            }
            next_id = content.get("next_node_id")
            if next_id and next_id in original_id_map:
                result["next_node_id"] = next_id
            return result

        # question/decision type — recursively build children
        options = content.get("options", [])

        # Count how many options point to buildable (not-yet-placed) targets
        buildable_targets = []
        for opt in options:
            next_id = opt.get("next_node_id")
            if next_id and next_id in original_id_map and next_id not in placed:
                buildable_targets.append(next_id)

        # Decision nodes MUST have at least 2 branches to pass validation.
        # If fewer than 2 buildable targets, demote to action node.
        if len(buildable_targets) < 2:
            demoted: dict = {
                "id": node_id,
                "type": "action",
                "title": question_text,
                "description": content.get("description", ""),
            }
            if buildable_targets:
                demoted["next_node_id"] = buildable_targets[0]
            elif options:
                # All targets already placed; reference first option's target
                first_next = options[0].get("next_node_id")
                if first_next:
                    demoted["next_node_id"] = first_next
            return demoted

        # Build children for decision node
        children = []
        built_options = []
        for opt in options:
            next_id = opt.get("next_node_id")
            opt_id = opt.get("id", f"opt-{node_id}-{len(built_options)}")
            if next_id and next_id in original_id_map:
                child_node = _build_node(original_id_map[next_id])
                if child_node is not None:
                    children.append(child_node)
                    _collect_action_chain(child_node, children)
                built_options.append({
                    "id": opt_id,
                    "label": opt.get("label", ""),
                    "next_node_id": next_id,
                })
            else:
                built_options.append({
                    "id": opt_id,
                    "label": opt.get("label", ""),
                    "next_node_id": next_id or "",
                })

        return {
            "id": node_id,
            "type": "decision",
            "question": question_text,
            "options": built_options,
            "children": children,
        }

    def _collect_action_chain(node: dict, siblings: list[dict]) -> None:
        """Follow action node next_node_id chains and add targets as siblings."""
        if node.get("type") != "action":
            return
        next_id = node.get("next_node_id")
        if not next_id or next_id not in original_id_map:
            return
        # Don't add if already in this siblings list or already placed
        if any(s["id"] == next_id for s in siblings):
            return
        target = _build_node(original_id_map[next_id])
        if target is None:
            return
        siblings.append(target)
        # Continue chain if the target is also an action
        _collect_action_chain(target, siblings)

    # Pick the root by node_order instead of trusting collection order.
    root_node = min(nodes, key=lambda n: n.node_order)
    result = _build_node(root_node)
    if not result:
        return {"id": "root", "type": "decision", "question": "Empty", "children": []}

    # Surface nodes that were never embedded instead of dropping them silently.
    unreachable = [oid for oid in original_id_map if oid not in placed]
    if unreachable:
        logger.warning(
            "KB tree build left out %d node(s) not reachable from the root: %s",
            len(unreachable), ", ".join(str(oid) for oid in unreachable[:5]),
        )

    # Post-build repair: fix structural issues caused by build order, e.g. a
    # decision that lost branches because their targets were already placed
    # elsewhere in the tree.
    _repair_tree(result)

    # Ensure root is a valid decision node (validator requires this)
    if result.get("type") == "decision":
        children = result.get("children", [])
        options = result.get("options", [])
        # Root must have >= 2 children and >= 2 options
        if len(children) < 2 or len(options) < 2:
            logger.warning(
                "KB tree root has %d children and %d options after repair; "
                "tree may fail validation",
                len(children), len(options),
            )
    return result
def _repair_tree(node: dict) -> None:
    """Repair structural problems in a built tree, in place.

    Children are repaired before their parent is examined:
    - a child decision left with fewer than 2 children is demoted to an
      action and its children are hoisted into the parent's list;
    - a decision with fewer than 2 options but at least 2 children gets its
      options rebuilt from those children;
    - an action without next_node_id becomes a solution.
    """
    # Walk the children list by position: demotion may insert hoisted nodes
    # right after the current position, and those are visited as well.
    kids = node.get("children", [])
    pos = 0
    while pos < len(kids):
        kid = kids[pos]
        if isinstance(kid, dict):
            # Depth first, so the child's own subtree is settled before we
            # decide whether the child itself must be demoted.
            _repair_tree(kid)
            if kid.get("type") == "decision" and len(kid.get("children", [])) < 2:
                _demote_decision_to_action(kid, kids, pos)
        pos += 1

    # Now fix this node itself
    kind = node.get("type")
    ident = node.get("id", "unknown")

    if kind == "decision":
        kids = node.get("children", [])
        opts = node.get("options", [])
        if len(opts) < 2 and len(kids) >= 2:
            # Rebuild options from children: one option per child.
            rebuilt = []
            for n, c in enumerate(kids):
                rebuilt.append({
                    "id": f"opt-{ident}-{n}",
                    "label": c.get("question") or c.get("title", f"Option {n+1}"),
                    "next_node_id": c.get("id", ""),
                })
            node["options"] = rebuilt
        elif not opts:
            node["options"] = []
        return

    if kind == "action" and not node.get("next_node_id"):
        # Action with no next_node_id → convert to solution
        node["type"] = "solution"
        if not node.get("title"):
            node["title"] = node.get("question", "Resolution")
        if not node.get("description"):
            node["description"] = ""
def _demote_decision_to_action(node: dict, siblings: list[dict], index: int) -> None:
    """Demote a decision node to action and hoist its children as siblings.

    The node's question becomes its title, and its options and children are
    removed from it. The successor is the first child's id or, when there
    are no children, the first option's target. If neither yields a
    successor the node becomes a "solution" instead, since an action without
    next_node_id has nowhere to go.

    Args:
        node: The decision node to demote (modified in-place).
        siblings: The parent's children list (may be expanded).
        index: Position of this node in siblings.
    """
    # Always detach the children key so a demoted node never keeps a stale
    # (possibly empty) `children` list.
    child_children = node.pop("children", None) or []
    question = node.get("question", "")

    # Pick next_node_id from first child, falling back to the first option
    next_id = None
    if child_children:
        next_id = child_children[0].get("id")
    else:
        options = node.get("options") or []
        if options:
            next_id = options[0].get("next_node_id")

    # Convert in-place
    node["title"] = question
    node["description"] = ""
    node.pop("question", None)
    node.pop("options", None)
    if next_id:
        node["type"] = "action"
        node["next_node_id"] = next_id
    else:
        node["type"] = "solution"

    # Hoist children as siblings directly after this node, keeping their order
    siblings[index + 1:index + 1] = child_children
# NOTE(review): stale leftover from an earlier edit. Only one _repair_tree is
# defined in the code shown here, so there is no broken duplicate left to delete.
def _build_procedural_tree(nodes: list[KBImportNode]) -> dict:
    """Build a procedural tree_structure from import nodes.

    Nodes are emitted in ``node_order``. AI node types map to step types:
    - step/action/warning/question/resolution → procedure_step
    - section_header → section_header
    - procedure_step / procedure_end → unchanged
    - anything else → procedure_step

    Each step carries 'id', 'type', 'title' and 'description'. The step text
    is read from the node's 'content' key and stored under 'description'.
    The title falls back from 'title' to 'question' to the first 80
    characters of the text, then to "Step". 'content_type' is copied when
    present. A procedure_end step is appended when there are steps but none
    of them is a procedure_end.

    Returns:
        ``{"id": "root", "type": "procedural", "steps": [...]}``.
    """
    # Type mapping from AI output to valid step types
    TYPE_MAP = {
        "step": "procedure_step",
        "action": "procedure_step",
        "warning": "procedure_step",
        "question": "procedure_step",
        "resolution": "procedure_step",
        "section_header": "section_header",
        "procedure_step": "procedure_step",
        "procedure_end": "procedure_end",
    }

    steps = []
    for node in sorted(nodes, key=lambda n: n.node_order):
        content = node.content
        step_type = TYPE_MAP.get(node.node_type, "procedure_step")
        # `or ""` also covers a 'content' key that is present but null, which
        # would otherwise break the slice below and store None as description.
        step_content = content.get("content") or ""
        step_title = content.get("title") or content.get("question") or step_content[:80] or "Step"
        step: dict = {
            "id": content.get("original_id", str(node.id)),
            "type": step_type,
            "title": step_title,
            "description": step_content,
        }
        # Preserve content_type if present
        content_type = content.get("content_type")
        if content_type:
            step["content_type"] = content_type
        steps.append(step)

    # Append a procedure_end when no step of that type exists yet
    has_end = any(s["type"] == "procedure_end" for s in steps)
    if not has_end and steps:
        steps.append({
            "id": "procedure-end",
            "type": "procedure_end",
            "title": "Procedure Complete",
            "description": "All steps have been completed.",
        })

    return {
        "id": "root",
        "type": "procedural",
        "steps": steps,
    }