"""AI service for generating network diagrams from natural language."""

import json
import logging

from app.core.ai_provider import get_ai_provider
from app.core.config import settings
from app.schemas.network_diagram import (
    AIGenerateRequest,
    AIGenerateResponse,
    DiagramNode,
    DiagramEdge,
    DeviceProperties,
    Position,
)

logger = logging.getLogger(__name__)

# System prompt sent to the model. It is rendered with str.format(), so the
# literal JSON braces of the example structure are escaped as "{{" / "}}";
# only {available_slugs} and {merge_instructions} are substituted.
SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers.
Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble.

Return this exact structure:
{{
  "nodes": [
    {{
      "id": "unique-string",
      "type": "device-type-slug",
      "label": "device label",
      "position": {{ "x": number, "y": number }},
      "properties": {{
        "hostname": "string or null",
        "ip": "string or null",
        "subnet": "string or null",
        "vendor": "string or null",
        "model": "string or null",
        "role": "string or null",
        "vlan": "string or null",
        "notes": "string or null",
        "status": "unknown"
      }}
    }}
  ],
  "edges": [
    {{
      "id": "unique-string",
      "source": "node-id",
      "target": "node-id",
      "label": "connection label or null",
      "connectionType": "ethernet|fiber|wifi|vpn|vlan|wan",
      "speed": "string or null",
      "notes": "string or null"
    }}
  ],
  "suggestedName": "short descriptive diagram name",
  "notes": "any important assumptions or missing info, or null"
}}

Available device type slugs: {available_slugs}

Position nodes thoughtfully in a logical network topology layout.
Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y.
Place WAN/internet at top, core network in middle, endpoints at bottom.
{merge_instructions}"""

# Extra instructions appended to the system prompt in "merge" mode; rendered
# with the existing diagram's bounding box and the IDs of its nodes.
MERGE_INSTRUCTIONS = """

IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices.
The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}.
Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100).
You may create edges that connect new nodes to existing nodes if the description implies a connection.
Use these existing node IDs for connections: {existing_node_ids}"""

# Device type used when the model returns a type outside the allowed slugs.
_FALLBACK_DEVICE_TYPE = "server"
# Connection type used when the model omits or nulls an edge's connectionType.
_DEFAULT_CONNECTION_TYPE = "ethernet"


async def generate_diagram(
    request: AIGenerateRequest,
    available_slugs: list[str],
    existing_node_ids: list[str] | None = None,
) -> AIGenerateResponse:
    """Generate diagram nodes and edges from a plain-English network description.

    Args:
        request: The user's description and generation mode. When
            ``request.mode == "merge"`` and ``request.existingBounds`` is set,
            the model is told to place new nodes outside those bounds.
        available_slugs: Device type slugs the model may use. Any other type
            it returns is replaced with ``"server"``.
        existing_node_ids: IDs of nodes already on the diagram. In merge mode
            they are offered to the model as connection targets; edges may
            reference them in addition to the newly generated nodes.

    Returns:
        The generated nodes, edges, suggested diagram name and model notes.
        Edges whose source or target matches neither a generated node nor an
        ID in ``existing_node_ids`` are dropped (with a warning).

    Raises:
        ValueError: If the model's reply is not a JSON object, is missing a
            required field, or contains values the diagram schema rejects.
    """
    system_prompt = _build_system_prompt(request, available_slugs, existing_node_ids)

    model = settings.get_model_for_action("network_diagram_generate")
    provider = get_ai_provider(model)
    messages = [{"role": "user", "content": request.description}]

    response_text, input_tokens, output_tokens = await provider.generate_json(
        system_prompt=system_prompt,
        messages=messages,
        max_tokens=4096,
    )

    logger.info(
        "Network diagram AI generation: input_tokens=%d, output_tokens=%d",
        input_tokens,
        output_tokens,
    )

    data = _parse_response_json(response_text)

    allowed_types = set(available_slugs)
    try:
        nodes = [
            _build_node(raw_node, allowed_types)
            for raw_node in _as_list(data.get("nodes"), "nodes")
        ]

        # An edge is only meaningful if both endpoints exist: either a node
        # generated in this reply or (when merging) a node already on the
        # diagram. Anything else is a hallucinated ID and would dangle.
        known_ids = {node.id for node in nodes} | set(existing_node_ids or [])
        edges = []
        for raw_edge in _as_list(data.get("edges"), "edges"):
            edge = _build_edge(raw_edge)
            if edge.source not in known_ids or edge.target not in known_ids:
                logger.warning(
                    "Dropping edge '%s' with unknown endpoint(s): %s -> %s",
                    edge.id,
                    edge.source,
                    edge.target,
                )
                continue
            edges.append(edge)

        response = AIGenerateResponse(
            nodes=nodes,
            edges=edges,
            suggestedName=data.get("suggestedName"),
            notes=data.get("notes"),
        )
    except KeyError as e:
        logger.warning("AI response missing required field: %s", e)
        raise ValueError(
            f"AI generated incomplete data (missing {e}), please try again"
        ) from e
    except (TypeError, ValueError) as e:
        # TypeError: wrong JSON shape (e.g. null/array where an object was
        # expected). ValueError: includes schema validation failures.
        logger.warning("AI response contained malformed data: %s", e)
        raise ValueError("AI generated malformed data, please try again") from e

    return response


def _build_system_prompt(
    request: AIGenerateRequest,
    available_slugs: list[str],
    existing_node_ids: list[str] | None,
) -> str:
    """Render the system prompt, adding merge instructions when applicable."""
    merge_instructions = ""
    if request.mode == "merge" and request.existingBounds:
        b = request.existingBounds
        merge_instructions = MERGE_INSTRUCTIONS.format(
            minX=b.minX,
            maxX=b.maxX,
            minY=b.minY,
            maxY=b.maxY,
            existing_node_ids=", ".join(existing_node_ids or []),
        )
    return SYSTEM_PROMPT_TEMPLATE.format(
        available_slugs=", ".join(available_slugs),
        merge_instructions=merge_instructions,
    )


def _parse_response_json(response_text: str | None) -> dict:
    """Decode the model's reply into a JSON object.

    A surrounding markdown code fence (```json ... ```) is tolerated because
    models sometimes add one despite the prompt forbidding it.

    Raises:
        ValueError: If the text is not valid JSON or is not a JSON object.
    """
    text = (response_text or "").strip()
    if text.startswith("```"):
        # Drop the opening fence line (which may carry a language tag) and a
        # trailing closing fence, if present.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]

    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        logger.error("Failed to parse AI response as JSON: %s", e)
        raise ValueError("AI generated an invalid response, please try again") from e

    if not isinstance(data, dict):
        logger.error("AI response JSON is a %s, expected an object", type(data).__name__)
        raise ValueError("AI generated an invalid response, please try again")
    return data


def _as_list(value: object, field: str) -> list:
    """Return *value* as a list; a missing or null field counts as empty.

    Raises:
        TypeError: If *value* is present but not a JSON array.
    """
    if value is None:
        return []
    if not isinstance(value, list):
        raise TypeError(f"'{field}' must be a list, got {type(value).__name__}")
    return value


def _require_object(value: object, what: str) -> dict:
    """Return *value* if it is a JSON object, else raise TypeError."""
    if not isinstance(value, dict):
        raise TypeError(f"{what} must be an object, got {type(value).__name__}")
    return value


def _build_node(raw_node: object, allowed_types: set[str]) -> DiagramNode:
    """Convert one raw node dict from the model into a DiagramNode.

    Unknown or missing device types fall back to ``"server"``; a missing or
    null position defaults to (0, 0); unknown property keys are discarded.

    Raises:
        KeyError: If the node has no ``"id"``.
        TypeError: If the node, its position or its properties are not objects.
    """
    node = _require_object(raw_node, "node entry")

    node_type = node.get("type") or _FALLBACK_DEVICE_TYPE
    if not isinstance(node_type, str) or node_type not in allowed_types:
        logger.warning(
            "Unknown device type '%s', falling back to '%s'",
            node_type,
            _FALLBACK_DEVICE_TYPE,
        )
        node_type = _FALLBACK_DEVICE_TYPE

    raw_position = _require_object(node.get("position") or {"x": 0, "y": 0}, "node position")
    raw_properties = _require_object(node.get("properties") or {}, "node properties")
    # Keep only keys the schema knows; the model may invent extra properties.
    known_properties = {
        key: value
        for key, value in raw_properties.items()
        if key in DeviceProperties.model_fields
    }

    return DiagramNode(
        id=node["id"],
        type=node_type,
        label=node.get("label") or node_type,
        position=Position(**raw_position),
        properties=DeviceProperties(**known_properties),
    )


def _build_edge(raw_edge: object) -> DiagramEdge:
    """Convert one raw edge dict from the model into a DiagramEdge.

    Raises:
        KeyError: If ``"id"``, ``"source"`` or ``"target"`` is missing.
        TypeError: If the edge entry is not an object.
    """
    edge = _require_object(raw_edge, "edge entry")
    return DiagramEdge(
        id=edge["id"],
        source=edge["source"],
        target=edge["target"],
        label=edge.get("label"),
        connectionType=edge.get("connectionType") or _DEFAULT_CONNECTION_TYPE,
        speed=edge.get("speed"),
        notes=edge.get("notes"),
    )