Files
resolutionflow/backend/app/api/endpoints/steps.py
chihlasm 758cd61621 fix: propagate account_id through all write paths missing NOT NULL coverage
Service layer (production code):
- branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint
  from session.account_id; load session in create_fork for this purpose
- handoff_manager: set account_id on SessionHandoff from session.account_id
- ai_suggestions endpoint: set account_id on AISuggestion from current_user
- steps endpoint (/feedback): set account_id on StepRating from current_user
- ratings endpoint: set account_id on StepRating from current_user

Test infrastructure:
- conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after
  Base.metadata.create_all so global categories and gallery items have a valid FK
- test_rls_isolation: add _ensure_rls_schema fixture that runs
  'alembic upgrade head' before module tests — previous function-scoped
  test_db fixtures drop the schema, leaving the RLS tests with no tables
- test_branding: create Account before User in helper functions
- test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate
- test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree,
  ScriptTemplate, TreeCategory
- test_resolution_outputs: set account_id=session.account_id on
  SessionResolutionOutput
- test_analytics_phase5: set account_id on PsaPostLog
- test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in
  migration default test (NOT NULL now enforced)
- test_maintenance_schedules: set account_id on other_tree
- test_save_session_as_tree: set account_id on all 5 Session() constructors

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 04:24:36 +00:00

669 lines
22 KiB
Python

from uuid import UUID
from typing import Optional
from datetime import datetime, timezone
from decimal import Decimal
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func, desc, Integer, case
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.database import get_db
from app.api.deps import get_current_active_user, require_engineer_or_admin
from app.core.permissions import can_view_step, can_edit_step
from app.core.filters import build_step_visibility_filter
from app.models.user import User
from app.models.step_library import StepLibrary, StepRating
from app.models.step_category import StepCategory
from app.schemas.step_library import (
StepLibraryCreate,
StepLibraryUpdate,
StepLibraryResponse,
StepLibraryListResponse,
StepRatingCreate,
StepRatingUpdate,
StepRatingResponse,
PopularTagResponse,
)
# Router for the step-library endpoints; every route below is registered
# under the "/steps" prefix and tagged "steps".
router = APIRouter(prefix="/steps", tags=["steps"])
async def get_step_or_404(
    step_id: UUID,
    db: AsyncSession,
    current_user: User,
    check_view: bool = True,
    check_edit: bool = False
) -> StepLibrary:
    """Load an active step by ID and enforce the requested permissions.

    Args:
        step_id: Primary key of the step to load.
        db: Async session used for the lookup.
        current_user: User whose permissions are evaluated.
        check_view: When true, ``can_view_step`` must pass.
        check_edit: When true, ``can_edit_step`` must pass.

    Returns:
        The matching ``StepLibrary`` row with ``source_tree`` eagerly loaded.

    Raises:
        HTTPException: 404 "Step not found" when the step does not exist, is
            inactive, or fails one of the requested permission checks. The
            same response is used for every failure.
    """
    lookup = (
        select(StepLibrary)
        .where(
            StepLibrary.id == step_id,
            StepLibrary.is_active == True,  # noqa: E712 - SQL expression
        )
        .options(selectinload(StepLibrary.source_tree))
    )
    step = (await db.execute(lookup)).scalar_one_or_none()

    # Short-circuiting keeps the permission helpers from running on a missing
    # step and preserves the view-before-edit evaluation order.
    if (
        not step
        or (check_view and not can_view_step(current_user, step))
        or (check_edit and not can_edit_step(current_user, step))
    ):
        raise HTTPException(status_code=404, detail="Step not found")

    return step
