Implements a three-phase AI assistant feature:
- Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration, a tree chunking service, and semantic search over the team's flow library
- Phase 1: In-session copilot panel during flow navigation, with contextual AI help, current-step awareness, and suggested related flows
- Phase 2: Standalone AI chat page with persistent conversation history, pin/delete, and configurable account-level retention policies

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
171 lines
5.4 KiB
Python
171 lines
5.4 KiB
Python
"""RAG service — index trees and search embeddings for AI context.
|
|
|
|
Orchestrates tree chunking, embedding, and semantic search over the
|
|
team's flow library via pgvector cosine similarity.
|
|
"""
|
|
import logging
|
|
from typing import Optional, Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import text, delete
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.tree import Tree
|
|
from app.models.tree_embedding import TreeEmbedding
|
|
from app.services.embedding_service import get_embedding, get_embeddings_batch
|
|
from app.services.tree_chunker import chunk_tree
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
    """Chunk and embed a tree, storing results in tree_embeddings.

    Existing embeddings for the tree are deleted first; the tree is then
    chunked, all chunk texts are embedded in one batch, and one row per
    chunk is inserted. This function does not commit; the caller owns the
    transaction.

    Args:
        tree_id: ID of the tree to (re-)index.
        db: Async database session used for every statement.

    Returns:
        The number of chunk rows inserted. 0 when the tree is not found,
        when chunking yields nothing, or when the embedding service
        returns None.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    result = await db.execute(
        select(Tree)
        .options(selectinload(Tree.tags))  # tags are read below; load them eagerly
        .where(Tree.id == tree_id)
    )
    tree = result.scalar_one_or_none()
    if tree is None:
        logger.warning("index_tree: tree %s not found", tree_id)
        return 0

    # Delete existing embeddings so a re-index never leaves stale chunks behind.
    await db.execute(
        delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    )

    # Chunk the tree
    tag_names = [t.name for t in tree.tags] if tree.tags else []
    chunks = chunk_tree(
        tree_name=tree.name,
        tree_type=tree.tree_type,
        description=tree.description,
        tags=tag_names,
        tree_structure=tree.tree_structure,
    )

    if not chunks:
        logger.info("index_tree: no chunks for tree %s", tree_id)
        return 0

    # Get embeddings for all chunks in batch
    texts = [c["chunk_text"] for c in chunks]
    embeddings = await get_embeddings_batch(texts, input_type="document")

    if embeddings is None:
        # NOTE(review): the DELETE above has already been executed in this
        # session, so if the caller commits, the tree is left unindexed.
        logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
        return 0

    account_id = str(tree.account_id) if tree.account_id else None
    rows = [
        {
            "tree_id": str(tree_id),
            "account_id": account_id,
            "chunk_type": chunk["chunk_type"],
            "node_type": chunk.get("node_type"),
            "node_id": chunk.get("node_id"),
            "chunk_text": chunk["chunk_text"],
            "embedding_model": "voyage-3.5",
            # pgvector text input format: "[v1,v2,...]"
            "embedding": "[" + ",".join(map(str, embedding)) + "]",
            "meta": "{}",
        }
        for chunk, embedding in zip(chunks, embeddings)
    ]

    if len(rows) != len(chunks):
        # zip() stops at the shorter input; make the shortfall visible
        # instead of reporting chunks that were never stored.
        logger.warning(
            "index_tree: expected %d embeddings for tree %s, got %d",
            len(chunks),
            tree_id,
            len(rows),
        )

    if rows:
        # CAST(... AS type) is used instead of the ":name::type" shorthand
        # because text() does not reliably parse a bind parameter that is
        # immediately followed by a "::" cast.
        # Passing a list of dicts runs the INSERT once per row (executemany).
        await db.execute(
            text("""
                INSERT INTO tree_embeddings
                    (tree_id, account_id, chunk_type, node_type, node_id, chunk_text, embedding_model, embedding, meta)
                VALUES
                    (:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text, :embedding_model, CAST(:embedding AS vector), CAST(:meta AS jsonb))
            """),
            rows,
        )

    logger.info("index_tree: indexed %d chunks for tree %s", len(rows), tree_id)
    return len(rows)
|
|
|
|
|
|
async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
    """Remove every TreeEmbedding row whose tree_id matches the given tree.

    Only executes the DELETE on the supplied session; no commit is issued here.
    """
    removal_stmt = delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    await db.execute(removal_stmt)
|
|
|
|
|
|
async def search(
    query: str,
    account_id: UUID,
    db: AsyncSession,
    limit: int = 8,
    exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
    """Semantic search over team's flow library.

    Embeds the query, then ranks rows of tree_embeddings for the account by
    pgvector cosine distance (the `<=>` operator), skipping trees whose
    deleted_at is set.

    Args:
        query: Natural language search query.
        account_id: Scope search to team's flows.
        db: Database session.
        limit: Max results to return. Non-positive values yield an empty list.
        exclude_tree_id: Exclude chunks from this tree (for copilot context).

    Returns:
        List of dicts with tree_id, tree_name, tree_type, chunk_text,
        chunk_type, node_id, similarity (1 - cosine distance), ordered from
        most to least similar. Empty when the query embedding is unavailable.
    """
    if limit <= 0:
        # Nothing can be returned; skip the embedding call and avoid sending
        # a negative LIMIT to the database.
        return []

    query_embedding = await get_embedding(query, input_type="query")
    if query_embedding is None:
        return []

    # pgvector text input format: "[v1,v2,...]"
    embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

    params: dict[str, Any] = {
        "embedding": embedding_str,
        "account_id": str(account_id),
        "limit": limit,
    }

    # exclude_clause is one of two fixed literals (never caller-supplied
    # text), so interpolating it into the statement below is safe; the
    # actual tree id still travels as a bind parameter.
    exclude_clause = ""
    if exclude_tree_id:
        exclude_clause = "AND te.tree_id != :exclude_tree_id"
        params["exclude_tree_id"] = str(exclude_tree_id)

    # CAST(:embedding AS vector) is used instead of ":embedding::vector"
    # because text() does not reliably parse a bind parameter that is
    # immediately followed by a "::" cast.
    result = await db.execute(
        text(f"""
            SELECT
                te.tree_id,
                t.name AS tree_name,
                t.tree_type,
                te.chunk_text,
                te.chunk_type,
                te.node_id,
                1 - (te.embedding <=> CAST(:embedding AS vector)) AS similarity
            FROM tree_embeddings te
            JOIN trees t ON t.id = te.tree_id
            WHERE te.account_id = :account_id
              AND t.deleted_at IS NULL
              {exclude_clause}
            ORDER BY te.embedding <=> CAST(:embedding AS vector)
            LIMIT :limit
        """),
        params,
    )

    rows = result.mappings().all()
    return [
        {
            "tree_id": str(row["tree_id"]),
            "tree_name": row["tree_name"],
            "tree_type": row["tree_type"],
            "chunk_text": row["chunk_text"],
            "chunk_type": row["chunk_type"],
            "node_id": row["node_id"],
            "similarity": float(row["similarity"]),
        }
        for row in rows
    ]
|