"""RAG service — index trees and search embeddings for AI context.

Orchestrates tree chunking, embedding, and semantic search over the team's
flow library via pgvector cosine similarity.
"""

import logging
from collections.abc import Iterable
from typing import Optional, Any
from uuid import UUID

from sqlalchemy import text, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.tree import Tree
from app.models.tree_embedding import TreeEmbedding
from app.services.embedding_service import get_embedding, get_embeddings_batch
from app.services.tree_chunker import chunk_tree

logger = logging.getLogger(__name__)

# Written verbatim into tree_embeddings.embedding_model for every stored vector.
# NOTE(review): keep in sync with the model embedding_service actually calls.
_EMBEDDING_MODEL = "voyage-3.5"

# Casts use CAST(... AS type), matching search() below, instead of Postgres'
# ``::type`` shorthand: SQLAlchemy's text() bind-parameter parser does not
# accept ``:name::type`` (it mis-reads the parameter name).
_INSERT_EMBEDDING_SQL = text(
    """
    INSERT INTO tree_embeddings
        (tree_id, account_id, chunk_type, node_type, node_id, chunk_text,
         embedding_model, embedding, meta)
    VALUES
        (:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text,
         :embedding_model, CAST(:embedding AS vector), CAST(:meta AS jsonb))
    """
)


def _to_pgvector_literal(values: Iterable[float]) -> str:
    """Serialize floats into pgvector's text input format, e.g. ``[0.1,0.2]``."""
    return "[" + ",".join(str(v) for v in values) + "]"


async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
    """Chunk and embed a tree, storing results in tree_embeddings.

    Existing embeddings are replaced only once replacements are in hand:
    if the tree yields no chunks, its stale embeddings are deleted; if the
    embedding service is unavailable (or returns a different number of
    vectors than chunks), the previous embeddings are left untouched so the
    tree stays searchable. This function does not commit; the caller is
    expected to commit ``db``.

    Args:
        tree_id: Tree to (re-)index.
        db: Database session used for all reads and writes.

    Returns:
        The number of chunks indexed; 0 when the tree is missing, produces
        no chunks, or could not be embedded.
    """
    result = await db.execute(
        select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id)
    )
    tree = result.scalar_one_or_none()
    if not tree:
        logger.warning("index_tree: tree %s not found", tree_id)
        return 0

    tag_names = [t.name for t in tree.tags] if tree.tags else []
    chunks = chunk_tree(
        tree_name=tree.name,
        tree_type=tree.tree_type,
        description=tree.description,
        tags=tag_names,
        tree_structure=tree.tree_structure,
    )
    if not chunks:
        # Nothing indexable any more: drop whatever was stored previously.
        await delete_tree_embeddings(tree_id, db)
        logger.info("index_tree: no chunks for tree %s", tree_id)
        return 0

    texts = [c["chunk_text"] for c in chunks]
    embeddings = await get_embeddings_batch(texts, input_type="document")
    if embeddings is None:
        # Keep the old embeddings rather than leaving the tree unindexed.
        logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
        return 0
    if len(embeddings) != len(chunks):
        # zip() would silently drop the surplus; refuse a partial index instead.
        logger.warning(
            "index_tree: got %d embeddings for %d chunks for tree %s",
            len(embeddings),
            len(chunks),
            tree_id,
        )
        return 0

    account_id = str(tree.account_id) if tree.account_id else None
    rows = [
        {
            "tree_id": str(tree_id),
            "account_id": account_id,
            "chunk_type": chunk["chunk_type"],
            "node_type": chunk.get("node_type"),
            "node_id": chunk.get("node_id"),
            "chunk_text": chunk["chunk_text"],
            "embedding_model": _EMBEDDING_MODEL,
            "embedding": _to_pgvector_literal(embedding),
            "meta": "{}",
        }
        for chunk, embedding in zip(chunks, embeddings)
    ]

    # Replace old rows with new ones; a list of parameter dicts makes
    # SQLAlchemy issue one executemany instead of a round trip per chunk.
    await delete_tree_embeddings(tree_id, db)
    await db.execute(_INSERT_EMBEDDING_SQL, rows)

    logger.info("index_tree: indexed %d chunks for tree %s", len(chunks), tree_id)
    return len(chunks)


async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
    """Delete all embeddings for a tree (does not commit)."""
    await db.execute(
        delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    )


