"""Flow Matching Engine v1 — find existing flows relevant to an AI session's intake.

Combines keyword matching, semantic search (via RAG embeddings), and recency
scoring to rank flows. Deliberately simple for v1; v2 (Phase 3) adds deeper
semantic matching.

Scoring weights: semantic 0.5, keyword 0.3, recency 0.2.
Threshold: only return matches with composite score > 0.5.
"""

import logging
from datetime import datetime, timezone, timedelta
from typing import Any, Optional
from uuid import UUID

from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.tree import Tree
from app.services.rag_service import search as rag_search

logger = logging.getLogger(__name__)

# Scoring weights
SEMANTIC_WEIGHT = 0.5
KEYWORD_WEIGHT = 0.3
RECENCY_WEIGHT = 0.2

# Only return matches above this composite score
SCORE_THRESHOLD = 0.5

# Minimum per-signal scores before a human-readable reason is recorded
_SEMANTIC_REASON_MIN = 0.5
_KEYWORD_REASON_MIN = 0.3

# Keyword-score floor granted to a tree whose category matches the domain
_DOMAIN_BOOST = 0.2

# Recency window: full score up to 7 days, linear decay to zero at 30 days
_FULL_RECENCY_DAYS = 7
_MAX_RECENCY_DAYS = 30

# Escape character used when building LIKE/ILIKE patterns
_LIKE_ESCAPE = "\\"


def _new_candidate(
    tree_id: str,
    tree_name: str,
    *,
    semantic_score: float = 0.0,
    keyword_score: float = 0.0,
) -> dict[str, Any]:
    """Build the per-tree scoring record used while matching."""
    return {
        "tree_id": tree_id,
        "tree_name": tree_name,
        "semantic_score": semantic_score,
        "keyword_score": keyword_score,
        "recency_score": 0.0,
        "match_reasons": [],
    }


