"""Script Generator API endpoints."""

from typing import Annotated, Optional
from uuid import UUID, uuid4
import re

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_

from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.script_template import ScriptCategory, ScriptTemplate, ScriptGeneration
from app.schemas.script_template import (
    ScriptCategoryResponse,
    ScriptTemplateCreate,
    ScriptTemplateUpdate,
    ScriptTemplateListItem,
    ScriptTemplateDetail,
    ScriptGenerateRequest,
    ScriptGenerateResponse,
    ScriptGenerationRecord,
)
from app.services.script_template_engine import ScriptTemplateEngine, ScriptRenderError

router = APIRouter(prefix="/scripts", tags=["scripts"])
_engine = ScriptTemplateEngine()

# Runs of characters that are not lowercase ASCII letters or digits; each run
# collapses to a single "-" when a slug is derived from a template name.
_SLUG_SEPARATORS = re.compile(r"[^a-z0-9]+")


def _require_team_admin(user: User) -> None:
    """Raise 403 if user is not a team admin or super admin."""
    if not (user.is_team_admin or user.is_super_admin):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Team admin access required",
        )


# ── Categories ────────────────────────────────────────────────────────────


@router.get("/categories", response_model=list[ScriptCategoryResponse])
async def list_categories(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[ScriptCategoryResponse]:
    """List active categories by sort order, each with its active-template count."""
    result = await db.execute(
        select(ScriptCategory)
        .where(ScriptCategory.is_active == True)  # noqa: E712
        .order_by(ScriptCategory.sort_order)
    )
    categories = result.scalars().all()

    # One grouped query for all counts instead of one query per category.
    count_result = await db.execute(
        select(ScriptTemplate.category_id, func.count(ScriptTemplate.id))
        .where(ScriptTemplate.is_active == True)  # noqa: E712
        .group_by(ScriptTemplate.category_id)
    )
    counts = dict(count_result.all())

    return [
        ScriptCategoryResponse(
            id=cat.id,
            name=cat.name,
            slug=cat.slug,
            description=cat.description,
            icon=cat.icon,
            sort_order=cat.sort_order,
            # Categories without any active template are absent from `counts`.
            template_count=counts.get(cat.id, 0),
        )
        for cat in categories
    ]


# ── Templates ─────────────────────────────────────────────────────────────


@router.get("/templates", response_model=list[ScriptTemplateListItem])
async def list_templates(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    category_slug: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    tags: Optional[str] = Query(None, description="Comma-separated tags"),
) -> list[ScriptTemplateListItem]:
    """List active templates visible to the caller, ordered by name.

    Visible templates are those with no team plus those of the caller's team.
    Optional filters: category slug, a case-insensitive substring search over
    name, description and slug, and a comma-separated tag list (a template
    matches when it carries at least one of the tags, case-insensitively).
    """
    query = (
        select(ScriptTemplate)
        .join(ScriptCategory, ScriptTemplate.category_id == ScriptCategory.id)
        .where(ScriptTemplate.is_active == True)  # noqa: E712
        .where(
            or_(
                ScriptTemplate.team_id == None,  # noqa: E711
                ScriptTemplate.team_id == current_user.team_id,
            )
        )
    )

    if category_slug:
        query = query.where(ScriptCategory.slug == category_slug)

    if search:
        needle = search.lower()
        # autoescape=True escapes "%" and "_" in the user's text so they are
        # matched literally rather than acting as LIKE wildcards.
        query = query.where(
            or_(
                func.lower(ScriptTemplate.name).contains(needle, autoescape=True),
                func.lower(ScriptTemplate.description).contains(
                    needle, autoescape=True
                ),
                func.lower(ScriptTemplate.slug).contains(needle, autoescape=True),
            )
        )

    result = await db.execute(query.order_by(ScriptTemplate.name))
    templates = result.scalars().all()

    if tags:
        # Blank entries (trailing commas, "a,,b", whitespace only) are dropped;
        # a filter that contains nothing but blanks is treated as no filter.
        wanted = {tag.strip().lower() for tag in tags.split(",")} - {""}
        if wanted:
            templates = [
                t
                for t in templates
                if not wanted.isdisjoint(tg.lower() for tg in (t.tags or []))
            ]

    return [ScriptTemplateListItem.model_validate(t) for t in templates]


@router.get("/templates/{template_id}", response_model=ScriptTemplateDetail)
async def get_template(
    template_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptTemplateDetail:
    """Return one active template that has no team or belongs to the caller's team.

    Responds with 404 when no such template exists.
    """
    result = await db.execute(
        select(ScriptTemplate).where(
            ScriptTemplate.id == template_id,
            ScriptTemplate.is_active == True,  # noqa: E712
            or_(
                ScriptTemplate.team_id == None,  # noqa: E711
                ScriptTemplate.team_id == current_user.team_id,
            ),
        )
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Template not found"
        )
    return ScriptTemplateDetail.model_validate(template)


@router.post(
    "/templates",
    response_model=ScriptTemplateDetail,
    status_code=status.HTTP_201_CREATED,
)
async def create_template(
    data: ScriptTemplateCreate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptTemplateDetail:
    """Create a template owned by the caller's team.

    Responds with 403 unless the caller is a team admin or super admin, and
    with 404 when the referenced category does not exist or is inactive.
    """
    _require_team_admin(current_user)

    cat_result = await db.execute(
        select(ScriptCategory).where(
            ScriptCategory.id == data.category_id,
            ScriptCategory.is_active == True,  # noqa: E712
        )
    )
    if not cat_result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Category not found"
        )

    slug = _SLUG_SEPARATORS.sub("-", data.name.lower()).strip("-")
    if not slug:
        # A name without any ASCII letter or digit (e.g. "!!!" or a name
        # written entirely in a non-Latin script) would otherwise be stored
        # with an empty slug.
        slug = f"template-{uuid4().hex[:8]}"

    template = ScriptTemplate(
        category_id=data.category_id,
        team_id=current_user.team_id,
        created_by=current_user.id,
        name=data.name,
        slug=slug,
        description=data.description,
        use_case=data.use_case,
        script_body=data.script_body,
        parameters_schema=data.parameters_schema,
        default_values=data.default_values,
        validation_rules=data.validation_rules,
        tags=data.tags,
        complexity=data.complexity,
        estimated_runtime=data.estimated_runtime,
        requires_elevation=data.requires_elevation,
        requires_modules=data.requires_modules,
    )
    db.add(template)
    await db.commit()
    await db.refresh(template)
    return ScriptTemplateDetail.model_validate(template)


@router.put("/templates/{template_id}", response_model=ScriptTemplateDetail)
async def update_template(
    template_id: UUID,
    data: ScriptTemplateUpdate,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptTemplateDetail:
    """Apply a partial update to a template owned by the caller's team.

    Only fields present in the request are written. The template version is
    incremented when the script body or the parameters schema is part of the
    update. Responds with 403 unless the caller is a team admin or super
    admin, and with 404 when the template is not found under the caller's team.
    """
    _require_team_admin(current_user)

    result = await db.execute(
        select(ScriptTemplate).where(
            ScriptTemplate.id == template_id,
            ScriptTemplate.team_id == current_user.team_id,
        )
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Template not found or not editable",
        )

    # exclude_unset keeps omitted fields untouched instead of resetting them.
    update_data = data.model_dump(exclude_unset=True)

    if "script_body" in update_data or "parameters_schema" in update_data:
        template.version += 1

    for field, value in update_data.items():
        setattr(template, field, value)

    await db.commit()
    await db.refresh(template)
    return ScriptTemplateDetail.model_validate(template)


@router.delete("/templates/{template_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_template(
    template_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Deactivate a template owned by the caller's team.

    The row is kept and only its ``is_active`` flag is cleared. Responds with
    403 unless the caller is a team admin or super admin, and with 404 when
    the template is not found under the caller's team.
    """
    _require_team_admin(current_user)

    result = await db.execute(
        select(ScriptTemplate).where(
            ScriptTemplate.id == template_id,
            ScriptTemplate.team_id == current_user.team_id,
        )
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Template not found or not deletable",
        )

    template.is_active = False
    await db.commit()


# ── Generate ──────────────────────────────────────────────────────────────


@router.post("/generate", response_model=ScriptGenerateResponse)
async def generate_script(
    data: ScriptGenerateRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptGenerateResponse:
    """Render a template with the supplied parameters and record the generation.

    The stored record holds the parameters after sensitive values have been
    redacted; the rendered script is returned to the caller. Responds with 404
    when the template is not visible to the caller and with 422 when the
    template engine rejects the render.
    """
    result = await db.execute(
        select(ScriptTemplate).where(
            ScriptTemplate.id == data.template_id,
            ScriptTemplate.is_active == True,  # noqa: E712
            or_(
                ScriptTemplate.team_id == None,  # noqa: E711
                ScriptTemplate.team_id == current_user.team_id,
            ),
        )
    )
    template = result.scalar_one_or_none()
    if not template:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Template not found"
        )

    try:
        rendered_script = _engine.render(template.script_body, data.parameters)
    except ScriptRenderError as exc:
        # Chain the cause so the original render failure stays in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(exc)
        ) from exc

    params_schema = template.parameters_schema or {}
    # Schema entries flagged sensitive but lacking a "key" cannot be matched
    # to a parameter, so they are skipped rather than raising KeyError.
    sensitive_keys = {
        p["key"]
        for p in params_schema.get("parameters", [])
        if p.get("sensitive", False) and "key" in p
    }
    redacted_params = _engine.redact_sensitive(data.parameters, sensitive_keys)

    # Read everything the response needs from the template BEFORE committing.
    # If the session expires objects on commit (SQLAlchemy's default), reading
    # these attributes afterwards would require an implicit reload, which an
    # AsyncSession cannot perform from plain attribute access.
    template_name = template.name
    template_version = template.version
    requires_elevation = template.requires_elevation
    requires_modules = template.requires_modules
    estimated_runtime = template.estimated_runtime

    generation = ScriptGeneration(
        template_id=template.id,
        user_id=current_user.id,
        team_id=current_user.team_id,
        session_id=data.session_id,
        parameters_used=redacted_params,
        generated_script=rendered_script,
    )
    db.add(generation)

    template.usage_count += 1

    await db.commit()
    await db.refresh(generation)

    warnings: list[str] = []
    if requires_elevation:
        warnings.append("This script requires 'Run as Administrator'")

    return ScriptGenerateResponse(
        id=generation.id,
        script=rendered_script,
        warnings=warnings,
        metadata={
            "template_name": template_name,
            "template_version": template_version,
            "requires_elevation": requires_elevation,
            "requires_modules": requires_modules,
            "generated_at": generation.created_at.isoformat(),
            "estimated_runtime": estimated_runtime,
        },
    )


# ── Generations history ───────────────────────────────────────────────────


@router.get("/generations", response_model=list[ScriptGenerationRecord])
async def list_generations(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
) -> list[ScriptGenerationRecord]:
    """List the caller's own generations, newest first, with limit/offset paging."""
    result = await db.execute(
        select(ScriptGeneration, ScriptTemplate.name)
        .join(ScriptTemplate, ScriptGeneration.template_id == ScriptTemplate.id)
        .where(ScriptGeneration.user_id == current_user.id)
        .order_by(ScriptGeneration.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = result.all()
    return [
        ScriptGenerationRecord(
            id=gen.id,
            template_id=gen.template_id,
            template_name=name,
            parameters_used=gen.parameters_used,
            created_at=gen.created_at,
        )
        for gen, name in rows
    ]