diff --git a/backend/app/api/endpoints/script_builder.py b/backend/app/api/endpoints/script_builder.py index 1cb849f8..f028ae28 100644 --- a/backend/app/api/endpoints/script_builder.py +++ b/backend/app/api/endpoints/script_builder.py @@ -3,12 +3,14 @@ from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy import text +from sqlalchemy import select, text +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_db from app.core.rate_limit import limiter from app.api.deps import get_current_active_user +from app.models.ai_session import AISession from app.models.user import User from app.models.script_builder_session import ScriptBuilderSession from app.schemas.script_builder import ( @@ -67,15 +69,85 @@ async def create_session( db: Annotated[AsyncSession, Depends(get_db)], current_user: Annotated[User, Depends(get_current_active_user)], ) -> ScriptBuilderSessionDetail: - """Start a new Script Builder session.""" + """Start a new Script Builder session. + + When origin='pilot_inline', behaves as get-or-create: the same row is + returned on repeated calls with the same (user, ai_session_id) pair. + Inline sessions are excluded from the session cap and the list endpoint. + """ + # Phase 9: inline origin validation + authorization + if data.origin == "pilot_inline": + if data.ai_session_id is None: + raise HTTPException( + status_code=400, + detail="ai_session_id is required when origin='pilot_inline'", + ) + # Ownership check: the pilot session must belong to the current user. + ai_session = await db.scalar( + select(AISession).where( + AISession.id == data.ai_session_id, + AISession.user_id == current_user.id, + ) + ) + if ai_session is None: + raise HTTPException( + status_code=404, + detail="Session not found", + ) + + # Idempotent get-or-create: if a pilot_inline row already exists for + # this (user, ai_session_id) pair, return it without creating a duplicate. + existing = await db.scalar( + select(ScriptBuilderSession).where( + ScriptBuilderSession.user_id == current_user.id, + ScriptBuilderSession.ai_session_id == data.ai_session_id, + ScriptBuilderSession.origin == "pilot_inline", + ) + ) + if existing is not None: + # Re-fetch with message_records loaded + session = await script_builder_service.get_session(db, existing.id, current_user.id) + return _session_to_detail(session) + + # Create the inline session — wrap in IntegrityError catch for races. + try: + session = await script_builder_service.create_session( + db=db, + user_id=current_user.id, + account_id=current_user.account_id, + team_id=current_user.team_id, + language=data.language, + origin=data.origin, + ai_session_id=data.ai_session_id, + ) + await db.commit() + except IntegrityError: + await db.rollback() + # Race: another request won the unique index — re-read the winner row. + existing = await db.scalar( + select(ScriptBuilderSession).where( + ScriptBuilderSession.user_id == current_user.id, + ScriptBuilderSession.ai_session_id == data.ai_session_id, + ScriptBuilderSession.origin == "pilot_inline", + ) + ) + if existing is None: + raise + session = existing + + # Re-fetch with message_records loaded + session = await script_builder_service.get_session(db, session.id, current_user.id) + return _session_to_detail(session) + + # ── Standalone session ────────────────────────────────────────────────── # Acquire per-user advisory lock so concurrent create requests are serialized. # Without this, two simultaneous requests both read count < limit and both # insert, exceeding MAX_SESSIONS_PER_USER. user_lock_key = hash(str(current_user.id)) % (2**62) await db.execute(text("SELECT pg_advisory_xact_lock(:key)"), {"key": user_lock_key}) - # Enforce max concurrent sessions - count = await script_builder_service.count_user_sessions(db, current_user.id) + # Enforce max concurrent sessions (inline sessions excluded from cap) + count = await script_builder_service.count_user_sessions(db, current_user.id, include_inline=False) if count >= MAX_SESSIONS_PER_USER: raise HTTPException( status_code=400, @@ -88,6 +160,8 @@ async def create_session( account_id=current_user.account_id, team_id=current_user.team_id, language=data.language, + origin=data.origin, + ai_session_id=data.ai_session_id, ) await db.commit() # Re-fetch with message_records loaded diff --git a/backend/app/services/script_builder_service.py b/backend/app/services/script_builder_service.py index aec7e87a..2422c437 100644 --- a/backend/app/services/script_builder_service.py +++ b/backend/app/services/script_builder_service.py @@ -148,6 +148,8 @@ async def create_session( team_id: UUID | None, language: str, initial_prompt: str | None = None, + origin: str = "standalone", + ai_session_id: UUID | None = None, ) -> ScriptBuilderSession: """Create a new Script Builder session.""" session = ScriptBuilderSession( @@ -155,6 +157,8 @@ async def create_session( account_id=account_id, team_id=team_id, language=language, + origin=origin, + ai_session_id=ai_session_id, ) db.add(session) await db.flush() @@ -295,15 +299,22 @@ async def list_sessions( user_id: UUID, limit: int = 20, offset: int = 0, + *, + include_inline: bool = False, ) -> list[ScriptBuilderSession]: - """List user's builder sessions ordered by updated_at desc.""" - result = await db.execute( + """List user's builder sessions ordered by updated_at desc. + + By default (include_inline=False) excludes pilot_inline sessions so the + /script-builder dashboard only shows standalone sessions. + """ + stmt = ( select(ScriptBuilderSession) .where(ScriptBuilderSession.user_id == user_id) - .order_by(ScriptBuilderSession.updated_at.desc()) - .limit(limit) - .offset(offset) ) + if not include_inline: + stmt = stmt.where(ScriptBuilderSession.origin == "standalone") + stmt = stmt.order_by(ScriptBuilderSession.updated_at.desc()).limit(limit).offset(offset) + result = await db.execute(stmt) return list(result.scalars().all()) @@ -321,13 +332,23 @@ async def delete_session( return True -async def count_user_sessions(db: AsyncSession, user_id: UUID) -> int: - """Count active builder sessions for a user.""" - result = await db.execute( - select(func.count(ScriptBuilderSession.id)).where( - ScriptBuilderSession.user_id == user_id - ) +async def count_user_sessions( + db: AsyncSession, + user_id: UUID, + *, + include_inline: bool = False, +) -> int: + """Count active builder sessions for a user. + + By default (include_inline=False) excludes pilot_inline sessions so they + don't consume slots against the MAX_SESSIONS_PER_USER cap. + """ + stmt = select(func.count(ScriptBuilderSession.id)).where( + ScriptBuilderSession.user_id == user_id ) + if not include_inline: + stmt = stmt.where(ScriptBuilderSession.origin == "standalone") + result = await db.execute(stmt) return result.scalar_one() diff --git a/backend/tests/test_script_builder_inline.py b/backend/tests/test_script_builder_inline.py new file mode 100644 index 00000000..2db04e2c --- /dev/null +++ b/backend/tests/test_script_builder_inline.py @@ -0,0 +1,176 @@ +"""Integration tests for inline pilot_inline script_builder_session behavior. + +Covers: +- Idempotent get-or-create for (user, ai_session_id) on origin='pilot_inline' +- Authorization: ai_session_id must belong to current user +- list_sessions + count_user_sessions default-scope to 'standalone' +""" +from __future__ import annotations + +import pytest +from httpx import AsyncClient +from sqlalchemy import select, func +from uuid import uuid4 + +from app.models.ai_session import AISession +from app.models.script_builder_session import ScriptBuilderSession + + +async def _make_pilot_session(test_db, user) -> str: + """Helper: create a minimal pilot session owned by `user`. + + Matches the existing pattern used by test_fix_outcome_endpoint.py. + `user` is the dict returned by the test_user fixture: + {"email": ..., "password": ..., "user_data": {"id": ..., "account_id": ..., ...}} + """ + user_id = user["user_data"]["id"] + account_id = user["user_data"]["account_id"] + session = AISession( + id=uuid4(), user_id=user_id, account_id=account_id, + session_type="tshoot", intake_type="psa_ticket", + intake_content={}, title="QA", + status="active", confidence_tier="exploring", confidence_score=0.0, + ) + test_db.add(session) + await test_db.commit() + return str(session.id) + + +@pytest.mark.asyncio +async def test_inline_create_is_idempotent( + client: AsyncClient, test_user, auth_headers, test_db +): + """Second create with same (user, ai_session_id) returns the existing row.""" + ai_session_id = await _make_pilot_session(test_db, test_user) + + r1 = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline", + "ai_session_id": ai_session_id}, + headers=auth_headers, + ) + assert r1.status_code in (200, 201), r1.text + first_id = r1.json()["id"] + + r2 = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline", + "ai_session_id": ai_session_id}, + headers=auth_headers, + ) + assert r2.status_code in (200, 201) + assert r2.json()["id"] == first_id + + # DB confirms only one row + row_count = await test_db.scalar( + select(func.count()).select_from(ScriptBuilderSession).where( + ScriptBuilderSession.user_id == test_user["user_data"]["id"], + ScriptBuilderSession.origin == "pilot_inline", + ) + ) + assert row_count == 1 + + +@pytest.mark.asyncio +async def test_inline_requires_ai_session_id( + client: AsyncClient, auth_headers +): + """origin='pilot_inline' without ai_session_id is rejected.""" + r = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline"}, + headers=auth_headers, + ) + assert r.status_code == 400 + assert "ai_session_id" in r.text.lower() + + +@pytest.mark.asyncio +async def test_inline_ai_session_must_belong_to_caller( + client: AsyncClient, test_user, auth_headers, test_db +): + """ai_session_id pointing at another user's session is rejected.""" + # Create pilot session owned by a DIFFERENT user + from app.models.user import User + from app.models.account import Account + other_account = Account(id=uuid4(), name="other", display_code="OTH-0001") + test_db.add(other_account) + await test_db.flush() + other_user = User( + id=uuid4(), email="other@example.com", + password_hash="x", name="Other", role="engineer", + is_super_admin=False, is_team_admin=False, is_active=True, + is_service_account=False, must_change_password=False, + account_id=other_account.id, account_role="engineer", + ) + test_db.add(other_user) + await test_db.flush() + # Build user dict in the same shape as the test_user fixture + other_user_dict = { + "user_data": {"id": str(other_user.id), "account_id": str(other_account.id)} + } + other_session_id = await _make_pilot_session(test_db, other_user_dict) + + r = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline", + "ai_session_id": other_session_id}, + headers=auth_headers, + ) + assert r.status_code in (403, 404), r.text + + +@pytest.mark.asyncio +async def test_list_sessions_excludes_inline( + client: AsyncClient, test_user, auth_headers, test_db +): + """GET /scripts/builder/sessions returns only standalone rows.""" + ai_session_id = await _make_pilot_session(test_db, test_user) + + # Create one inline session + await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline", + "ai_session_id": ai_session_id}, + headers=auth_headers, + ) + # Create one standalone session + await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell"}, + headers=auth_headers, + ) + + r = await client.get("/api/v1/scripts/builder/sessions", headers=auth_headers) + assert r.status_code == 200 + body = r.json() + # Depending on response shape, this may be a list or {"sessions": [...]}. + items = body if isinstance(body, list) else body.get("sessions", body.get("items", [])) + # Response schema does not surface `origin`; len==1 is the only meaningful guard: + # inline row would push this to 2. + assert len(items) == 1 + + +@pytest.mark.asyncio +async def test_inline_sessions_do_not_count_against_cap( + client: AsyncClient, test_user, auth_headers, test_db +): + """Creating 5 pilot_inline sessions does not block a subsequent standalone.""" + # Create 5 distinct pilot sessions and attach inline builder sessions to each + for _ in range(5): + ai_session_id = await _make_pilot_session(test_db, test_user) + r = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell", "origin": "pilot_inline", + "ai_session_id": ai_session_id}, + headers=auth_headers, + ) + assert r.status_code in (200, 201), r.text + + # A standalone create should still succeed — inline sessions don't count + r = await client.post( + "/api/v1/scripts/builder/sessions", + json={"language": "powershell"}, + headers=auth_headers, + ) + assert r.status_code in (200, 201), r.text