feat: add backend tests for AI chat builder + fix conversation_id FK issue

Tests cover session create, send message with tree update, get session,
abandon, 404 on missing session, and 503 when AI disabled.

Fixed: ai_usage.conversation_id has FK to ai_conversations, not
ai_chat_sessions. Chat builder now passes conversation_id=None and
tracks session reference in extra_data.chat_session_id.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
chihlasm
2026-02-27 04:06:47 -05:00
parent ef96b1a12f
commit 0da67586da
2 changed files with 199 additions and 12 deletions

View File

@@ -106,7 +106,7 @@ async def create_session(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_message",
tier=plan,
input_tokens=session.total_input_tokens,
@@ -118,7 +118,7 @@ async def create_session(
succeeded=True,
counts_toward_quota=False,
error_code=None,
extra_data={"phase": "scoping"},
extra_data={"phase": "scoping", "chat_session_id": str(session.id)},
db=db,
)
@@ -175,7 +175,7 @@ async def post_message(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_message",
tier=plan,
input_tokens=0,
@@ -184,7 +184,7 @@ async def post_message(
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data=None,
extra_data={"chat_session_id": str(session.id)},
db=db,
)
await db.commit()
@@ -198,7 +198,7 @@ async def post_message(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_message",
tier=plan,
input_tokens=input_delta,
@@ -210,7 +210,7 @@ async def post_message(
succeeded=True,
counts_toward_quota=False,
error_code=None,
extra_data={"phase": session.current_phase},
extra_data={"phase": session.current_phase, "chat_session_id": str(session.id)},
db=db,
)
@@ -288,7 +288,7 @@ async def generate_tree(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_generate",
tier=plan,
input_tokens=session.total_input_tokens - prev_input,
@@ -297,7 +297,7 @@ async def generate_tree(
succeeded=False,
counts_toward_quota=False,
error_code="invalid_output",
extra_data={"error": str(e)},
extra_data={"error": str(e), "chat_session_id": str(session.id)},
db=db,
)
await db.commit()
@@ -312,7 +312,7 @@ async def generate_tree(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_generate",
tier=plan,
input_tokens=input_delta,
@@ -321,7 +321,7 @@ async def generate_tree(
succeeded=False,
counts_toward_quota=False,
error_code=type(e).__name__,
extra_data={"error": str(e)},
extra_data={"error": str(e), "chat_session_id": str(session.id)},
db=db,
)
await db.commit()
@@ -342,7 +342,7 @@ async def generate_tree(
await record_ai_usage(
user_id=current_user.id,
account_id=current_user.account_id,
conversation_id=session.id,
conversation_id=None,
generation_type="chat_generate",
tier=plan,
input_tokens=input_delta,
@@ -354,7 +354,7 @@ async def generate_tree(
succeeded=True,
counts_toward_quota=True,
error_code=None,
extra_data=None,
extra_data={"chat_session_id": str(session.id)},
db=db,
)

View File

@@ -0,0 +1,187 @@
"""Integration tests for AI Chat Builder endpoints.
These tests mock the AI provider to avoid real API calls.
"""
import pytest
from unittest.mock import AsyncMock, patch
pytestmark = pytest.mark.asyncio
@pytest.fixture
def mock_ai_provider():
"""Mock AI provider that returns realistic responses."""
provider = AsyncMock()
provider.generate_text = AsyncMock(return_value=(
"Great question! Let's build a troubleshooting flow for DNS resolution issues. "
"To start, I need to understand the scope.\n\n"
"Who is the target audience for this flow? Are we targeting:\n"
"- Tier 1 help desk (basic checks only)\n"
"- Tier 2 desktop support (intermediate diagnostics)\n"
"- Tier 3 systems engineers (deep DNS troubleshooting)\n\n"
"[PHASE:scoping]",
500, # input tokens
200, # output tokens
))
return provider
async def test_create_chat_session(client, auth_headers, mock_ai_provider):
"""POST /ai/chat/sessions creates a session and returns AI greeting."""
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
resp = await client.post(
"/api/v1/ai/chat/sessions",
json={"flow_type": "troubleshooting"},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
assert "session_id" in data
assert "greeting" in data
assert data["current_phase"] == "scoping"
assert len(data["greeting"]) > 0
async def test_send_message(client, auth_headers, mock_ai_provider):
"""POST /ai/chat/sessions/{id}/messages returns AI response."""
# Create session first
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
create_resp = await client.post(
"/api/v1/ai/chat/sessions",
json={"flow_type": "troubleshooting"},
headers=auth_headers,
)
session_id = create_resp.json()["session_id"]
# Mock response with tree update — must pass validate_generated_tree (min 5 nodes)
import json
tree_obj = {
"id": "root", "type": "decision",
"question": "What DNS symptom is the user experiencing?",
"options": [
{"id": "opt-1", "label": "Cannot resolve any domains", "next_node_id": "dns-check"},
{"id": "opt-2", "label": "Intermittent failures", "next_node_id": "dns-cache-fix"},
],
"children": [
{
"id": "dns-check", "type": "decision",
"question": "Is the DNS Client service running?",
"options": [
{"id": "dc-1", "label": "Yes", "next_node_id": "dns-fwd-fix"},
{"id": "dc-2", "label": "No", "next_node_id": "dns-svc-fix"},
],
"children": [
{"id": "dns-fwd-fix", "type": "solution", "title": "Check DNS Forwarders",
"description": "DNS forwarders may be misconfigured",
"resolution_steps": ["Check forwarder config"]},
{"id": "dns-svc-fix", "type": "solution", "title": "Restart DNS Service",
"description": "DNS Client service is stopped",
"resolution_steps": ["Start-Service Dnscache"]},
],
},
{"id": "dns-cache-fix", "type": "solution", "title": "Stale DNS Cache",
"description": "DNS cache has stale entries",
"resolution_steps": ["ipconfig /flushdns"]},
],
}
tree_json = json.dumps(tree_obj)
mock_ai_provider.generate_text = AsyncMock(return_value=(
"Good, targeting Tier 2 support. Let's start with the first diagnostic question.\n\n"
"The root question should be: 'What DNS symptom is the user experiencing?'\n\n"
f"[TREE_UPDATE]\n{tree_json}\n[/TREE_UPDATE]\n\n"
"[PHASE:discovery]",
800,
400,
))
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
resp = await client.post(
f"/api/v1/ai/chat/sessions/{session_id}/messages",
json={"content": "This is for Tier 2 support, hybrid environment with on-prem AD."},
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert "content" in data
assert data["current_phase"] == "discovery"
assert data["working_tree"] is not None
assert data["working_tree"]["type"] == "decision"
# Markers should be stripped from content
assert "[TREE_UPDATE]" not in data["content"]
assert "[PHASE:" not in data["content"]
async def test_get_session(client, auth_headers, mock_ai_provider):
"""GET /ai/chat/sessions/{id} returns full session state."""
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
create_resp = await client.post(
"/api/v1/ai/chat/sessions",
json={"flow_type": "troubleshooting"},
headers=auth_headers,
)
session_id = create_resp.json()["session_id"]
resp = await client.get(
f"/api/v1/ai/chat/sessions/{session_id}",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["session_id"] == session_id
assert data["status"] == "active"
assert data["flow_type"] == "troubleshooting"
# Hidden primer message should be filtered out
assert all(
msg.get("role") == "assistant" or not msg.get("hidden")
for msg in data["conversation_history"]
)
async def test_abandon_session(client, auth_headers, mock_ai_provider):
"""DELETE /ai/chat/sessions/{id} sets status to abandoned."""
with patch("app.core.ai_chat_service.get_ai_provider", return_value=mock_ai_provider):
create_resp = await client.post(
"/api/v1/ai/chat/sessions",
json={"flow_type": "troubleshooting"},
headers=auth_headers,
)
session_id = create_resp.json()["session_id"]
resp = await client.delete(
f"/api/v1/ai/chat/sessions/{session_id}",
headers=auth_headers,
)
assert resp.status_code == 204
# Verify session is abandoned
get_resp = await client.get(
f"/api/v1/ai/chat/sessions/{session_id}",
headers=auth_headers,
)
assert get_resp.json()["status"] == "abandoned"
async def test_session_not_found(client, auth_headers):
"""Accessing nonexistent session returns 404."""
import uuid
fake_id = str(uuid.uuid4())
resp = await client.get(
f"/api/v1/ai/chat/sessions/{fake_id}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_ai_disabled_returns_503(client, auth_headers):
"""When AI is not configured, endpoints return 503."""
with patch("app.api.endpoints.ai_chat.settings") as mock_settings:
mock_settings.ai_enabled = False
resp = await client.post(
"/api/v1/ai/chat/sessions",
json={"flow_type": "troubleshooting"},
headers=auth_headers,
)
assert resp.status_code == 503