diff --git a/backend/app/api/endpoints/ai_chat.py b/backend/app/api/endpoints/ai_chat.py index 326290d9..a1b8c648 100644 --- a/backend/app/api/endpoints/ai_chat.py +++ b/backend/app/api/endpoints/ai_chat.py @@ -106,7 +106,7 @@ async def create_session( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_message", tier=plan, input_tokens=session.total_input_tokens, @@ -118,7 +118,7 @@ async def create_session( succeeded=True, counts_toward_quota=False, error_code=None, - extra_data={"phase": "scoping"}, + extra_data={"phase": "scoping", "chat_session_id": str(session.id)}, db=db, ) @@ -175,7 +175,7 @@ async def post_message( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_message", tier=plan, input_tokens=0, @@ -184,7 +184,7 @@ async def post_message( succeeded=False, counts_toward_quota=False, error_code=type(e).__name__, - extra_data=None, + extra_data={"chat_session_id": str(session.id)}, db=db, ) await db.commit() @@ -198,7 +198,7 @@ async def post_message( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_message", tier=plan, input_tokens=input_delta, @@ -210,7 +210,7 @@ async def post_message( succeeded=True, counts_toward_quota=False, error_code=None, - extra_data={"phase": session.current_phase}, + extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)}, db=db, ) @@ -288,7 +288,7 @@ async def generate_tree( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_generate", tier=plan, input_tokens=session.total_input_tokens - prev_input, @@ -297,7 +297,7 @@ async def generate_tree( succeeded=False, counts_toward_quota=False, error_code="invalid_output", - extra_data={"error": str(e)}, + extra_data={"error": str(e), "chat_session_id": str(session.id)}, db=db, ) await db.commit() @@ -312,7 +312,7 @@ async def generate_tree( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_generate", tier=plan, input_tokens=input_delta, @@ -321,7 +321,7 @@ async def generate_tree( succeeded=False, counts_toward_quota=False, error_code=type(e).__name__, - extra_data={"error": str(e)}, + extra_data={"error": str(e), "chat_session_id": str(session.id)}, db=db, ) await db.commit() @@ -342,7 +342,7 @@ async def generate_tree( await record_ai_usage( user_id=current_user.id, account_id=current_user.account_id, - conversation_id=session.id, + conversation_id=None, generation_type="chat_generate", tier=plan, input_tokens=input_delta, @@ -354,7 +354,7 @@ async def generate_tree( succeeded=True, counts_toward_quota=True, error_code=None, - extra_data=None, + extra_data={"chat_session_id": str(session.id)}, db=db, ) diff --git a/backend/tests/test_ai_chat.py b/backend/tests/test_ai_chat.py new file mode 100644 index 00000000..0b52f650 --- /dev/null +++ b/backend/tests/test_ai_chat.py @@ -0,0 +1,187 @@ +"""Integration tests for AI Chat Builder endpoints. + +These tests mock the AI provider to avoid real API calls. +""" +import pytest +from unittest.mock import AsyncMock, patch + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_ai_provider(): + """Mock AI provider that returns realistic responses.""" + provider = AsyncMock() + provider.generate_text = AsyncMock(return_value=( + "Great question! Let's build a troubleshooting flow for DNS resolution issues. " + "To start, I need to understand the scope.\n\n" + "Who is the target audience for this flow? Are we targeting:\n" + "- Tier 1 help desk (basic checks only)\n" + "- Tier 2 desktop support (intermediate diagnostics)\n" + "- Tier 3 systems engineers (deep DNS troubleshooting)\n\n" + "[PHASE:scoping]", + 500, # input tokens + 200, # output tokens + )) + return provider + + +async def test_create_chat_session(client, auth_headers, mock_ai_provider): + """POST /ai/chat/sessions creates a session and returns AI greeting.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + + assert resp.status_code == 201 + data = resp.json() + assert "session_id" in data + assert "greeting" in data + assert data["current_phase"] == "scoping" + assert len(data["greeting"]) > 0 + + +async def test_send_message(client, auth_headers, mock_ai_provider): + """POST /ai/chat/sessions/{id}/messages returns AI response.""" + # Create session first + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + # Mock response with tree update — must pass validate_generated_tree (min 5 nodes) + import json + tree_obj = { + "id": "root", "type": "decision", + "question": "What DNS symptom is the user experiencing?", + "options": [ + {"id": "opt-1", "label": "Cannot resolve any domains", "next_node_id": "dns-check"}, + {"id": "opt-2", "label": "Intermittent failures", "next_node_id": "dns-cache-fix"}, + ], + "children": [ + { + "id": "dns-check", "type": "decision", + "question": "Is the DNS Client service running?", + "options": [ + {"id": "dc-1", "label": "Yes", "next_node_id": "dns-fwd-fix"}, + {"id": "dc-2", "label": "No", "next_node_id": "dns-svc-fix"}, + ], + "children": [ + {"id": "dns-fwd-fix", "type": "solution", "title": "Check DNS Forwarders", + "description": "DNS forwarders may be misconfigured", + "resolution_steps": ["Check forwarder config"]}, + {"id": "dns-svc-fix", "type": "solution", "title": "Restart DNS Service", + "description": "DNS Client service is stopped", + "resolution_steps": ["Start-Service Dnscache"]}, + ], + }, + {"id": "dns-cache-fix", "type": "solution", "title": "Stale DNS Cache", + "description": "DNS cache has stale entries", + "resolution_steps": ["ipconfig /flushdns"]}, + ], + } + tree_json = json.dumps(tree_obj) + mock_ai_provider.generate_text = AsyncMock(return_value=( + "Good, targeting Tier 2 support. Let's start with the first diagnostic question.\n\n" + "The root question should be: 'What DNS symptom is the user experiencing?'\n\n" + f"[TREE_UPDATE]\n{tree_json}\n[/TREE_UPDATE]\n\n" + "[PHASE:discovery]", + 800, + 400, + )) + + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + resp = await client.post( + f"/api/v1/ai/chat/sessions/{session_id}/messages", + json={"content": "This is for Tier 2 support, hybrid environment with on-prem AD."}, + headers=auth_headers, + ) + + assert resp.status_code == 200 + data = resp.json() + assert "content" in data + assert data["current_phase"] == "discovery" + assert data["working_tree"] is not None + assert data["working_tree"]["type"] == "decision" + # Markers should be stripped from content + assert "[TREE_UPDATE]" not in data["content"] + assert "[PHASE:" not in data["content"] + + +async def test_get_session(client, auth_headers, mock_ai_provider): + """GET /ai/chat/sessions/{id} returns full session state.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + resp = await client.get( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["session_id"] == session_id + assert data["status"] == "active" + assert data["flow_type"] == "troubleshooting" + # Hidden primer message should be filtered out + assert all( + msg.get("role") == "assistant" or not msg.get("hidden") + for msg in data["conversation_history"] + ) + + +async def test_abandon_session(client, auth_headers, mock_ai_provider): + """DELETE /ai/chat/sessions/{id} sets status to abandoned.""" + with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider): + create_resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + session_id = create_resp.json()["session_id"] + + resp = await client.delete( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + assert resp.status_code == 204 + + # Verify session is abandoned + get_resp = await client.get( + f"/api/v1/ai/chat/sessions/{session_id}", + headers=auth_headers, + ) + assert get_resp.json()["status"] == "abandoned" + + +async def test_session_not_found(client, auth_headers): + """Accessing nonexistent session returns 404.""" + import uuid + fake_id = str(uuid.uuid4()) + resp = await client.get( + f"/api/v1/ai/chat/sessions/{fake_id}", + headers=auth_headers, + ) + assert resp.status_code == 404 + + +async def test_ai_disabled_returns_503(client, auth_headers): + """When AI is not configured, endpoints return 503.""" + with patch("app.api.endpoints.ai_chat.settings") as mock_settings: + mock_settings.ai_enabled = False + resp = await client.post( + "/api/v1/ai/chat/sessions", + json={"flow_type": "troubleshooting"}, + headers=auth_headers, + ) + assert resp.status_code == 503