feat: add backend tests for AI chat builder + fix conversation_id FK issue
Tests cover session create, send message with tree update, get session, abandon, 404 on missing session, and 503 when AI disabled. Fixed: ai_usage.conversation_id has FK to ai_conversations, not ai_chat_sessions. Chat builder now passes conversation_id=None and tracks session reference in extra_data.chat_session_id. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -106,7 +106,7 @@ async def create_session(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_message",
|
||||
tier=plan,
|
||||
input_tokens=session.total_input_tokens,
|
||||
@@ -118,7 +118,7 @@ async def create_session(
|
||||
succeeded=True,
|
||||
counts_toward_quota=False,
|
||||
error_code=None,
|
||||
extra_data={"phase": "scoping"},
|
||||
extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ async def post_message(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_message",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
@@ -184,7 +184,7 @@ async def post_message(
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data=None,
|
||||
extra_data={"chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
@@ -198,7 +198,7 @@ async def post_message(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_message",
|
||||
tier=plan,
|
||||
input_tokens=input_delta,
|
||||
@@ -210,7 +210,7 @@ async def post_message(
|
||||
succeeded=True,
|
||||
counts_toward_quota=False,
|
||||
error_code=None,
|
||||
extra_data={"phase": session.current_phase},
|
||||
extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -288,7 +288,7 @@ async def generate_tree(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_generate",
|
||||
tier=plan,
|
||||
input_tokens=session.total_input_tokens - prev_input,
|
||||
@@ -297,7 +297,7 @@ async def generate_tree(
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code="invalid_output",
|
||||
extra_data={"error": str(e)},
|
||||
extra_data={"error": str(e), "chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
@@ -312,7 +312,7 @@ async def generate_tree(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_generate",
|
||||
tier=plan,
|
||||
input_tokens=input_delta,
|
||||
@@ -321,7 +321,7 @@ async def generate_tree(
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"error": str(e)},
|
||||
extra_data={"error": str(e), "chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
@@ -342,7 +342,7 @@ async def generate_tree(
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=session.id,
|
||||
conversation_id=None,
|
||||
generation_type="chat_generate",
|
||||
tier=plan,
|
||||
input_tokens=input_delta,
|
||||
@@ -354,7 +354,7 @@ async def generate_tree(
|
||||
succeeded=True,
|
||||
counts_toward_quota=True,
|
||||
error_code=None,
|
||||
extra_data=None,
|
||||
extra_data={"chat_session_id": str(session.id)},
|
||||
db=db,
|
||||
)
|
||||
|
||||
|
||||
187
backend/tests/test_ai_chat.py
Normal file
187
backend/tests/test_ai_chat.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Integration tests for AI Chat Builder endpoints.
|
||||
|
||||
These tests mock the AI provider to avoid real API calls.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_provider():
|
||||
"""Mock AI provider that returns realistic responses."""
|
||||
provider = AsyncMock()
|
||||
provider.generate_text = AsyncMock(return_value=(
|
||||
"Great question! Let's build a troubleshooting flow for DNS resolution issues. "
|
||||
"To start, I need to understand the scope.\n\n"
|
||||
"Who is the target audience for this flow? Are we targeting:\n"
|
||||
"- Tier 1 help desk (basic checks only)\n"
|
||||
"- Tier 2 desktop support (intermediate diagnostics)\n"
|
||||
"- Tier 3 systems engineers (deep DNS troubleshooting)\n\n"
|
||||
"[PHASE:scoping]",
|
||||
500, # input tokens
|
||||
200, # output tokens
|
||||
))
|
||||
return provider
|
||||
|
||||
|
||||
async def test_create_chat_session(client, auth_headers, mock_ai_provider):
|
||||
"""POST /ai/chat/sessions creates a session and returns AI greeting."""
|
||||
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||
resp = await client.post(
|
||||
"/api/v1/ai/chat/sessions",
|
||||
json={"flow_type": "troubleshooting"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert "session_id" in data
|
||||
assert "greeting" in data
|
||||
assert data["current_phase"] == "scoping"
|
||||
assert len(data["greeting"]) > 0
|
||||
|
||||
|
||||
async def test_send_message(client, auth_headers, mock_ai_provider):
|
||||
"""POST /ai/chat/sessions/{id}/messages returns AI response."""
|
||||
# Create session first
|
||||
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||
create_resp = await client.post(
|
||||
"/api/v1/ai/chat/sessions",
|
||||
json={"flow_type": "troubleshooting"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
session_id = create_resp.json()["session_id"]
|
||||
|
||||
# Mock response with tree update — must pass validate_generated_tree (min 5 nodes)
|
||||
import json
|
||||
tree_obj = {
|
||||
"id": "root", "type": "decision",
|
||||
"question": "What DNS symptom is the user experiencing?",
|
||||
"options": [
|
||||
{"id": "opt-1", "label": "Cannot resolve any domains", "next_node_id": "dns-check"},
|
||||
{"id": "opt-2", "label": "Intermittent failures", "next_node_id": "dns-cache-fix"},
|
||||
],
|
||||
"children": [
|
||||
{
|
||||
"id": "dns-check", "type": "decision",
|
||||
"question": "Is the DNS Client service running?",
|
||||
"options": [
|
||||
{"id": "dc-1", "label": "Yes", "next_node_id": "dns-fwd-fix"},
|
||||
{"id": "dc-2", "label": "No", "next_node_id": "dns-svc-fix"},
|
||||
],
|
||||
"children": [
|
||||
{"id": "dns-fwd-fix", "type": "solution", "title": "Check DNS Forwarders",
|
||||
"description": "DNS forwarders may be misconfigured",
|
||||
"resolution_steps": ["Check forwarder config"]},
|
||||
{"id": "dns-svc-fix", "type": "solution", "title": "Restart DNS Service",
|
||||
"description": "DNS Client service is stopped",
|
||||
"resolution_steps": ["Start-Service Dnscache"]},
|
||||
],
|
||||
},
|
||||
{"id": "dns-cache-fix", "type": "solution", "title": "Stale DNS Cache",
|
||||
"description": "DNS cache has stale entries",
|
||||
"resolution_steps": ["ipconfig /flushdns"]},
|
||||
],
|
||||
}
|
||||
tree_json = json.dumps(tree_obj)
|
||||
mock_ai_provider.generate_text = AsyncMock(return_value=(
|
||||
"Good, targeting Tier 2 support. Let's start with the first diagnostic question.\n\n"
|
||||
"The root question should be: 'What DNS symptom is the user experiencing?'\n\n"
|
||||
f"[TREE_UPDATE]\n{tree_json}\n[/TREE_UPDATE]\n\n"
|
||||
"[PHASE:discovery]",
|
||||
800,
|
||||
400,
|
||||
))
|
||||
|
||||
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||
resp = await client.post(
|
||||
f"/api/v1/ai/chat/sessions/{session_id}/messages",
|
||||
json={"content": "This is for Tier 2 support, hybrid environment with on-prem AD."},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "content" in data
|
||||
assert data["current_phase"] == "discovery"
|
||||
assert data["working_tree"] is not None
|
||||
assert data["working_tree"]["type"] == "decision"
|
||||
# Markers should be stripped from content
|
||||
assert "[TREE_UPDATE]" not in data["content"]
|
||||
assert "[PHASE:" not in data["content"]
|
||||
|
||||
|
||||
async def test_get_session(client, auth_headers, mock_ai_provider):
|
||||
"""GET /ai/chat/sessions/{id} returns full session state."""
|
||||
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||
create_resp = await client.post(
|
||||
"/api/v1/ai/chat/sessions",
|
||||
json={"flow_type": "troubleshooting"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
session_id = create_resp.json()["session_id"]
|
||||
|
||||
resp = await client.get(
|
||||
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["session_id"] == session_id
|
||||
assert data["status"] == "active"
|
||||
assert data["flow_type"] == "troubleshooting"
|
||||
# Hidden primer message should be filtered out
|
||||
assert all(
|
||||
msg.get("role") == "assistant" or not msg.get("hidden")
|
||||
for msg in data["conversation_history"]
|
||||
)
|
||||
|
||||
|
||||
async def test_abandon_session(client, auth_headers, mock_ai_provider):
|
||||
"""DELETE /ai/chat/sessions/{id} sets status to abandoned."""
|
||||
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||
create_resp = await client.post(
|
||||
"/api/v1/ai/chat/sessions",
|
||||
json={"flow_type": "troubleshooting"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
session_id = create_resp.json()["session_id"]
|
||||
|
||||
resp = await client.delete(
|
||||
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Verify session is abandoned
|
||||
get_resp = await client.get(
|
||||
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert get_resp.json()["status"] == "abandoned"
|
||||
|
||||
|
||||
async def test_session_not_found(client, auth_headers):
|
||||
"""Accessing nonexistent session returns 404."""
|
||||
import uuid
|
||||
fake_id = str(uuid.uuid4())
|
||||
resp = await client.get(
|
||||
f"/api/v1/ai/chat/sessions/{fake_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_ai_disabled_returns_503(client, auth_headers):
|
||||
"""When AI is not configured, endpoints return 503."""
|
||||
with patch("app.api.endpoints.ai_chat.settings") as mock_settings:
|
||||
mock_settings.ai_enabled = False
|
||||
resp = await client.post(
|
||||
"/api/v1/ai/chat/sessions",
|
||||
json={"flow_type": "troubleshooting"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 503
|
||||
Reference in New Issue
Block a user