diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 7fdf7fb6..4768ac0d 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -105,6 +105,7 @@ class Settings(BaseSettings): "variable_inference": "fast", "kb_convert": "standard", "script_build": "standard", + "network_diagram_generate": "standard", } def get_model_for_action(self, action_type: str) -> str: diff --git a/backend/app/services/network_diagram_ai_service.py b/backend/app/services/network_diagram_ai_service.py new file mode 100644 index 00000000..61defa50 --- /dev/null +++ b/backend/app/services/network_diagram_ai_service.py @@ -0,0 +1,146 @@ +"""AI service for generating network diagrams from natural language.""" +import json +import logging + +from app.core.ai_provider import get_ai_provider +from app.core.config import settings +from app.schemas.network_diagram import ( + AIGenerateRequest, + AIGenerateResponse, + DiagramNode, + DiagramEdge, + DeviceProperties, +) + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers. +Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble. + +Return this exact structure: +{{ + "nodes": [ + {{ + "id": "unique-string", + "type": "device-type-slug", + "label": "device label", + "position": {{ "x": number, "y": number }}, + "properties": {{ + "hostname": "string or null", + "ip": "string or null", + "subnet": "string or null", + "vendor": "string or null", + "model": "string or null", + "role": "string or null", + "vlan": "string or null", + "notes": "string or null", + "status": "unknown" + }} + }} + ], + "edges": [ + {{ + "id": "unique-string", + "source": "node-id", + "target": "node-id", + "label": "connection label or null", + "connectionType": "ethernet|fiber|wifi|vpn|vlan|wan", + "speed": "string or null", + "notes": "string or null" + }} + ], + "suggestedName": "short descriptive diagram name", + "notes": "any important assumptions or missing info, or null" +}} + +Available device type slugs: {available_slugs} + +Position nodes thoughtfully in a logical network topology layout. +Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y. +Place WAN/internet at top, core network in middle, endpoints at bottom. +{merge_instructions}""" + +MERGE_INSTRUCTIONS = """ +IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices. +The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}. +Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100). +You may create edges that connect new nodes to existing nodes if the description implies a connection. +Use these existing node IDs for connections: {existing_node_ids}""" + + +async def generate_diagram( + request: AIGenerateRequest, + available_slugs: list[str], + existing_node_ids: list[str] | None = None, +) -> AIGenerateResponse: + merge_instructions = "" + if request.mode == "merge" and request.existingBounds: + b = request.existingBounds + merge_instructions = MERGE_INSTRUCTIONS.format( + minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY, + existing_node_ids=", ".join(existing_node_ids or []), + ) + + system_prompt = SYSTEM_PROMPT_TEMPLATE.format( + available_slugs=", ".join(available_slugs), + merge_instructions=merge_instructions, + ) + + model = settings.get_model_for_action("network_diagram_generate") + provider = get_ai_provider(model) + + messages = [{"role": "user", "content": request.description}] + + response_text, input_tokens, output_tokens = await provider.generate_json( + system_prompt=system_prompt, + messages=messages, + max_tokens=4096, + ) + + logger.info( + "Network diagram AI generation: input_tokens=%d, output_tokens=%d", + input_tokens, output_tokens, + ) + + try: + data = json.loads(response_text) + except json.JSONDecodeError as e: + logger.error("Failed to parse AI response as JSON: %s", e) + raise ValueError("AI generated an invalid response, please try again") + + nodes = [] + for raw_node in data.get("nodes", []): + node_type = raw_node.get("type", "server") + if node_type not in available_slugs: + logger.warning("Unknown device type '%s', falling back to 'server'", node_type) + node_type = "server" + + nodes.append(DiagramNode( + id=raw_node["id"], + type=node_type, + label=raw_node.get("label", node_type), + position=raw_node.get("position", {"x": 0, "y": 0}), + properties=DeviceProperties(**{ + k: v for k, v in raw_node.get("properties", {}).items() + if k in DeviceProperties.model_fields + }), + )) + + edges = [] + for raw_edge in data.get("edges", []): + edges.append(DiagramEdge( + id=raw_edge["id"], + source=raw_edge["source"], + target=raw_edge["target"], + label=raw_edge.get("label"), + connectionType=raw_edge.get("connectionType", "ethernet"), + speed=raw_edge.get("speed"), + notes=raw_edge.get("notes"), + )) + + return AIGenerateResponse( + nodes=nodes, + edges=edges, + suggestedName=data.get("suggestedName"), + notes=data.get("notes"), + )