Files
resolutionflow/backend/app/services/rag_service.py
Michael Chihlas 1aa60dada2 feat: add AI assistant with in-session copilot and standalone chat with RAG
Implements three-phase AI assistant feature:
- Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration,
  tree chunking service, and semantic search over team's flow library
- Phase 1: In-session copilot panel during flow navigation with contextual
  AI help, current step awareness, and suggested related flows
- Phase 2: Standalone AI chat page with persistent conversation history,
  pin/delete, and configurable retention policies (account-level)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 01:36:36 -05:00

171 lines
5.4 KiB
Python

"""RAG service — index trees and search embeddings for AI context.
Orchestrates tree chunking, embedding, and semantic search over the
team's flow library via pgvector cosine similarity.
"""
import logging
from typing import Optional, Any
from uuid import UUID
from sqlalchemy import text, delete
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.tree import Tree
from app.models.tree_embedding import TreeEmbedding
from app.services.embedding_service import get_embedding, get_embeddings_batch
from app.services.tree_chunker import chunk_tree
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
    """Chunk and embed a tree, storing the results in ``tree_embeddings``.

    Any embeddings previously stored for the tree are replaced. The old rows
    are only removed once replacement vectors have been obtained (or when the
    tree yields no chunks at all), so a failed call to the embedding service
    does not leave the tree without an index.

    Args:
        tree_id: Primary key of the tree to index.
        db: Async database session the statements are executed on. No commit
            is issued here.

    Returns:
        The number of chunks indexed. ``0`` when the tree does not exist,
        when chunking produces nothing, or when the embedding service
        returns ``None``.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    result = await db.execute(
        select(Tree)
        .options(selectinload(Tree.tags))
        .where(Tree.id == tree_id)
    )
    tree = result.scalar_one_or_none()
    if not tree:
        logger.warning("index_tree: tree %s not found", tree_id)
        return 0

    # Chunk the tree.
    tag_names = [t.name for t in tree.tags] if tree.tags else []
    chunks = chunk_tree(
        tree_name=tree.name,
        tree_type=tree.tree_type,
        description=tree.description,
        tags=tag_names,
        tree_structure=tree.tree_structure,
    )
    if not chunks:
        # Nothing to index any more: drop whatever was stored before.
        await db.execute(
            delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
        )
        logger.info("index_tree: no chunks for tree %s", tree_id)
        return 0

    # Get embeddings for all chunks in one batch.
    texts = [c["chunk_text"] for c in chunks]
    embeddings = await get_embeddings_batch(texts, input_type="document")
    if embeddings is None:
        # Keep the existing rows: the service being unavailable should not
        # wipe the tree's current index.
        logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
        return 0

    # Replacement vectors are available, so the old rows can be removed now.
    await db.execute(
        delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    )

    # CAST(... AS ...) is used instead of PostgreSQL's ``::`` shorthand so the
    # colons of the cast cannot be confused with text()'s ``:name``
    # bind-parameter syntax. The statement is loop-invariant, so build it once.
    insert_stmt = text("""
        INSERT INTO tree_embeddings
            (tree_id, account_id, chunk_type, node_type, node_id, chunk_text, embedding_model, embedding, meta)
        VALUES
            (:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text, :embedding_model,
             CAST(:embedding AS vector), CAST(:meta AS jsonb))
    """)
    account_id = str(tree.account_id) if tree.account_id else None

    for chunk, embedding in zip(chunks, embeddings):
        # The vector is sent as its text literal form: "[v1,v2,...]".
        embedding_str = "[" + ",".join(str(v) for v in embedding) + "]"
        await db.execute(
            insert_stmt,
            {
                "tree_id": str(tree_id),
                "account_id": account_id,
                "chunk_type": chunk["chunk_type"],
                "node_type": chunk.get("node_type"),
                "node_id": chunk.get("node_id"),
                "chunk_text": chunk["chunk_text"],
                "embedding_model": "voyage-3.5",
                "embedding": embedding_str,
                "meta": "{}",
            },
        )

    logger.info("index_tree: indexed %d chunks for tree %s", len(chunks), tree_id)
    return len(chunks)
async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
    """Remove every stored embedding row that belongs to the given tree.

    Args:
        tree_id: Primary key of the tree whose embeddings are removed.
        db: Async database session the DELETE is executed on.
    """
    purge_stmt = delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
    await db.execute(purge_stmt)
async def search(
    query: str,
    account_id: UUID,
    db: AsyncSession,
    limit: int = 8,
    exclude_tree_id: Optional[UUID] = None,
) -> list[dict[str, Any]]:
    """Semantic search over the team's flow library.

    Embeds ``query`` and ranks stored chunks by pgvector distance
    (``<=>``) to it, restricted to the given account and to trees that
    are not soft-deleted.

    Args:
        query: Natural language search query.
        account_id: Scope search to this team's flows.
        db: Database session.
        limit: Max results to return. A value of zero or less yields an
            empty list without calling the embedding service.
        exclude_tree_id: Exclude chunks from this tree (for copilot context).

    Returns:
        List of dicts with tree_id, tree_name, tree_type, chunk_text,
        chunk_type, node_id and similarity, best match first. Empty when
        the query embedding could not be obtained.
    """
    if limit <= 0:
        # Nothing can be returned; skip the embedding call and avoid sending
        # a negative LIMIT to the database.
        return []

    query_embedding = await get_embedding(query, input_type="query")
    if query_embedding is None:
        return []

    # The vector is sent as its text literal form: "[v1,v2,...]".
    embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"

    params: dict[str, Any] = {
        "embedding": embedding_str,
        "account_id": str(account_id),
        "limit": limit,
    }

    # Only a fixed SQL fragment is interpolated into the statement; the
    # value itself always travels as a bind parameter.
    exclude_clause = ""
    if exclude_tree_id is not None:
        exclude_clause = "AND te.tree_id != :exclude_tree_id"
        params["exclude_tree_id"] = str(exclude_tree_id)

    # CAST(... AS vector) is used instead of PostgreSQL's ``::`` shorthand so
    # the colons of the cast cannot be confused with text()'s ``:name``
    # bind-parameter syntax.
    result = await db.execute(
        text(f"""
            SELECT
                te.tree_id,
                t.name AS tree_name,
                t.tree_type,
                te.chunk_text,
                te.chunk_type,
                te.node_id,
                1 - (te.embedding <=> CAST(:embedding AS vector)) AS similarity
            FROM tree_embeddings te
            JOIN trees t ON t.id = te.tree_id
            WHERE te.account_id = :account_id
              AND t.deleted_at IS NULL
              {exclude_clause}
            ORDER BY te.embedding <=> CAST(:embedding AS vector)
            LIMIT :limit
        """),
        params,
    )
    rows = result.mappings().all()

    return [
        {
            "tree_id": str(row["tree_id"]),
            "tree_name": row["tree_name"],
            "tree_type": row["tree_type"],
            "chunk_text": row["chunk_text"],
            "chunk_type": row["chunk_type"],
            "node_id": row["node_id"],
            "similarity": float(row["similarity"]),
        }
        for row in rows
    ]