diff --git a/backend/app/api/endpoints/scripts.py b/backend/app/api/endpoints/scripts.py new file mode 100644 index 00000000..e7ab9158 --- /dev/null +++ b/backend/app/api/endpoints/scripts.py @@ -0,0 +1,356 @@ +"""Script Generator API endpoints.""" +from typing import Annotated, Optional +from uuid import UUID +import re + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func, or_ + +from app.core.database import get_db +from app.api.deps import get_current_active_user +from app.models.user import User +from app.models.script_template import ScriptCategory, ScriptTemplate, ScriptGeneration +from app.schemas.script_template import ( + ScriptCategoryResponse, + ScriptTemplateCreate, + ScriptTemplateUpdate, + ScriptTemplateListItem, + ScriptTemplateDetail, + ScriptGenerateRequest, + ScriptGenerateResponse, + ScriptGenerationRecord, +) +from app.services.script_template_engine import ScriptTemplateEngine, ScriptRenderError + +router = APIRouter(prefix="/scripts", tags=["scripts"]) +_engine = ScriptTemplateEngine() + + +def _require_team_admin(user: User) -> None: + """Raise 403 if user is not a team admin or super admin.""" + if not (user.is_team_admin or user.is_super_admin): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Team admin access required", + ) + + +# ── Categories ──────────────────────────────────────────────────────────── + + +@router.get("/categories", response_model=list[ScriptCategoryResponse]) +async def list_categories( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> list[ScriptCategoryResponse]: + result = await db.execute( + select(ScriptCategory) + .where(ScriptCategory.is_active == True) # noqa: E712 + .order_by(ScriptCategory.sort_order) + ) + categories = result.scalars().all() + + count_result = await db.execute( + select(ScriptTemplate.category_id, func.count(ScriptTemplate.id)) + .where(ScriptTemplate.is_active == True) # noqa: E712 + .group_by(ScriptTemplate.category_id) + ) + counts = dict(count_result.all()) + + return [ + ScriptCategoryResponse( + id=cat.id, + name=cat.name, + slug=cat.slug, + description=cat.description, + icon=cat.icon, + sort_order=cat.sort_order, + template_count=counts.get(cat.id, 0), + ) + for cat in categories + ] + + +# ── Templates ───────────────────────────────────────────────────────────── + + +@router.get("/templates", response_model=list[ScriptTemplateListItem]) +async def list_templates( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + category_slug: Optional[str] = Query(None), + search: Optional[str] = Query(None), + tags: Optional[str] = Query(None, description="Comma-separated tags"), +) -> list[ScriptTemplateListItem]: + query = ( + select(ScriptTemplate) + .join(ScriptCategory, ScriptTemplate.category_id == ScriptCategory.id) + .where(ScriptTemplate.is_active == True) # noqa: E712 + .where( + or_( + ScriptTemplate.team_id == None, # noqa: E711 + ScriptTemplate.team_id == current_user.team_id, + ) + ) + ) + + if category_slug: + query = query.where(ScriptCategory.slug == category_slug) + + if search: + term = f"%{search.lower()}%" + query = query.where( + or_( + func.lower(ScriptTemplate.name).like(term), + func.lower(ScriptTemplate.description).like(term), + func.lower(ScriptTemplate.slug).like(term), + ) + ) + + result = await db.execute(query.order_by(ScriptTemplate.name)) + templates = result.scalars().all() + + if tags: + tag_list = [t.strip().lower() for t in tags.split(",")] + templates = [ + t + for t in templates + if any(tag in [tg.lower() for tg in (t.tags or [])] for tag in tag_list) + ] + + return [ScriptTemplateListItem.model_validate(t) for t in templates] + + +@router.get("/templates/{template_id}", response_model=ScriptTemplateDetail) +async def get_template( + template_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> ScriptTemplateDetail: + result = await db.execute( + select(ScriptTemplate).where( + ScriptTemplate.id == template_id, + ScriptTemplate.is_active == True, # noqa: E712 + or_( + ScriptTemplate.team_id == None, # noqa: E711 + ScriptTemplate.team_id == current_user.team_id, + ), + ) + ) + template = result.scalar_one_or_none() + if not template: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Template not found" + ) + return ScriptTemplateDetail.model_validate(template) + + +@router.post( + "/templates", + response_model=ScriptTemplateDetail, + status_code=status.HTTP_201_CREATED, +) +async def create_template( + data: ScriptTemplateCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> ScriptTemplateDetail: + _require_team_admin(current_user) + + cat_result = await db.execute( + select(ScriptCategory).where( + ScriptCategory.id == data.category_id, + ScriptCategory.is_active == True, # noqa: E712 + ) + ) + if not cat_result.scalar_one_or_none(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Category not found" + ) + + slug = re.sub(r"[^a-z0-9]+", "-", data.name.lower()).strip("-") + + template = ScriptTemplate( + category_id=data.category_id, + team_id=current_user.team_id, + created_by=current_user.id, + name=data.name, + slug=slug, + description=data.description, + use_case=data.use_case, + script_body=data.script_body, + parameters_schema=data.parameters_schema, + default_values=data.default_values, + validation_rules=data.validation_rules, + tags=data.tags, + complexity=data.complexity, + estimated_runtime=data.estimated_runtime, + requires_elevation=data.requires_elevation, + requires_modules=data.requires_modules, + ) + db.add(template) + await db.commit() + await db.refresh(template) + return ScriptTemplateDetail.model_validate(template) + + +@router.put("/templates/{template_id}", response_model=ScriptTemplateDetail) +async def update_template( + template_id: UUID, + data: ScriptTemplateUpdate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> ScriptTemplateDetail: + _require_team_admin(current_user) + + result = await db.execute( + select(ScriptTemplate).where( + ScriptTemplate.id == template_id, + ScriptTemplate.team_id == current_user.team_id, + ) + ) + template = result.scalar_one_or_none() + if not template: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Template not found or not editable", + ) + + update_data = data.model_dump(exclude_unset=True) + if "script_body" in update_data or "parameters_schema" in update_data: + template.version += 1 + for field, value in update_data.items(): + setattr(template, field, value) + + await db.commit() + await db.refresh(template) + return ScriptTemplateDetail.model_validate(template) + + +@router.delete("/templates/{template_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_template( + template_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> None: + _require_team_admin(current_user) + + result = await db.execute( + select(ScriptTemplate).where( + ScriptTemplate.id == template_id, + ScriptTemplate.team_id == current_user.team_id, + ) + ) + template = result.scalar_one_or_none() + if not template: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Template not found or not deletable", + ) + + template.is_active = False + await db.commit() + + +# ── Generate ────────────────────────────────────────────────────────────── + + +@router.post("/generate", response_model=ScriptGenerateResponse) +async def generate_script( + data: ScriptGenerateRequest, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> ScriptGenerateResponse: + result = await db.execute( + select(ScriptTemplate).where( + ScriptTemplate.id == data.template_id, + ScriptTemplate.is_active == True, # noqa: E712 + or_( + ScriptTemplate.team_id == None, # noqa: E711 + ScriptTemplate.team_id == current_user.team_id, + ), + ) + ) + template = result.scalar_one_or_none() + if not template: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Template not found" + ) + + try: + rendered_script = _engine.render(template.script_body, data.parameters) + except ScriptRenderError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e) + ) + + params_schema = template.parameters_schema or {} + sensitive_keys = { + p["key"] + for p in params_schema.get("parameters", []) + if p.get("sensitive", False) + } + redacted_params = _engine.redact_sensitive(data.parameters, sensitive_keys) + + generation = ScriptGeneration( + template_id=template.id, + user_id=current_user.id, + team_id=current_user.team_id, + session_id=data.session_id, + parameters_used=redacted_params, + generated_script=rendered_script, + ) + db.add(generation) + template.usage_count += 1 + await db.commit() + await db.refresh(generation) + + warnings: list[str] = [] + if template.requires_elevation: + warnings.append("This script requires 'Run as Administrator'") + + return ScriptGenerateResponse( + id=generation.id, + script=rendered_script, + warnings=warnings, + metadata={ + "template_name": template.name, + "template_version": template.version, + "requires_elevation": template.requires_elevation, + "requires_modules": template.requires_modules, + "generated_at": generation.created_at.isoformat(), + "estimated_runtime": template.estimated_runtime, + }, + ) + + +# ── Generations history ─────────────────────────────────────────────────── + + +@router.get("/generations", response_model=list[ScriptGenerationRecord]) +async def list_generations( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + limit: int = Query(20, ge=1, le=100), + offset: int = Query(0, ge=0), +) -> list[ScriptGenerationRecord]: + result = await db.execute( + select(ScriptGeneration, ScriptTemplate.name) + .join(ScriptTemplate, ScriptGeneration.template_id == ScriptTemplate.id) + .where(ScriptGeneration.user_id == current_user.id) + .order_by(ScriptGeneration.created_at.desc()) + .limit(limit) + .offset(offset) + ) + rows = result.all() + return [ + ScriptGenerationRecord( + id=gen.id, + template_id=gen.template_id, + template_name=name, + parameters_used=gen.parameters_used, + created_at=gen.created_at, + ) + for gen, name in rows + ] diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 5e789ff9..13293cca 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -16,6 +16,7 @@ from app.api.endpoints import tree_transfer from app.api.endpoints import ai_suggestions from app.api.endpoints import kb_accelerator from app.api.endpoints import beta_signup +from app.api.endpoints import scripts api_router = APIRouter() @@ -56,3 +57,4 @@ api_router.include_router(tree_transfer.router) api_router.include_router(ai_suggestions.router) api_router.include_router(kb_accelerator.router) api_router.include_router(beta_signup.router) +api_router.include_router(scripts.router) diff --git a/backend/tests/test_scripts.py b/backend/tests/test_scripts.py new file mode 100644 index 00000000..544ad74e --- /dev/null +++ b/backend/tests/test_scripts.py @@ -0,0 +1,336 @@ +"""Integration tests for Script Generator API endpoints.""" +import json +import uuid +from datetime import datetime, timezone + +import pytest +import sqlalchemy as sa + + +# ── Fixtures ────────────────────────────────────────────────────────────── + +@pytest.fixture +async def seed_script_data(test_db): + """Seed script categories and templates into the test database.""" + now = datetime.now(timezone.utc) + cat_id = uuid.UUID("00000000-0000-0000-0000-000000000001") + + # Insert category + await test_db.execute( + sa.text(""" + INSERT INTO script_categories (id, name, slug, description, icon, sort_order, is_active, created_at, updated_at) + VALUES (:id, :name, :slug, :description, :icon, :sort_order, true, :now, :now) + """), + { + "id": cat_id, + "name": "Active Directory", + "slug": "active-directory", + "description": "User account and group management scripts", + "icon": "shield-check", + "sort_order": 1, + "now": now, + }, + ) + + # Minimal template data for testing + templates = [ + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000001"), + "slug": "create-ad-user", + "name": "Create AD User Account", + "description": "Creates a new Active Directory user account.", + "script_body": "$SamAccountName = '{{ sam_account_name }}'", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "sam_account_name", "label": "SAM Account Name", "type": "text", "required": True, "order": 1}, + ] + }), + "complexity": "intermediate", + "estimated_runtime": "< 5 seconds", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "user-management"]), + }, + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000002"), + "slug": "disable-ad-user", + "name": "Disable AD User Account", + "description": "Disables an Active Directory user account.", + "script_body": "$SamAccountName = '{{ sam_account_name }}'", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "sam_account_name", "label": "SAM Account Name", "type": "text", "required": True, "order": 1}, + ] + }), + "complexity": "beginner", + "estimated_runtime": "< 5 seconds", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "offboarding"]), + }, + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000003"), + "slug": "reset-ad-password", + "name": "Reset AD Password", + "description": "Resets an Active Directory user password.", + "script_body": "$SamAccountName = '{{ sam_account_name }}'\n$NewPassword = {{ new_password | as_secure_string }}\n$ForceChange = {{ force_change_at_logon | as_bool }}\n$UnlockAccount = {{ unlock_account | as_bool }}", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "sam_account_name", "label": "SAM Account Name", "type": "text", "required": True, "order": 1}, + {"key": "new_password", "label": "New Password", "type": "password", "required": True, "order": 2, "sensitive": True}, + {"key": "force_change_at_logon", "label": "Force Change at Next Logon", "type": "boolean", "required": True, "order": 3}, + {"key": "unlock_account", "label": "Unlock Account if Locked", "type": "boolean", "required": True, "order": 4}, + ] + }), + "complexity": "beginner", + "estimated_runtime": "< 5 seconds", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "password"]), + }, + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000004"), + "slug": "unlock-ad-account", + "name": "Unlock AD Account", + "description": "Unlocks a locked-out Active Directory user account.", + "script_body": "$SamAccountName = '{{ sam_account_name }}'\n$ShowLockoutInfo = {{ show_lockout_info | as_bool }}", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "sam_account_name", "label": "SAM Account Name", "type": "text", "required": True, "order": 1}, + {"key": "show_lockout_info", "label": "Show Lockout Source Info", "type": "boolean", "required": False, "order": 2}, + ] + }), + "complexity": "beginner", + "estimated_runtime": "< 5 seconds", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "lockout"]), + }, + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000005"), + "slug": "delete-ad-user", + "name": "Delete AD User Account", + "description": "Permanently deletes an Active Directory user account.", + "script_body": "$SamAccountName = '{{ sam_account_name }}'\n$ConfirmDeletion = {{ confirm_deletion | as_bool }}", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "sam_account_name", "label": "SAM Account Name", "type": "text", "required": True, "order": 1}, + {"key": "confirm_deletion", "label": "Confirm Deletion", "type": "boolean", "required": True, "order": 2}, + ] + }), + "complexity": "advanced", + "estimated_runtime": "< 10 seconds", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "destructive"]), + }, + { + "id": uuid.UUID("00000000-0000-0000-0001-000000000006"), + "slug": "bulk-user-import", + "name": "Bulk User Import from CSV", + "description": "Imports multiple Active Directory user accounts from a CSV file.", + "script_body": "$CSVPath = '{{ csv_path }}'\n$OUPath = '{{ ou_path }}'", + "parameters_schema": json.dumps({ + "parameters": [ + {"key": "csv_path", "label": "CSV File Path", "type": "text", "required": True, "order": 1}, + {"key": "ou_path", "label": "Target OU", "type": "text", "required": True, "order": 2}, + ] + }), + "complexity": "advanced", + "estimated_runtime": "1-2 minutes", + "requires_elevation": True, + "tags": json.dumps(["active-directory", "bulk"]), + }, + ] + + for tmpl in templates: + await test_db.execute( + sa.text(""" + INSERT INTO script_templates ( + id, category_id, name, slug, description, + script_body, parameters_schema, default_values, validation_rules, + tags, complexity, estimated_runtime, requires_elevation, + requires_modules, version, is_verified, is_active, usage_count, + created_at, updated_at + ) VALUES ( + :id, :category_id, :name, :slug, :description, + :script_body, CAST(:parameters_schema AS jsonb), '{}'::jsonb, '{}'::jsonb, + CAST(:tags AS jsonb), :complexity, :estimated_runtime, :requires_elevation, + '[]'::jsonb, 1, true, true, 0, + :now, :now + ) + """), + {**tmpl, "category_id": cat_id, "now": now}, + ) + + await test_db.commit() + return cat_id + + +# ── Categories ──────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_list_categories_requires_auth(client): + response = await client.get("/api/v1/scripts/categories") + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_list_categories_returns_seeded_data(client, auth_headers, seed_script_data): + response = await client.get("/api/v1/scripts/categories", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert any(c["slug"] == "active-directory" for c in data) + + +# ── Templates ───────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_list_templates_requires_auth(client): + response = await client.get("/api/v1/scripts/templates") + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_list_templates_returns_seeded_data(client, auth_headers, seed_script_data): + response = await client.get("/api/v1/scripts/templates", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert len(data) == 6 + slugs = [t["slug"] for t in data] + assert "create-ad-user" in slugs + assert "reset-ad-password" in slugs + + +@pytest.mark.asyncio +async def test_list_templates_filter_by_category(client, auth_headers, seed_script_data): + response = await client.get( + "/api/v1/scripts/templates?category_slug=active-directory", + headers=auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 6 + + +@pytest.mark.asyncio +async def test_list_templates_search(client, auth_headers, seed_script_data): + response = await client.get( + "/api/v1/scripts/templates?search=password", + headers=auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert any("password" in t["name"].lower() or "password" in t["slug"] for t in data) + + +@pytest.mark.asyncio +async def test_get_template_detail(client, auth_headers, seed_script_data): + list_resp = await client.get("/api/v1/scripts/templates", headers=auth_headers) + templates = list_resp.json() + template_id = templates[0]["id"] + + response = await client.get(f"/api/v1/scripts/templates/{template_id}", headers=auth_headers) + assert response.status_code == 200 + data = response.json() + assert "script_body" in data + assert "parameters_schema" in data + + +@pytest.mark.asyncio +async def test_get_template_detail_not_found(client, auth_headers): + response = await client.get( + "/api/v1/scripts/templates/00000000-0000-0000-0000-000000000099", + headers=auth_headers, + ) + assert response.status_code == 404 + + +# ── Generate ────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_generate_script_success(client, auth_headers, seed_script_data): + list_resp = await client.get( + "/api/v1/scripts/templates?search=unlock", + headers=auth_headers, + ) + unlock_template = list_resp.json()[0] + + response = await client.post( + "/api/v1/scripts/generate", + json={ + "template_id": unlock_template["id"], + "parameters": {"sam_account_name": "jsmith", "show_lockout_info": False}, + }, + headers=auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "script" in data + assert "jsmith" in data["script"] + assert "id" in data + + +@pytest.mark.asyncio +async def test_generate_script_missing_required_param(client, auth_headers, seed_script_data): + list_resp = await client.get( + "/api/v1/scripts/templates?search=unlock", + headers=auth_headers, + ) + unlock_template = list_resp.json()[0] + + response = await client.post( + "/api/v1/scripts/generate", + json={ + "template_id": unlock_template["id"], + "parameters": {}, + }, + headers=auth_headers, + ) + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_generate_script_password_redacted_in_record(client, auth_headers, seed_script_data): + list_resp = await client.get( + "/api/v1/scripts/templates?search=reset-ad-password", + headers=auth_headers, + ) + reset_template = list_resp.json()[0] + + await client.post( + "/api/v1/scripts/generate", + json={ + "template_id": reset_template["id"], + "parameters": { + "sam_account_name": "jsmith", + "new_password": "SuperSecret123!", + "force_change_at_logon": True, + "unlock_account": True, + }, + }, + headers=auth_headers, + ) + + history_resp = await client.get("/api/v1/scripts/generations", headers=auth_headers) + assert history_resp.status_code == 200 + generations = history_resp.json() + assert len(generations) > 0 + latest = generations[0] + assert latest["parameters_used"].get("new_password") == "[REDACTED]" + + +# ── Team template CRUD ──────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_create_team_template_requires_team_admin(client, auth_headers, seed_script_data): + list_resp = await client.get("/api/v1/scripts/categories", headers=auth_headers) + cat_id = list_resp.json()[0]["id"] + + response = await client.post( + "/api/v1/scripts/templates", + json={ + "category_id": cat_id, + "name": "My Custom Script", + "script_body": "Write-Host 'hello'", + "parameters_schema": {}, + }, + headers=auth_headers, # regular engineer + ) + assert response.status_code == 403