feat: add backend tests for AI chat builder + fix conversation_id FK issue
Tests cover session create, send message with tree update, get session, abandon, 404 on missing session, and 503 when AI disabled. Fixed: ai_usage.conversation_id has FK to ai_conversations, not ai_chat_sessions. Chat builder now passes conversation_id=None and tracks session reference in extra_data.chat_session_id. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -106,7 +106,7 @@ async def create_session(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_message",
|
generation_type="chat_message",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=session.total_input_tokens,
|
input_tokens=session.total_input_tokens,
|
||||||
@@ -118,7 +118,7 @@ async def create_session(
|
|||||||
succeeded=True,
|
succeeded=True,
|
||||||
counts_toward_quota=False,
|
counts_toward_quota=False,
|
||||||
error_code=None,
|
error_code=None,
|
||||||
extra_data={"phase": "scoping"},
|
extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ async def post_message(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_message",
|
generation_type="chat_message",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=0,
|
input_tokens=0,
|
||||||
@@ -184,7 +184,7 @@ async def post_message(
|
|||||||
succeeded=False,
|
succeeded=False,
|
||||||
counts_toward_quota=False,
|
counts_toward_quota=False,
|
||||||
error_code=type(e).__name__,
|
error_code=type(e).__name__,
|
||||||
extra_data=None,
|
extra_data={"chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -198,7 +198,7 @@ async def post_message(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_message",
|
generation_type="chat_message",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=input_delta,
|
input_tokens=input_delta,
|
||||||
@@ -210,7 +210,7 @@ async def post_message(
|
|||||||
succeeded=True,
|
succeeded=True,
|
||||||
counts_toward_quota=False,
|
counts_toward_quota=False,
|
||||||
error_code=None,
|
error_code=None,
|
||||||
extra_data={"phase": session.current_phase},
|
extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -288,7 +288,7 @@ async def generate_tree(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_generate",
|
generation_type="chat_generate",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=session.total_input_tokens - prev_input,
|
input_tokens=session.total_input_tokens - prev_input,
|
||||||
@@ -297,7 +297,7 @@ async def generate_tree(
|
|||||||
succeeded=False,
|
succeeded=False,
|
||||||
counts_toward_quota=False,
|
counts_toward_quota=False,
|
||||||
error_code="invalid_output",
|
error_code="invalid_output",
|
||||||
extra_data={"error": str(e)},
|
extra_data={"error": str(e), "chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -312,7 +312,7 @@ async def generate_tree(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_generate",
|
generation_type="chat_generate",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=input_delta,
|
input_tokens=input_delta,
|
||||||
@@ -321,7 +321,7 @@ async def generate_tree(
|
|||||||
succeeded=False,
|
succeeded=False,
|
||||||
counts_toward_quota=False,
|
counts_toward_quota=False,
|
||||||
error_code=type(e).__name__,
|
error_code=type(e).__name__,
|
||||||
extra_data={"error": str(e)},
|
extra_data={"error": str(e), "chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -342,7 +342,7 @@ async def generate_tree(
|
|||||||
await record_ai_usage(
|
await record_ai_usage(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
account_id=current_user.account_id,
|
account_id=current_user.account_id,
|
||||||
conversation_id=session.id,
|
conversation_id=None,
|
||||||
generation_type="chat_generate",
|
generation_type="chat_generate",
|
||||||
tier=plan,
|
tier=plan,
|
||||||
input_tokens=input_delta,
|
input_tokens=input_delta,
|
||||||
@@ -354,7 +354,7 @@ async def generate_tree(
|
|||||||
succeeded=True,
|
succeeded=True,
|
||||||
counts_toward_quota=True,
|
counts_toward_quota=True,
|
||||||
error_code=None,
|
error_code=None,
|
||||||
extra_data=None,
|
extra_data={"chat_session_id": str(session.id)},
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
187
backend/tests/test_ai_chat.py
Normal file
187
backend/tests/test_ai_chat.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""Integration tests for AI Chat Builder endpoints.
|
||||||
|
|
||||||
|
These tests mock the AI provider to avoid real API calls.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ai_provider():
|
||||||
|
"""Mock AI provider that returns realistic responses."""
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.generate_text = AsyncMock(return_value=(
|
||||||
|
"Great question! Let's build a troubleshooting flow for DNS resolution issues. "
|
||||||
|
"To start, I need to understand the scope.\n\n"
|
||||||
|
"Who is the target audience for this flow? Are we targeting:\n"
|
||||||
|
"- Tier 1 help desk (basic checks only)\n"
|
||||||
|
"- Tier 2 desktop support (intermediate diagnostics)\n"
|
||||||
|
"- Tier 3 systems engineers (deep DNS troubleshooting)\n\n"
|
||||||
|
"[PHASE:scoping]",
|
||||||
|
500, # input tokens
|
||||||
|
200, # output tokens
|
||||||
|
))
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_chat_session(client, auth_headers, mock_ai_provider):
|
||||||
|
"""POST /ai/chat/sessions creates a session and returns AI greeting."""
|
||||||
|
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/ai/chat/sessions",
|
||||||
|
json={"flow_type": "troubleshooting"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "session_id" in data
|
||||||
|
assert "greeting" in data
|
||||||
|
assert data["current_phase"] == "scoping"
|
||||||
|
assert len(data["greeting"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_message(client, auth_headers, mock_ai_provider):
|
||||||
|
"""POST /ai/chat/sessions/{id}/messages returns AI response."""
|
||||||
|
# Create session first
|
||||||
|
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/ai/chat/sessions",
|
||||||
|
json={"flow_type": "troubleshooting"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# Mock response with tree update — must pass validate_generated_tree (min 5 nodes)
|
||||||
|
import json
|
||||||
|
tree_obj = {
|
||||||
|
"id": "root", "type": "decision",
|
||||||
|
"question": "What DNS symptom is the user experiencing?",
|
||||||
|
"options": [
|
||||||
|
{"id": "opt-1", "label": "Cannot resolve any domains", "next_node_id": "dns-check"},
|
||||||
|
{"id": "opt-2", "label": "Intermittent failures", "next_node_id": "dns-cache-fix"},
|
||||||
|
],
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"id": "dns-check", "type": "decision",
|
||||||
|
"question": "Is the DNS Client service running?",
|
||||||
|
"options": [
|
||||||
|
{"id": "dc-1", "label": "Yes", "next_node_id": "dns-fwd-fix"},
|
||||||
|
{"id": "dc-2", "label": "No", "next_node_id": "dns-svc-fix"},
|
||||||
|
],
|
||||||
|
"children": [
|
||||||
|
{"id": "dns-fwd-fix", "type": "solution", "title": "Check DNS Forwarders",
|
||||||
|
"description": "DNS forwarders may be misconfigured",
|
||||||
|
"resolution_steps": ["Check forwarder config"]},
|
||||||
|
{"id": "dns-svc-fix", "type": "solution", "title": "Restart DNS Service",
|
||||||
|
"description": "DNS Client service is stopped",
|
||||||
|
"resolution_steps": ["Start-Service Dnscache"]},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"id": "dns-cache-fix", "type": "solution", "title": "Stale DNS Cache",
|
||||||
|
"description": "DNS cache has stale entries",
|
||||||
|
"resolution_steps": ["ipconfig /flushdns"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
tree_json = json.dumps(tree_obj)
|
||||||
|
mock_ai_provider.generate_text = AsyncMock(return_value=(
|
||||||
|
"Good, targeting Tier 2 support. Let's start with the first diagnostic question.\n\n"
|
||||||
|
"The root question should be: 'What DNS symptom is the user experiencing?'\n\n"
|
||||||
|
f"[TREE_UPDATE]\n{tree_json}\n[/TREE_UPDATE]\n\n"
|
||||||
|
"[PHASE:discovery]",
|
||||||
|
800,
|
||||||
|
400,
|
||||||
|
))
|
||||||
|
|
||||||
|
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||||
|
resp = await client.post(
|
||||||
|
f"/api/v1/ai/chat/sessions/{session_id}/messages",
|
||||||
|
json={"content": "This is for Tier 2 support, hybrid environment with on-prem AD."},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "content" in data
|
||||||
|
assert data["current_phase"] == "discovery"
|
||||||
|
assert data["working_tree"] is not None
|
||||||
|
assert data["working_tree"]["type"] == "decision"
|
||||||
|
# Markers should be stripped from content
|
||||||
|
assert "[TREE_UPDATE]" not in data["content"]
|
||||||
|
assert "[PHASE:" not in data["content"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_session(client, auth_headers, mock_ai_provider):
|
||||||
|
"""GET /ai/chat/sessions/{id} returns full session state."""
|
||||||
|
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/ai/chat/sessions",
|
||||||
|
json={"flow_type": "troubleshooting"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["session_id"] == session_id
|
||||||
|
assert data["status"] == "active"
|
||||||
|
assert data["flow_type"] == "troubleshooting"
|
||||||
|
# Hidden primer message should be filtered out
|
||||||
|
assert all(
|
||||||
|
msg.get("role") == "assistant" or not msg.get("hidden")
|
||||||
|
for msg in data["conversation_history"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abandon_session(client, auth_headers, mock_ai_provider):
|
||||||
|
"""DELETE /ai/chat/sessions/{id} sets status to abandoned."""
|
||||||
|
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
|
||||||
|
create_resp = await client.post(
|
||||||
|
"/api/v1/ai/chat/sessions",
|
||||||
|
json={"flow_type": "troubleshooting"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
resp = await client.delete(
|
||||||
|
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
# Verify session is abandoned
|
||||||
|
get_resp = await client.get(
|
||||||
|
f"/api/v1/ai/chat/sessions/{session_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert get_resp.json()["status"] == "abandoned"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_session_not_found(client, auth_headers):
|
||||||
|
"""Accessing nonexistent session returns 404."""
|
||||||
|
import uuid
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = await client.get(
|
||||||
|
f"/api/v1/ai/chat/sessions/{fake_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ai_disabled_returns_503(client, auth_headers):
|
||||||
|
"""When AI is not configured, endpoints return 503."""
|
||||||
|
with patch("app.api.endpoints.ai_chat.settings") as mock_settings:
|
||||||
|
mock_settings.ai_enabled = False
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/ai/chat/sessions",
|
||||||
|
json={"flow_type": "troubleshooting"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 503
|
||||||
Reference in New Issue
Block a user