- Use CAST(:embedding AS vector) instead of :embedding::vector to avoid SQLAlchemy's named-parameter conflict with PostgreSQL's :: cast syntax
- Add db.rollback() before recording AI usage on failure to prevent an InFailedSQLTransactionError cascade

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
210 lines
6.6 KiB
Python
210 lines
6.6 KiB
Python
"""RAG service — index trees and search embeddings for AI context.
|
|
|
|
Orchestrates tree chunking, embedding, and semantic search over the
|
|
team's flow library via pgvector cosine similarity.
|
|
"""
|
|
import logging
|
|
from typing import Optional, Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import text, delete
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.tree import Tree
|
|
from app.models.tree_embedding import TreeEmbedding
|
|
from app.services.embedding_service import get_embedding, get_embeddings_batch
|
|
from app.services.tree_chunker import chunk_tree
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
    """Chunk and embed a tree, storing results in tree_embeddings.

    Deletes existing embeddings for this tree before re-indexing. No commit
    is issued here; the statements are only executed on the given session.

    Args:
        tree_id: ID of the tree to (re-)index.
        db: Async database session used for all reads and writes.

    Returns:
        The number of embedding rows inserted. 0 when the tree is not found,
        when chunking yields nothing, or when the embedding service returns
        None.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    result = await db.execute(
        select(Tree)
        .options(selectinload(Tree.tags))
        .where(Tree.id == tree_id)
    )
    tree = result.scalar_one_or_none()
    if not tree:
        logger.warning("index_tree: tree %s not found", tree_id)
        return 0

    # Delete existing embeddings
    await db.execute(
        delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    )

    # Chunk the tree
    tag_names = [t.name for t in tree.tags] if tree.tags else []
    chunks = chunk_tree(
        tree_name=tree.name,
        tree_type=tree.tree_type,
        description=tree.description,
        tags=tag_names,
        tree_structure=tree.tree_structure,
    )

    if not chunks:
        logger.info("index_tree: no chunks for tree %s", tree_id)
        return 0

    # Get embeddings for all chunks in batch
    texts = [c["chunk_text"] for c in chunks]
    embeddings = await get_embeddings_batch(texts, input_type="document")

    if embeddings is None:
        logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
        return 0

    # Use CAST(... AS type) rather than the PostgreSQL "::type" shorthand:
    # SQLAlchemy's text() does not recognise ":name" as a bind parameter when
    # it is immediately followed by another colon, so ":embedding::vector"
    # would be sent to the server verbatim and fail to parse.
    insert_stmt = text("""
        INSERT INTO tree_embeddings
            (tree_id, account_id, chunk_type, node_type, node_id, chunk_text, embedding_model, embedding, meta)
        VALUES
            (:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text, :embedding_model, CAST(:embedding AS vector), CAST(:meta AS jsonb))
    """)

    # Values that are identical for every row of this tree.
    tree_id_str = str(tree_id)
    account_id_str = str(tree.account_id) if tree.account_id else None

    # Insert embeddings. Count rows as they are written: zip() stops at the
    # shorter sequence, so len(chunks) could overstate what was stored.
    indexed = 0
    for chunk, embedding in zip(chunks, embeddings):
        # Vector literal in the "[v1,v2,...]" text form, cast server-side.
        embedding_str = "[" + ",".join(str(v) for v in embedding) + "]"
        await db.execute(
            insert_stmt,
            {
                "tree_id": tree_id_str,
                "account_id": account_id_str,
                "chunk_type": chunk["chunk_type"],
                "node_type": chunk.get("node_type"),
                "node_id": chunk.get("node_id"),
                "chunk_text": chunk["chunk_text"],
                "embedding_model": "voyage-3.5",
                "embedding": embedding_str,
                "meta": "{}",
            },
        )
        indexed += 1

    if indexed != len(chunks):
        logger.warning(
            "index_tree: expected %d embeddings for tree %s, stored %d",
            len(chunks),
            tree_id,
            indexed,
        )

    logger.info("index_tree: indexed %d chunks for tree %s", indexed, tree_id)
    return indexed
|
|
|
|
|
|
async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
    """Remove every tree_embeddings row that belongs to the given tree.

    Args:
        tree_id: ID of the tree whose embeddings should be removed.
        db: Async database session the DELETE is executed on.
    """
    purge_stmt = delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    await db.execute(purge_stmt)
|
|
|
|
|
|
async def search(
    query: str,
    account_id: UUID,
    db: AsyncSession,
    limit: int = 8,
    exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
    """Run a semantic search over the account's embedded tree chunks.

    The query is embedded, then compared against stored chunk embeddings
    with the pgvector ``<=>`` operator; rows are ordered by that distance
    and ``similarity`` is reported as ``1 - distance``.

    Args:
        query: Natural language search query.
        account_id: Only chunks stored under this account are searched.
        db: Database session.
        limit: Max results to return.
        exclude_tree_id: When given, chunks of this tree are left out.

    Returns:
        List of dicts with tree_id, tree_name, tree_type, chunk_text,
        chunk_type, node_id and similarity. Empty when the query could not
        be embedded.
    """
    query_vector = await get_embedding(query, input_type="query")
    if query_vector is None:
        return []

    # Text form of the vector; cast to the vector type inside the SQL.
    vector_literal = "[" + ",".join(map(str, query_vector)) + "]"

    bind_values: dict[str, Any] = {
        "embedding": vector_literal,
        "account_id": str(account_id),
        "limit": limit,
    }

    # Optional extra predicate; a fixed SQL fragment, the value stays bound.
    tree_filter = ""
    if exclude_tree_id:
        tree_filter = "AND te.tree_id != :exclude_tree_id"
        bind_values["exclude_tree_id"] = str(exclude_tree_id)

    statement = text(f"""
        SELECT
            te.tree_id,
            t.name as tree_name,
            t.tree_type,
            te.chunk_text,
            te.chunk_type,
            te.node_id,
            1 - (te.embedding <=> CAST(:embedding AS vector)) as similarity
        FROM tree_embeddings te
        JOIN trees t ON t.id = te.tree_id
        WHERE te.account_id = :account_id
            AND t.deleted_at IS NULL
            {tree_filter}
        ORDER BY te.embedding <=> CAST(:embedding AS vector)
        LIMIT :limit
    """)
    result = await db.execute(statement, bind_values)

    hits: list[dict[str, Any]] = []
    for record in result.mappings().all():
        hits.append(
            {
                "tree_id": str(record["tree_id"]),
                "tree_name": record["tree_name"],
                "tree_type": record["tree_type"],
                "chunk_text": record["chunk_text"],
                "chunk_type": record["chunk_type"],
                "node_id": record["node_id"],
                "similarity": float(record["similarity"]),
            }
        )
    return hits
|
|
|
|
|
|
def build_rag_context(rag_results: list[dict[str, Any]]) -> str:
    """Render RAG hits as a bulleted section for the system prompt.

    Only the first five results are used, and each chunk text is cut to its
    first 200 characters to keep the prompt small. Returns an empty string
    when there are no results.
    """
    if not rag_results:
        return ""

    header = "\n--- RELEVANT FLOWS FROM TEAM LIBRARY ---"
    top_hits = rag_results[:5]  # Cap at 5 for prompt size
    bullets = [
        f"- [{hit['tree_type']}] {hit['tree_name']}: {hit['chunk_text'][:200]}"
        for hit in top_hits
    ]
    return "\n".join([header, *bullets])
|
|
|
|
|
|
def extract_suggested_flows(
    rag_results: list[dict[str, Any]],
    exclude_tree_id: Optional[UUID] = None,
    *,
    min_similarity: float = 0.3,
    max_suggestions: int = 3,
) -> list[dict[str, Any]]:
    """Extract unique suggested flows from RAG results.

    Results are scanned in the order given. The first qualifying chunk of
    each tree becomes that tree's suggestion; later chunks of the same tree
    are ignored.

    Args:
        rag_results: Dicts carrying tree_id, tree_name, tree_type,
            chunk_text and similarity.
        exclude_tree_id: Tree to leave out of the suggestions, compared by
            its string form against each result's tree_id.
        min_similarity: Results scoring below this value are skipped.
        max_suggestions: Upper bound on the number of suggestions returned.

    Returns:
        At most ``max_suggestions`` dicts with tree_id, tree_name, tree_type
        and relevance_snippet (first 150 characters of the chunk text).
    """
    if max_suggestions <= 0:
        return []

    # Convert once instead of on every iteration.
    excluded = str(exclude_tree_id) if exclude_tree_id else None

    seen_tree_ids: set[str] = set()
    suggestions: list[dict[str, Any]] = []

    for r in rag_results:
        tid = r["tree_id"]
        if excluded is not None and tid == excluded:
            continue
        if tid in seen_tree_ids:
            continue
        if r["similarity"] < min_similarity:
            continue
        seen_tree_ids.add(tid)
        suggestions.append({
            "tree_id": tid,
            "tree_name": r["tree_name"],
            "tree_type": r["tree_type"],
            "relevance_snippet": r["chunk_text"][:150],
        })
        # Nothing after the cap can be returned, so stop scanning.
        if len(suggestions) >= max_suggestions:
            break

    return suggestions
|