def _escape_like(value: str, escape: str = _LIKE_ESCAPE) -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape character itself."""
    return (
        value.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


async def find_matches(
    intake_text: str,
    problem_domain: Optional[str],
    account_id: UUID,
    db: AsyncSession,
    limit: int = 5,
) -> list[dict[str, Any]]:
    """Find existing flows that match the intake description.

    Each signal (semantic, keyword, domain, recency) is best-effort: a failure
    in one is logged as a warning and the remaining signals are still used.

    Args:
        intake_text: Free-text description to match against.
        problem_domain: Optional domain; when truthy, trees whose category
            contains it get a small keyword-score floor.
        account_id: Account whose trees are searched.
        db: Async database session used for all queries.
        limit: Maximum number of results to return (negative is treated as 0).

    Returns:
        List of dicts sorted by composite score, highest first:
        {tree_id, tree_name, score, match_reason}. Only candidates whose
        composite score is strictly greater than SCORE_THRESHOLD are included.
    """
    candidates: dict[str, dict[str, Any]] = {}

    # 1. Semantic search via existing RAG embeddings
    try:
        rag_results = await rag_search(
            query=intake_text,
            account_id=account_id,
            db=db,
            limit=10,
        )
        for r in rag_results:
            tree_id = str(r["tree_id"])
            # A missing or null similarity counts as no semantic signal.
            similarity = r.get("similarity") or 0.0
            if tree_id not in candidates:
                candidates[tree_id] = _new_candidate(
                    tree_id, r["tree_name"], semantic_score=similarity
                )
            else:
                # Take the best semantic score across chunks
                candidates[tree_id]["semantic_score"] = max(
                    candidates[tree_id]["semantic_score"], similarity
                )
    except Exception as e:
        logger.warning("Semantic search failed during flow matching: %s", e)

    # Record one semantic reason per tree (using its best chunk) so several
    # chunks of the same tree cannot crowd out keyword/domain reasons.
    for candidate in candidates.values():
        best_similarity = candidate["semantic_score"]
        if best_similarity > _SEMANTIC_REASON_MIN:
            candidate["match_reasons"].append(
                f"semantic match ({best_similarity:.0%})"
            )

    # 2. Keyword matching against trees.match_keywords
    try:
        keyword_matches = await _keyword_match(intake_text, account_id, db)
        for km in keyword_matches:
            tree_id = str(km["tree_id"])
            if tree_id not in candidates:
                candidates[tree_id] = _new_candidate(
                    tree_id, km["tree_name"], keyword_score=km["score"]
                )
            else:
                candidates[tree_id]["keyword_score"] = km["score"]
            if km["score"] > _KEYWORD_REASON_MIN:
                candidates[tree_id]["match_reasons"].append(
                    f"keyword match: {', '.join(km.get('matched_keywords', []))}"
                )
    except Exception as e:
        logger.warning("Keyword matching failed: %s", e)

    # 3. Category/domain match
    if problem_domain:
        try:
            domain_matches = await _domain_match(problem_domain, account_id, db)
            for dm in domain_matches:
                tree_id = str(dm["tree_id"])
                if tree_id not in candidates:
                    # Small boost for domain match
                    candidates[tree_id] = _new_candidate(
                        tree_id, dm["tree_name"], keyword_score=_DOMAIN_BOOST
                    )
                else:
                    candidates[tree_id]["keyword_score"] = max(
                        candidates[tree_id]["keyword_score"], _DOMAIN_BOOST
                    )
                candidates[tree_id]["match_reasons"].append(
                    f"domain match: {problem_domain}"
                )
        except Exception as e:
            logger.warning("Domain matching failed: %s", e)

    # 4. Apply recency boost (every candidate starts at recency_score 0.0)
    if candidates:
        try:
            recency_data = await _get_recency_scores(list(candidates), db)
            for tree_id, recency_score in recency_data.items():
                if tree_id in candidates:
                    candidates[tree_id]["recency_score"] = recency_score
        except Exception as e:
            logger.warning("Recency scoring failed: %s", e)

    # 5. Compute composite scores and filter
    results = []
    for tree_id, c in candidates.items():
        composite = (
            c["semantic_score"] * SEMANTIC_WEIGHT
            + c["keyword_score"] * KEYWORD_WEIGHT
            + c["recency_score"] * RECENCY_WEIGHT
        )
        if composite > SCORE_THRESHOLD:
            results.append({
                "tree_id": UUID(tree_id),
                "tree_name": c["tree_name"],
                "score": round(composite, 3),
                "match_reason": (
                    "; ".join(c["match_reasons"][:3])
                    if c["match_reasons"]
                    else "composite match"
                ),
            })

    # Sort by score descending, take top N
    results.sort(key=lambda x: x["score"], reverse=True)
    return results[:max(limit, 0)]


async def _keyword_match(
    intake_text: str,
    account_id: UUID,
    db: AsyncSession,
) -> list[dict[str, Any]]:
    """Match intake text against trees.match_keywords JSONB arrays.

    Simple approach: tokenize intake text, check overlap with each tree's
    keywords. The score is the fraction of a tree's distinct keywords that
    appear as tokens in the intake text.

    Returns:
        One dict per tree with at least one matching keyword:
        {tree_id, tree_name, score, matched_keywords}; matched_keywords holds
        at most five keywords in sorted order.
    """
    # Extract meaningful tokens from intake (lowercase, alphanumerics only, 3+ chars)
    tokens = set()
    for word in (intake_text or "").lower().split():
        cleaned = "".join(c for c in word if c.isalnum())
        if len(cleaned) >= 3:
            tokens.add(cleaned)

    if not tokens:
        return []

    # Find trees with match_keywords set
    result = await db.execute(
        select(Tree.id, Tree.name, Tree.match_keywords)
        .where(
            Tree.account_id == account_id,
            Tree.deleted_at.is_(None),
            Tree.status == "published",
            Tree.match_keywords.isnot(None),
        )
    )
    rows = result.all()

    matches = []
    for row in rows:
        tree_keywords = row.match_keywords or []
        if not isinstance(tree_keywords, list):
            continue

        # Lowercase keywords for comparison
        kw_lower = {str(kw).lower() for kw in tree_keywords}

        # Calculate overlap
        matched = tokens & kw_lower
        if matched:
            score = len(matched) / max(len(kw_lower), 1)
            matches.append({
                "tree_id": row.id,
                "tree_name": row.name,
                "score": min(score, 1.0),
                # Sorted so the reported keywords are stable between runs.
                "matched_keywords": sorted(matched)[:5],
            })

    return matches


async def _domain_match(
    problem_domain: str,
    account_id: UUID,
    db: AsyncSession,
) -> list[dict[str, Any]]:
    """Find trees whose category matches the classified problem domain.

    Performs a case-insensitive substring match of ``problem_domain`` against
    ``Tree.category`` for published, non-deleted trees of the account, capped
    at 10 rows. A blank domain yields no matches.

    Returns:
        List of {tree_id, tree_name} dicts.
    """
    domain = problem_domain.strip() if problem_domain else ""
    if not domain:
        # An empty pattern would degrade to '%%' and match every category.
        return []

    # Escape LIKE wildcards so the domain is matched literally.
    pattern = f"%{_escape_like(domain)}%"

    result = await db.execute(
        select(Tree.id, Tree.name)
        .where(
            Tree.account_id == account_id,
            Tree.deleted_at.is_(None),
            Tree.status == "published",
            Tree.category.ilike(pattern, escape=_LIKE_ESCAPE),
        )
        .limit(10)
    )
    rows = result.all()
    return [{"tree_id": row.id, "tree_name": row.name} for row in rows]


async def _get_recency_scores(
    tree_ids: list[str],
    db: AsyncSession,
) -> dict[str, float]:
    """Calculate recency scores based on last_matched_at.

    Trees matched within the last 7 days get the full recency score (1.0).
    Trees matched 8-30 days ago get a linearly decaying score.
    Older or never-matched trees get 0.
    When a tree has a success_rate, the recency score is multiplied by it.
    Every score is clamped to [0.0, 1.0].

    Args:
        tree_ids: Tree ids as UUID strings.
        db: Async database session.

    Returns:
        Mapping of tree id (string) to recency score for each tree found.
    """
    if not tree_ids:
        return {}

    result = await db.execute(
        select(Tree.id, Tree.last_matched_at, Tree.success_rate)
        .where(Tree.id.in_([UUID(tid) for tid in tree_ids]))
    )
    rows = result.all()

    now = datetime.now(timezone.utc)
    decay_span = _MAX_RECENCY_DAYS - _FULL_RECENCY_DAYS
    scores: dict[str, float] = {}

    for row in rows:
        tree_id = str(row.id)
        last_matched = row.last_matched_at
        if last_matched is None:
            scores[tree_id] = 0.0
            continue

        if last_matched.tzinfo is None:
            # Aware minus naive raises TypeError.
            # NOTE(review): naive timestamps are assumed to be UTC — confirm
            # against the column definition.
            last_matched = last_matched.replace(tzinfo=timezone.utc)

        days_since = (now - last_matched).days
        if days_since <= _FULL_RECENCY_DAYS:
            recency = 1.0
        elif days_since <= _MAX_RECENCY_DAYS:
            # Linear decay 7-30 days
            recency = 1.0 - ((days_since - _FULL_RECENCY_DAYS) / decay_span)
        else:
            recency = 0.0

        # Factor in success rate if available; float() keeps the arithmetic
        # valid if the value arrives as a Decimal.
        if row.success_rate is not None:
            recency *= float(row.success_rate)

        scores[tree_id] = max(0.0, min(1.0, recency))

    return scores