diff --git a/backend/alembic/versions/e2d81e82ea5e_add_ai_chat_sessions_table.py b/backend/alembic/versions/e2d81e82ea5e_add_ai_chat_sessions_table.py new file mode 100644 index 00000000..42d2264c --- /dev/null +++ b/backend/alembic/versions/e2d81e82ea5e_add_ai_chat_sessions_table.py @@ -0,0 +1,46 @@ +"""add ai_chat_sessions table + +Revision ID: e2d81e82ea5e +Revises: 1490781700bc +Create Date: 2026-02-27 03:41:33.832260 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSONB + + +# revision identifiers, used by Alembic. +revision: str = 'e2d81e82ea5e' +down_revision: Union[str, None] = '1490781700bc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "ai_chat_sessions", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("status", sa.String(20), nullable=False, server_default="active"), + sa.Column("current_phase", sa.String(20), nullable=False, server_default="scoping"), + sa.Column("flow_type", sa.String(20), nullable=False), + sa.Column("conversation_history", JSONB, nullable=False, server_default="[]"), + sa.Column("working_tree", JSONB, nullable=True), + sa.Column("tree_metadata", JSONB, nullable=False, server_default="{}"), + sa.Column("provider_used", sa.String(20), nullable=True), + sa.Column("message_count", sa.Integer, nullable=False, server_default="0"), + sa.Column("total_input_tokens", sa.Integer, nullable=False, server_default="0"), + sa.Column("total_output_tokens", sa.Integer, nullable=False, server_default="0"), + sa.Column("generated_tree_id", UUID(as_uuid=True), sa.ForeignKey("trees.id", ondelete="SET NULL"), nullable=True), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("ai_chat_sessions") diff --git a/backend/app/api/endpoints/ai_chat.py b/backend/app/api/endpoints/ai_chat.py new file mode 100644 index 00000000..defebd5e --- /dev/null +++ b/backend/app/api/endpoints/ai_chat.py @@ -0,0 +1,426 @@ +"""AI Chat Builder endpoints. + +Conversational flow builder: + POST /ai/chat/sessions — Start session, get AI greeting + POST /ai/chat/sessions/{id}/messages — Send message, get AI response + GET /ai/chat/sessions/{id} — Get session state (for resume) + POST /ai/chat/sessions/{id}/generate — Generate final TreeStructure + POST /ai/chat/sessions/{id}/import — Create Tree from generated structure + DELETE /ai/chat/sessions/{id} — Abandon session +""" +import logging +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.rate_limit import limiter +from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin +from app.core.config import settings +from app.core.ai_chat_service import ( + start_chat_session, + send_message, + generate_final_tree, + get_chat_session, + MAX_MESSAGES_FREE, + MAX_MESSAGES_PAID, +) +from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan +from app.models.user import User +from app.models.tree import Tree +from app.schemas.ai_chat import ( + AIChatStartRequest, + AIChatStartResponse, + AIChatMessageRequest, + AIChatMessageResponse, + AIChatSessionResponse, + AIChatGenerateResponse, + AIChatImportRequest, + AIChatImportResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ai/chat", tags=["ai-chat-builder"]) + + +def _require_ai_enabled() -> None: + if not settings.ai_enabled: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.", + ) + + +@router.post("/sessions", response_model=AIChatStartResponse, status_code=201) +@limiter.limit("10/minute") +async def create_session( + request: Request, + data: AIChatStartRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Start a new AI chat builder session.""" + _require_ai_enabled() + + allowed, quota_status = await check_ai_quota( + user_id=current_user.id, + account_id=current_user.account_id, + db=db, + billing_anchor=current_user.ai_billing_cycle_anchor_at, + is_super_admin=current_user.is_super_admin, + ) + if not allowed: + reset_key = ( + "daily_reset_at" + if quota_status.get("deny_reason") == "daily" + else "monthly_reset_at" + ) + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "message": f"AI build limit exceeded ({quota_status['deny_reason']})", + "reset_at": quota_status.get(reset_key), + "quota": quota_status, + }, + ) + + plan = await get_user_plan(current_user.account_id, db) + + try: + session, greeting = await start_chat_session( + flow_type=data.flow_type, + user_id=current_user.id, + account_id=current_user.account_id, + db=db, + ) + except Exception as e: + logger.exception("AI chat session start failed: %s", e) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"AI provider error ({type(e).__name__}). Please try again.", + ) + + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_message", + tier=plan, + input_tokens=session.total_input_tokens, + output_tokens=session.total_output_tokens, + estimated_cost=( + session.total_input_tokens * 1.0 / 1_000_000 + + session.total_output_tokens * 5.0 / 1_000_000 + ), + succeeded=True, + counts_toward_quota=False, + error_code=None, + extra_data={"phase": "scoping", "chat_session_id": str(session.id)}, + db=db, + ) + + await db.commit() + + return AIChatStartResponse( + session_id=session.id, + greeting=greeting, + current_phase=session.current_phase, + ) + + +@router.post("/sessions/{session_id}/messages", response_model=AIChatMessageResponse) +@limiter.limit("10/minute") +async def post_message( + request: Request, + session_id: UUID, + data: AIChatMessageRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Send a user message and get AI response.""" + _require_ai_enabled() + + session = await get_chat_session(session_id, current_user.id, db) + + if session.status != "active": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Session is {session.status}, cannot send messages", + ) + + plan = await get_user_plan(current_user.account_id, db) + max_messages = MAX_MESSAGES_PAID if plan != "free" else MAX_MESSAGES_FREE + if current_user.is_super_admin: + max_messages = 999 + + if session.message_count >= max_messages: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=f"Maximum messages per session reached ({max_messages}). Generate your tree or start a new session.", + ) + + prev_input = session.total_input_tokens + prev_output = session.total_output_tokens + + try: + ai_content, tree_update, new_phase, metadata = await send_message( + session, data.content, db + ) + except Exception as e: + logger.exception("AI chat message failed: %s", e) + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_message", + tier=plan, + input_tokens=0, + output_tokens=0, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"chat_session_id": str(session.id)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"AI provider error ({type(e).__name__}). Please try again.", + ) + + input_delta = session.total_input_tokens - prev_input + output_delta = session.total_output_tokens - prev_output + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_message", + tier=plan, + input_tokens=input_delta, + output_tokens=output_delta, + estimated_cost=( + input_delta * 1.0 / 1_000_000 + + output_delta * 5.0 / 1_000_000 + ), + succeeded=True, + counts_toward_quota=False, + error_code=None, + extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)}, + db=db, + ) + + await db.commit() + + return AIChatMessageResponse( + content=ai_content, + current_phase=session.current_phase, + working_tree=session.working_tree, + tree_metadata=session.tree_metadata if session.tree_metadata else None, + ) + + +@router.get("/sessions/{session_id}", response_model=AIChatSessionResponse) +async def get_session( + session_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Get full session state for resume after page reload.""" + session = await get_chat_session(session_id, current_user.id, db) + + visible_history = [ + msg for msg in session.conversation_history + if not msg.get("hidden") + ] + + return AIChatSessionResponse( + session_id=session.id, + status=session.status, + current_phase=session.current_phase, + flow_type=session.flow_type, + conversation_history=visible_history, + working_tree=session.working_tree, + tree_metadata=session.tree_metadata if session.tree_metadata else None, + message_count=session.message_count, + generated_tree=session.working_tree if session.status == "completed" else None, + ) + + +@router.post("/sessions/{session_id}/generate", response_model=AIChatGenerateResponse) +@limiter.limit("10/minute") +async def generate_tree( + request: Request, + session_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Generate final TreeStructure JSON from conversation.""" + _require_ai_enabled() + + session = await get_chat_session(session_id, current_user.id, db) + + if session.status == "completed" and session.working_tree: + return AIChatGenerateResponse( + tree_structure=session.working_tree, + tree_metadata=session.tree_metadata, + status="completed", + ) + + if session.status != "active": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Session is {session.status}, cannot generate", + ) + + plan = await get_user_plan(current_user.account_id, db) + prev_input = session.total_input_tokens + prev_output = session.total_output_tokens + + try: + tree_structure, metadata = await generate_final_tree(session, db) + except ValueError as e: + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_generate", + tier=plan, + input_tokens=session.total_input_tokens - prev_input, + output_tokens=session.total_output_tokens - prev_output, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code="invalid_output", + extra_data={"error": str(e), "chat_session_id": str(session.id)}, + db=db, + ) + await db.commit() + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Tree generation failed: {e}", + ) + except Exception as e: + logger.exception("AI chat generate failed: %s", e) + input_delta = session.total_input_tokens - prev_input + output_delta = session.total_output_tokens - prev_output + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_generate", + tier=plan, + input_tokens=input_delta, + output_tokens=output_delta, + estimated_cost=0, + succeeded=False, + counts_toward_quota=False, + error_code=type(e).__name__, + extra_data={"error": str(e), "chat_session_id": str(session.id)}, + db=db, + ) + await db.commit() + + error_name = type(e).__name__ + if "timeout" in error_name.lower() or "Timeout" in str(e): + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Tree generation timed out. Please try again.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"AI provider error ({error_name}). Please try again.", + ) + + input_delta = session.total_input_tokens - prev_input + output_delta = session.total_output_tokens - prev_output + await record_ai_usage( + user_id=current_user.id, + account_id=current_user.account_id, + conversation_id=None, + generation_type="chat_generate", + tier=plan, + input_tokens=input_delta, + output_tokens=output_delta, + estimated_cost=( + input_delta * 1.0 / 1_000_000 + + output_delta * 5.0 / 1_000_000 + ), + succeeded=True, + counts_toward_quota=True, + error_code=None, + extra_data={"chat_session_id": str(session.id)}, + db=db, + ) + + session.status = "completed" + await db.commit() + + return AIChatGenerateResponse( + tree_structure=tree_structure, + tree_metadata=metadata, + status="completed", + ) + + +@router.post("/sessions/{session_id}/import", response_model=AIChatImportResponse) +@limiter.limit("10/minute") +async def import_tree( + request: Request, + session_id: UUID, + data: AIChatImportRequest, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], + _: None = Depends(require_engineer_or_admin), +): + """Create a Tree record from the generated tree structure.""" + session = await get_chat_session(session_id, current_user.id, db) + + if session.status != "completed" or not session.working_tree: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Session must be completed with a generated tree before importing", + ) + + # Always create a new Tree record (no duplicate check — user may + # want multiple copies or re-import after edits) + metadata = session.tree_metadata or {} + tree = Tree( + name=data.name or metadata.get("name", "AI-Generated Flow"), + description=data.description or metadata.get("description", ""), + tree_type=session.flow_type, + tree_structure=session.working_tree, + author_id=current_user.id, + account_id=current_user.account_id, + category_id=data.category_id, + is_public=False, + ) + db.add(tree) + await db.flush() + + session.generated_tree_id = tree.id + await db.commit() + + return AIChatImportResponse( + tree_id=tree.id, + tree_type=session.flow_type, + ) + + +@router.delete("/sessions/{session_id}", status_code=204) +@limiter.limit("10/minute") +async def abandon_session( + request: Request, + session_id: UUID, + current_user: Annotated[User, Depends(get_current_active_user)], + db: Annotated[AsyncSession, Depends(get_db)], +): + """Abandon a chat session.""" + session = await get_chat_session(session_id, current_user.id, db) + session.status = "abandoned" + await db.commit() diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 27963a1f..41fdb0b2 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -7,6 +7,7 @@ from app.api.endpoints import maintenance_schedules from app.api.endpoints import feedback from app.api.endpoints import ai_builder from app.api.endpoints import ai_fix +from app.api.endpoints import ai_chat api_router = APIRouter() @@ -38,3 +39,4 @@ api_router.include_router(maintenance_schedules.router) api_router.include_router(feedback.router) api_router.include_router(ai_builder.router) api_router.include_router(ai_fix.router) +api_router.include_router(ai_chat.router) diff --git a/backend/app/core/ai_chat_service.py b/backend/app/core/ai_chat_service.py new file mode 100644 index 00000000..948f4701 --- /dev/null +++ b/backend/app/core/ai_chat_service.py @@ -0,0 +1,474 @@ +"""AI Chat Builder service. + +Manages the conversational flow builder: system prompt construction, +message exchange with AI provider, and response parsing (extracting +tree updates, phase transitions, and metadata from structured markers). +""" +import json +import logging +import re +import uuid +from datetime import datetime, timezone, timedelta +from typing import Any, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.ai_provider import get_ai_provider +from app.core.ai_tree_validator import validate_generated_tree +from app.core.config import settings +from app.models.ai_chat_session import AIChatSession + +logger = logging.getLogger(__name__) + +# ── Cost estimation ── +COST_PER_INPUT_TOKEN = 1.0 / 1_000_000 +COST_PER_OUTPUT_TOKEN = 5.0 / 1_000_000 + +# ── Max messages per session ── +MAX_MESSAGES_FREE = 10 +MAX_MESSAGES_PAID = 25 + + +# ── System Prompt ── + +ROLE_PERSONA = """You are a senior IT engineer embedded in ResolutionFlow, a troubleshooting platform for MSP (Managed Service Provider) engineers. You have 15+ years of hands-on experience across Windows Server, Active Directory, Entra ID/Azure AD, Microsoft 365, networking (DNS, DHCP, routing, VPN, firewalls), virtualization (Hyper-V, VMware), security, backup/DR, and cloud infrastructure. + +Your job is to help engineers build troubleshooting decision trees by interviewing them about a problem space. You are NOT a generic assistant. You are a colleague who has seen these issues hundreds of times and knows the optimal diagnostic order. + +CRITICAL BEHAVIORS: +- Act as a senior engineer, not a chatbot. Use your domain knowledge to SUGGEST diagnostic steps, not just record what the user says. +- When the user describes a problem area, demonstrate understanding by naming specific sub-categories, common causes, and relevant tools. +- Challenge assumptions constructively: "Before we go down that path, have you considered checking X first? In my experience, that resolves 60% of these cases." +- Capture SPECIFIC commands with exact syntax. Not "check the service" but "Get-Service ADSync | Select-Object Status, StartType". +- Include expected outcomes for every action: what does success look like? +- Surface edge cases proactively: "What about multi-forest environments?" or "Does this change if they have conditional access policies?" +- Explain WHY the diagnostic order matters: "We check connectivity before auth because a network issue masquerades as an auth failure." +- Ask ONE focused question at a time. Do not overwhelm with multiple questions. +- Use plain, collegial language. Sound like a colleague, not a form.""" + +SCHEMA_CONTEXT = """ +TREESTRUCTURE SCHEMA — This is what you are building: + +The tree is a recursive JSON structure. Each node has a "type" field: + +1. decision — A diagnostic question with branching options + Required: id (string), type ("decision"), question (string), options (array), children (array) + Optional: help_text (string) + Each option: { id (string), label (string), next_node_id (string — must match a child's id) } + +2. action — A step the engineer performs + Required: id (string), type ("action"), title (string), description (string) + Optional: commands (string array — exact CLI/PowerShell syntax), expected_outcome (string), help_text (string), next_node_id (string — ID of the next node to navigate to) + +3. solution — A resolution endpoint + Required: id (string), type ("solution"), title (string), description (string) + Optional: resolution_steps (string array) + +STRUCTURAL RULES: +- Root node MUST be type "decision" +- Decision nodes contain their children in the "children" array +- Each decision option's next_node_id typically references a child node's id, BUT can also reference ANY other node in the tree for loop-back / re-verification patterns +- Action nodes use next_node_id to chain to the next step — this can point to any node in the tree, including ancestors, for loop-backs (e.g., "remediate → re-verify from earlier checkpoint") +- Solution nodes are terminal — no next_node_id or children +- All IDs must be unique strings (use descriptive slugs like "check-service-status") + +CROSS-REFERENCE / LOOP-BACK PATTERN: +When a troubleshooting path needs to loop back (e.g., after remediation, re-verify from an earlier checkpoint), set next_node_id to the target node's ID. Example: an action node "restart-ssh-service" can set next_node_id to "verify-ssh-connection" (an ancestor decision node) to create a re-verification loop. +""" + +INTERVIEW_PROTOCOL = """ +INTERVIEW PHASES — Follow this progression: + +PHASE 1 - SCOPING (current_phase: scoping): +Ask broad questions to understand the problem domain and scope: +- What type of issue is this flow for? +- Who is the target audience? (Tier 1 help desk, Tier 2, Tier 3?) +- What environment assumptions? (On-prem, hybrid, specific vendors?) +Demonstrate domain expertise immediately. If the user says "Azure AD Sync failures," show understanding: "Are you primarily seeing password hash sync issues, object attribute sync failures, or full directory sync errors?" +DO NOT emit [TREE_UPDATE] during scoping. You are still understanding the problem. + +PHASE 2 - DISCOVERY (current_phase: discovery): +Work through the troubleshooting logic branch by branch: +- Establish the first diagnostic question (the root decision node) +- For each branch, ask what the engineer would check next +- Suggest checks the user might not have considered +- Capture specific commands, tools, and procedures +EMIT [TREE_UPDATE] ONLY when you and the user have agreed on a concrete node — a decision with clear options, or an action with a specific command. If you are asking a question, you are NOT updating the tree. + +PHASE 3 - ENRICHMENT (current_phase: enrichment): +Circle back to enrich existing nodes: +- Add exact PowerShell/CLI commands with syntax +- Add help text with relevant documentation links +- Add expected outcomes for action nodes +- Suggest edge cases needing additional branches +EMIT [TREE_UPDATE] when enriching existing nodes or adding edge case branches. + +PHASE 4 - REVIEW (current_phase: review): +Present a summary: +- Total node count by type +- Text outline of the flow structure +- Flag any areas of uncertainty +- Offer chance to add/remove/modify branches +EMIT [TREE_UPDATE] only if the user requests structural changes. + +TRANSITION between phases by emitting [PHASE:phase_name] when the conversation naturally moves to the next stage. You decide when enough information has been gathered for each phase. +""" + +RESPONSE_FORMAT = """ +RESPONSE FORMAT: + +Your response is natural conversational text. When the tree structure changes, include structured markers that will be parsed by the system (the user will NOT see these markers): + +1. Tree update (only when structure changes — see phase rules above): +[TREE_UPDATE] +{...valid TreeStructure JSON...} +[/TREE_UPDATE] + +2. Phase transition (when moving to next phase): +[PHASE:discovery] + +3. Metadata capture (when you learn the flow's name, description, or tags): +[METADATA] +{"name": "...", "description": "...", "tags": ["..."]} +[/METADATA] + +IMPORTANT: +- Include [TREE_UPDATE] sparingly. Only when concrete nodes are established or modified. +- The tree update should be the COMPLETE working tree, not a diff. +- Always include conversational text OUTSIDE the markers — never respond with only markers. +""" + + +def _build_system_prompt(flow_type: str) -> str: + """Assemble the full system prompt for the chat builder.""" + flow_context = ( + "The user wants to build a TROUBLESHOOTING flow — a diagnostic decision tree " + "that guides engineers through symptom identification, diagnostic checks, and " + "resolution steps." + if flow_type == "troubleshooting" + else "The user wants to build a PROCEDURAL flow — a step-by-step process guide " + "with phases, checklists, and verification steps." + ) + + return f"{ROLE_PERSONA}\n\n{flow_context}\n\n{SCHEMA_CONTEXT}\n\n{INTERVIEW_PROTOCOL}\n\n{RESPONSE_FORMAT}" + + +def _strip_markdown_fences(text: str) -> str: + """Strip markdown code fences if the model wrapped its JSON response.""" + text = text.strip() + match = re.match(r"^```(?:json)?\s*([\s\S]*?)```$", text) + if match: + return match.group(1).strip() + return text + + +def _parse_ai_response(raw_response: str) -> dict[str, Any]: + """Parse structured markers from AI response. + + Returns dict with: + - content: str (conversational text with markers stripped) + - tree_update: dict | None (parsed TreeStructure JSON) + - phase: str | None (new phase name) + - metadata: dict | None (name, description, tags) + """ + result: dict[str, Any] = { + "content": raw_response, + "tree_update": None, + "phase": None, + "metadata": None, + } + + # Extract [TREE_UPDATE]...[/TREE_UPDATE] + tree_match = re.search( + r"\[TREE_UPDATE\]\s*([\s\S]*?)\s*\[/TREE_UPDATE\]", raw_response + ) + if tree_match: + try: + raw_json = _strip_markdown_fences(tree_match.group(1)) + result["tree_update"] = json.loads(raw_json) + except (json.JSONDecodeError, ValueError) as e: + logger.warning("Failed to parse tree update JSON: %s", e) + result["content"] = raw_response[: tree_match.start()] + raw_response[tree_match.end() :] + else: + # Handle truncated response — opening tag exists but no closing tag + # (happens when max_tokens cuts off the JSON block) + truncated_match = re.search(r"\[TREE_UPDATE\][\s\S]*$", raw_response) + if truncated_match: + logger.warning("Truncated [TREE_UPDATE] block detected (no closing tag) — stripping from display") + result["content"] = raw_response[: truncated_match.start()] + + # Extract [PHASE:name] + phase_match = re.search(r"\[PHASE:(\w+)\]", result["content"]) + if phase_match: + result["phase"] = phase_match.group(1) + result["content"] = result["content"][: phase_match.start()] + result["content"][phase_match.end() :] + + # Extract [METADATA]...[/METADATA] + meta_match = re.search( + r"\[METADATA\]\s*([\s\S]*?)\s*\[/METADATA\]", result["content"] + ) + if meta_match: + try: + raw_json = _strip_markdown_fences(meta_match.group(1)) + result["metadata"] = json.loads(raw_json) + except (json.JSONDecodeError, ValueError) as e: + logger.warning("Failed to parse metadata JSON: %s", e) + result["content"] = result["content"][: meta_match.start()] + result["content"][meta_match.end() :] + else: + truncated_meta = re.search(r"\[METADATA\][\s\S]*$", result["content"]) + if truncated_meta: + logger.warning("Truncated [METADATA] block detected — stripping from display") + result["content"] = result["content"][: truncated_meta.start()] + + # Clean up extra whitespace from marker removal + result["content"] = re.sub(r"\n{3,}", "\n\n", result["content"]).strip() + + return result + + +# ── Main Service Functions ── + + +async def start_chat_session( + flow_type: str, + user_id: uuid.UUID, + account_id: uuid.UUID, + db: AsyncSession, +) -> tuple[AIChatSession, str]: + """Create a chat session and return the AI's opening greeting. + + Returns (session, greeting_text). + """ + session = AIChatSession( + user_id=user_id, + account_id=account_id, + flow_type=flow_type, + expires_at=datetime.now(timezone.utc) + timedelta(hours=settings.AI_CONVERSATION_TTL_HOURS), + ) + db.add(session) + await db.flush() + + # Build system prompt and get opening message + system_prompt = _build_system_prompt(flow_type) + primer = f"I want to build a {flow_type} flow. Help me get started." + + provider = get_ai_provider() + provider_name = settings.AI_PROVIDER + + messages = [{"role": "user", "content": primer}] + response_text, input_tokens, output_tokens = await provider.generate_text( + system_prompt=system_prompt, + messages=messages, + max_tokens=1500, + ) + + # Parse response (greeting shouldn't have tree updates, but handle gracefully) + parsed = _parse_ai_response(response_text) + + # Store conversation history + now_iso = datetime.now(timezone.utc).isoformat() + session.conversation_history = [ + {"role": "user", "content": primer, "timestamp": now_iso, "hidden": True}, + {"role": "assistant", "content": parsed["content"], "timestamp": now_iso}, + ] + session.provider_used = provider_name + session.message_count = 1 + session.total_input_tokens = input_tokens + session.total_output_tokens = output_tokens + + if parsed["metadata"]: + session.tree_metadata = parsed["metadata"] + + return session, parsed["content"] + + +async def send_message( + session: AIChatSession, + user_message: str, + db: AsyncSession, +) -> tuple[str, Optional[dict], Optional[str], Optional[dict]]: + """Send a user message and get AI response. + + Returns (ai_content, working_tree_update, new_phase, metadata_update). + """ + system_prompt = _build_system_prompt(session.flow_type) + + # Build messages array from conversation history + now_iso = datetime.now(timezone.utc).isoformat() + history = list(session.conversation_history) + history.append({"role": "user", "content": user_message, "timestamp": now_iso}) + + # Convert to provider format (just role + content) + provider_messages = [ + {"role": msg["role"], "content": msg["content"]} + for msg in history + ] + + provider = get_ai_provider() + response_text, input_tokens, output_tokens = await provider.generate_text( + system_prompt=system_prompt, + messages=provider_messages, + max_tokens=8000, + ) + + parsed = _parse_ai_response(response_text) + + # Validate tree update if present (lightweight check for progressive builds — + # only require valid root structure, not min node counts) + tree_update = parsed["tree_update"] + if tree_update: + if not isinstance(tree_update, dict) or tree_update.get("type") != "decision": + logger.warning("AI tree update rejected: root must be a decision node") + tree_update = None + elif not tree_update.get("id"): + logger.warning("AI tree update rejected: root node missing id") + tree_update = None + + # Update session state + history.append({"role": "assistant", "content": parsed["content"], "timestamp": now_iso}) + session.conversation_history = history + session.message_count = session.message_count + 1 + session.total_input_tokens = session.total_input_tokens + input_tokens + session.total_output_tokens = session.total_output_tokens + output_tokens + + if tree_update: + session.working_tree = tree_update + + if parsed["phase"]: + valid_phases = {"scoping", "discovery", "enrichment", "review", "generation"} + if parsed["phase"] in valid_phases: + session.current_phase = parsed["phase"] + + if parsed["metadata"]: + merged = dict(session.tree_metadata) + merged.update(parsed["metadata"]) + session.tree_metadata = merged + + session.updated_at = datetime.now(timezone.utc) + + return parsed["content"], tree_update, parsed["phase"], parsed["metadata"] + + +async def generate_final_tree( + session: AIChatSession, + db: AsyncSession, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Generate the final validated TreeStructure from the conversation. + + Returns (tree_structure, metadata). + Raises ValueError if generation fails after retry. + """ + system_prompt = _build_system_prompt(session.flow_type) + + # Build generation prompt from full conversation + provider_messages = [ + {"role": msg["role"], "content": msg["content"]} + for msg in session.conversation_history + ] + + generation_instruction = """Based on our entire conversation, generate the COMPLETE and FINAL TreeStructure JSON for this flow. + +Requirements: +- Include ALL branches, steps, and solutions we discussed +- Use descriptive node IDs (slugs, not UUIDs) +- Root node must be type "decision" +- Every decision option must have a valid next_node_id pointing to a child +- Every action node should have commands with exact syntax where discussed +- Every action node should have expected_outcome where discussed +- Solution nodes should have resolution_steps +- Respond with ONLY the JSON — no conversational text, no markdown fences + +Also provide metadata as a separate JSON object after the tree: +[METADATA] +{"name": "...", "description": "...", "tags": ["..."]} +[/METADATA]""" + + provider_messages.append({"role": "user", "content": generation_instruction}) + + provider = get_ai_provider() + + for attempt in range(2): # One try + one retry + response_text, input_tokens, output_tokens = await provider.generate_text( + system_prompt=system_prompt, + messages=provider_messages, + max_tokens=8000, + ) + + session.total_input_tokens = session.total_input_tokens + input_tokens + session.total_output_tokens = session.total_output_tokens + output_tokens + + # Extract metadata first + parsed = _parse_ai_response(response_text) + metadata = parsed["metadata"] or dict(session.tree_metadata) + + # Parse tree JSON — could be in tree_update marker or raw + tree = parsed["tree_update"] + if not tree: + try: + raw = _strip_markdown_fences(parsed["content"]) + tree = json.loads(raw) + except (json.JSONDecodeError, ValueError): + pass + + if not tree: + if attempt == 0: + provider_messages.append({"role": "assistant", "content": response_text}) + provider_messages.append({ + "role": "user", + "content": "That response was not valid JSON. Please respond with ONLY the TreeStructure JSON object, starting with { and ending with }. No markdown fences, no explanatory text.", + }) + continue + raise ValueError("AI failed to produce valid JSON after retry") + + errors = validate_generated_tree(tree) + if errors: + if attempt == 0: + provider_messages.append({"role": "assistant", "content": response_text}) + correction = ( + f"The tree has validation errors: {'; '.join(errors)}. " + "Please fix these issues and respond with the corrected JSON only." + ) + provider_messages.append({"role": "user", "content": correction}) + continue + raise ValueError(f"Generated tree failed validation: {'; '.join(errors)}") + + # Success + session.working_tree = tree + session.tree_metadata = metadata + session.current_phase = "generation" + session.updated_at = datetime.now(timezone.utc) + + return tree, metadata + + raise ValueError("AI failed to generate a valid tree") + + +async def get_chat_session( + session_id: uuid.UUID, + user_id: uuid.UUID, + db: AsyncSession, +) -> AIChatSession: + """Get a chat session, validating ownership and expiry. + + Raises HTTPException on not found, forbidden, or expired. + """ + from fastapi import HTTPException, status + + result = await db.execute( + select(AIChatSession).where(AIChatSession.id == session_id) + ) + session = result.scalar_one_or_none() + + if not session: + raise HTTPException(status_code=404, detail="Chat session not found") + + if session.user_id != user_id: + raise HTTPException(status_code=403, detail="Access denied") + + if session.expires_at < datetime.now(timezone.utc): + session.status = "abandoned" + await db.flush() + raise HTTPException(status_code=410, detail="Chat session has expired") + + return session diff --git a/backend/app/core/ai_provider.py b/backend/app/core/ai_provider.py index cb3f7178..993012c6 100644 --- a/backend/app/core/ai_provider.py +++ b/backend/app/core/ai_provider.py @@ -35,6 +35,25 @@ class AIProvider(ABC): """ ... + @abstractmethod + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + """Generate a text response from the AI model (no JSON constraint). + + Args: + system_prompt: System-level instruction for the model. + messages: List of message dicts with "role" and "content" keys. + max_tokens: Maximum output tokens. + + Returns: + Tuple of (response_text, input_tokens, output_tokens). + """ + ... + class GeminiProvider(AIProvider): """Google Gemini provider using the google-genai SDK.""" @@ -95,6 +114,56 @@ class GeminiProvider(AIProvider): return text, input_tokens, output_tokens + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + from google import genai + from google.genai import types as genai_types + + client = genai.Client(api_key=self._api_key) + + contents: list[genai_types.Content] = [] + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + contents.append( + genai_types.Content( + role=role, + parts=[genai_types.Part(text=msg["content"])], + ) + ) + + config = genai_types.GenerateContentConfig( + system_instruction=system_prompt, + max_output_tokens=max_tokens, + # No response_mime_type — allow free-form text + ) + + response = await client.aio.models.generate_content( + model=self._model, + contents=contents, + config=config, + ) + + if response.candidates: + finish_reason = getattr(response.candidates[0], "finish_reason", None) + logger.info("Gemini finish_reason=%s model=%s", finish_reason, self._model) + if str(finish_reason) == "MAX_TOKENS": + logger.warning( + "Gemini output truncated (MAX_TOKENS). max_output_tokens=%d", + max_tokens, + ) + + text = response.text or "" + input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = ( + getattr(response.usage_metadata, "candidates_token_count", 0) or 0 + ) + + return text, input_tokens, output_tokens + class AnthropicProvider(AIProvider): """Anthropic Claude provider using the anthropic SDK.""" @@ -130,6 +199,15 @@ class AnthropicProvider(AIProvider): return text, input_tokens, output_tokens + async def generate_text( + self, + system_prompt: str, + messages: list[dict[str, str]], + max_tokens: int = 4096, + ) -> tuple[str, int, int]: + # Anthropic doesn't differentiate between JSON and text mode + return await self.generate_json(system_prompt, messages, max_tokens) + def get_ai_provider() -> AIProvider: """Factory that returns the configured AI provider. diff --git a/backend/app/core/ai_quota_service.py b/backend/app/core/ai_quota_service.py index 67eed9e5..49264caa 100644 --- a/backend/app/core/ai_quota_service.py +++ b/backend/app/core/ai_quota_service.py @@ -115,7 +115,7 @@ async def check_ai_quota( select(func.count(AIUsage.id)).where( AIUsage.user_id == user_id, AIUsage.succeeded == True, # noqa: E712 - AIUsage.generation_type.in_(["scaffold", "branch_detail"]), + AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate"]), AIUsage.created_at >= day_start, ) ) or 0 diff --git a/backend/app/core/ai_tree_validator.py b/backend/app/core/ai_tree_validator.py index e57767b4..351a223f 100644 --- a/backend/app/core/ai_tree_validator.py +++ b/backend/app/core/ai_tree_validator.py @@ -40,7 +40,7 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: # Collect all node IDs and validate structure all_ids: set[str] = set() - all_referenced_ids: set[str] = set() # option next_node_ids (already checked locally) + all_referenced_ids: set[str] = set() # option next_node_ids (checked globally below) action_next_ids: set[str] = set() # action next_node_ids (checked globally below) node_count = 0 solution_count = 0 @@ -111,11 +111,6 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: next_id = opt.get("next_node_id") if next_id: all_referenced_ids.add(next_id) - if child_ids and next_id not in child_ids: - errors.append( - f"Option '{opt.get('label', '?')}' in node '{node_id}' " - f"references non-existent child '{next_id}'" - ) elif node_type == "action": next_id = node.get("next_node_id") @@ -144,6 +139,13 @@ def validate_generated_tree(tree: dict[str, Any]) -> list[str]: f"Action next_node_id '{ref_id}' references a node that does not exist in the tree" ) + # Check that all option next_node_ids exist in the tree (allows cross-references) + for ref_id in all_referenced_ids - action_next_ids: + if ref_id not in all_ids: + errors.append( + f"Option next_node_id '{ref_id}' references a node that does not exist in the tree" + ) + # Global checks if node_count < 5: errors.append( diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 3731740b..3748d9c8 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -28,6 +28,7 @@ from .maintenance_schedule import MaintenanceSchedule from .feedback import Feedback from .ai_conversation import AIConversation from .ai_usage import AIUsage +from .ai_chat_session import AIChatSession __all__ = [ "User", @@ -67,4 +68,5 @@ __all__ = [ "Feedback", "AIConversation", "AIUsage", + "AIChatSession", ] diff --git a/backend/app/models/ai_chat_session.py b/backend/app/models/ai_chat_session.py new file mode 100644 index 00000000..8fbde1c6 --- /dev/null +++ b/backend/app/models/ai_chat_session.py @@ -0,0 +1,88 @@ +"""AI Chat Builder session tracking. + +Stores conversational flow builder state across the multi-phase interview. +Sessions expire after 24 hours. +""" +import uuid +from datetime import datetime, timezone +from typing import Optional, Any + +from sqlalchemy import String, DateTime, ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.dialects.postgresql import UUID, JSONB + +from app.core.database import Base + + +class AIChatSession(Base): + __tablename__ = "ai_chat_sessions" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + user_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + status: Mapped[str] = mapped_column( + String(20), + nullable=False, + default="active", + comment="active | completed | abandoned", + ) + current_phase: Mapped[str] = mapped_column( + String(20), + nullable=False, + default="scoping", + comment="scoping | discovery | enrichment | review | generation", + ) + flow_type: Mapped[str] = mapped_column( + String(20), + nullable=False, + comment="troubleshooting | procedural", + ) + conversation_history: Mapped[list[dict[str, Any]]] = mapped_column( + JSONB, nullable=False, default=list + ) + working_tree: Mapped[Optional[dict[str, Any]]] = mapped_column( + JSONB, nullable=True + ) + tree_metadata: Mapped[dict[str, Any]] = mapped_column( + JSONB, nullable=False, default=dict + ) + provider_used: Mapped[Optional[str]] = mapped_column( + String(20), nullable=True + ) + message_count: Mapped[int] = mapped_column( + Integer, nullable=False, default=0 + ) + total_input_tokens: Mapped[int] = mapped_column( + Integer, nullable=False, default=0 + ) + total_output_tokens: Mapped[int] = mapped_column( + Integer, nullable=False, default=0 + ) + generated_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("trees.id", ondelete="SET NULL"), + nullable=True, + ) + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) diff --git a/backend/app/schemas/ai_chat.py b/backend/app/schemas/ai_chat.py new file mode 100644 index 00000000..35ae66b7 --- /dev/null +++ b/backend/app/schemas/ai_chat.py @@ -0,0 +1,80 @@ +"""Pydantic schemas for the AI Chat Builder.""" +from typing import Any, Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +# ── Requests ── + + +class AIChatStartRequest(BaseModel): + """Start a new chat builder session.""" + + flow_type: Literal["troubleshooting", "procedural"] = Field( + ..., description="Type of flow to build" + ) + + +class AIChatMessageRequest(BaseModel): + """Send a user message in a chat session.""" + + content: str = Field(..., min_length=1, max_length=5000) + + +class AIChatImportRequest(BaseModel): + """Import generated tree with optional metadata overrides.""" + + name: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = Field(None, max_length=2000) + category_id: Optional[UUID] = None + tags: list[str] = Field(default_factory=list) + + +# ── Responses ── + + +class AIChatStartResponse(BaseModel): + """Response after creating a chat session.""" + + session_id: UUID + greeting: str + current_phase: str + + +class AIChatMessageResponse(BaseModel): + """Response after sending a message.""" + + content: str + current_phase: str + working_tree: Optional[dict[str, Any]] = None + tree_metadata: Optional[dict[str, Any]] = None + + +class AIChatSessionResponse(BaseModel): + """Full session state for resume.""" + + session_id: UUID + status: str + current_phase: str + flow_type: str + conversation_history: list[dict[str, Any]] + working_tree: Optional[dict[str, Any]] = None + tree_metadata: Optional[dict[str, Any]] = None + message_count: int + generated_tree: Optional[dict[str, Any]] = None + + +class AIChatGenerateResponse(BaseModel): + """Response with the final generated tree.""" + + tree_structure: dict[str, Any] + tree_metadata: dict[str, Any] + status: str + + +class AIChatImportResponse(BaseModel): + """Response after importing tree to editor.""" + + tree_id: UUID + tree_type: str diff --git a/backend/tests/test_ai_chat.py b/backend/tests/test_ai_chat.py new file mode 100644 index 00000000..0b52f650 --- /dev/null +++ b/backend/tests/test_ai_chat.py @@ -0,0 +1,187 @@ +"""Integration tests for AI Chat Builder endpoints. + +These tests mock the AI provider to avoid real API calls. +""" +import pytest +from unittest.mock import AsyncMock, patch + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_ai_provider(): + """Mock AI provider that returns realistic responses.""" + provider = AsyncMock() + provider.generate_text = AsyncMock(return_value=( + "Great question! Let's build a troubleshooting flow for DNS resolution issues. " + "To start, I need to understand the scope.\n\n" + "Who is the target audience for this flow? Are we targeting:\n" + "- Tier 1 help desk (basic checks only)\n" + "- Tier 2 desktop support (intermediate diagnostics)\n" + "- Tier 3 systems engineers (deep DNS troubleshooting)\n\n" + "[PHASE:scoping]", + 500, # input tokens + 200, # output tokens + )) + return provider + + +async def test_create_chat_session(client, auth_headers, mock_ai_provider): + """POST /ai/chat/sessions creates a session and returns AI greeting.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + + assert resp.status_code == 201 + data = resp.json() + assert "session_id" in data + assert "greeting" in data + assert data["current_phase"] == "scoping" + assert len(data["greeting"]) > 0 + + +async def test_send_message(client, auth_headers, mock_ai_provider): + """POST /ai/chat/sessions/{id}/messages returns AI response.""" + # Create session first + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + # Mock response with tree update — must pass validate_generated_tree (min 5 nodes) + import json + tree_obj = { + "id": "root", "type": "decision", + "question": "What DNS symptom is the user experiencing?", + "options": [ + {"id": "opt-1", "label": "Cannot resolve any domains", "next_node_id": "dns-check"}, + {"id": "opt-2", "label": "Intermittent failures", "next_node_id": "dns-cache-fix"}, + ], + "children": [ + { + "id": "dns-check", "type": "decision", + "question": "Is the DNS Client service running?", + "options": [ + {"id": "dc-1", "label": "Yes", "next_node_id": "dns-fwd-fix"}, + {"id": "dc-2", "label": "No", "next_node_id": "dns-svc-fix"}, + ], + "children": [ + {"id": "dns-fwd-fix", "type": "solution", "title": "Check DNS Forwarders", + "description": "DNS forwarders may be misconfigured", + "resolution_steps": ["Check forwarder config"]}, + {"id": "dns-svc-fix", "type": "solution", "title": "Restart DNS Service", + "description": "DNS Client service is stopped", + "resolution_steps": ["Start-Service Dnscache"]}, + ], + }, + {"id": "dns-cache-fix", "type": "solution", "title": "Stale DNS Cache", + "description": "DNS cache has stale entries", + "resolution_steps": ["ipconfig /flushdns"]}, + ], + } + tree_json = json.dumps(tree_obj) + mock_ai_provider.generate_text = AsyncMock(return_value=( + "Good, targeting Tier 2 support. Let's start with the first diagnostic question.\n\n" + "The root question should be: 'What DNS symptom is the user experiencing?'\n\n" + f"[TREE_UPDATE]\n{tree_json}\n[/TREE_UPDATE]\n\n" + "[PHASE:discovery]", + 800, + 400, + )) + + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + resp = await client.post( + f"/api/v1/ai/chat/sessions/{session_id}/messages", + json={"content": "This is for Tier 2 support, hybrid environment with on-prem AD."}, + headers=auth_headers, + ) + + assert resp.status_code == 200 + data = resp.json() + assert "content" in data + assert data["current_phase"] == "discovery" + assert data["working_tree"] is not None + assert data["working_tree"]["type"] == "decision" + # Markers should be stripped from content + assert "[TREE_UPDATE]" not in data["content"] + assert "[PHASE:" not in data["content"] + + +async def test_get_session(client, auth_headers, mock_ai_provider): + """GET /ai/chat/sessions/{id} returns full session state.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + resp = await client.get( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["session_id"] == session_id + assert data["status"] == "active" + assert data["flow_type"] == "troubleshooting" + # Hidden primer message should be filtered out + assert all( + msg.get("role") == "assistant" or not msg.get("hidden") + for msg in data["conversation_history"] + ) + + +async def test_abandon_session(client, auth_headers, mock_ai_provider): + """DELETE /ai/chat/sessions/{id} sets status to abandoned.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + resp = await client.delete( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + assert resp.status_code == 204 + + # Verify session is abandoned + get_resp = await client.get( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + assert get_resp.json()["status"] == "abandoned" + + +async def test_session_not_found(client, auth_headers): + """Accessing nonexistent session returns 404.""" + import uuid + fake_id = str(uuid.uuid4()) + resp = await client.get( + f"/api/v1/ai/chat/sessions/{fake_id}", + headers=auth_headers, + ) + assert resp.status_code == 404 + + +async def test_ai_disabled_returns_503(client, auth_headers): + """When AI is not configured, endpoints return 503.""" + with patch("app.api.endpoints.ai_chat.settings") as mock_settings: + mock_settings.ai_enabled = False + resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + assert resp.status_code == 503 diff --git a/backend/tests/test_ai_tree_validator.py b/backend/tests/test_ai_tree_validator.py index f8f3f4d7..cfa58ff0 100644 --- a/backend/tests/test_ai_tree_validator.py +++ b/backend/tests/test_ai_tree_validator.py @@ -122,7 +122,7 @@ class TestReferenceIntegrity: tree = _make_valid_tree() tree["options"][0]["next_node_id"] = "nonexistent" errors = validate_generated_tree(tree) - assert any("non-existent child" in e for e in errors) + assert any("does not exist" in e for e in errors) def test_action_next_node_id_references_nonexistent_node(self): """Action next_node_id pointing to a node that doesn't exist anywhere in the tree.""" @@ -188,6 +188,31 @@ class TestDeadEndDetection: assert any("dead end" in e for e in errors) +class TestCrossReferenceSupport: + def test_option_referencing_non_child_node_in_tree_is_valid(self): + """A decision option can reference any node in the tree, not just direct children.""" + tree = _make_valid_tree() + # Make root option point to a grandchild (not a direct child) — cross-reference + tree["options"][0]["next_node_id"] = "fix-errors" # grandchild of root + errors = validate_generated_tree(tree) + assert not any("non-existent child" in e for e in errors) + assert not any("does not exist" in e for e in errors) + + def test_option_referencing_nonexistent_node_still_fails(self): + """Cross-references must still point to nodes that exist in the tree.""" + tree = _make_valid_tree() + tree["options"][0]["next_node_id"] = "totally-fake-id" + errors = validate_generated_tree(tree) + assert any("does not exist" in e for e in errors) + + def test_action_next_node_id_to_ancestor_is_valid(self): + """Action node can loop back to an ancestor node.""" + tree = _make_valid_tree() + tree["children"][1]["next_node_id"] = "root" + errors = validate_generated_tree(tree) + assert not any("does not exist" in e for e in errors) + + class TestCountTreeStats: def test_stats_correct(self): tree = _make_valid_tree() diff --git a/docs/plans/2026-02-28-cross-reference-loopback-design.md b/docs/plans/2026-02-28-cross-reference-loopback-design.md new file mode 100644 index 00000000..98e6be32 --- /dev/null +++ b/docs/plans/2026-02-28-cross-reference-loopback-design.md @@ -0,0 +1,80 @@ +# Cross-Reference / Loop-Back Support — Design + +**Goal:** Allow tree nodes to reference any other node in the tree (not just direct children), enabling loop-back patterns like "remediate → re-verify from earlier checkpoint." + +**Architecture:** Ghost references on existing tree structure. No schema change, no migration. A cross-reference is any `next_node_id` that points outside the current node's `children` array. The canvas renders these as dashed SVG overlay arrows. Navigation already supports this. + +**Approach chosen:** Approach 1 — "Ghost references" (keep tree structure, add visual cross-ref edges) + +--- + +## 1. Data Model — No Changes + +The `TreeStructure` type and database stay as-is. The distinction is semantic: + +- **Local link:** `next_node_id` → direct child → normal tree edge +- **Cross-reference:** `next_node_id` → node elsewhere in tree → dashed overlay arrow + +No new fields, no new node types, no migration. + +## 2. Validation Changes + +### Backend (`ai_tree_validator.py`) + +- Relax decision option validation: `option.next_node_id` can reference any node in the tree (not just children). Check existence only, same as action nodes. + +### Frontend circular reference detector (`treeEditorStore.ts`) + +- Change loop detection from **error** to **warning**. Loops are now intentional. Warning text: "This path loops back to [node title]." + +### Frontend orphan detection + +- Keep as-is. Orphaned nodes still flagged as warnings. + +## 3. Canvas Rendering — Cross-Reference Edges + +- **SVG overlay** layer on top of the canvas (absolute positioned) +- **Dashed line** with **arrowhead** pointing at target node +- **Purple/primary color** to distinguish from normal gray tree connectors +- Small label on the arrow (option label or "loops back") +- After dagre layout, scan all nodes for `next_node_id` values not matching a direct child +- Look up source/target positions from layout, draw curved SVG bezier path +- Target node gets a subtle badge/indicator for inbound cross-references +- Hovering the badge highlights source nodes + +## 4. Editor UX — Creating Cross-References + +### A. Node picker dropdown (in node form) + +- Action nodes and decision option rows get "Link to existing node" dropdown +- Lists all nodes by title/question, grouped by type +- Selecting sets `next_node_id`; orphaned answer stubs cleaned up +- "Clear link" option to remove + +### B. Canvas drag-to-link + +- Small output port (dot) at bottom of each node +- Drag from port starts a dashed line following cursor +- Drop on any node creates cross-reference +- Drop on empty space cancels +- Existing answer stubs cleaned up if replaced + +### Visual feedback + +- Node form: "Linked to: [node title]" with navigate + remove actions +- Canvas: dashed arrow (Section 3) + +## 5. AI Flow Assist — Prompt Changes + +- Update system prompt STRUCTURAL RULES: "Action nodes can set `next_node_id` to any node in the tree, including ancestors, for loop-backs." +- Add SSH loop example to schema context +- No changes to generation or progressive validation + +## 6. Navigation — No Changes + +`findNode` already searches the full tree. `handleSelectOption` and `handleContinue` follow `next_node_id` without hierarchy checks. Session `pathTaken` will contain repeated IDs for loops — this is correct behavior. + +## 7. Testing + +- Backend: extend validator tests for cross-references +- Frontend: `npm run build` after each piece, manual testing of editor + navigation loops diff --git a/docs/plans/2026-02-28-cross-reference-loopback.md b/docs/plans/2026-02-28-cross-reference-loopback.md new file mode 100644 index 00000000..1b516289 --- /dev/null +++ b/docs/plans/2026-02-28-cross-reference-loopback.md @@ -0,0 +1,656 @@ +# Cross-Reference / Loop-Back Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Enable tree nodes to reference any other node in the tree (not just direct children), supporting loop-back patterns like "remediate → re-verify from earlier checkpoint." + +**Architecture:** Ghost references on existing tree structure. No schema change, no migration. A cross-reference is any `next_node_id` that points outside the current node's `children` array. The canvas renders these as dashed overlay arrows. Navigation already supports this pattern. + +**Tech Stack:** Python FastAPI (backend validation), React + @xyflow/react (canvas rendering), Zustand (store validation), TypeScript + +**Design Doc:** `docs/plans/2026-02-28-cross-reference-loopback-design.md` + +--- + +## Task 1: Backend — Relax Decision Option Validation + +Allow decision option `next_node_id` to reference ANY node in the tree, not just direct children. + +**Files:** +- Modify: `backend/app/core/ai_tree_validator.py:111-118` +- Test: `backend/tests/test_ai_tree_validator.py` + +**Step 1: Write the failing test — cross-reference option passes validation** + +Add a new test class `TestCrossReferenceSupport` at the bottom of `test_ai_tree_validator.py`: + +```python +class TestCrossReferenceSupport: + def test_option_referencing_non_child_node_in_tree_is_valid(self): + """A decision option can reference any node in the tree, not just direct children.""" + tree = _make_valid_tree() + # Make root option point to a grandchild (not a direct child) — cross-reference + tree["options"][0]["next_node_id"] = "fix-errors" # grandchild of root + errors = validate_generated_tree(tree) + # Should NOT have the "non-existent child" error for this reference + assert not any("non-existent child" in e for e in errors) + + def test_option_referencing_nonexistent_node_still_fails(self): + """Cross-references must still point to nodes that exist in the tree.""" + tree = _make_valid_tree() + tree["options"][0]["next_node_id"] = "totally-fake-id" + errors = validate_generated_tree(tree) + assert any("does not exist" in e for e in errors) + + def test_action_next_node_id_to_ancestor_is_valid(self): + """Action node can loop back to an ancestor node (the whole point of cross-refs).""" + tree = _make_valid_tree() + # Make the action node loop back to root + tree["children"][1]["next_node_id"] = "root" + errors = validate_generated_tree(tree) + assert not any("does not exist" in e for e in errors) +``` + +**Step 2: Run the tests to verify they fail** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py::TestCrossReferenceSupport -v` +Expected: `test_option_referencing_non_child_node_in_tree_is_valid` FAILS (currently raises "non-existent child" error). The other two should already pass. + +**Step 3: Update validator — check global existence, not just children** + +In `backend/app/core/ai_tree_validator.py`, replace lines 111-118 (the decision option next_node_id check): + +Old code (lines 111-118): +```python + next_id = opt.get("next_node_id") + if next_id: + all_referenced_ids.add(next_id) + if child_ids and next_id not in child_ids: + errors.append( + f"Option '{opt.get('label', '?')}' in node '{node_id}' " + f"references non-existent child '{next_id}'" + ) +``` + +New code: +```python + next_id = opt.get("next_node_id") + if next_id: + all_referenced_ids.add(next_id) +``` + +Then add a new global check after line 145 (after the action next_node_id existence check). This checks ALL option references exist anywhere in the tree: + +After the existing `for ref_id in action_next_ids:` block, add: +```python + # Check that all option next_node_ids exist in the tree (allows cross-references) + for ref_id in all_referenced_ids: + if ref_id not in all_ids: + errors.append( + f"Option next_node_id '{ref_id}' references a node that does not exist in the tree" + ) +``` + +**Step 4: Run all validator tests to verify they pass** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: ALL tests pass. The old `test_option_references_nonexistent_child` test in `TestReferenceIntegrity` will now fail because the error message changed from "non-existent child" to "does not exist in the tree". Update that test: + +In `TestReferenceIntegrity.test_option_references_nonexistent_child`, change: +```python + def test_option_references_nonexistent_child(self): + tree = _make_valid_tree() + tree["options"][0]["next_node_id"] = "nonexistent" + errors = validate_generated_tree(tree) + assert any("does not exist" in e for e in errors) +``` + +Run again: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: ALL PASS + +**Step 5: Commit** + +```bash +git add backend/app/core/ai_tree_validator.py backend/tests/test_ai_tree_validator.py +git commit -m "feat: relax decision option validation — allow cross-references to any node in tree" +``` + +--- + +## Task 2: Frontend — Change Circular Reference Detection From Error to Warning + +Loop-backs are now intentional. The circular reference detector should warn instead of error. + +**Files:** +- Modify: `frontend/src/store/treeEditorStore.ts:791-824` + +**Step 1: Update severity from 'error' to 'warning' and improve messages** + +In `frontend/src/store/treeEditorStore.ts`, find the `detectCircularRefs` function (lines 792-824). Change both `severity: 'error'` to `severity: 'warning'` and update messages: + +Replace line 803-807: +```typescript + errors.push({ + nodeId: node.id, + message: `Circular reference detected: "${opt.label}" creates a loop`, + severity: 'error' + }) +``` +With: +```typescript + errors.push({ + nodeId: node.id, + message: `This path loops back to an earlier node via "${opt.label}"`, + severity: 'warning' + }) +``` + +Replace lines 815-819: +```typescript + errors.push({ + nodeId: node.id, + message: `Circular reference detected in node "${node.title || node.id}"`, + severity: 'error' + }) +``` +With: +```typescript + errors.push({ + nodeId: node.id, + message: `This node loops back to an earlier node ("${node.title || node.id}")`, + severity: 'warning' + }) +``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build succeeds with no errors. + +**Step 3: Commit** + +```bash +git add frontend/src/store/treeEditorStore.ts +git commit -m "feat: change circular reference detection from error to warning — loops are intentional" +``` + +--- + +## Task 3: Canvas — Render Cross-Reference Edges as Dashed Arrows + +Add dashed purple overlay edges for any `next_node_id` pointing outside the current node's children. + +**Files:** +- Modify: `frontend/src/components/tree-editor/useTreeLayout.ts` + +**Step 1: Add cross-reference edge collection to the `walk` function** + +In `useTreeLayout.ts`, inside the `useMemo` callback (line 57), after the `walk(treeStructure, null)` call on line 129, add a second pass to collect cross-reference edges. + +Add this helper function before the `useTreeLayout` export (around line 40): + +```typescript +/** Collect all node IDs in the tree. */ +function collectAllIds(root: TreeStructure): Set { + const ids = new Set() + function walk(node: TreeStructure) { + ids.add(node.id) + node.children?.forEach(walk) + } + walk(root) + return ids +} + +/** Find all cross-reference edges (next_node_id pointing outside children). */ +function collectCrossRefEdges(root: TreeStructure): Array<{ source: string; target: string; label?: string }> { + const refs: Array<{ source: string; target: string; label?: string }> = [] + const allIds = collectAllIds(root) + + function walk(node: TreeStructure) { + const childIds = new Set(node.children?.map(c => c.id) ?? []) + + // Decision options pointing outside children + if (node.type === 'decision' && node.options) { + for (const opt of node.options) { + if (opt.next_node_id && !childIds.has(opt.next_node_id) && allIds.has(opt.next_node_id)) { + refs.push({ source: node.id, target: opt.next_node_id, label: opt.label }) + } + } + } + + // Action next_node_id pointing to non-child (always a cross-ref since actions use next_node_id not children) + if (node.type === 'action' && node.next_node_id && allIds.has(node.next_node_id) && !childIds.has(node.next_node_id)) { + refs.push({ source: node.id, target: node.next_node_id, label: 'loops back' }) + } + + node.children?.forEach(walk) + } + + walk(root) + return refs +} +``` + +**Step 2: Add cross-reference edges to the edges array** + +In the `useMemo` callback, after `walk(treeStructure, null)` (line 129) and before the return (line 131), add: + +```typescript + // Add cross-reference edges (dashed, purple) + if (treeStructure) { + const crossRefs = collectCrossRefEdges(treeStructure) + for (const ref of crossRefs) { + // Only add if both source and target nodes are visible (not collapsed away) + const sourceVisible = nodes.some(n => n.id === ref.source) + const targetVisible = nodes.some(n => n.id === ref.target) + if (sourceVisible && targetVisible) { + edges.push({ + id: `xref-${ref.source}->${ref.target}`, + source: ref.source, + target: ref.target, + type: 'smoothstep', + animated: true, + label: ref.label ? truncateLabel(ref.label) : undefined, + labelStyle: { fill: 'hsl(var(--primary))', fontSize: 10, fontWeight: 500 }, + labelBgStyle: { fill: 'hsl(var(--card))', fillOpacity: 0.95 }, + labelBgPadding: [4, 2] as [number, number], + style: { + stroke: 'hsl(var(--primary))', + strokeWidth: 2, + strokeDasharray: '6 3', + }, + markerEnd: { + type: 'arrowclosed' as const, + color: 'hsl(var(--primary))', + width: 16, + height: 16, + }, + }) + } + } + } +``` + +**Step 3: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build succeeds. Cross-reference edges render as animated, dashed purple arrows with arrowheads. + +**Step 4: Commit** + +```bash +git add frontend/src/components/tree-editor/useTreeLayout.ts +git commit -m "feat: render cross-reference edges as dashed purple arrows on canvas" +``` + +--- + +## Task 4: Editor UX — Node Picker Dropdown for Action Nodes + +Add a "Link to existing node" dropdown to `NodeFormAction.tsx`. + +**Files:** +- Modify: `frontend/src/components/tree-editor/NodeFormAction.tsx` +- Modify: `frontend/src/store/treeEditorStore.ts` (add helper to collect all nodes) + +**Step 1: Add `collectAllNodes` helper to the tree editor store** + +In `frontend/src/store/treeEditorStore.ts`, add a standalone exported helper function (near the top of the file, after imports, or as a utility): + +Find the `findNodeInTree` helper function. Near it, add: + +```typescript +/** Collect all nodes in the tree as a flat list with depth info. */ +export function collectAllNodesFlat( + root: TreeStructure | null +): Array<{ id: string; label: string; type: string; depth: number }> { + if (!root) return [] + const result: Array<{ id: string; label: string; type: string; depth: number }> = [] + + function walk(node: TreeStructure, depth: number) { + const label = node.type === 'decision' + ? (node.question || 'Untitled Decision') + : (node.title || `Untitled ${node.type}`) + result.push({ id: node.id, label, type: node.type, depth }) + node.children?.forEach(child => walk(child, depth + 1)) + } + + walk(root, 0) + return result +} +``` + +**Step 2: Update NodeFormAction to include the node picker** + +Replace the "Next step hint" section at the bottom of `NodeFormAction.tsx` (lines 161-170) with a full node picker: + +```tsx +import { Link2, X } from 'lucide-react' +import { collectAllNodesFlat } from '@/store/treeEditorStore' +``` + +(Add these to existing imports at top of file) + +Replace lines 161-170 (the `{hasNextNode ? ... : ...}` block): + +```tsx + {/* Link to existing node */} +
+ + {hasNextNode ? ( +
+ + Linked to: {(() => { + const treeStructure = useTreeEditorStore.getState().treeStructure + const allNodes = collectAllNodesFlat(treeStructure) + const target = allNodes.find(n => n.id === node.next_node_id) + return target ? target.label : node.next_node_id + })()} + + +
+ ) : ( + + )} +

+ {hasNextNode + ? 'This action will navigate to the linked node.' + : 'Select a node to navigate to after this action, or save to create a new placeholder.'} +

+
+``` + +**Step 3: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build succeeds. + +**Step 4: Commit** + +```bash +git add frontend/src/components/tree-editor/NodeFormAction.tsx frontend/src/store/treeEditorStore.ts +git commit -m "feat: add node picker dropdown to action node form for cross-references" +``` + +--- + +## Task 5: Editor UX — Node Picker for Decision Option Rows + +Add "Link to existing node" capability to each decision option row. + +**Files:** +- Modify: `frontend/src/components/tree-editor/NodeFormDecision.tsx` + +**Step 1: Add link icon and dropdown per option row** + +Add imports at top of `NodeFormDecision.tsx`: +```tsx +import { Link2 } from 'lucide-react' +import { collectAllNodesFlat } from '@/store/treeEditorStore' +``` + +In the option render callback (inside `renderItem` around line 161), after the label input and its error message, add a cross-reference link indicator per option. After the closing `` of the `flex-1` wrapper (around line 197), add: + +```tsx + {/* Cross-reference link indicator */} + {option.next_node_id && (() => { + const treeStructure = useTreeEditorStore.getState().treeStructure + const childIds = new Set(node.children?.map(c => c.id) ?? []) + // Only show if it's a cross-reference (points outside children) + if (childIds.has(option.next_node_id)) return null + const allNodes = collectAllNodesFlat(treeStructure) + const target = allNodes.find(n => n.id === option.next_node_id) + if (!target) return null + return ( +
+ +
+ ) + })()} +``` + +**Step 2: Add "Link to node" option below the options list** + +After the `DynamicArrayField` closing tag (line 201, before the root tip), add: + +```tsx + {/* Quick-link: assign an option to an existing node */} +
+ + + Link an option to an existing node (cross-reference) + +
+

+ Select an option, then pick a target node. This creates a loop-back or cross-reference. +

+
+ + +
+
+
+``` + +**Step 2: Build to verify** + +Run: `cd frontend && npm run build` +Expected: Build succeeds. + +**Step 3: Commit** + +```bash +git add frontend/src/components/tree-editor/NodeFormDecision.tsx +git commit -m "feat: add cross-reference node picker to decision option rows" +``` + +--- + +## Task 6: AI System Prompt — Update Structural Rules for Cross-References + +Update the AI chat system prompt to allow and encourage loop-back patterns. + +**Files:** +- Modify: `backend/app/core/ai_chat_service.py:68-75` + +**Step 1: Update STRUCTURAL RULES in SCHEMA_CONTEXT** + +Replace the STRUCTURAL RULES section (lines 68-75 of `ai_chat_service.py`): + +Old: +```python +STRUCTURAL RULES: +- Root node MUST be type "decision" +- Decision nodes contain their children in the "children" array +- Each decision option's next_node_id must reference a child node's id +- Action nodes use next_node_id to chain to the next step (NOT children) +- Solution nodes are terminal — no next_node_id or children +- All IDs must be unique strings (use descriptive slugs like "check-service-status") +``` + +New: +```python +STRUCTURAL RULES: +- Root node MUST be type "decision" +- Decision nodes contain their children in the "children" array +- Each decision option's next_node_id typically references a child node's id, BUT can also reference ANY other node in the tree for loop-back / re-verification patterns +- Action nodes use next_node_id to chain to the next step — this can point to any node in the tree, including ancestors, for loop-backs (e.g., "remediate → re-verify from earlier checkpoint") +- Solution nodes are terminal — no next_node_id or children +- All IDs must be unique strings (use descriptive slugs like "check-service-status") + +CROSS-REFERENCE / LOOP-BACK PATTERN: +When a troubleshooting path needs to loop back (e.g., after remediation, re-verify from an earlier checkpoint), set next_node_id to the target node's ID. Example: an action node "restart-ssh-service" can set next_node_id to "verify-ssh-connection" (an ancestor decision node) to create a re-verification loop. +``` + +**Step 2: Build backend to verify syntax** + +Run: `cd backend && python -c "from app.core.ai_chat_service import SCHEMA_CONTEXT; print('OK')"` +Expected: Prints "OK" with no import errors. + +**Step 3: Commit** + +```bash +git add backend/app/core/ai_chat_service.py +git commit -m "feat: update AI system prompt to allow cross-reference loop-back patterns" +``` + +--- + +## Task 7: Backend — Update Option Validation Error for `all_referenced_ids` + +The `all_referenced_ids` set currently holds only option `next_node_id` values. After Task 1's change, the global existence check also needs to handle the case where `action_next_ids` and `all_referenced_ids` may overlap. + +**Files:** +- Modify: `backend/app/core/ai_tree_validator.py` +- Test: `backend/tests/test_ai_tree_validator.py` + +**Step 1: Verify no double-counting between action and option refs** + +Check: `action_next_ids` are added to `all_referenced_ids` on line 128. After Task 1, we added a global check for `all_referenced_ids`. This means action refs get checked twice — once in the action-specific loop (lines 141-145) and once in the new option loop. We should only check option refs in the new loop. + +Update the global check added in Task 1 to exclude action refs: + +```python + # Check that all option next_node_ids exist in the tree (allows cross-references) + for ref_id in all_referenced_ids - action_next_ids: + if ref_id not in all_ids: + errors.append( + f"Option next_node_id '{ref_id}' references a node that does not exist in the tree" + ) +``` + +**Step 2: Run all validator tests** + +Run: `cd backend && python -m pytest tests/test_ai_tree_validator.py -v` +Expected: ALL PASS + +**Step 3: Commit** + +```bash +git add backend/app/core/ai_tree_validator.py +git commit -m "fix: prevent double-counting action refs in global option cross-reference check" +``` + +--- + +## Task 8: Full Integration Test + +Run the full backend test suite and frontend build to verify nothing is broken. + +**Files:** (none — testing only) + +**Step 1: Run backend tests** + +Run: `cd backend && python -m pytest --override-ini="addopts=" -v` +Expected: ALL PASS + +**Step 2: Run frontend build** + +Run: `cd frontend && npm run build` +Expected: Build succeeds with no errors. + +**Step 3: Manual smoke test** + +1. Start backend: `cd backend && uvicorn app.main:app --reload` +2. Start frontend: `cd frontend && npm run dev` +3. Open tree editor with an existing tree +4. Edit an action node → verify "Next Step" dropdown appears with all nodes listed +5. Select a node from a different branch → verify dashed purple arrow appears on canvas +6. Edit a decision node → expand "Link an option to an existing node" → create a cross-reference +7. Verify circular reference warning (not error) appears in validation panel +8. Navigate the tree → verify loop-back works (session follows `next_node_id`) + +--- + +## Summary + +| Task | What | Files | +|------|------|-------| +| 1 | Backend: relax option validation | `ai_tree_validator.py`, `test_ai_tree_validator.py` | +| 2 | Frontend: circular ref → warning | `treeEditorStore.ts` | +| 3 | Canvas: dashed purple cross-ref edges | `useTreeLayout.ts` | +| 4 | Editor: action node picker | `NodeFormAction.tsx`, `treeEditorStore.ts` | +| 5 | Editor: decision option picker | `NodeFormDecision.tsx` | +| 6 | AI prompt: loop-back awareness | `ai_chat_service.py` | +| 7 | Backend: fix ref overlap check | `ai_tree_validator.py` | +| 8 | Integration test | (testing only) | diff --git a/frontend/src/api/aiChat.ts b/frontend/src/api/aiChat.ts new file mode 100644 index 00000000..1722b272 --- /dev/null +++ b/frontend/src/api/aiChat.ts @@ -0,0 +1,44 @@ +import { apiClient } from './client' +import type { + AIChatStartResponse, + AIChatMessageResponse, + AIChatSessionResponse, + AIChatGenerateResponse, + AIChatImportResponse, +} from '@/types' + +export const aiChatApi = { + startSession: async (flowType: 'troubleshooting' | 'procedural'): Promise => { + const { data } = await apiClient.post('/ai/chat/sessions', { flow_type: flowType }) + return data + }, + + sendMessage: async (sessionId: string, content: string): Promise => { + const { data } = await apiClient.post(`/ai/chat/sessions/${sessionId}/messages`, { content }) + return data + }, + + getSession: async (sessionId: string): Promise => { + const { data } = await apiClient.get(`/ai/chat/sessions/${sessionId}`) + return data + }, + + generateTree: async (sessionId: string): Promise => { + const { data } = await apiClient.post(`/ai/chat/sessions/${sessionId}/generate`) + return data + }, + + importTree: async ( + sessionId: string, + params?: { name?: string; description?: string; category_id?: string; tags?: string[] } + ): Promise => { + const { data } = await apiClient.post(`/ai/chat/sessions/${sessionId}/import`, params || {}) + return data + }, + + abandonSession: async (sessionId: string): Promise => { + await apiClient.delete(`/ai/chat/sessions/${sessionId}`) + }, +} + +export default aiChatApi diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index fba65d24..016efcbe 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -17,3 +17,4 @@ export { targetListsApi } from './targetLists' export { maintenanceSchedulesApi, batchLaunchApi } from './maintenanceSchedules' export { default as feedbackApi } from './feedback' export { default as aiBuilderApi } from './aiBuilder' +export { default as aiChatApi } from './aiChat' diff --git a/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx b/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx index 49013055..77d84f7b 100644 --- a/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx +++ b/frontend/src/components/ai-builder/AIFlowBuilderModal.tsx @@ -98,7 +98,7 @@ export function AIFlowBuilderModal({ isOpen, onClose }: AIFlowBuilderModalProps) const getTitle = () => { switch (phase) { case 'foundation': - return 'Build with AI' + return 'Flow Assist' case 'scaffolding': case 'generating': return 'AI Scaffold' @@ -107,9 +107,9 @@ export function AIFlowBuilderModal({ isOpen, onClose }: AIFlowBuilderModalProps) case 'reviewing': return 'Review & Assemble' case 'error': - return 'AI Flow Builder' + return 'Flow Assist' default: - return 'Build with AI' + return 'Flow Assist' } } diff --git a/frontend/src/components/ai-chat/ChatInput.tsx b/frontend/src/components/ai-chat/ChatInput.tsx new file mode 100644 index 00000000..c10f4970 --- /dev/null +++ b/frontend/src/components/ai-chat/ChatInput.tsx @@ -0,0 +1,72 @@ +import { useState, useRef, useCallback } from 'react' +import { Send } from 'lucide-react' +import { cn } from '@/lib/utils' + +interface ChatInputProps { + onSend: (content: string) => void + disabled?: boolean + placeholder?: string +} + +export function ChatInput({ onSend, disabled, placeholder = 'Type a message...' }: ChatInputProps) { + const [value, setValue] = useState('') + const textareaRef = useRef(null) + + const handleSend = useCallback(() => { + const trimmed = value.trim() + if (!trimmed || disabled) return + onSend(trimmed) + setValue('') + if (textareaRef.current) { + textareaRef.current.style.height = 'auto' + } + }, [value, disabled, onSend]) + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault() + handleSend() + } + } + + const handleInput = () => { + if (textareaRef.current) { + textareaRef.current.style.height = 'auto' + textareaRef.current.style.height = Math.min(textareaRef.current.scrollHeight, 160) + 'px' + } + } + + return ( +
+