feat(pilot): inline Script Builder session — idempotent create + auth + filtered list
POST /script-builder/sessions now supports origin='pilot_inline': - Requires ai_session_id; validates it against current user ownership. - Get-or-create: returns existing row for (user, ai_session_id) pair. - Partial unique index on the DB backs the invariant; races resolve to the single winner row. list_sessions + count_user_sessions default-scope to origin='standalone' so inline scratch sessions don't pollute the /script-builder dashboard or count against the 5-session cap. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,12 +3,14 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.user import User
|
||||
from app.models.script_builder_session import ScriptBuilderSession
|
||||
from app.schemas.script_builder import (
|
||||
@@ -67,15 +69,85 @@ async def create_session(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> ScriptBuilderSessionDetail:
|
||||
"""Start a new Script Builder session."""
|
||||
"""Start a new Script Builder session.
|
||||
|
||||
When origin='pilot_inline', behaves as get-or-create: the same row is
|
||||
returned on repeated calls with the same (user, ai_session_id) pair.
|
||||
Inline sessions are excluded from the session cap and the list endpoint.
|
||||
"""
|
||||
# Phase 9: inline origin validation + authorization
|
||||
if data.origin == "pilot_inline":
|
||||
if data.ai_session_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="ai_session_id is required when origin='pilot_inline'",
|
||||
)
|
||||
# Ownership check: the pilot session must belong to the current user.
|
||||
ai_session = await db.scalar(
|
||||
select(AISession).where(
|
||||
AISession.id == data.ai_session_id,
|
||||
AISession.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if ai_session is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Session not found",
|
||||
)
|
||||
|
||||
# Idempotent get-or-create: if a pilot_inline row already exists for
|
||||
# this (user, ai_session_id) pair, return it without creating a duplicate.
|
||||
existing = await db.scalar(
|
||||
select(ScriptBuilderSession).where(
|
||||
ScriptBuilderSession.user_id == current_user.id,
|
||||
ScriptBuilderSession.ai_session_id == data.ai_session_id,
|
||||
ScriptBuilderSession.origin == "pilot_inline",
|
||||
)
|
||||
)
|
||||
if existing is not None:
|
||||
# Re-fetch with message_records loaded
|
||||
session = await script_builder_service.get_session(db, existing.id, current_user.id)
|
||||
return _session_to_detail(session)
|
||||
|
||||
# Create the inline session — wrap in IntegrityError catch for races.
|
||||
try:
|
||||
session = await script_builder_service.create_session(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
language=data.language,
|
||||
origin=data.origin,
|
||||
ai_session_id=data.ai_session_id,
|
||||
)
|
||||
await db.commit()
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
# Race: another request won the unique index — re-read the winner row.
|
||||
existing = await db.scalar(
|
||||
select(ScriptBuilderSession).where(
|
||||
ScriptBuilderSession.user_id == current_user.id,
|
||||
ScriptBuilderSession.ai_session_id == data.ai_session_id,
|
||||
ScriptBuilderSession.origin == "pilot_inline",
|
||||
)
|
||||
)
|
||||
if existing is None:
|
||||
raise
|
||||
session = existing
|
||||
|
||||
# Re-fetch with message_records loaded
|
||||
session = await script_builder_service.get_session(db, session.id, current_user.id)
|
||||
return _session_to_detail(session)
|
||||
|
||||
# ── Standalone session ──────────────────────────────────────────────────
|
||||
# Acquire per-user advisory lock so concurrent create requests are serialized.
|
||||
# Without this, two simultaneous requests both read count < limit and both
|
||||
# insert, exceeding MAX_SESSIONS_PER_USER.
|
||||
user_lock_key = hash(str(current_user.id)) % (2**62)
|
||||
await db.execute(text("SELECT pg_advisory_xact_lock(:key)"), {"key": user_lock_key})
|
||||
|
||||
# Enforce max concurrent sessions
|
||||
count = await script_builder_service.count_user_sessions(db, current_user.id)
|
||||
# Enforce max concurrent sessions (inline sessions excluded from cap)
|
||||
count = await script_builder_service.count_user_sessions(db, current_user.id, include_inline=False)
|
||||
if count >= MAX_SESSIONS_PER_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -88,6 +160,8 @@ async def create_session(
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
language=data.language,
|
||||
origin=data.origin,
|
||||
ai_session_id=data.ai_session_id,
|
||||
)
|
||||
await db.commit()
|
||||
# Re-fetch with message_records loaded
|
||||
|
||||
@@ -148,6 +148,8 @@ async def create_session(
|
||||
team_id: UUID | None,
|
||||
language: str,
|
||||
initial_prompt: str | None = None,
|
||||
origin: str = "standalone",
|
||||
ai_session_id: UUID | None = None,
|
||||
) -> ScriptBuilderSession:
|
||||
"""Create a new Script Builder session."""
|
||||
session = ScriptBuilderSession(
|
||||
@@ -155,6 +157,8 @@ async def create_session(
|
||||
account_id=account_id,
|
||||
team_id=team_id,
|
||||
language=language,
|
||||
origin=origin,
|
||||
ai_session_id=ai_session_id,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
@@ -295,15 +299,22 @@ async def list_sessions(
|
||||
user_id: UUID,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
*,
|
||||
include_inline: bool = False,
|
||||
) -> list[ScriptBuilderSession]:
|
||||
"""List user's builder sessions ordered by updated_at desc."""
|
||||
result = await db.execute(
|
||||
"""List user's builder sessions ordered by updated_at desc.
|
||||
|
||||
By default (include_inline=False) excludes pilot_inline sessions so the
|
||||
/script-builder dashboard only shows standalone sessions.
|
||||
"""
|
||||
stmt = (
|
||||
select(ScriptBuilderSession)
|
||||
.where(ScriptBuilderSession.user_id == user_id)
|
||||
.order_by(ScriptBuilderSession.updated_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
if not include_inline:
|
||||
stmt = stmt.where(ScriptBuilderSession.origin == "standalone")
|
||||
stmt = stmt.order_by(ScriptBuilderSession.updated_at.desc()).limit(limit).offset(offset)
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@@ -321,13 +332,23 @@ async def delete_session(
|
||||
return True
|
||||
|
||||
|
||||
async def count_user_sessions(db: AsyncSession, user_id: UUID) -> int:
|
||||
"""Count active builder sessions for a user."""
|
||||
result = await db.execute(
|
||||
select(func.count(ScriptBuilderSession.id)).where(
|
||||
ScriptBuilderSession.user_id == user_id
|
||||
)
|
||||
async def count_user_sessions(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
*,
|
||||
include_inline: bool = False,
|
||||
) -> int:
|
||||
"""Count active builder sessions for a user.
|
||||
|
||||
By default (include_inline=False) excludes pilot_inline sessions so they
|
||||
don't consume slots against the MAX_SESSIONS_PER_USER cap.
|
||||
"""
|
||||
stmt = select(func.count(ScriptBuilderSession.id)).where(
|
||||
ScriptBuilderSession.user_id == user_id
|
||||
)
|
||||
if not include_inline:
|
||||
stmt = stmt.where(ScriptBuilderSession.origin == "standalone")
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
|
||||
176
backend/tests/test_script_builder_inline.py
Normal file
176
backend/tests/test_script_builder_inline.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Integration tests for inline pilot_inline script_builder_session behavior.
|
||||
|
||||
Covers:
|
||||
- Idempotent get-or-create for (user, ai_session_id) on origin='pilot_inline'
|
||||
- Authorization: ai_session_id must belong to current user
|
||||
- list_sessions + count_user_sessions default-scope to 'standalone'
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select, func
|
||||
from uuid import uuid4
|
||||
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.script_builder_session import ScriptBuilderSession
|
||||
|
||||
|
||||
async def _make_pilot_session(test_db, user) -> str:
|
||||
"""Helper: create a minimal pilot session owned by `user`.
|
||||
|
||||
Matches the existing pattern used by test_fix_outcome_endpoint.py.
|
||||
`user` is the dict returned by the test_user fixture:
|
||||
{"email": ..., "password": ..., "user_data": {"id": ..., "account_id": ..., ...}}
|
||||
"""
|
||||
user_id = user["user_data"]["id"]
|
||||
account_id = user["user_data"]["account_id"]
|
||||
session = AISession(
|
||||
id=uuid4(), user_id=user_id, account_id=account_id,
|
||||
session_type="tshoot", intake_type="psa_ticket",
|
||||
intake_content={}, title="QA",
|
||||
status="active", confidence_tier="exploring", confidence_score=0.0,
|
||||
)
|
||||
test_db.add(session)
|
||||
await test_db.commit()
|
||||
return str(session.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inline_create_is_idempotent(
|
||||
client: AsyncClient, test_user, auth_headers, test_db
|
||||
):
|
||||
"""Second create with same (user, ai_session_id) returns the existing row."""
|
||||
ai_session_id = await _make_pilot_session(test_db, test_user)
|
||||
|
||||
r1 = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline",
|
||||
"ai_session_id": ai_session_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r1.status_code in (200, 201), r1.text
|
||||
first_id = r1.json()["id"]
|
||||
|
||||
r2 = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline",
|
||||
"ai_session_id": ai_session_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r2.status_code in (200, 201)
|
||||
assert r2.json()["id"] == first_id
|
||||
|
||||
# DB confirms only one row
|
||||
row_count = await test_db.scalar(
|
||||
select(func.count()).select_from(ScriptBuilderSession).where(
|
||||
ScriptBuilderSession.user_id == test_user["user_data"]["id"],
|
||||
ScriptBuilderSession.origin == "pilot_inline",
|
||||
)
|
||||
)
|
||||
assert row_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inline_requires_ai_session_id(
|
||||
client: AsyncClient, auth_headers
|
||||
):
|
||||
"""origin='pilot_inline' without ai_session_id is rejected."""
|
||||
r = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r.status_code == 400
|
||||
assert "ai_session_id" in r.text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inline_ai_session_must_belong_to_caller(
|
||||
client: AsyncClient, test_user, auth_headers, test_db
|
||||
):
|
||||
"""ai_session_id pointing at another user's session is rejected."""
|
||||
# Create pilot session owned by a DIFFERENT user
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
other_account = Account(id=uuid4(), name="other", display_code="OTH-0001")
|
||||
test_db.add(other_account)
|
||||
await test_db.flush()
|
||||
other_user = User(
|
||||
id=uuid4(), email="other@example.com",
|
||||
password_hash="x", name="Other", role="engineer",
|
||||
is_super_admin=False, is_team_admin=False, is_active=True,
|
||||
is_service_account=False, must_change_password=False,
|
||||
account_id=other_account.id, account_role="engineer",
|
||||
)
|
||||
test_db.add(other_user)
|
||||
await test_db.flush()
|
||||
# Build user dict in the same shape as the test_user fixture
|
||||
other_user_dict = {
|
||||
"user_data": {"id": str(other_user.id), "account_id": str(other_account.id)}
|
||||
}
|
||||
other_session_id = await _make_pilot_session(test_db, other_user_dict)
|
||||
|
||||
r = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline",
|
||||
"ai_session_id": other_session_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r.status_code in (403, 404), r.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_excludes_inline(
|
||||
client: AsyncClient, test_user, auth_headers, test_db
|
||||
):
|
||||
"""GET /scripts/builder/sessions returns only standalone rows."""
|
||||
ai_session_id = await _make_pilot_session(test_db, test_user)
|
||||
|
||||
# Create one inline session
|
||||
await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline",
|
||||
"ai_session_id": ai_session_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
# Create one standalone session
|
||||
await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
r = await client.get("/api/v1/scripts/builder/sessions", headers=auth_headers)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
# Depending on response shape, this may be a list or {"sessions": [...]}.
|
||||
items = body if isinstance(body, list) else body.get("sessions", body.get("items", []))
|
||||
# Response schema does not surface `origin`; len==1 is the only meaningful guard:
|
||||
# inline row would push this to 2.
|
||||
assert len(items) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inline_sessions_do_not_count_against_cap(
|
||||
client: AsyncClient, test_user, auth_headers, test_db
|
||||
):
|
||||
"""Creating 5 pilot_inline sessions does not block a subsequent standalone."""
|
||||
# Create 5 distinct pilot sessions and attach inline builder sessions to each
|
||||
for _ in range(5):
|
||||
ai_session_id = await _make_pilot_session(test_db, test_user)
|
||||
r = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell", "origin": "pilot_inline",
|
||||
"ai_session_id": ai_session_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
|
||||
# A standalone create should still succeed — inline sessions don't count
|
||||
r = await client.post(
|
||||
"/api/v1/scripts/builder/sessions",
|
||||
json={"language": "powershell"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert r.status_code in (200, 201), r.text
|
||||
Reference in New Issue
Block a user