Files
resolutionflow/backend/app/services/script_builder_service.py
chihlasm 64f004a62c feat: tenant isolation Phase 4 — RLS on 31 remaining tables + script_builder fix
Enable RLS on all remaining tenant-scoped tables (31 tables):

Standard policy (tenant sees own rows):
  users, account_invites, account_limit_overrides, account_feature_overrides,
  subscriptions, ai_chat_sessions, ai_conversations, ai_session_steps,
  ai_session_embeddings, ai_suggestions, ai_usage, assistant_chats,
  attachments, copilot_conversations, feedback, file_uploads, fork_points,
  kb_imports, notifications, notification_configs, notification_logs,
  psa_activity_logs, psa_member_mappings, script_builder_sessions,
  script_categories, session_ratings, tree_embeddings, user_folders,
  user_pinned_trees

Platform-visibility policy (own rows OR PLATFORM_ACCOUNT_ID):
  platform_steps, template_trees

Intentionally skipped:
  accounts (IS the root table, no account_id column)
  plan_feature_defaults (platform config, no account_id column)

Also fixes script_builder_service.create_session() which was missing
account_id= on ScriptBuilderSession construction, causing 500s on all
script builder endpoints (pre-existing CI failure).

Adds Phase 4 RLS isolation tests covering: users, script_builder_sessions,
ai_session_steps, notifications, platform_steps, template_trees.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-12 01:25:28 +00:00

401 lines
14 KiB
Python

