"""Unified chat service — chat sessions on ai_sessions table. Replaces assistant_chat_service for new chat sessions. Messages are stored in ai_sessions.conversation_messages JSONB. Reuses the same AI calling infrastructure and system prompt from assistant_chat_service. """ import json import logging import re from typing import Any from uuid import UUID from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models.ai_session import AISession from app.services.assistant_chat_service import ( ASSISTANT_SYSTEM_PROMPT, _call_ai, _auto_title, ) from app.services.rag_service import search as rag_search, build_rag_context, extract_suggested_flows logger = logging.getLogger(__name__) def _parse_fork_marker(ai_content: str) -> tuple[str, dict[str, Any] | None]: """Extract [FORK]...[/FORK] JSON from AI response. Returns (cleaned_content, fork_data_or_None). The fork marker is stripped from the display text. """ match = re.search(r'\[FORK\]\s*([\s\S]*?)\s*\[/FORK\]', ai_content) if not match: return ai_content, None try: raw = match.group(1).strip() # Strip markdown fences if AI wrapped it if raw.startswith("```"): raw = re.sub(r'^```(?:json)?\s*', '', raw) raw = re.sub(r'\s*```$', '', raw) fork_data = json.loads(raw) except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to parse [FORK] marker: %s", e) return ai_content, None # Validate structure if not isinstance(fork_data, dict) or "options" not in fork_data: logger.warning("Invalid [FORK] data — missing 'options'") return ai_content, None options = fork_data["options"] if not isinstance(options, list) or len(options) < 2: logger.warning("Invalid [FORK] data — need at least 2 options") return ai_content, None # Strip the marker from display text cleaned = ai_content[:match.start()] + ai_content[match.end():] cleaned = cleaned.strip() return cleaned, fork_data def _parse_actions_marker(ai_content: str) -> tuple[str, list[dict[str, Any]] | None]: """Extract [ACTIONS]...[/ACTIONS] JSON from AI response. Returns (cleaned_content, actions_list_or_None). The actions marker is stripped from the display text. """ match = re.search(r'\[ACTIONS\]\s*([\s\S]*?)\s*\[/ACTIONS\]', ai_content) if not match: return ai_content, None try: raw = match.group(1).strip() if raw.startswith("```"): raw = re.sub(r'^```(?:json)?\s*', '', raw) raw = re.sub(r'\s*```$', '', raw) actions = json.loads(raw) except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to parse [ACTIONS] marker: %s", e) return ai_content, None if not isinstance(actions, list) or len(actions) == 0: logger.warning("Invalid [ACTIONS] data — need at least 1 action") return ai_content, None # Validate each action has at minimum a label valid_actions = [] for a in actions: if isinstance(a, dict) and a.get("label"): valid_actions.append({ "label": a["label"], "command": a.get("command"), "description": a.get("description", ""), }) if not valid_actions: return ai_content, None cleaned = ai_content[:match.start()] + ai_content[match.end():] cleaned = cleaned.strip() return cleaned, valid_actions def _parse_questions_marker(ai_content: str) -> tuple[str, list[dict[str, Any]] | None]: """Extract [QUESTIONS]...[/QUESTIONS] JSON from AI response. Returns (cleaned_content, questions_list_or_None). The questions marker is stripped from the display text. """ match = re.search(r'\[QUESTIONS\]\s*([\s\S]*?)\s*\[/QUESTIONS\]', ai_content) if not match: return ai_content, None try: raw = match.group(1).strip() if raw.startswith("```"): raw = re.sub(r'^```(?:json)?\s*', '', raw) raw = re.sub(r'\s*```$', '', raw) questions = json.loads(raw) except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to parse [QUESTIONS] marker: %s", e) return ai_content, None if not isinstance(questions, list) or len(questions) == 0: logger.warning("Invalid [QUESTIONS] data — need at least 1 question") return ai_content, None # Validate each question has at minimum a text field valid_questions = [] for q in questions: if isinstance(q, dict) and q.get("text"): valid_questions.append({ "text": q["text"], "context": q.get("context", ""), }) if not valid_questions: return ai_content, None cleaned = ai_content[:match.start()] + ai_content[match.end():] cleaned = cleaned.strip() return cleaned, valid_questions async def create_chat_session( user_id: UUID, account_id: UUID, team_id: UUID | None, intake_content: dict[str, Any], db: AsyncSession, ) -> AISession: """Create a new chat session on ai_sessions.""" first_message = intake_content.get("text", "") title = _auto_title(first_message) if first_message else "New Chat" session = AISession( user_id=user_id, account_id=account_id, team_id=team_id, session_type="chat", title=title, intake_type="free_text", intake_content=intake_content, status="active", confidence_tier="discovery", confidence_score=0.0, conversation_messages=[], ) db.add(session) await db.flush() return session async def send_chat_message( session_id: UUID, user_id: UUID, account_id: UUID, message: str, db: AsyncSession, images: list[dict[str, Any]] | None = None, ) -> tuple[str, list[dict[str, Any]], AISession, dict[str, Any] | None, list[dict[str, Any]] | None, list[dict[str, Any]] | None]: """Send a message in a chat session and get AI response. Args: images: Optional list of {"media_type": str, "data": str (base64)} for vision content attached to this message. Returns (ai_content, suggested_flows, session, fork_metadata, actions_data, questions_data). """ result = await db.execute( select(AISession).where( AISession.id == session_id, AISession.user_id == user_id, AISession.session_type == "chat", ) ) session = result.scalar_one_or_none() if not session: raise ValueError("Chat session not found") if session.status not in ("active", "paused"): raise ValueError(f"Cannot send messages to a {session.status} session") # If branching is active, route to branch message handler if session.is_branching and session.active_branch_id: from app.services.branch_manager import BranchManager from app.services.branch_aware_prompt_builder import BranchAwarePromptBuilder from app.models.session_branch import SessionBranch branch_result = await db.execute( select(SessionBranch).where(SessionBranch.id == session.active_branch_id) ) branch = branch_result.scalar_one_or_none() if branch: manager = BranchManager(db) sibling_ctx = await manager.build_cross_branch_context(branch.id) builder = BranchAwarePromptBuilder() session_context = f"Problem: {session.problem_summary or 'Unknown'}. Domain: {session.problem_domain or 'Unknown'}." prompt_args = builder.build( branch_messages=branch.conversation_messages, sibling_summaries=sibling_ctx, session_context=session_context, attachments=[], new_message=message, revival_context=branch.evidence_description if branch.status == "revived" else None, ) # Override images from prompt_args with actual images if provided if images: prompt_args["images"] = images ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args) # Update branch conversation # Strip _(not yet completed)_ markers before storage (same reason as main path) stored_message = message.replace("_(not yet completed)_", "(pending)").replace("_(skipped)_", "(skipped)") msgs = list(branch.conversation_messages or []) msgs.append({"role": "user", "content": stored_message}) msgs.append({"role": "assistant", "content": ai_content}) branch.conversation_messages = msgs session.total_input_tokens += input_tokens session.total_output_tokens += output_tokens session.step_count += 2 if session.status == "paused": session.status = "active" # Check for fork, actions, and questions markers in branch response too branch_display, branch_fork_data = _parse_fork_marker(ai_content) branch_display, branch_actions_data = _parse_actions_marker(branch_display) branch_display, branch_questions_data = _parse_questions_marker(branch_display) if branch_display != ai_content: # Store stripped content in branch history msgs[-1] = {"role": "assistant", "content": branch_display} branch.conversation_messages = msgs branch_fork_metadata = None if branch_fork_data: try: fork_point, new_branches = await manager.create_fork( session_id=session.id, parent_branch_id=branch.id, trigger_step_id=None, fork_reason=branch_fork_data.get("fork_reason", ""), options=[ {"label": o["label"], "description": o.get("description", "")} for o in branch_fork_data["options"] ], ) first_branch = new_branches[0] await manager.switch_branch(session.id, first_branch.id) branch_fork_metadata = { "fork_point_id": str(fork_point.id), "fork_reason": branch_fork_data.get("fork_reason", ""), "branches": [ {"branch_id": str(b.id), "label": b.label} for b in new_branches ], "active_branch_id": str(first_branch.id), } await db.flush() except Exception: logger.exception("Failed to create fork within branch for session %s", session.id) # Persist task lane state on session if branch_questions_data or branch_actions_data: session.pending_task_lane = { "questions": branch_questions_data or [], "actions": branch_actions_data or [], } else: session.pending_task_lane = None suggested_flows = extract_suggested_flows( await rag_search(query=message, account_id=account_id, db=db, limit=8) ) return branch_display, suggested_flows, session, branch_fork_metadata, branch_actions_data, branch_questions_data # Auto-title from first message if still default if session.step_count == 0 and message.strip(): session.title = _auto_title(message) # Auto-detect problem domain from first message if not session.problem_summary and message.strip(): session.problem_summary = _auto_title(message) # RAG search for relevant flows rag_results = await rag_search( query=message, account_id=account_id, db=db, limit=8, ) rag_context = build_rag_context(rag_results) # Build message history for AI ai_messages: list[dict[str, Any]] = [] for msg in (session.conversation_messages or []): if msg.get("role") in ("user", "assistant"): ai_messages.append({"role": msg["role"], "content": msg["content"]}) # Call AI ai_content, input_tokens, output_tokens = await _call_ai( system_base=ASSISTANT_SYSTEM_PROMPT, rag_context=rag_context, history=ai_messages, new_message=message, images=images, ) # Check for fork marker in AI response display_content, fork_data = _parse_fork_marker(ai_content) # Check for actions marker in AI response display_content, actions_data = _parse_actions_marker(display_content) # Check for questions marker in AI response display_content, questions_data = _parse_questions_marker(display_content) logger.info( "Marker parsing results — actions: %s, questions: %s, fork: %s, raw_length: %d, display_length: %d", bool(actions_data), bool(questions_data), bool(fork_data), len(ai_content), len(display_content), ) # Store DISPLAY content (markers stripped) in conversation_messages. # The format reminder in the user message + system prompt final reminder # are sufficient to keep the AI emitting markers on subsequent turns. # # Strip _(not yet completed)_ task markers from the stored user message. # The AI processes them correctly on the current turn, but persisting them # into history causes the AI to re-inject stale task lane items from prior # turns — even across unrelated topics in a long session. stored_message = message.replace("_(not yet completed)_", "(pending)").replace("_(skipped)_", "(skipped)") msgs = list(session.conversation_messages or []) msgs.append({"role": "user", "content": stored_message}) msgs.append({"role": "assistant", "content": display_content}) session.conversation_messages = msgs session.step_count += 2 # message count for display session.total_input_tokens += input_tokens session.total_output_tokens += output_tokens # Resume if paused if session.status == "paused": session.status = "active" # If fork was detected, create branches fork_metadata = None if fork_data: try: from app.services.branch_manager import BranchManager mgr = BranchManager(db) # Create root branch if this is the first fork if not session.is_branching: await mgr.create_root_branch(session.id) fork_point, new_branches = await mgr.create_fork( session_id=session.id, parent_branch_id=session.active_branch_id, trigger_step_id=None, fork_reason=fork_data.get("fork_reason", ""), options=[ {"label": o["label"], "description": o.get("description", "")} for o in fork_data["options"] ], ) # Don't auto-switch — conversation continues on current branch. # Branches appear in sidebar. User switches when ready. fork_metadata = { "fork_point_id": str(fork_point.id), "fork_reason": fork_data.get("fork_reason", ""), "branches": [ {"branch_id": str(b.id), "label": b.label} for b in new_branches ], "active_branch_id": str(session.active_branch_id) if session.active_branch_id else None, } await db.flush() logger.info("Created fork with %d branches for session %s", len(new_branches), session_id) except Exception: logger.exception("Failed to create fork for session %s", session_id) # Fork failed but chat message still sent — don't break the response # Persist task lane state on session if questions_data or actions_data: session.pending_task_lane = { "questions": questions_data or [], "actions": actions_data or [], } else: session.pending_task_lane = None suggested_flows = extract_suggested_flows(rag_results) return display_content, suggested_flows, session, fork_metadata, actions_data, questions_data