@router.get("", response_model=list[StepLibraryListResponse])
async def list_steps(
    visibility: Optional[str] = Query(None, pattern="^(private|team|public)$"),
    category_id: Optional[UUID] = None,
    tags: Optional[list[str]] = Query(None),
    min_rating: Optional[float] = Query(None, ge=0, le=5),
    step_type: Optional[str] = Query(None, pattern="^(decision|action|solution)$"),
    sort_by: str = Query("recent", pattern="^(recent|popular|highest_rated|most_used)$"),
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """List active steps visible to the current user, with filters and pagination.

    Args:
        visibility: Restrict to one visibility level.
        category_id: Restrict to one category.
        tags: Keep steps whose tags overlap any of the given tags.
        min_rating: Minimum ``rating_average`` (inclusive).
        step_type: Restrict to one step type.
        sort_by: ``recent`` (newest first), ``popular`` / ``most_used``
            (highest ``usage_count`` first) or ``highest_rated``.
        limit: Page size.
        offset: Number of rows to skip.
        db: Async database session.
        current_user: Authenticated user; drives the visibility filter.

    Returns:
        A list of ``StepLibraryListResponse`` items for the requested page.
    """
    query = select(StepLibrary).where(
        StepLibrary.is_active == True,  # noqa: E712 - SQL expression
        build_step_visibility_filter(current_user),
    ).options(selectinload(StepLibrary.source_tree))

    # Apply filters
    if visibility:
        query = query.where(StepLibrary.visibility == visibility)
    if category_id:
        query = query.where(StepLibrary.category_id == category_id)
    if tags:
        # Match any of the provided tags
        query = query.where(StepLibrary.tags.overlap(tags))
    if min_rating is not None:
        query = query.where(StepLibrary.rating_average >= Decimal(str(min_rating)))
    if step_type:
        query = query.where(StepLibrary.step_type == step_type)

    # Apply sorting
    if sort_by == "recent":
        query = query.order_by(desc(StepLibrary.created_at))
    elif sort_by in ("popular", "most_used"):
        query = query.order_by(desc(StepLibrary.usage_count))
    elif sort_by == "highest_rated":
        query = query.order_by(desc(StepLibrary.rating_average), desc(StepLibrary.rating_count))
    # The sort keys above are not unique; a final tie-breaker on the primary
    # key keeps offset/limit pages stable (no duplicated or skipped rows).
    query = query.order_by(StepLibrary.id)

    # Apply pagination
    query = query.offset(offset).limit(limit)
    steps = (await db.execute(query)).scalars().all()

    # Resolve category and author names with one query each instead of two
    # queries per returned step.
    category_ids = list({step.category_id for step in steps if step.category_id})
    category_names = {}
    if category_ids:
        cat_rows = await db.execute(
            select(StepCategory.id, StepCategory.name).where(StepCategory.id.in_(category_ids))
        )
        category_names = {cat_id: cat_name for cat_id, cat_name in cat_rows.all()}

    author_ids = list({step.created_by for step in steps})
    author_names = {}
    if author_ids:
        author_rows = await db.execute(
            select(User.id, User.name).where(User.id.in_(author_ids))
        )
        author_names = {user_id: user_name for user_id, user_name in author_rows.all()}

    response = []
    for step in steps:
        step_dict = {
            "id": step.id,
            "title": step.title,
            "step_type": step.step_type,
            "visibility": step.visibility,
            "category_id": step.category_id,
            "tags": step.tags,
            "usage_count": step.usage_count,
            "rating_average": step.rating_average,
            "rating_count": step.rating_count,
            "is_featured": step.is_featured,
            "created_by": step.created_by,
            "created_at": step.created_at,
            "is_flow_synced": step.is_flow_synced,
            "source_tree_name": step.source_tree.name if step.source_tree else None,
            "author_name": author_names.get(step.created_by),
        }
        # category_name is only supplied for steps that reference a category.
        if step.category_id:
            step_dict["category_name"] = category_names.get(step.category_id)
        response.append(StepLibraryListResponse(**step_dict))
    return response
@router.get("/search", response_model=list[StepLibraryListResponse])
async def search_steps(
    q: str = Query(..., min_length=1),
    limit: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Full-text search over step titles (PostgreSQL, 'english' configuration).

    Args:
        q: Free-text search string; all terms must match.
        limit: Maximum number of results.
        db: Async database session.
        current_user: Authenticated user; drives the visibility filter.

    Returns:
        Matching active, visible steps ordered by ``rating_average``
        descending, as ``StepLibraryListResponse`` items.
    """
    # plainto_tsquery parses arbitrary user text and ANDs the terms together.
    # Feeding raw input to to_tsquery fails with a syntax error on
    # punctuation, tsquery operators or repeated spaces.
    ts_query = func.plainto_tsquery('english', q)
    ts_vector = func.to_tsvector('english', StepLibrary.title)
    query = select(StepLibrary).where(
        StepLibrary.is_active == True,  # noqa: E712 - SQL expression
        build_step_visibility_filter(current_user),
        # Use the @@ operator directly: .match() would wrap the already-built
        # tsquery in another tsquery function on PostgreSQL.
        ts_vector.bool_op("@@")(ts_query),
    ).options(selectinload(StepLibrary.source_tree)).order_by(desc(StepLibrary.rating_average)).limit(limit)
    steps = (await db.execute(query)).scalars().all()

    # Resolve category and author names with one query each instead of two
    # queries per returned step.
    category_ids = list({step.category_id for step in steps if step.category_id})
    category_names = {}
    if category_ids:
        cat_rows = await db.execute(
            select(StepCategory.id, StepCategory.name).where(StepCategory.id.in_(category_ids))
        )
        category_names = {cat_id: cat_name for cat_id, cat_name in cat_rows.all()}

    author_ids = list({step.created_by for step in steps})
    author_names = {}
    if author_ids:
        author_rows = await db.execute(
            select(User.id, User.name).where(User.id.in_(author_ids))
        )
        author_names = {user_id: user_name for user_id, user_name in author_rows.all()}

    response = []
    for step in steps:
        step_dict = {
            "id": step.id,
            "title": step.title,
            "step_type": step.step_type,
            "visibility": step.visibility,
            "category_id": step.category_id,
            "tags": step.tags,
            "usage_count": step.usage_count,
            "rating_average": step.rating_average,
            "rating_count": step.rating_count,
            "is_featured": step.is_featured,
            "created_by": step.created_by,
            "created_at": step.created_at,
            "is_flow_synced": step.is_flow_synced,
            "source_tree_name": step.source_tree.name if step.source_tree else None,
            "author_name": author_names.get(step.created_by),
        }
        # category_name is only supplied for steps that reference a category.
        if step.category_id:
            step_dict["category_name"] = category_names.get(step.category_id)
        response.append(StepLibraryListResponse(**step_dict))
    return response
@router.get("/tags/popular", response_model=list[PopularTagResponse])
async def get_popular_tags(
    limit: int = Query(20, ge=1, le=50),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Return the most frequently used tags among steps visible to the user.

    Args:
        limit: Maximum number of tags to return.
        db: Async database session.
        current_user: Authenticated user; drives the visibility filter.

    Returns:
        ``PopularTagResponse`` items ordered by usage count, highest first.
    """
    # unnest() expands each step's tag array into one row per tag so the
    # occurrences can be grouped and counted.
    tag_expr = func.unnest(StepLibrary.tags)
    usage = func.count()
    query = select(
        tag_expr.label('tag'),
        usage.label('count'),
    ).where(
        StepLibrary.is_active == True,  # noqa: E712 - SQL expression
        build_step_visibility_filter(current_user),
    ).group_by(
        tag_expr
    ).order_by(
        desc(usage)
    ).limit(limit)
    result = await db.execute(query)

    # Unpack rows positionally: attribute access ``row.count`` resolves to the
    # tuple-style ``Row.count`` method rather than the column labelled "count".
    return [
        PopularTagResponse(tag=tag, count=tag_count)
        for tag, tag_count in result.all()
    ]
@router.get("/{step_id}", response_model=StepLibraryResponse)
async def get_step(
    step_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Return the full detail view of one step.

    Args:
        step_id: ID of the step to fetch.
        db: Async database session.
        current_user: Authenticated user; must be allowed to view the step.

    Returns:
        A ``StepLibraryResponse`` including the author name, the source tree
        name and, when the step has a category, the category name.

    Raises:
        HTTPException: 404 via ``get_step_or_404`` when the step is missing
            or not viewable.
    """
    step = await get_step_or_404(step_id, db, current_user, check_view=True)

    # Attributes copied one-to-one from the ORM row into the response.
    copied_fields = (
        "id", "title", "step_type", "content", "category_id", "tags",
        "visibility", "created_by", "account_id", "usage_count",
        "rating_average", "rating_count", "helpful_yes", "helpful_no",
        "is_featured", "is_verified", "is_active", "created_at",
        "updated_at", "is_flow_synced",
    )
    payload = {name: getattr(step, name) for name in copied_fields}
    payload["source_tree_name"] = step.source_tree.name if step.source_tree else None

    # The category name is only looked up for steps that have a category.
    if step.category_id:
        category_lookup = select(StepCategory.name).where(StepCategory.id == step.category_id)
        payload["category_name"] = (await db.execute(category_lookup)).scalar_one_or_none()

    author_lookup = select(User.name).where(User.id == step.created_by)
    payload["author_name"] = (await db.execute(author_lookup)).scalar_one_or_none()

    return StepLibraryResponse(**payload)
@router.post("", response_model=StepLibraryResponse, status_code=201)
async def create_step(
    step_data: StepLibraryCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_engineer_or_admin)
):
    """Create a new step owned by the current user.

    Access is gated by the ``require_engineer_or_admin`` dependency.

    Args:
        step_data: Payload describing the new step.
        db: Async database session.
        current_user: Authenticated user creating the step.

    Returns:
        The persisted step as a ``StepLibraryResponse``.

    Raises:
        HTTPException: 400 "Invalid category" when ``category_id`` does not
            reference an active category; 403 when a non-super-admin targets
            an account other than their own.
    """
    # A supplied category must exist and be active.
    if step_data.category_id:
        category_check = select(StepCategory).where(
            StepCategory.id == step_data.category_id,
            StepCategory.is_active == True,  # noqa: E712 - SQL expression
        )
        category = (await db.execute(category_check)).scalar_one_or_none()
        if not category:
            raise HTTPException(status_code=400, detail="Invalid category")

    # Only super admins may create a step on behalf of another account.
    account_id = step_data.account_id
    targets_other_account = bool(account_id) and account_id != current_user.account_id
    if targets_other_account and not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Cannot create step for another account")

    step = StepLibrary(
        title=step_data.title,
        step_type=step_data.step_type,
        content=step_data.content.model_dump(),
        category_id=step_data.category_id,
        tags=step_data.tags,
        visibility=step_data.visibility,
        created_by=current_user.id,
        # Fall back to the creator's own account when none was requested.
        account_id=account_id or current_user.account_id,
    )
    db.add(step)
    await db.commit()
    await db.refresh(step)

    # Attributes copied one-to-one from the refreshed ORM row.
    copied_fields = (
        "id", "title", "step_type", "content", "category_id", "tags",
        "visibility", "created_by", "account_id", "usage_count",
        "rating_average", "rating_count", "helpful_yes", "helpful_no",
        "is_featured", "is_verified", "is_active", "created_at",
        "updated_at",
    )
    payload = {name: getattr(step, name) for name in copied_fields}
    payload["author_name"] = current_user.name

    if step.category_id:
        name_lookup = select(StepCategory.name).where(StepCategory.id == step.category_id)
        payload["category_name"] = (await db.execute(name_lookup)).scalar_one_or_none()

    return StepLibraryResponse(**payload)
@router.put("/{step_id}", response_model=StepLibraryResponse)
async def update_step(
    step_id: UUID,
    step_data: StepLibraryUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Apply a partial update to a step the caller may edit.

    Only fields explicitly present in the request body are written.

    Args:
        step_id: ID of the step to update.
        step_data: Fields to change.
        db: Async database session.
        current_user: Authenticated user; must pass the edit permission check.

    Returns:
        The updated step as a ``StepLibraryResponse``.

    Raises:
        HTTPException: 404 via ``get_step_or_404``; 400 when the step is
            flow-synced (read-only) or the new category is not an active
            category.
    """
    step = await get_step_or_404(step_id, db, current_user, check_edit=True)

    if step.is_flow_synced:
        raise HTTPException(
            status_code=400,
            detail="Flow-synced steps are read-only. Fork to customize."
        )

    # A newly supplied category must exist and be active.
    if step_data.category_id:
        category_check = select(StepCategory).where(
            StepCategory.id == step_data.category_id,
            StepCategory.is_active == True,  # noqa: E712 - SQL expression
        )
        if not (await db.execute(category_check)).scalar_one_or_none():
            raise HTTPException(status_code=400, detail="Invalid category")

    changes = step_data.model_dump(exclude_unset=True)
    # Store content as a plain dict if it is still a model instance.
    new_content = changes.get('content')
    if new_content and hasattr(new_content, 'model_dump'):
        changes['content'] = new_content.model_dump()

    for attr_name, new_value in changes.items():
        setattr(step, attr_name, new_value)
    step.updated_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(step)

    # Attributes copied one-to-one from the refreshed ORM row.
    copied_fields = (
        "id", "title", "step_type", "content", "category_id", "tags",
        "visibility", "created_by", "account_id", "usage_count",
        "rating_average", "rating_count", "helpful_yes", "helpful_no",
        "is_featured", "is_verified", "is_active", "created_at",
        "updated_at",
    )
    payload = {name: getattr(step, name) for name in copied_fields}

    if step.category_id:
        name_lookup = select(StepCategory.name).where(StepCategory.id == step.category_id)
        payload["category_name"] = (await db.execute(name_lookup)).scalar_one_or_none()

    author_lookup = select(User.name).where(User.id == step.created_by)
    payload["author_name"] = (await db.execute(author_lookup)).scalar_one_or_none()

    return StepLibraryResponse(**payload)
@router.delete("/{step_id}", status_code=204)
async def delete_step(
    step_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Soft-delete a step by marking it inactive.

    The row is kept; ``is_active`` is cleared and ``updated_at`` is stamped
    with the current UTC time.

    Args:
        step_id: ID of the step to deactivate.
        db: Async database session.
        current_user: Authenticated user; must pass the edit permission check.

    Raises:
        HTTPException: 404 via ``get_step_or_404`` when the step is missing
            or not editable by the caller.
    """
    target = await get_step_or_404(step_id, db, current_user, check_edit=True)

    deactivated_at = datetime.now(timezone.utc)
    target.is_active = False
    target.updated_at = deactivated_at

    await db.commit()
# Rating endpoints
@router.post("/{step_id}/rate", response_model=StepRatingResponse, status_code=201)
async def rate_step(
    step_id: UUID,
    rating_data: StepRatingCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Create the current user's rating for a step.

    Args:
        step_id: ID of the step being rated.
        rating_data: Rating payload (rating, helpfulness, optional review
            text and session reference).
        db: Async database session.
        current_user: Authenticated user submitting the rating.

    Returns:
        The stored rating as a ``StepRatingResponse``.

    Raises:
        HTTPException: 404 via ``get_step_or_404``; 400 when the user has
            already rated this step.
    """
    step = await get_step_or_404(step_id, db, current_user, check_view=True)

    # Each user may hold at most one rating per step.
    prior_lookup = select(StepRating).where(
        StepRating.step_id == step_id,
        StepRating.user_id == current_user.id,
    )
    already_rated = (await db.execute(prior_lookup)).scalar_one_or_none()
    if already_rated:
        raise HTTPException(status_code=400, detail="You have already rated this step. Use PUT to update.")

    linked_session = rating_data.session_id
    rating = StepRating(
        step_id=step_id,
        user_id=current_user.id,
        account_id=current_user.account_id,
        rating=rating_data.rating,
        was_helpful=rating_data.was_helpful,
        review_text=rating_data.review_text,
        session_id=linked_session,
        # A rating tied to a session is flagged as a verified use.
        is_verified_use=linked_session is not None,
    )
    db.add(rating)

    # Refresh the aggregate columns on the step, then persist everything.
    await _update_step_ratings(db, step)
    await db.commit()
    await db.refresh(rating)

    echoed_fields = (
        "id", "step_id", "user_id", "rating", "was_helpful", "review_text",
        "is_verified_use", "session_id", "is_visible", "created_at",
        "updated_at",
    )
    return StepRatingResponse(
        **{name: getattr(rating, name) for name in echoed_fields},
        user_name=current_user.name,
    )
@router.put("/{step_id}/rate", response_model=StepRatingResponse)
async def update_rating(
    step_id: UUID,
    rating_data: StepRatingUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Update the current user's existing rating for a step.

    Only fields explicitly present in the request body are written.

    Args:
        step_id: ID of the rated step.
        rating_data: Fields to change on the rating.
        db: Async database session.
        current_user: Authenticated user who owns the rating.

    Returns:
        The updated rating as a ``StepRatingResponse``.

    Raises:
        HTTPException: 404 via ``get_step_or_404``, or 404 when the user has
            no rating for this step yet.
    """
    step = await get_step_or_404(step_id, db, current_user, check_view=True)

    own_rating_lookup = select(StepRating).where(
        StepRating.step_id == step_id,
        StepRating.user_id == current_user.id,
    )
    rating = (await db.execute(own_rating_lookup)).scalar_one_or_none()
    if not rating:
        raise HTTPException(status_code=404, detail="Rating not found. Use POST to create.")

    for attr_name, new_value in rating_data.model_dump(exclude_unset=True).items():
        setattr(rating, attr_name, new_value)
    rating.updated_at = datetime.now(timezone.utc)

    # Refresh the aggregate columns on the step, then persist everything.
    await _update_step_ratings(db, step)
    await db.commit()
    await db.refresh(rating)

    echoed_fields = (
        "id", "step_id", "user_id", "rating", "was_helpful", "review_text",
        "is_verified_use", "session_id", "is_visible", "created_at",
        "updated_at",
    )
    return StepRatingResponse(
        **{name: getattr(rating, name) for name in echoed_fields},
        user_name=current_user.name,
    )
@router.delete("/{step_id}/rate", status_code=204)
async def delete_rating(
    step_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Remove the current user's rating for a step.

    Args:
        step_id: ID of the rated step.
        db: Async database session.
        current_user: Authenticated user who owns the rating.

    Raises:
        HTTPException: 404 via ``get_step_or_404``, or 404 "Rating not found"
            when the user has no rating for this step.
    """
    step = await get_step_or_404(step_id, db, current_user, check_view=True)

    own_rating_lookup = select(StepRating).where(
        StepRating.step_id == step_id,
        StepRating.user_id == current_user.id,
    )
    own_rating = (await db.execute(own_rating_lookup)).scalar_one_or_none()
    if not own_rating:
        raise HTTPException(status_code=404, detail="Rating not found")

    await db.delete(own_rating)

    # Refresh the aggregate columns on the step, then persist everything.
    await _update_step_ratings(db, step)
    await db.commit()
# Backward-compatible /ratings alias routes
@router.post(
    "/{step_id}/ratings",
    response_model=StepRatingResponse,
    status_code=201,
    include_in_schema=False,
)
async def rate_step_alias(
    step_id: UUID,
    rating_data: StepRatingCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Backward-compatible alias: delegates to ``rate_step`` (POST /{step_id}/rate).

    Hidden from the generated API schema.
    """
    return await rate_step(
        step_id=step_id,
        rating_data=rating_data,
        db=db,
        current_user=current_user,
    )
@router.put(
    "/{step_id}/ratings",
    response_model=StepRatingResponse,
    include_in_schema=False,
)
async def update_rating_alias(
    step_id: UUID,
    rating_data: StepRatingUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Backward-compatible alias: delegates to ``update_rating`` (PUT /{step_id}/rate).

    Hidden from the generated API schema.
    """
    return await update_rating(
        step_id=step_id,
        rating_data=rating_data,
        db=db,
        current_user=current_user,
    )
@router.delete(
    "/{step_id}/ratings",
    status_code=204,
    include_in_schema=False,
)
async def delete_rating_alias(
    step_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Backward-compatible alias: delegates to ``delete_rating`` (DELETE /{step_id}/rate).

    Hidden from the generated API schema.
    """
    return await delete_rating(step_id=step_id, db=db, current_user=current_user)
@router.get("/{step_id}/reviews", response_model=list[StepRatingResponse])
async def get_reviews(
    step_id: UUID,
    limit: int = Query(20, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """List the visible written reviews for a step, newest first.

    Only ratings that are marked visible and carry review text are returned.

    Args:
        step_id: ID of the step whose reviews are listed.
        limit: Page size.
        offset: Number of rows to skip.
        db: Async database session.
        current_user: Authenticated user; must be allowed to view the step.

    Returns:
        A list of ``StepRatingResponse`` items, each with the reviewer's name.

    Raises:
        HTTPException: 404 via ``get_step_or_404`` when the step is missing
            or not viewable.
    """
    await get_step_or_404(step_id, db, current_user, check_view=True)

    result = await db.execute(
        select(StepRating).where(
            StepRating.step_id == step_id,
            StepRating.is_visible == True,  # noqa: E712 - SQL expression
            StepRating.review_text.isnot(None),
        )
        # created_at is not unique; the primary-key tie-breaker keeps
        # offset/limit pages stable.
        .order_by(desc(StepRating.created_at), StepRating.id)
        .offset(offset)
        .limit(limit)
    )
    ratings = result.scalars().all()

    # Resolve all reviewer names in a single query instead of one per review.
    reviewer_ids = list({rating.user_id for rating in ratings})
    reviewer_names = {}
    if reviewer_ids:
        name_rows = await db.execute(
            select(User.id, User.name).where(User.id.in_(reviewer_ids))
        )
        reviewer_names = {user_id: user_name for user_id, user_name in name_rows.all()}

    return [
        StepRatingResponse(
            id=rating.id,
            step_id=rating.step_id,
            user_id=rating.user_id,
            rating=rating.rating,
            was_helpful=rating.was_helpful,
            review_text=rating.review_text,
            is_verified_use=rating.is_verified_use,
            session_id=rating.session_id,
            is_visible=rating.is_visible,
            created_at=rating.created_at,
            updated_at=rating.updated_at,
            user_name=reviewer_names.get(rating.user_id),
        )
        for rating in ratings
    ]
async def _update_step_ratings(db: AsyncSession, step: StepLibrary) -> None:
    """Recompute the aggregate rating columns stored on a step.

    Sets ``rating_count``, ``rating_average`` (rounded to two decimals),
    ``helpful_yes`` and ``helpful_no`` on ``step`` from its current
    ``StepRating`` rows. The function does not commit; the caller does.

    Args:
        db: Async session holding the caller's pending rating changes.
        step: Step whose aggregate columns are refreshed in place.
    """
    # Send the caller's pending rating insert/update/delete to the database
    # first, so the aggregate SELECT below sees it even if the session was
    # created with autoflush disabled.
    await db.flush()

    result = await db.execute(
        select(
            func.count(StepRating.id),
            func.avg(StepRating.rating),
            func.sum(case((StepRating.was_helpful == True, 1), else_=0)),  # noqa: E712
            func.sum(case((StepRating.was_helpful == False, 1), else_=0)),  # noqa: E712
        ).where(StepRating.step_id == step.id)
    )
    total, average, helpful_yes, helpful_no = result.one()

    # With no ratings, AVG and SUM return NULL; fall back to zero.
    step.rating_count = total or 0
    step.rating_average = Decimal(str(round(average or 0, 2)))
    step.helpful_yes = helpful_yes or 0
    step.helpful_no = helpful_no or 0