"""AI Script Builder service — generates scripts from natural language descriptions."""
import logging
import re
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID
from sqlalchemy import select, func, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.ai_provider import get_ai_provider
from app.core.config import settings
from app.models.script_builder_session import ScriptBuilderSession, ScriptBuilderMessage
from app.schemas.script_builder import (
ScriptBuilderMessageResponse,
ScriptBuilderSessionDetail,
ScriptBuilderSessionSummary,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maximum number of user-role messages allowed in one session; send_message()
# raises ValueError once a session has reached this count.
MAX_MESSAGES_PER_SESSION = 30
# Per-language system-prompt preambles, keyed by the session's language value.
# send_message() looks the session language up here and falls back to the
# "powershell" entry for any other value. Each value is substituted into
# SYSTEM_PROMPT_TEMPLATE as {language_prompt}.
LANGUAGE_PROMPTS = {
"powershell": """\
You are an expert PowerShell script writer for MSP (Managed Service Provider) environments.
## Script Standards
- Use Advanced Functions with CmdletBinding and param() blocks
- Include comment-based help (.SYNOPSIS, .DESCRIPTION, .PARAMETER, .EXAMPLE)
- Use try/catch/finally for error handling
- Use Write-Verbose for diagnostic output, Write-Error for failures
- Support pipeline input where appropriate
- Use approved PowerShell verbs (Get-, Set-, New-, Remove-, etc.)
- Import required modules at the top (e.g., Import-Module ActiveDirectory)
- Use [Parameter(Mandatory=$true)] for required params
- Default to UTF-8 output for exports
## Security
- Never hardcode credentials — use Get-Credential or SecureString params
- Use -WhatIf and -Confirm support via SupportsShouldProcess
- Validate input with ValidateSet, ValidatePattern, ValidateRange
""",
"bash": """\
You are an expert Bash script writer for Linux/macOS system administration.
## Script Standards
- Start with #!/bin/bash
- Use set -euo pipefail for safety
- Parse arguments with getopts or positional parameters
- Include a usage() function for --help
- Use functions for logical grouping
- Quote all variable expansions ("$var" not $var)
- Use [[ ]] for conditionals (not [ ])
- Add comments explaining non-obvious logic
- Use lowercase_with_underscores for variable names
- Exit with meaningful exit codes (0=success, 1=general error, 2=usage error)
## Security
- Never store passwords in scripts — use environment variables or prompts
- Validate all user inputs
- Use mktemp for temporary files
""",
"python": """\
You are an expert Python script writer for IT automation and system administration.
## Script Standards
- Use Python 3.10+ syntax
- Add type hints to all function signatures
- Use argparse for CLI argument parsing
- Include if __name__ == "__main__": guard
- Use logging module (not print) for diagnostic output
- Use docstrings for functions and modules
- Use pathlib.Path instead of os.path
- Handle exceptions with specific exception types
- Use f-strings for string formatting
- Follow PEP 8 naming conventions
## Security
- Never hardcode secrets — use environment variables or config files
- Validate and sanitize all user inputs
- Use subprocess.run() with shell=False (never shell=True with user input)
""",
}
# Full system prompt sent to the AI provider. send_message() fills the single
# {language_prompt} placeholder via str.format() with an entry from
# LANGUAGE_PROMPTS. The response-format rules (one fenced code block, a
# backtick-quoted filename) are what _extract_script_from_response() parses.
# The trailing backslash on the last text line keeps the string from ending
# in a newline.
SYSTEM_PROMPT_TEMPLATE = """\
{language_prompt}
## Response Format
Respond conversationally. When generating a script:
1. Briefly explain what the script does and any assumptions
2. Include the complete script in a single fenced code block with the language tag
3. Suggest a filename (e.g., `Get-LinkedGPOs.ps1`)
When the user asks for modifications, generate the COMPLETE updated script (not a diff).
## Context
The user is an MSP engineer using ResolutionFlow. They need scripts for managing client infrastructure.
Keep scripts practical, production-ready, and well-documented.\
"""
def _extract_script_from_response(content: str, language: str) -> tuple[str | None, str | None]:
"""Extract code block and filename from AI response.
Returns (script, filename) tuple.
"""
# Map language to code fence tags
lang_tags = {
"powershell": ["powershell", "ps1", "pwsh"],
"bash": ["bash", "sh", "shell"],
"python": ["python", "py", "python3"],
}
tags = lang_tags.get(language, [language])
# Try each language-specific tag, then fall back to generic fence
script = None
for tag in tags:
pattern = rf"```{tag}\s*\n(.*?)```"
match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
if match:
script = match.group(1).strip()
break
else:
# Try generic code fence
pattern = r"```\s*\n(.*?)```"
match = re.search(pattern, content, re.DOTALL)
script = match.group(1).strip() if match else None
# Extract filename suggestion
filename = None
ext_map = {"powershell": ".ps1", "bash": ".sh", "python": ".py"}
ext = ext_map.get(language, ".txt")
filename_pattern = rf"`([A-Za-z0-9_\-]+{re.escape(ext)})`"
fname_match = re.search(filename_pattern, content)
if fname_match:
filename = fname_match.group(1)
return script, filename
async def create_session(
    db: AsyncSession,
    user_id: UUID,
    account_id: UUID,
    team_id: UUID | None,
    language: str,
    initial_prompt: str | None = None,
) -> ScriptBuilderSession:
    """Create a Script Builder session and optionally send its first message.

    Args:
        db: Async database session; the new row is flushed, not committed.
        user_id: Owner of the session.
        account_id: Account (tenant) the session belongs to.
        team_id: Team the session belongs to, if any.
        language: Script language for the session.
        initial_prompt: When non-empty (e.g. handed over from FlowPilot), it is
            sent as the first user message via send_message().

    Returns:
        The newly created session.
    """
    new_session = ScriptBuilderSession(
        language=language,
        team_id=team_id,
        account_id=account_id,
        user_id=user_id,
    )
    db.add(new_session)
    # Flush before any message is sent: send_message() reads the session's id.
    await db.flush()

    # Nothing more to do without an initial prompt.
    if not initial_prompt:
        return new_session

    await send_message(db, new_session, initial_prompt)
    return new_session
async def send_message(
    db: AsyncSession,
    session: ScriptBuilderSession,
    user_content: str,
) -> ScriptBuilderMessageResponse:
    """Record a user message, call the AI provider, and store its reply.

    Steps: take a per-session advisory lock, enforce the per-session cap on
    user messages, persist the user message, send the 20 most recent messages
    to the provider, extract any script from the reply, persist the assistant
    message, and refresh the session's denormalized fields.

    Args:
        db: Async database session; changes are flushed but not committed here.
        session: Session the message belongs to.
        user_content: Text of the user's message.

    Returns:
        The assistant reply, with the extracted script, suggested filename and
        script line count (each ``None`` when no script was found).

    Raises:
        ValueError: If the session already holds MAX_MESSAGES_PER_SESSION user
            messages.
    """
    # Function-scope import, like the local imports in save_to_library().
    import hashlib

    # Acquire per-session advisory lock to prevent concurrent message count races.
    # Two simultaneous sends to the same session would otherwise both read the same
    # count, both pass the limit check, and both insert — exceeding the cap.
    #
    # The key must be identical in every worker process. Built-in hash() on a
    # str is salted per interpreter process, so it would hand different keys
    # to different workers and the lock would not serialize them. A digest of
    # the id is stable everywhere; the modulus keeps it inside a signed bigint.
    digest = hashlib.sha256(str(session.id).encode("utf-8")).digest()
    session_lock_key = int.from_bytes(digest[:8], "big") % (2**62)
    await db.execute(text("SELECT pg_advisory_xact_lock(:key)"), {"key": session_lock_key})

    # Count existing user messages for the session
    msg_count_result = await db.execute(
        select(func.count(ScriptBuilderMessage.id)).where(
            ScriptBuilderMessage.session_id == session.id,
            ScriptBuilderMessage.role == "user",
        )
    )
    user_msg_count = msg_count_result.scalar_one()
    if user_msg_count >= MAX_MESSAGES_PER_SESSION:
        raise ValueError(f"Session has reached the maximum of {MAX_MESSAGES_PER_SESSION} messages.")

    now = datetime.now(timezone.utc)

    # Create user message record
    user_msg = ScriptBuilderMessage(
        session_id=session.id,
        role="user",
        content=user_content,
        created_at=now,
    )
    db.add(user_msg)
    # Flush so the new message is included in the context query below.
    await db.flush()

    # Build system prompt (unknown languages fall back to the PowerShell prompt)
    language_prompt = LANGUAGE_PROMPTS.get(session.language, LANGUAGE_PROMPTS["powershell"])
    system_prompt = SYSTEM_PROMPT_TEMPLATE.format(language_prompt=language_prompt)

    # Build conversation for AI — get last 20 messages for context window.
    # Fetched newest-first so LIMIT keeps the most recent ones, then reversed
    # back into chronological order.
    # NOTE(review): once a session holds more than 20 messages this window may
    # begin with an assistant message — confirm the configured provider accepts that.
    recent_result = await db.execute(
        select(ScriptBuilderMessage)
        .where(ScriptBuilderMessage.session_id == session.id)
        .order_by(ScriptBuilderMessage.created_at.desc())
        .limit(20)
    )
    recent_msgs = list(reversed(recent_result.scalars().all()))
    ai_messages = [{"role": m.role, "content": m.content} for m in recent_msgs]

    # Call AI
    model = settings.get_model_for_action("script_build")
    provider = get_ai_provider(model=model)
    ai_text, input_tokens, output_tokens = await provider.generate_text(
        system_prompt=system_prompt,
        messages=ai_messages,
        max_tokens=8192,
    )

    # Extract script from response
    script, filename = _extract_script_from_response(ai_text, session.language)
    line_count = len(script.splitlines()) if script else None

    # One timestamp for the stored reply, the session update and the response,
    # so the value returned to the caller matches what was persisted.
    responded_at = datetime.now(timezone.utc)

    # Create assistant message record
    assistant_msg = ScriptBuilderMessage(
        session_id=session.id,
        role="assistant",
        content=ai_text,
        script=script,
        script_filename=filename,
        line_count=line_count,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        created_at=responded_at,
    )
    db.add(assistant_msg)

    # Update session denormalized fields; the latest script is only replaced
    # when this reply actually contained one.
    if script:
        session.latest_script = script
        session.latest_script_filename = filename
    if not session.title:
        # Auto-generate title from first user message (truncate)
        first_user = user_content[:100]
        session.title = first_user if len(user_content) <= 100 else first_user + "..."
    session.updated_at = responded_at
    await db.flush()

    return ScriptBuilderMessageResponse(
        role="assistant",
        content=ai_text,
        script=script,
        script_filename=filename,
        line_count=line_count,
        timestamp=responded_at,
    )
async def get_session(
    db: AsyncSession,
    session_id: UUID,
    user_id: UUID,
) -> ScriptBuilderSession | None:
    """Fetch one session by id, restricted to sessions owned by ``user_id``.

    The session's ``message_records`` relationship is eagerly loaded.

    Returns:
        The matching session, or ``None`` when no session has that id for
        this user.
    """
    owned_session_query = (
        select(ScriptBuilderSession)
        .where(
            ScriptBuilderSession.id == session_id,
            ScriptBuilderSession.user_id == user_id,
        )
        .options(selectinload(ScriptBuilderSession.message_records))
    )
    rows = await db.execute(owned_session_query)
    return rows.scalar_one_or_none()
async def list_sessions(
    db: AsyncSession,
    user_id: UUID,
    limit: int = 20,
    offset: int = 0,
) -> list[ScriptBuilderSession]:
    """Return one page of a user's builder sessions, most recently updated first.

    Args:
        db: Async database session.
        user_id: Owner whose sessions are listed.
        limit: Maximum number of sessions to return.
        offset: Number of sessions to skip before the page starts.
    """
    query = select(ScriptBuilderSession).where(ScriptBuilderSession.user_id == user_id)
    query = query.order_by(ScriptBuilderSession.updated_at.desc())
    query = query.limit(limit).offset(offset)

    page = await db.execute(query)
    return list(page.scalars().all())
async def delete_session(
    db: AsyncSession,
    session_id: UUID,
    user_id: UUID,
) -> bool:
    """Delete a builder session owned by ``user_id``.

    Returns:
        ``True`` when a session was found and deleted (flushed, not
        committed), ``False`` when the user has no session with that id.
    """
    target = await get_session(db, session_id, user_id)
    found = bool(target)
    if found:
        await db.delete(target)
        await db.flush()
    return found
async def count_user_sessions(db: AsyncSession, user_id: UUID) -> int:
    """Return how many builder sessions belong to ``user_id``."""
    owned_by_user = ScriptBuilderSession.user_id == user_id
    count_query = select(func.count(ScriptBuilderSession.id)).where(owned_by_user)
    counted = await db.execute(count_query)
    return counted.scalar_one()
async def save_to_library(
    db: AsyncSession,
    session: ScriptBuilderSession,
    name: str,
    description: str | None,
    category_id: UUID | None,
    share_with_team: bool,
    user_id: UUID,
    team_id: UUID | None,
    script_body: str | None = None,
    parameters_schema: dict | None = None,
) -> "ScriptTemplate":
    """Save a generated script to the Script Library as a ScriptTemplate.

    Args:
        db: Async database session; the template is flushed, not committed.
        session: Builder session the script came from. It must already hold a
            generated script, even when ``script_body`` is supplied.
        name: Display name; also the source of the template slug.
        description: Optional description.
        category_id: Target category; when omitted, the category with slug
            ``"ai-generated"`` is used.
        share_with_team: When true the template is attached to ``team_id``,
            otherwise it is stored without a team.
        user_id: Creator of the template.
        team_id: Team to share with (only used when ``share_with_team``).
        script_body: Script text to store; defaults to the session's latest
            script.
        parameters_schema: Parameter schema to store; defaults to an empty
            parameter list.

    Returns:
        The newly created template.

    Raises:
        ValueError: If the session has no generated script, or the default
            category is needed but missing.
        IntegrityError: If the insert violates a constraint other than the
            slug uniqueness constraint, or the slug still collides on the
            third attempt.
    """
    import uuid as uuid_mod
    from app.models.script_template import ScriptTemplate, ScriptCategory

    if not session.latest_script:
        raise ValueError("No script has been generated in this session yet")

    # Read everything needed from `session` up front so the retry loop below
    # never has to touch ORM attributes after a failed flush.
    final_script_body = script_body or session.latest_script
    language = session.language

    # Resolve category: use provided, or find "AI Generated" default
    resolved_category_id = category_id
    if not resolved_category_id:
        result = await db.execute(
            select(ScriptCategory.id).where(ScriptCategory.slug == "ai-generated")
        )
        default_cat = result.scalar_one_or_none()
        if not default_cat:
            raise ValueError("Default 'AI Generated' category not found. Run migrations.")
        resolved_category_id = default_cat

    # Derive the slug from the name: lowercase, spaces/underscores to hyphens,
    # capped at 80 characters, then everything outside [a-z0-9-] removed.
    base_slug = name.lower().replace(" ", "-").replace("_", "-")[:80]
    # A name with no slug-safe characters (e.g. only punctuation or non-ASCII
    # text) would otherwise produce an empty slug.
    base_slug = re.sub(r"[^a-z0-9\-]", "", base_slug) or "script"

    # Check if the base slug is already taken; if not, use it clean (no suffix).
    # If taken, or if the insert races with a concurrent request, retry with a
    # fresh UUID suffix. The unique constraint on script_templates.slug is the
    # authoritative guard — the application check just avoids unnecessary retries.
    existing = await db.execute(
        select(ScriptTemplate.id).where(ScriptTemplate.slug == base_slug)
    )
    slug = base_slug if not existing.scalar_one_or_none() else f"{base_slug}-{uuid_mod.uuid4().hex[:6]}"

    for attempt in range(3):
        template = ScriptTemplate(
            id=uuid_mod.uuid4(),
            category_id=resolved_category_id,
            created_by=user_id,
            team_id=team_id if share_with_team else None,
            name=name,
            slug=slug,
            description=description,
            script_body=final_script_body,
            parameters_schema=parameters_schema or {"parameters": []},
            default_values={},
            validation_rules={},
            tags=[language, "ai-generated"],
            complexity="intermediate",
            is_verified=False,
            is_active=True,
            version=1,
            usage_count=0,
        )
        try:
            # Insert inside a SAVEPOINT: a slug collision then undoes only this
            # insert. A full db.rollback() here would discard the caller's
            # whole transaction and expire `session` before the retry.
            async with db.begin_nested():
                db.add(template)
                await db.flush()
        except IntegrityError as exc:
            # Only slug collisions are retried, and only twice.
            if "uq_script_templates_slug" not in str(exc.orig) or attempt == 2:
                raise
            slug = f"{base_slug}-{uuid_mod.uuid4().hex[:8]}"
        else:
            return template

    # Defensive: the loop always returns or raises before reaching this point.
    raise RuntimeError("Failed to generate a unique slug after 3 attempts")