274 lines
10 KiB
Python
274 lines
10 KiB
Python
"""Script Builder API endpoints."""
|
|
import hashlib
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.core.rate_limit import limiter
from app.api.deps import get_current_active_user
from app.models.ai_session import AISession
from app.models.user import User
from app.models.script_builder_session import ScriptBuilderSession
from app.schemas.script_builder import (
    ScriptBuilderCreateRequest,
    ScriptBuilderMessageRequest,
    ScriptBuilderMessageResponse,
    ScriptBuilderMessageSchema,
    ScriptBuilderSessionDetail,
    ScriptBuilderSessionSummary,
    SaveToLibraryRequest,
)
from app.schemas.script_template import ScriptTemplateDetail
from app.services import script_builder_service
|
|
|
|
# Every Script Builder endpoint in this module is mounted under /scripts/builder.
router = APIRouter(
    prefix="/scripts/builder",
    tags=["script-builder"],
)


# Upper bound on standalone builder sessions per user; create_session exempts
# origin='pilot_inline' sessions from this cap.
MAX_SESSIONS_PER_USER = 5
|
|
|
|
|
|
def _session_to_detail(session: ScriptBuilderSession) -> ScriptBuilderSessionDetail:
|
|
"""Convert a session ORM object (with message_records loaded) to detail schema."""
|
|
messages = [
|
|
ScriptBuilderMessageSchema.model_validate(m)
|
|
for m in session.message_records
|
|
]
|
|
return ScriptBuilderSessionDetail(
|
|
id=session.id,
|
|
language=session.language,
|
|
title=session.title,
|
|
messages=messages,
|
|
latest_script=session.latest_script,
|
|
latest_script_filename=session.latest_script_filename,
|
|
message_count=len([m for m in messages if m.role == "user"]),
|
|
ai_session_id=session.ai_session_id,
|
|
created_at=session.created_at,
|
|
updated_at=session.updated_at,
|
|
)
|
|
|
|
|
|
def _session_to_summary(session: ScriptBuilderSession) -> ScriptBuilderSessionSummary:
    """Build the lightweight list-view schema for a builder session.

    ``message_records`` is deliberately not touched, so this works on sessions
    loaded without their messages; ``message_count`` is therefore always 0.
    """
    summary_fields = {
        "id": session.id,
        "language": session.language,
        "title": session.title,
        # Messages are not loaded for list views, so no count is available.
        "message_count": 0,
        "latest_script_filename": session.latest_script_filename,
        "created_at": session.created_at,
        "updated_at": session.updated_at,
    }
    return ScriptBuilderSessionSummary(**summary_fields)
|
|
|
|
|
|
def _advisory_lock_key(user_id: object) -> int:
    """Derive a process-independent PostgreSQL advisory-lock key for a user.

    The built-in ``hash()`` of a ``str`` is salted per interpreter process
    (``PYTHONHASHSEED``). Two API workers would therefore compute different
    keys for the same user, and the lock would not serialize their requests.
    A BLAKE2b digest is identical in every process. The namespace prefix
    lowers the chance of colliding with advisory locks used for other purposes.

    Returns:
        A signed 64-bit integer, the range of ``pg_advisory_xact_lock(bigint)``.
    """
    digest = hashlib.blake2b(
        f"script_builder:create_session:{user_id}".encode(), digest_size=8
    ).digest()
    return int.from_bytes(digest, "big", signed=True)


async def _find_inline_session(
    db: AsyncSession, user_id: object, ai_session_id: object
) -> "ScriptBuilderSession | None":
    """Return the user's pilot_inline builder session for *ai_session_id*, or None."""
    return await db.scalar(
        select(ScriptBuilderSession).where(
            ScriptBuilderSession.user_id == user_id,
            ScriptBuilderSession.ai_session_id == ai_session_id,
            ScriptBuilderSession.origin == "pilot_inline",
        )
    )


