feat: add supporting data CRUD endpoints with tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
201
backend/app/api/endpoints/supporting_data.py
Normal file
201
backend/app/api/endpoints/supporting_data.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import base64
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models import User
|
||||
from app.models.session import Session
|
||||
from app.models.supporting_data import SessionSupportingData
|
||||
from app.schemas.supporting_data import (
|
||||
SupportingDataCreate,
|
||||
SupportingDataUpdate,
|
||||
SupportingDataResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/sessions", tags=["supporting-data"])
|
||||
|
||||
MAX_ITEMS_PER_SESSION = 20
|
||||
MAX_TEXT_SNIPPET_CHARS = 50_000
|
||||
MAX_SCREENSHOT_RAW_BYTES = 2 * 1024 * 1024 # 2MB
|
||||
|
||||
|
||||
async def _check_session_access(user: User, session: Session, db: AsyncSession) -> None:
|
||||
"""Verify user has access to the session (owner, team admin, or super admin)."""
|
||||
if user.is_super_admin:
|
||||
return
|
||||
if session.user_id == user.id:
|
||||
return
|
||||
# Team admins can only access sessions from their own team members
|
||||
if user.is_team_admin and user.team_id is not None:
|
||||
session_owner = await db.get(User, session.user_id)
|
||||
if session_owner and session_owner.team_id == user.team_id:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
|
||||
async def _get_session_or_404(session_id: UUID, db: AsyncSession) -> Session:
|
||||
"""Fetch session by ID or raise 404."""
|
||||
result = await db.execute(select(Session).where(Session.id == session_id))
|
||||
session = result.scalar_one_or_none()
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{session_id}/supporting-data",
|
||||
response_model=SupportingDataResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_supporting_data(
|
||||
session_id: UUID,
|
||||
data: SupportingDataCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Add a supporting data item (text snippet or screenshot) to a session."""
|
||||
session = await _get_session_or_404(session_id, db)
|
||||
await _check_session_access(current_user, session, db)
|
||||
|
||||
# Check item limit
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(SessionSupportingData).where(
|
||||
SessionSupportingData.session_id == session_id
|
||||
)
|
||||
)
|
||||
current_count = count_result.scalar() or 0
|
||||
if current_count >= MAX_ITEMS_PER_SESSION:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Maximum {MAX_ITEMS_PER_SESSION} supporting data items per session",
|
||||
)
|
||||
|
||||
# Validate content size based on type
|
||||
if data.data_type == "text_snippet":
|
||||
if len(data.content) > MAX_TEXT_SNIPPET_CHARS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Text snippet exceeds maximum {MAX_TEXT_SNIPPET_CHARS} characters",
|
||||
)
|
||||
elif data.data_type == "screenshot":
|
||||
try:
|
||||
raw_bytes = base64.b64decode(data.content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid base64 content for screenshot")
|
||||
if len(raw_bytes) > MAX_SCREENSHOT_RAW_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Screenshot exceeds maximum {MAX_SCREENSHOT_RAW_BYTES // (1024 * 1024)}MB raw size",
|
||||
)
|
||||
|
||||
# Auto-increment sort_order
|
||||
max_order_result = await db.execute(
|
||||
select(func.max(SessionSupportingData.sort_order)).where(
|
||||
SessionSupportingData.session_id == session_id
|
||||
)
|
||||
)
|
||||
max_order = max_order_result.scalar()
|
||||
next_order = (max_order or 0) + 1
|
||||
|
||||
item = SessionSupportingData(
|
||||
session_id=session_id,
|
||||
label=data.label,
|
||||
data_type=data.data_type,
|
||||
content=data.content,
|
||||
content_type=data.content_type,
|
||||
sort_order=next_order,
|
||||
)
|
||||
db.add(item)
|
||||
await db.commit()
|
||||
await db.refresh(item)
|
||||
|
||||
return item
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{session_id}/supporting-data",
|
||||
response_model=list[SupportingDataResponse],
|
||||
)
|
||||
async def list_supporting_data(
|
||||
session_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""List all supporting data items for a session, ordered by sort_order."""
|
||||
session = await _get_session_or_404(session_id, db)
|
||||
await _check_session_access(current_user, session, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(SessionSupportingData)
|
||||
.where(SessionSupportingData.session_id == session_id)
|
||||
.order_by(SessionSupportingData.sort_order)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{session_id}/supporting-data/{item_id}",
|
||||
response_model=SupportingDataResponse,
|
||||
)
|
||||
async def update_supporting_data(
|
||||
session_id: UUID,
|
||||
item_id: UUID,
|
||||
data: SupportingDataUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Update a supporting data item's label or content."""
|
||||
session = await _get_session_or_404(session_id, db)
|
||||
await _check_session_access(current_user, session, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(SessionSupportingData).where(
|
||||
SessionSupportingData.id == item_id,
|
||||
SessionSupportingData.session_id == session_id,
|
||||
)
|
||||
)
|
||||
item = result.scalar_one_or_none()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Supporting data item not found")
|
||||
|
||||
if data.label is not None:
|
||||
item.label = data.label
|
||||
if data.content is not None:
|
||||
item.content = data.content
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(item)
|
||||
|
||||
return item
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{session_id}/supporting-data/{item_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def delete_supporting_data(
|
||||
session_id: UUID,
|
||||
item_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Remove a supporting data item from a session."""
|
||||
session = await _get_session_or_404(session_id, db)
|
||||
await _check_session_access(current_user, session, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(SessionSupportingData).where(
|
||||
SessionSupportingData.id == item_id,
|
||||
SessionSupportingData.session_id == session_id,
|
||||
)
|
||||
)
|
||||
item = result.scalar_one_or_none()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Supporting data item not found")
|
||||
|
||||
await db.delete(item)
|
||||
await db.commit()
|
||||
@@ -20,6 +20,7 @@ from app.api.endpoints import scripts
|
||||
from app.api.endpoints import integrations
|
||||
from app.api.endpoints import onboarding
|
||||
from app.api.endpoints import branding
|
||||
from app.api.endpoints import supporting_data
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -65,3 +66,4 @@ api_router.include_router(scripts.router)
|
||||
api_router.include_router(integrations.router)
|
||||
api_router.include_router(onboarding.router)
|
||||
api_router.include_router(branding.router)
|
||||
api_router.include_router(supporting_data.router)
|
||||
|
||||
217
backend/tests/test_supporting_data.py
Normal file
217
backend/tests/test_supporting_data.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import base64
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_session(client: AsyncClient, auth_headers: dict, test_tree: dict):
|
||||
"""Create a test session from the test tree."""
|
||||
response = await client.post(
|
||||
"/api/v1/sessions",
|
||||
json={"tree_id": test_tree["id"]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201, f"Failed to create session: {response.text}"
|
||||
return response.json()
|
||||
|
||||
|
||||
# --- Create ---
|
||||
|
||||
|
||||
async def test_create_text_snippet(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Create a text snippet supporting data item — returns 201."""
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "Error log",
|
||||
"data_type": "text_snippet",
|
||||
"content": "NullReferenceException at line 42",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["label"] == "Error log"
|
||||
assert data["data_type"] == "text_snippet"
|
||||
assert data["content"] == "NullReferenceException at line 42"
|
||||
assert data["sort_order"] == 1
|
||||
assert data["session_id"] == test_session["id"]
|
||||
|
||||
|
||||
async def test_create_screenshot(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Create a screenshot supporting data item — returns 201."""
|
||||
# Small valid base64 content (a tiny PNG-like payload)
|
||||
small_content = base64.b64encode(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100).decode()
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "Error screenshot",
|
||||
"data_type": "screenshot",
|
||||
"content": small_content,
|
||||
"content_type": "image/png",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["label"] == "Error screenshot"
|
||||
assert data["data_type"] == "screenshot"
|
||||
assert data["content_type"] == "image/png"
|
||||
|
||||
|
||||
# --- List ---
|
||||
|
||||
|
||||
async def test_list_items_in_sort_order(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""List returns items ordered by sort_order."""
|
||||
# Create 3 items
|
||||
for i in range(3):
|
||||
resp = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": f"Item {i}",
|
||||
"data_type": "text_snippet",
|
||||
"content": f"Content {i}",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
items = response.json()
|
||||
assert len(items) == 3
|
||||
assert items[0]["label"] == "Item 0"
|
||||
assert items[1]["label"] == "Item 1"
|
||||
assert items[2]["label"] == "Item 2"
|
||||
assert items[0]["sort_order"] < items[1]["sort_order"] < items[2]["sort_order"]
|
||||
|
||||
|
||||
# --- Delete ---
|
||||
|
||||
|
||||
async def test_delete_item(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Delete removes the item."""
|
||||
create_resp = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "To delete",
|
||||
"data_type": "text_snippet",
|
||||
"content": "Will be removed",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_resp.status_code == 201
|
||||
item_id = create_resp.json()["id"]
|
||||
|
||||
delete_resp = await client.delete(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data/{item_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert delete_resp.status_code == 204
|
||||
|
||||
# Verify it's gone
|
||||
list_resp = await client.get(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert list_resp.status_code == 200
|
||||
assert len(list_resp.json()) == 0
|
||||
|
||||
|
||||
# --- Validation ---
|
||||
|
||||
|
||||
async def test_exceed_20_item_limit(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Cannot exceed 20 items per session — returns 400."""
|
||||
for i in range(20):
|
||||
resp = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": f"Item {i}",
|
||||
"data_type": "text_snippet",
|
||||
"content": f"Content {i}",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201, f"Failed creating item {i}: {resp.text}"
|
||||
|
||||
# 21st should fail
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "One too many",
|
||||
"data_type": "text_snippet",
|
||||
"content": "Should fail",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "20" in response.json()["detail"]
|
||||
|
||||
|
||||
async def test_screenshot_exceeds_2mb(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Screenshot over 2MB raw (base64 decoded) — returns 400."""
|
||||
# Create content that decodes to > 2MB
|
||||
large_raw = b"\x00" * (2 * 1024 * 1024 + 1) # 2MB + 1 byte
|
||||
large_b64 = base64.b64encode(large_raw).decode()
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "Large screenshot",
|
||||
"data_type": "screenshot",
|
||||
"content": large_b64,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "2MB" in response.json()["detail"]
|
||||
|
||||
|
||||
async def test_text_snippet_over_50k_chars(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""Text snippet over 50,000 characters — returns 400."""
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "Huge text",
|
||||
"data_type": "text_snippet",
|
||||
"content": "x" * 50_001,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "50000" in response.json()["detail"]
|
||||
|
||||
|
||||
# --- Update ---
|
||||
|
||||
|
||||
async def test_patch_update_label(client: AsyncClient, auth_headers: dict, test_session: dict):
|
||||
"""PATCH to update label returns updated item."""
|
||||
create_resp = await client.post(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data",
|
||||
json={
|
||||
"label": "Original label",
|
||||
"data_type": "text_snippet",
|
||||
"content": "Some content",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_resp.status_code == 201
|
||||
item_id = create_resp.json()["id"]
|
||||
|
||||
patch_resp = await client.patch(
|
||||
f"/api/v1/sessions/{test_session['id']}/supporting-data/{item_id}",
|
||||
json={"label": "Updated label"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert patch_resp.status_code == 200
|
||||
data = patch_resp.json()
|
||||
assert data["label"] == "Updated label"
|
||||
assert data["content"] == "Some content" # unchanged
|
||||
Reference in New Issue
Block a user