feat: add AI assistant with in-session copilot and standalone chat with RAG

Implements a three-phase AI assistant feature:
- Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration,
  tree chunking service, and semantic search over team's flow library
- Phase 1: In-session copilot panel during flow navigation with contextual
  AI help, current step awareness, and suggested related flows
- Phase 2: Standalone AI chat page with persistent conversation history,
  pin/delete, and configurable retention policies (account-level)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-03-04 01:36:36 -05:00
parent 41cb7956cb
commit 1aa60dada2
44 changed files with 3080 additions and 14 deletions

View File

@@ -0,0 +1,152 @@
"""Standalone AI assistant chat service with RAG context.
Provides persistent conversation history for general IT questions
with semantic search over the team's flow library.
"""
import logging
from typing import Optional, Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.ai_provider import get_ai_provider
from app.models.assistant_chat import AssistantChat
from app.services import rag_service
logger = logging.getLogger(__name__)
# Base system prompt for the standalone assistant chat. send_message() appends
# the section produced by _build_rag_context() to this text before calling the
# AI provider, so the wording here is part of runtime behavior.
ASSISTANT_SYSTEM_PROMPT = """You are a Senior Systems and Network Engineer with 15+ years of experience working in Managed Service Provider (MSP) environments. You specialize in:
- Windows Server, Active Directory, Group Policy, and Hybrid Identity (Entra ID)
- Networking (TCP/IP, DNS, DHCP, VPN, firewall troubleshooting, Cisco/Fortinet)
- Virtualization (VMware, Hyper-V) and cloud platforms (Azure, AWS, M365)
- Endpoint management, RMM tools, and PSA platforms (ConnectWise, Datto, Kaseya)
- PowerShell scripting and automation
When answering:
- Be direct and actionable — MSP engineers need fast, practical answers
- Include specific commands, paths, and config values when relevant
- Mention potential risks or gotchas before suggesting changes
- If a relevant troubleshooting flow exists in the team's library, reference it
- Keep responses concise but thorough — prefer bullet points and code blocks
- Format code with proper markdown code blocks
"""
def _build_rag_context(rag_results: list[dict[str, Any]]) -> str:
"""Format RAG results into a system prompt section."""
if not rag_results:
return ""
parts = ["\n--- RELEVANT FLOWS FROM TEAM LIBRARY ---"]
for r in rag_results[:5]:
parts.append(f"- [{r['tree_type']}] {r['tree_name']}: {r['chunk_text'][:200]}")
return "\n".join(parts)
def _extract_suggested_flows(rag_results: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Extract unique suggested flows from RAG results."""
seen: set[str] = set()
suggestions = []
for r in rag_results:
tid = r["tree_id"]
if tid in seen or r["similarity"] < 0.3:
continue
seen.add(tid)
suggestions.append({
"tree_id": tid,
"tree_name": r["tree_name"],
"tree_type": r["tree_type"],
"relevance_snippet": r["chunk_text"][:150],
})
return suggestions[:3]
def _auto_title(message: str) -> str:
"""Generate a short title from the first user message."""
title = message.strip()[:100]
if len(message) > 100:
title = title.rsplit(" ", 1)[0] + "..."
return title
async def create_chat(
    user_id: UUID,
    account_id: UUID,
    db: AsyncSession,
) -> AssistantChat:
    """Create a new assistant chat with no messages.

    The chat is added to the session and flushed, but not committed, before
    being returned.
    """
    new_chat = AssistantChat(user_id=user_id, account_id=account_id, messages=[])
    db.add(new_chat)
    await db.flush()
    return new_chat
async def send_message(
    chat_id: UUID,
    user_id: UUID,
    account_id: UUID,
    message: str,
    db: AsyncSession,
) -> tuple[str, list[dict[str, Any]], AssistantChat]:
    """Append a user message to a chat and obtain the AI reply.

    Steps: load the chat owned by ``user_id``, title it from the first
    message, run a RAG search scoped to ``account_id``, call the AI provider
    with the stored history plus the new message, then record both turns and
    the token usage on the chat.

    Returns:
        ``(ai_content, suggested_flows, chat)``.

    Raises:
        ValueError: If no chat with ``chat_id`` belongs to ``user_id``.
    """
    lookup = select(AssistantChat).where(
        AssistantChat.id == chat_id,
        AssistantChat.user_id == user_id,
    )
    chat = (await db.execute(lookup)).scalar_one_or_none()
    if not chat:
        raise ValueError("Chat not found")

    # The first message of a chat doubles as its title.
    if chat.message_count == 0:
        chat.title = _auto_title(message)

    rag_results = await rag_service.search(
        query=message,
        account_id=account_id,
        db=db,
        limit=8,
    )
    system_prompt = ASSISTANT_SYSTEM_PROMPT + _build_rag_context(rag_results)

    # Only user/assistant turns are forwarded, reduced to role + content.
    history = [
        {"role": turn["role"], "content": turn["content"]}
        for turn in chat.messages
        if turn["role"] in ("user", "assistant")
    ]
    history.append({"role": "user", "content": message})

    provider = get_ai_provider()
    ai_content, input_tokens, output_tokens = await provider.generate_text(
        system_prompt=system_prompt,
        messages=history,
        max_tokens=4096,
    )

    # Assign a new list rather than mutating the stored one in place.
    chat.messages = [
        *chat.messages,
        {"role": "user", "content": message},
        {"role": "assistant", "content": ai_content},
    ]
    chat.message_count += 2
    chat.total_input_tokens += input_tokens
    chat.total_output_tokens += output_tokens

    return ai_content, _extract_suggested_flows(rag_results), chat

View File

@@ -0,0 +1,241 @@
"""Copilot service — in-session AI assistant with RAG context.
Builds system prompts with current flow context and RAG results,
manages conversation state, and returns AI responses with flow suggestions.
"""
import logging
from datetime import datetime, timezone, timedelta
from typing import Optional, Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.ai_provider import get_ai_provider
from app.models.tree import Tree
from app.models.copilot_conversation import CopilotConversation
from app.services import rag_service
logger = logging.getLogger(__name__)
# Base system prompt for the in-session copilot. send_message() appends the
# output of _build_flow_context() and _build_rag_context() to this text before
# calling the AI provider, so the wording here is part of runtime behavior.
COPILOT_SYSTEM_PROMPT = """You are a Senior Systems and Network Engineer with 15+ years of experience working in Managed Service Provider (MSP) environments. You specialize in:
- Windows Server, Active Directory, Group Policy, and Hybrid Identity (Entra ID)
- Networking (TCP/IP, DNS, DHCP, VPN, firewall troubleshooting, Cisco/Fortinet)
- Virtualization (VMware, Hyper-V) and cloud platforms (Azure, AWS, M365)
- Endpoint management, RMM tools, and PSA platforms (ConnectWise, Datto, Kaseya)
- PowerShell scripting and automation
You are acting as an in-session copilot while the user navigates a troubleshooting or procedural flow. You can see the flow context and their current position.
When answering:
- Be direct and actionable — MSP engineers need fast, practical answers
- Include specific commands, paths, and config values when relevant
- Mention potential risks or gotchas before suggesting changes
- If a relevant troubleshooting flow exists in the team's library, reference it
- Keep responses concise but thorough — prefer bullet points and code blocks
"""
def _build_flow_context(tree: Tree, current_node_id: Optional[str]) -> str:
"""Build flow context string for the system prompt."""
parts = [
f"\n--- CURRENT FLOW CONTEXT ---",
f"Flow: {tree.name}",
f"Type: {tree.tree_type}",
]
if tree.description:
parts.append(f"Description: {tree.description}")
if current_node_id and tree.tree_structure:
node = _find_node(tree.tree_structure, current_node_id)
if node:
parts.append(f"Current node type: {node.get('type', 'unknown')}")
parts.append(f"Current node: {node.get('content', node.get('label', 'Unknown'))}")
# Add options if it's a question/decision node
children = node.get("children", [])
if children and isinstance(children, list):
option_labels = [
c.get("label", c.get("content", ""))
for c in children if isinstance(c, dict)
]
if option_labels:
parts.append(f"Available options: {', '.join(option_labels)}")
return "\n".join(parts)
def _find_node(structure: dict, node_id: str) -> Optional[dict]:
"""Recursively find a node by ID in tree structure."""
if structure.get("id") == node_id:
return structure
for child in structure.get("children", []):
if isinstance(child, dict):
found = _find_node(child, node_id)
if found:
return found
# Check steps array for procedural flows
for step in structure.get("steps", []):
if isinstance(step, dict):
found = _find_node(step, node_id)
if found:
return found
return None
def _build_rag_context(rag_results: list[dict[str, Any]]) -> str:
"""Format RAG results into a system prompt section."""
if not rag_results:
return ""
parts = ["\n--- RELEVANT FLOWS FROM TEAM LIBRARY ---"]
for r in rag_results[:5]: # Cap at 5 for prompt size
parts.append(f"- [{r['tree_type']}] {r['tree_name']}: {r['chunk_text'][:200]}")
return "\n".join(parts)
def _extract_suggested_flows(
rag_results: list[dict[str, Any]],
exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
"""Extract unique suggested flows from RAG results."""
seen_tree_ids: set[str] = set()
suggestions = []
for r in rag_results:
tid = r["tree_id"]
if exclude_tree_id and tid == str(exclude_tree_id):
continue
if tid in seen_tree_ids:
continue
if r["similarity"] < 0.3:
continue
seen_tree_ids.add(tid)
suggestions.append({
"tree_id": tid,
"tree_name": r["tree_name"],
"tree_type": r["tree_type"],
"relevance_snippet": r["chunk_text"][:150],
})
return suggestions[:3]
async def start_conversation(
    user_id: UUID,
    account_id: UUID,
    tree_id: UUID,
    session_id: Optional[UUID],
    current_node_id: Optional[str],
    db: AsyncSession,
) -> tuple[CopilotConversation, str]:
    """Open a copilot conversation for a flow and seed it with a greeting.

    The conversation is created with a 24-hour expiry, added to the session
    and flushed; the greeting is then stored as its first assistant message.

    Returns:
        ``(conversation, greeting_message)``.

    Raises:
        ValueError: If no tree with ``tree_id`` exists.
    """
    tree_query = (
        select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id)
    )
    tree = (await db.execute(tree_query)).scalar_one_or_none()
    if not tree:
        raise ValueError(f"Tree {tree_id} not found")

    expiry = datetime.now(timezone.utc) + timedelta(hours=24)
    conversation = CopilotConversation(
        user_id=user_id,
        account_id=account_id,
        tree_id=tree_id,
        session_id=session_id,
        current_node_id=current_node_id,
        messages=[],
        expires_at=expiry,
    )
    db.add(conversation)
    await db.flush()

    greeting = (
        f"I'm your copilot for this **{tree.tree_type}** flow: **{tree.name}**. "
        "Ask me anything about the current step, alternative approaches, "
        "or related troubleshooting tips."
    )
    conversation.messages = [{"role": "assistant", "content": greeting}]
    conversation.message_count = 1
    return conversation, greeting
async def send_message(
    conversation_id: UUID,
    user_id: UUID,
    message: str,
    current_node_id: Optional[str],
    db: AsyncSession,
) -> tuple[str, list[dict[str, Any]]]:
    """Send a user message within a copilot conversation and get the reply.

    Loads the conversation owned by ``user_id`` and its flow, updates the
    tracked node when ``current_node_id`` is given, runs a RAG search scoped
    to the conversation's account, and calls the AI provider with a system
    prompt made of the base prompt, the flow context and the RAG context.
    Both turns and the token usage are then recorded on the conversation.

    Returns:
        ``(ai_content, suggested_flows)``; suggestions never include the
        conversation's own flow.

    Raises:
        ValueError: If the conversation is missing or expired, or its flow
            no longer exists.
    """
    conv_query = select(CopilotConversation).where(
        CopilotConversation.id == conversation_id,
        CopilotConversation.user_id == user_id,
    )
    conversation = (await db.execute(conv_query)).scalar_one_or_none()
    if not conversation:
        raise ValueError("Conversation not found")
    if conversation.expires_at < datetime.now(timezone.utc):
        raise ValueError("Conversation has expired")

    tree_query = (
        select(Tree)
        .options(selectinload(Tree.tags))
        .where(Tree.id == conversation.tree_id)
    )
    tree = (await db.execute(tree_query)).scalar_one_or_none()
    if not tree:
        raise ValueError("Associated flow not found")

    # Track the user's latest position; keep the previous one if none given.
    if current_node_id:
        conversation.current_node_id = current_node_id

    rag_results = await rag_service.search(
        query=message,
        account_id=conversation.account_id,
        db=db,
        limit=8,
    )

    prompt_sections = [COPILOT_SYSTEM_PROMPT]
    prompt_sections.append(_build_flow_context(tree, conversation.current_node_id))
    prompt_sections.append(_build_rag_context(rag_results))
    system_prompt = "".join(prompt_sections)

    # Only user/assistant turns are forwarded, reduced to role + content.
    history = [
        {"role": turn["role"], "content": turn["content"]}
        for turn in conversation.messages
        if turn["role"] in ("user", "assistant")
    ]
    history.append({"role": "user", "content": message})

    provider = get_ai_provider()
    ai_content, input_tokens, output_tokens = await provider.generate_text(
        system_prompt=system_prompt,
        messages=history,
        max_tokens=2048,
    )

    # Assign a new list rather than mutating the stored one in place.
    conversation.messages = [
        *conversation.messages,
        {"role": "user", "content": message},
        {"role": "assistant", "content": ai_content},
    ]
    conversation.message_count += 2
    conversation.total_input_tokens += input_tokens
    conversation.total_output_tokens += output_tokens

    return ai_content, _extract_suggested_flows(rag_results, exclude_tree_id=tree.id)

View File

@@ -0,0 +1,78 @@
"""Embedding provider abstraction for RAG.
Uses Voyage AI (voyage-3.5, 1024 dims) as the embedding provider.
Supports document and query input types for asymmetric search.
"""
import logging
from typing import Optional
from app.core.config import settings
logger = logging.getLogger(__name__)
async def get_embedding(
    text: str,
    input_type: str = "document",
) -> Optional[list[float]]:
    """Embed a single text with Voyage AI.

    Args:
        text: The text to embed.
        input_type: "document" for indexing, "query" for search queries.

    Returns:
        The embedding vector, or None when no API key is configured or the
        request fails for any reason (the failure is logged, not raised).
    """
    api_key = settings.VOYAGE_API_KEY
    if not api_key:
        logger.warning("VOYAGE_API_KEY not set — embedding service unavailable")
        return None

    try:
        # Imported here so a failed import is handled like any other failure.
        import voyageai

        voyage = voyageai.AsyncClient(api_key=api_key)
        response = await voyage.embed(
            texts=[text],
            model=settings.EMBEDDING_MODEL,
            input_type=input_type,
        )
        return response.embeddings[0]
    except Exception as e:
        logger.error("Embedding failed: %s", e)
        return None
async def get_embeddings_batch(
    texts: list[str],
    input_type: str = "document",
) -> Optional[list[list[float]]]:
    """Embed several texts with one Voyage AI request.

    Args:
        texts: List of texts to embed.
        input_type: "document" for indexing, "query" for search queries.

    Returns:
        An empty list for empty input (no request is made); the embedding
        vectors on success; None when no API key is configured or the request
        fails for any reason (the failure is logged, not raised).
    """
    if not texts:
        return []

    api_key = settings.VOYAGE_API_KEY
    if not api_key:
        logger.warning("VOYAGE_API_KEY not set — embedding service unavailable")
        return None

    try:
        # Imported here so a failed import is handled like any other failure.
        import voyageai

        voyage = voyageai.AsyncClient(api_key=api_key)
        response = await voyage.embed(
            texts=texts,
            model=settings.EMBEDDING_MODEL,
            input_type=input_type,
        )
        return response.embeddings
    except Exception as e:
        logger.error("Batch embedding failed: %s", e)
        return None

View File

@@ -0,0 +1,170 @@
"""RAG service — index trees and search embeddings for AI context.
Orchestrates tree chunking, embedding, and semantic search over the
team's flow library via pgvector cosine similarity.
"""
import logging
from typing import Optional, Any
from uuid import UUID
from sqlalchemy import text, delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.tree import Tree
from app.models.tree_embedding import TreeEmbedding
from app.services.embedding_service import get_embedding, get_embeddings_batch
from app.services.tree_chunker import chunk_tree
logger = logging.getLogger(__name__)
async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
    """Chunk and embed a tree, replacing its rows in tree_embeddings.

    Existing embeddings are removed only once replacement vectors are
    available (or the tree yields no chunks at all). If the embedding service
    is unavailable the previous index is left untouched, so a transient
    outage cannot wipe a tree from search.

    Args:
        tree_id: ID of the tree to (re)index.
        db: Database session; statements are executed but not committed here.

    Returns:
        The number of chunks indexed, or 0 when the tree is missing, has no
        chunks, or could not be embedded.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    result = await db.execute(
        select(Tree)
        .options(selectinload(Tree.tags))
        .where(Tree.id == tree_id)
    )
    tree = result.scalar_one_or_none()
    if not tree:
        logger.warning("index_tree: tree %s not found", tree_id)
        return 0

    # Chunk the tree
    tag_names = [t.name for t in tree.tags] if tree.tags else []
    chunks = chunk_tree(
        tree_name=tree.name,
        tree_type=tree.tree_type,
        description=tree.description,
        tags=tag_names,
        tree_structure=tree.tree_structure,
    )
    if not chunks:
        # Nothing left to index: drop any stale rows for this tree.
        await db.execute(
            delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
        )
        logger.info("index_tree: no chunks for tree %s", tree_id)
        return 0

    # Get embeddings for all chunks in one batch BEFORE touching stored rows.
    texts = [c["chunk_text"] for c in chunks]
    embeddings = await get_embeddings_batch(texts, input_type="document")
    if embeddings is None:
        logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
        return 0

    # Replacement data is in hand; now it is safe to clear the old rows.
    await db.execute(
        delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    )

    # CAST(... AS ...) is used instead of the ``::`` shorthand so that no bind
    # parameter name sits directly against a colon, which text() bind-parameter
    # parsing does not handle reliably. Built once, outside the loop.
    insert_stmt = text("""
        INSERT INTO tree_embeddings
            (tree_id, account_id, chunk_type, node_type, node_id, chunk_text, embedding_model, embedding, meta)
        VALUES
            (:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text, :embedding_model, CAST(:embedding AS vector), CAST(:meta AS jsonb))
    """)
    account_id = str(tree.account_id) if tree.account_id else None

    for chunk, embedding in zip(chunks, embeddings):
        # The vector is passed as its text literal form: "[v1,v2,...]".
        embedding_str = "[" + ",".join(str(v) for v in embedding) + "]"
        await db.execute(
            insert_stmt,
            {
                "tree_id": str(tree_id),
                "account_id": account_id,
                "chunk_type": chunk["chunk_type"],
                "node_type": chunk.get("node_type"),
                "node_id": chunk.get("node_id"),
                "chunk_text": chunk["chunk_text"],
                # NOTE(review): recorded as a literal; if the embedding service
                # is configured with a different model this label will be
                # wrong. Confirm it should track the configured model.
                "embedding_model": "voyage-3.5",
                "embedding": embedding_str,
                "meta": "{}",
            },
        )

    logger.info("index_tree: indexed %d chunks for tree %s", len(chunks), tree_id)
    return len(chunks)
async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
    """Remove every stored embedding row that belongs to ``tree_id``.

    The statement is executed on ``db``; committing is left to the caller.
    """
    purge_stmt = delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    await db.execute(purge_stmt)
async def search(
    query: str,
    account_id: UUID,
    db: AsyncSession,
    limit: int = 8,
    exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
    """Semantic search over the team's flow library.

    Embeds the query, then ranks stored chunks by pgvector cosine distance
    (``<=>``), restricted to the given account and to trees that are not
    soft-deleted.

    Args:
        query: Natural language search query.
        account_id: Scope search to this account's flows.
        db: Database session.
        limit: Max results to return.
        exclude_tree_id: Exclude chunks from this tree (for copilot context).

    Returns:
        List of dicts with tree_id, tree_name, tree_type, chunk_text,
        chunk_type, node_id and similarity (1 - cosine distance), best match
        first. Empty when the query could not be embedded.
    """
    query_embedding = await get_embedding(query, input_type="query")
    if query_embedding is None:
        return []

    # The vector is passed as its text literal form: "[v1,v2,...]".
    embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"

    params: dict[str, Any] = {
        "embedding": embedding_str,
        "account_id": str(account_id),
        "limit": limit,
    }

    # Only fixed SQL text is interpolated below; every value is a bind param.
    exclude_clause = ""
    if exclude_tree_id:
        exclude_clause = "AND te.tree_id != :exclude_tree_id"
        params["exclude_tree_id"] = str(exclude_tree_id)

    # CAST(... AS vector) is used instead of the ``::vector`` shorthand so
    # the bind parameter name never sits directly against a colon, which
    # text() bind-parameter parsing does not handle reliably.
    result = await db.execute(
        text(f"""
            SELECT
                te.tree_id,
                t.name as tree_name,
                t.tree_type,
                te.chunk_text,
                te.chunk_type,
                te.node_id,
                1 - (te.embedding <=> CAST(:embedding AS vector)) as similarity
            FROM tree_embeddings te
            JOIN trees t ON t.id = te.tree_id
            WHERE te.account_id = :account_id
              AND t.deleted_at IS NULL
              {exclude_clause}
            ORDER BY te.embedding <=> CAST(:embedding AS vector)
            LIMIT :limit
        """),
        params,
    )

    return [
        {
            "tree_id": str(row["tree_id"]),
            "tree_name": row["tree_name"],
            "tree_type": row["tree_type"],
            "chunk_text": row["chunk_text"],
            "chunk_type": row["chunk_type"],
            "node_id": row["node_id"],
            "similarity": float(row["similarity"]),
        }
        for row in result.mappings().all()
    ]

View File

@@ -0,0 +1,84 @@
"""Chat retention cleanup job.
Runs daily via APScheduler to enforce account-level retention settings:
- Delete non-pinned chats older than chat_retention_days
- Delete oldest non-pinned chats when count exceeds chat_retention_max_count
"""
import logging
from datetime import datetime, timezone, timedelta
from sqlalchemy import select, delete, func
from app.core.database import async_session_maker
from app.models.account import Account
from app.models.assistant_chat import AssistantChat
logger = logging.getLogger(__name__)
async def cleanup_expired_chats() -> None:
    """Apply chat retention rules to every account in one transaction.

    Each account is processed in turn and the work is committed once at the
    end. Any exception is logged and the whole transaction is rolled back;
    nothing is re-raised to the caller.
    """
    async with async_session_maker() as db:
        try:
            accounts = (await db.execute(select(Account))).scalars().all()

            total_deleted = 0
            for account in accounts:
                total_deleted += await _cleanup_account_chats(account, db)

            await db.commit()
            if total_deleted > 0:
                logger.info("[retention] Cleaned up %d expired chats", total_deleted)
        except Exception as e:
            logger.error("[retention] Chat cleanup failed: %s", e)
            await db.rollback()
async def _cleanup_account_chats(account: Account, db) -> int:
    """Apply one account's retention settings and return how many chats went.

    Two independent rules run in order, each only when its setting is truthy:
    an age rule that deletes non-pinned chats not updated within
    ``chat_retention_days``, then a count rule that trims the oldest
    non-pinned chats when the account holds more than
    ``chat_retention_max_count`` chats. Pinned chats are never deleted.
    """
    removed = 0

    # Rule 1: age-based retention.
    if account.chat_retention_days:
        oldest_allowed = datetime.now(timezone.utc) - timedelta(
            days=account.chat_retention_days
        )
        purge_old = (
            delete(AssistantChat)
            .where(
                AssistantChat.account_id == account.id,
                AssistantChat.pinned == False,  # noqa: E712
                AssistantChat.updated_at < oldest_allowed,
            )
            .returning(AssistantChat.id)
        )
        purged_rows = (await db.execute(purge_old)).all()
        removed += len(purged_rows)

    # Rule 2: count-based retention.
    if not account.chat_retention_max_count:
        return removed

    chat_count = await db.scalar(
        select(func.count(AssistantChat.id)).where(
            AssistantChat.account_id == account.id,
        )
    ) or 0
    if chat_count <= account.chat_retention_max_count:
        return removed

    # The total includes pinned chats, but only non-pinned ones are eligible
    # for removal, so fewer than ``overflow`` rows may actually be deleted.
    overflow = chat_count - account.chat_retention_max_count
    victims = await db.execute(
        select(AssistantChat.id)
        .where(
            AssistantChat.account_id == account.id,
            AssistantChat.pinned == False,  # noqa: E712
        )
        .order_by(AssistantChat.updated_at.asc())
        .limit(overflow)
    )
    victim_ids = [row[0] for row in victims.all()]
    if victim_ids:
        await db.execute(
            delete(AssistantChat).where(AssistantChat.id.in_(victim_ids))
        )
        removed += len(victim_ids)

    return removed

View File

@@ -0,0 +1,165 @@
"""Tree chunker — converts tree_structure JSON into embeddable text chunks.
Produces three chunk types:
- tree_summary: Name + description + tags + type overview
- node: Individual node content with breadcrumb path context
- solution: Full solution/action text with path context
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
def _get_breadcrumb(node: dict, parent_path: str = "") -> str:
"""Build a breadcrumb path string for a node."""
content = node.get("content", node.get("label", ""))[:80]
if parent_path:
return f"{parent_path} > {content}"
return content
def _node_text(item: dict, primary: str, fallback: str) -> str:
    """Return item[primary], else item[fallback], else "", always as a str."""
    value = item.get(primary) or item.get(fallback) or ""
    return value if isinstance(value, str) else str(value)


def _chunk_node(
    node: dict,
    tree_name: str,
    tree_type: str,
    tags: list[str],
    parent_path: str = "",
) -> list[dict[str, Any]]:
    """Recursively chunk a node and its children.

    A chunk is emitted for nodes of a recognized type (question/decision,
    action/solution/info/warning, step/section_header); nodes of any other
    type emit nothing, but their children are still visited. Every chunk's
    text ends with the breadcrumb path, the flow name/type and, when present,
    the tags.

    Null or non-string labels, content and descriptions are tolerated, and a
    ``children`` value that is not a list is treated as "no children" instead
    of raising.

    Args:
        node: Node dict from the tree structure.
        tree_name: Name of the flow the node belongs to.
        tree_type: Type of the flow.
        tags: Tag names attached to the flow.
        parent_path: Breadcrumb of the parent node; empty for a root node.

    Returns:
        Chunk dicts with keys chunk_type, node_type, node_id, chunk_text,
        for this node first and then its descendants in depth-first order.
    """
    chunks: list[dict[str, Any]] = []
    node_type = node.get("type", "unknown")
    node_id = node.get("id", "")
    content = _node_text(node, "content", "label")
    breadcrumb = _get_breadcrumb(node, parent_path)

    raw_children = node.get("children", [])
    child_nodes = (
        [c for c in raw_children if isinstance(c, dict)]
        if isinstance(raw_children, list)
        else []
    )

    headline = f"[{node_type}] {content}"
    location = [f"Path: {breadcrumb}", f"Flow: {tree_name} | Type: {tree_type}"]

    # Build the type-specific leading lines; None means "no chunk for this type".
    lead = None
    chunk_type = "node"
    if node_type in ("question", "decision"):
        lead = [headline]
        # Child labels are the answer options; blank ones carry no signal.
        option_labels = [
            label
            for label in (_node_text(c, "label", "content")[:100] for c in child_nodes)
            if label
        ]
        if option_labels:
            lead.append(f"Options: {', '.join(option_labels)}")
    elif node_type in ("action", "solution", "info", "warning"):
        lead = [headline]
        if node_type == "solution":
            chunk_type = "solution"
    elif node_type in ("step", "section_header"):
        lead = [headline]
        if node.get("description"):
            lead.append(str(node["description"]))

    if lead is not None:
        text_parts = lead + location
        if tags:
            text_parts.append(f"Tags: {', '.join(tags)}")
        chunks.append({
            "chunk_type": chunk_type,
            "node_type": node_type,
            "node_id": node_id,
            "chunk_text": "\n".join(text_parts),
        })

    # Recurse into children, extending the breadcrumb with this node.
    for child in child_nodes:
        chunks.extend(
            _chunk_node(child, tree_name, tree_type, tags, breadcrumb)
        )

    # Nodes linked via next_node_id (action nodes) are not followed here;
    # those are handled at the tree level, not recursively.
    return chunks
def chunk_tree(
tree_name: str,
tree_type: str,
description: str | None,
tags: list[str],
tree_structure: dict[str, Any],
) -> list[dict[str, Any]]:
"""Convert a tree into embeddable text chunks.
Args:
tree_name: Name of the flow.
tree_type: troubleshooting | procedural | maintenance.
description: Flow description.
tags: List of tag names.
tree_structure: The tree_structure JSONB content.
Returns:
List of chunk dicts with keys: chunk_type, node_type, node_id, chunk_text.
"""
chunks = []
# Tree summary chunk
summary_parts = [
f"Flow: {tree_name}",
f"Type: {tree_type}",
]
if description:
summary_parts.append(f"Description: {description}")
if tags:
summary_parts.append(f"Tags: {', '.join(tags)}")
chunks.append({
"chunk_type": "tree_summary",
"node_type": None,
"node_id": None,
"chunk_text": "\n".join(summary_parts),
})
# Chunk the tree structure nodes
root = tree_structure
if isinstance(root, dict):
# Handle both flat structure and nested
if "children" in root or "type" in root:
chunks.extend(
_chunk_node(root, tree_name, tree_type, tags)
)
# Handle steps array (procedural flows)
if "steps" in root and isinstance(root["steps"], list):
for step in root["steps"]:
if isinstance(step, dict):
chunks.extend(
_chunk_node(step, tree_name, tree_type, tags)
)
return chunks