async def _get_or_create_inline_session(
    db: AsyncSession,
    data: ScriptBuilderCreateRequest,
    user_id: object,
    account_id: object,
    team_id: object,
) -> ScriptBuilderSession:
    """Return the user's pilot_inline session for ``data.ai_session_id``, creating it if absent.

    A concurrent request may insert the same row between our lookup and our
    insert. The unique index then raises ``IntegrityError``, and the row
    committed by the other request is returned instead.
    """
    existing = await _find_inline_session(db, user_id, data.ai_session_id)
    if existing is not None:
        return existing

    try:
        created = await script_builder_service.create_session(
            db=db,
            user_id=user_id,
            account_id=account_id,
            team_id=team_id,
            language=data.language,
            origin=data.origin,
            ai_session_id=data.ai_session_id,
        )
        await db.commit()
    except IntegrityError:
        await db.rollback()
        # Race: another request won the unique index — re-read the winner row.
        winner = await _find_inline_session(db, user_id, data.ai_session_id)
        if winner is None:
            # Not the inline-uniqueness race; nothing to fall back to.
            raise
        return winner
    return created


@router.post("/sessions", response_model=ScriptBuilderSessionDetail, status_code=201)
async def create_session(
    data: ScriptBuilderCreateRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptBuilderSessionDetail:
    """Start a new Script Builder session.

    When origin='pilot_inline', behaves as get-or-create: the same row is
    returned on repeated calls with the same (user, ai_session_id) pair.
    Inline sessions are excluded from the session cap and the list endpoint.

    Raises:
        HTTPException: 400 if origin='pilot_inline' is sent without
            ai_session_id, or if the user already holds
            MAX_SESSIONS_PER_USER standalone sessions; 404 if ai_session_id
            does not name a pilot session owned by the current user.
    """
    # Snapshot identity fields before any commit/rollback. rollback() expires
    # the ORM instances held by the session (presumably including current_user,
    # if its dependency shares this db session), and an expired attribute
    # cannot be lazily reloaded under AsyncSession.
    user_id = current_user.id
    account_id = current_user.account_id
    team_id = current_user.team_id

    if data.origin == "pilot_inline" and data.ai_session_id is None:
        raise HTTPException(
            status_code=400,
            detail="ai_session_id is required when origin='pilot_inline'",
        )

    # Ownership check for every origin: no builder session may be linked to a
    # pilot session that belongs to someone else.
    if data.ai_session_id is not None:
        ai_session = await db.scalar(
            select(AISession).where(
                AISession.id == data.ai_session_id,
                AISession.user_id == user_id,
            )
        )
        if ai_session is None:
            raise HTTPException(
                status_code=404,
                detail="Session not found",
            )

    if data.origin == "pilot_inline":
        session = await _get_or_create_inline_session(
            db, data, user_id, account_id, team_id
        )
    else:
        # Serialize concurrent creates for this user. Without this, two
        # simultaneous requests can both read count < limit and both insert.
        # The lock is transaction-scoped and released when the transaction ends.
        await db.execute(
            text("SELECT pg_advisory_xact_lock(:key)"),
            {"key": _advisory_lock_key(user_id)},
        )

        # Enforce max concurrent sessions (inline sessions excluded from cap)
        count = await script_builder_service.count_user_sessions(
            db, user_id, include_inline=False
        )
        if count >= MAX_SESSIONS_PER_USER:
            raise HTTPException(
                status_code=400,
                detail=f"Maximum of {MAX_SESSIONS_PER_USER} builder sessions allowed. Delete an old session first.",
            )

        session = await script_builder_service.create_session(
            db=db,
            user_id=user_id,
            account_id=account_id,
            team_id=team_id,
            language=data.language,
            origin=data.origin,
            ai_session_id=data.ai_session_id,
        )
        await db.commit()

    # Re-fetch with message_records loaded
    detailed = await script_builder_service.get_session(db, session.id, user_id)
    return _session_to_detail(detailed)
|
|
|
|
|
|
@router.get("/sessions", response_model=list[ScriptBuilderSessionSummary])
async def list_sessions(
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
    # Validated so negative values never reach SQL LIMIT/OFFSET (a database
    # error, i.e. a 500) and a client cannot request an unbounded page.
    limit: Annotated[int, Query(ge=1, le=100)] = 20,
    offset: Annotated[int, Query(ge=0)] = 0,
) -> list[ScriptBuilderSessionSummary]:
    """List the current user's recent builder sessions (lightweight, no messages).

    Args:
        limit: Page size, 1-100 (default 20). Out-of-range values get a 422.
        offset: Number of rows to skip, >= 0 (default 0).
    """
    sessions = await script_builder_service.list_sessions(
        db=db, user_id=current_user.id, limit=limit, offset=offset
    )
    return [_session_to_summary(s) for s in sessions]
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=ScriptBuilderSessionDetail)
async def get_session(
    session_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptBuilderSessionDetail:
    """Return one builder session together with its full message history.

    Raises:
        HTTPException: 404 when the service finds no session for this
            (session_id, current user) pair.
    """
    found = await script_builder_service.get_session(db, session_id, current_user.id)
    if found:
        return _session_to_detail(found)
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages", response_model=ScriptBuilderMessageResponse)
@limiter.limit("10/minute")
async def send_message(
    request: Request,
    session_id: UUID,
    data: ScriptBuilderMessageRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptBuilderMessageResponse:
    """Forward a user message to the builder service and return its AI-generated reply.

    Rate limited to 10 calls per minute. ``request`` is not read in the body;
    presumably the limiter decorator needs it in the signature.

    Raises:
        HTTPException: 404 if the session is not found for the current user;
            400 carrying the message of any ``ValueError`` from the service.
    """
    builder_session = await script_builder_service.get_session(db, session_id, current_user.id)
    if not builder_session:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        reply = await script_builder_service.send_message(db, builder_session, data.content)
    except ValueError as err:
        # Surface the service's ValueError text to the client as a 400.
        raise HTTPException(status_code=400, detail=str(err))
    else:
        await db.commit()
        return reply
|
|
|
|
|
|
@router.delete("/sessions/{session_id}", status_code=204)
async def delete_session(
    session_id: UUID,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
    """Remove one of the current user's builder sessions.

    Responds 204 on success. Raises a 404 ``HTTPException`` when the service
    reports that nothing was deleted; in that case no commit is issued.
    """
    if not await script_builder_service.delete_session(db, session_id, current_user.id):
        raise HTTPException(status_code=404, detail="Session not found")
    await db.commit()
|
|
|
|
|
|
@router.post("/sessions/{session_id}/save", response_model=ScriptTemplateDetail, status_code=201)
async def save_to_library(
    session_id: UUID,
    data: SaveToLibraryRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ScriptTemplateDetail:
    """Persist the session's latest generated script as a Script Library template.

    The optional ``script_body`` / ``parameters_schema`` from the request are
    passed through to the service unchanged.

    Raises:
        HTTPException: 404 if the session is not found for the current user;
            400 carrying the message of any ``ValueError`` from the service.
    """
    builder_session = await script_builder_service.get_session(db, session_id, current_user.id)
    if not builder_session:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        template_fields = dict(
            name=data.name,
            description=data.description,
            category_id=data.category_id,
            share_with_team=data.share_with_team,
            user_id=current_user.id,
            account_id=current_user.account_id,
            team_id=current_user.team_id,
            script_body=data.script_body,
            parameters_schema=data.parameters_schema,
        )
        template = await script_builder_service.save_to_library(
            db=db, session=builder_session, **template_fields
        )
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))

    await db.commit()
    # Reload server-generated columns before serializing the response.
    await db.refresh(template)
    return ScriptTemplateDetail.model_validate(template)
|