async def search(
    query: str,
    account_id: UUID,
    db: AsyncSession,
    limit: int = 8,
    exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
    """Semantic search over team's flow library.

    Args:
        query: Natural language search query. A blank query returns ``[]``
            without calling the embedding service.
        account_id: Scope search to team's flows.
        db: Database session.
        limit: Max results to return; ``limit <= 0`` returns ``[]``.
        exclude_tree_id: Exclude chunks from this tree (for copilot context).

    Returns:
        List of dicts with tree_id, tree_name, tree_type, chunk_text,
        chunk_type, node_id, similarity — ordered most-similar first.
        Empty if the query embedding could not be obtained.
    """
    # Short-circuit cases whose answer is necessarily empty: a blank query has
    # nothing to match, and Postgres rejects a negative LIMIT.
    if limit <= 0 or not (query and query.strip()):
        return []

    query_embedding = await get_embedding(query, input_type="query")
    if query_embedding is None:
        return []

    params: dict[str, Any] = {
        "embedding": _to_pgvector_literal(query_embedding),
        "account_id": str(account_id),
        "limit": limit,
    }
    # Only this fixed fragment is interpolated into the SQL; all values
    # travel as bind parameters.
    exclude_clause = ""
    if exclude_tree_id:
        exclude_clause = "AND te.tree_id != :exclude_tree_id"
        params["exclude_tree_id"] = str(exclude_tree_id)

    # similarity = 1 - cosine distance (<=>). Ordering by the raw distance
    # expression, ascending, is the form a pgvector index (if present) can serve.
    result = await db.execute(
        text(f"""
            SELECT
                te.tree_id,
                t.name as tree_name,
                t.tree_type,
                te.chunk_text,
                te.chunk_type,
                te.node_id,
                1 - (te.embedding <=> CAST(:embedding AS vector)) as similarity
            FROM tree_embeddings te
            JOIN trees t ON t.id = te.tree_id
            WHERE te.account_id = :account_id
              AND t.deleted_at IS NULL
              {exclude_clause}
            ORDER BY te.embedding <=> CAST(:embedding AS vector)
            LIMIT :limit
        """),
        params,
    )
    rows = result.mappings().all()

    return [
        {
            "tree_id": str(row["tree_id"]),
            "tree_name": row["tree_name"],
            "tree_type": row["tree_type"],
            "chunk_text": row["chunk_text"],
            "chunk_type": row["chunk_type"],
            "node_id": row["node_id"],
            "similarity": float(row["similarity"]),
        }
        for row in rows
    ]


def build_rag_context(
    rag_results: list[dict[str, Any]],
    *,
    max_results: int = 5,
    snippet_chars: int = 200,
) -> str:
    """Format RAG results into a system prompt section.

    Args:
        rag_results: Results as returned by ``search``.
        max_results: Cap on listed results, to bound prompt size.
        snippet_chars: Max characters of each chunk's text to include.

    Returns:
        A newline-joined section starting with a header line, or ``""`` when
        there are no results.
    """
    if not rag_results:
        return ""

    parts = ["\n--- RELEVANT FLOWS FROM TEAM LIBRARY ---"]
    parts.extend(
        f"- [{r['tree_type']}] {r['tree_name']}: {r['chunk_text'][:snippet_chars]}"
        for r in rag_results[:max_results]
    )
    return "\n".join(parts)


def extract_suggested_flows(
    rag_results: list[dict[str, Any]],
    exclude_tree_id: Optional[UUID] = None,
    *,
    min_similarity: float = 0.3,
    max_suggestions: int = 3,
) -> list[dict[str, Any]]:
    """Extract unique suggested flows from RAG results.

    Walks results in order, keeping the first chunk seen for each tree,
    skipping ``exclude_tree_id`` and anything below ``min_similarity``.

    Args:
        rag_results: Results as returned by ``search`` (best match first).
        exclude_tree_id: Tree to leave out of the suggestions.
        min_similarity: Results with a lower similarity are ignored.
        max_suggestions: Maximum number of suggestions returned.

    Returns:
        Up to ``max_suggestions`` dicts with tree_id, tree_name, tree_type
        and a 150-character relevance_snippet.
    """
    excluded = str(exclude_tree_id) if exclude_tree_id else None
    seen_tree_ids: set[str] = set()
    suggestions: list[dict[str, Any]] = []

    for r in rag_results:
        if len(suggestions) >= max_suggestions:
            break
        tid = r["tree_id"]
        if tid == excluded or tid in seen_tree_ids:
            continue
        if r["similarity"] < min_similarity:
            continue
        seen_tree_ids.add(tid)
        suggestions.append({
            "tree_id": tid,
            "tree_name": r["tree_name"],
            "tree_type": r["tree_type"],
            "relevance_snippet": r["chunk_text"][:150],
        })

    return suggestions