diff --git a/backend/app/api/endpoints/ai_chat.py b/backend/app/api/endpoints/ai_chat.py index b95fa1bd..e306ca8a 100644 --- a/backend/app/api/endpoints/ai_chat.py +++ b/backend/app/api/endpoints/ai_chat.py @@ -95,6 +95,7 @@ async def create_session( user_id=current_user.id, account_id=current_user.account_id, db=db, + tree_id=data.tree_id, ) except Exception as e: logger.exception("AI chat session start failed: %s", e) @@ -168,7 +169,9 @@ async def post_message( try: ai_content, tree_update, new_phase, metadata = await send_message( - session, data.content, db + session, data.content, db, + action_type=data.action_type or "open_chat", + focal_node_id=data.focal_node_id, ) except Exception as e: logger.exception("AI chat message failed: %s", e) diff --git a/backend/app/core/ai_chat_service.py b/backend/app/core/ai_chat_service.py index cab333f2..09ae57ee 100644 --- a/backend/app/core/ai_chat_service.py +++ b/backend/app/core/ai_chat_service.py @@ -391,6 +391,7 @@ async def start_chat_session( user_id: uuid.UUID, account_id: uuid.UUID, db: AsyncSession, + tree_id: str | None = None, ) -> tuple[AIChatSession, str]: """Create a chat session and return the AI's opening greeting. @@ -400,6 +401,7 @@ async def start_chat_session( user_id=user_id, account_id=account_id, flow_type=flow_type, + tree_id=uuid.UUID(tree_id) if tree_id else None, expires_at=datetime.now(timezone.utc) + timedelta(hours=settings.AI_CONVERSATION_TTL_HOURS), ) db.add(session) @@ -443,6 +445,8 @@ async def send_message( session: AIChatSession, user_message: str, db: AsyncSession, + action_type: str = "open_chat", + focal_node_id: str | None = None, ) -> tuple[str, Optional[dict], Optional[str], Optional[dict]]: """Send a user message and get AI response. diff --git a/backend/app/schemas/ai_chat.py b/backend/app/schemas/ai_chat.py index 35ae66b7..761f4064 100644 --- a/backend/app/schemas/ai_chat.py +++ b/backend/app/schemas/ai_chat.py @@ -14,12 +14,35 @@ class AIChatStartRequest(BaseModel): flow_type: Literal["troubleshooting", "procedural"] = Field( ..., description="Type of flow to build" ) + tree_id: Optional[str] = Field( + default=None, + description="ID of existing tree for editor-embedded sessions", + ) + + +VALID_ACTION_TYPES = Literal[ + "generate_full", + "generate_branch", + "modify_node", + "add_steps", + "quick_action", + "open_chat", + "variable_inference", +] class AIChatMessageRequest(BaseModel): """Send a user message in a chat session.""" content: str = Field(..., min_length=1, max_length=5000) + action_type: Optional[VALID_ACTION_TYPES] = Field( + default="open_chat", + description="Type of AI action to perform", + ) + focal_node_id: Optional[str] = Field( + default=None, + description="ID of the node/step being acted on", + ) class AIChatImportRequest(BaseModel):