fix: prevent InFailedSQLTransactionError in session creation

Root cause: a SQL statement failing during embedding generation put the DB
transaction into a failed state. The except block caught the Python exception
but did not roll back, so subsequent queries (_record_usage → subscription
lookup) then failed with InFailedSQLTransactionError.

Fixes:
- session_embedding_service: use begin_nested() savepoint so failures don't
  poison the parent transaction
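- ai_sessions.py: add db.rollback() before _record_usage in all 3 error
  handlers (create, respond, pickup) to recover from broken transactions, and
  wrap the usage recording in try/except so a failure there is logged as a
  warning instead of masking the 502 response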

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-20 04:36:12 +00:00
parent 2ed8a2af15
commit eed771cb27
2 changed files with 74 additions and 61 deletions

View File

@@ -139,13 +139,18 @@ async def create_session(
) )
except Exception as e: except Exception as e:
logger.exception("FlowPilot session start failed: %s", e) logger.exception("FlowPilot session start failed: %s", e)
await _record_usage( # Rollback the failed transaction before attempting usage recording
current_user, db, await db.rollback()
generation_type="flowpilot_start", try:
input_tokens=0, output_tokens=0, await _record_usage(
succeeded=False, error_code=type(e).__name__, current_user, db,
) generation_type="flowpilot_start",
await db.commit() input_tokens=0, output_tokens=0,
succeeded=False, error_code=type(e).__name__,
)
await db.commit()
except Exception:
logger.warning("Failed to record usage after session start failure", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.", detail=f"AI provider error ({type(e).__name__}). Please try again.",
@@ -193,15 +198,19 @@ async def respond_to_step(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e: except Exception as e:
logger.exception("FlowPilot response failed: %s", e) logger.exception("FlowPilot response failed: %s", e)
await _record_usage( await db.rollback()
current_user, db, try:
generation_type="flowpilot_respond", await _record_usage(
input_tokens=0, output_tokens=0, current_user, db,
succeeded=False, generation_type="flowpilot_respond",
session_id=session_id, input_tokens=0, output_tokens=0,
error_code=type(e).__name__, succeeded=False,
) session_id=session_id,
await db.commit() error_code=type(e).__name__,
)
await db.commit()
except Exception:
logger.warning("Failed to record usage after response failure", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.", detail=f"AI provider error ({type(e).__name__}). Please try again.",
@@ -387,15 +396,19 @@ async def pickup_session(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e)) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))
except Exception as e: except Exception as e:
logger.exception("FlowPilot pickup failed: %s", e) logger.exception("FlowPilot pickup failed: %s", e)
await _record_usage( await db.rollback()
current_user, db, try:
generation_type="flowpilot_pickup", await _record_usage(
input_tokens=0, output_tokens=0, current_user, db,
succeeded=False, generation_type="flowpilot_pickup",
session_id=session_id, input_tokens=0, output_tokens=0,
error_code=type(e).__name__, succeeded=False,
) session_id=session_id,
await db.commit() error_code=type(e).__name__,
)
await db.commit()
except Exception:
logger.warning("Failed to record usage after pickup failure", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"AI provider error ({type(e).__name__}). Please try again.", detail=f"AI provider error ({type(e).__name__}). Please try again.",

View File

@@ -54,44 +54,44 @@ async def generate_session_embedding(session_id: UUID, db: AsyncSession) -> None
embedding_str = "[" + ",".join(str(v) for v in embedding_vector) + "]" embedding_str = "[" + ",".join(str(v) for v in embedding_vector) + "]"
# Check for existing embedding # Use a savepoint so failures don't poison the parent transaction
existing = await db.execute( async with db.begin_nested():
select(AISessionEmbedding).where( # Check for existing embedding
AISessionEmbedding.session_id == session_id existing = await db.execute(
select(AISessionEmbedding).where(
AISessionEmbedding.session_id == session_id
)
) )
) embed_record = existing.scalar_one_or_none()
embed_record = existing.scalar_one_or_none()
if embed_record: if embed_record:
# Update existing # Update existing
embed_record.chunk_text = chunk_text embed_record.chunk_text = chunk_text
await db.execute( await db.execute(
text( text(
"UPDATE ai_session_embeddings " "UPDATE ai_session_embeddings "
"SET embedding = :emb::vector, updated_at = now() " "SET embedding = :emb::vector, updated_at = now() "
"WHERE session_id = :sid" "WHERE session_id = :sid"
), ),
{"emb": embedding_str, "sid": str(session_id)}, {"emb": embedding_str, "sid": str(session_id)},
) )
else: else:
# Insert new via raw SQL to include vector column # Insert new via raw SQL to include vector column
await db.execute( await db.execute(
text(""" text("""
INSERT INTO ai_session_embeddings INSERT INTO ai_session_embeddings
(id, session_id, account_id, chunk_text, embedding_model, embedding, created_at, updated_at) (id, session_id, account_id, chunk_text, embedding_model, embedding, created_at, updated_at)
VALUES VALUES
(gen_random_uuid(), :session_id, :account_id, :chunk_text, :model, :embedding::vector, now(), now()) (gen_random_uuid(), :session_id, :account_id, :chunk_text, :model, :embedding::vector, now(), now())
"""), """),
{ {
"session_id": str(session_id), "session_id": str(session_id),
"account_id": str(session.account_id), "account_id": str(session.account_id),
"chunk_text": chunk_text, "chunk_text": chunk_text,
"model": "voyage-3.5", "model": "voyage-3.5",
"embedding": embedding_str, "embedding": embedding_str,
}, },
) )
await db.flush()
except Exception: except Exception:
logger.warning( logger.warning(
"Failed to generate embedding for session %s", session_id, exc_info=True "Failed to generate embedding for session %s", session_id, exc_info=True