61 Commits

Author SHA1 Message Date
Claude
1c6e22ceb3 docs: update CHANGELOG with tenant isolation Phase 1 and recent fixes
Added comprehensive documentation of Tenant Isolation Phase 1 including RLS implementation,
database role separation, admin database isolation, account_id backfill, and global content
separation. Also documented recent bug fixes including non-default tree handling.

https://claude.ai/code/session_01CwER1BcCEGkdUNLRR3fqc6
2026-04-10 10:46:04 +00:00
chihlasm
8292e6ec65 fix: handle non-default, no-team trees in global content migration
Migration 019 only backfills trees with team_id IS NOT NULL.
Migration 3a40fe11b427 only covered is_default=TRUE trees.
Trees with team_id=NULL and is_default=FALSE (e.g. inactive test trees,
pre-team-system content) fell through both passes and triggered the NULL
guard.

Add two new UPDATE steps after the is_default pass:
1. Assign remaining trees to their author's account (if author has one)
2. Final fallback to PLATFORM_ACCOUNT_ID for any still-NULL rows

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 05:21:26 +00:00
chihlasm
20bd428d83 Merge pull request #133 from resolutionflow/feat/tenant-isolation-phase-1
feat: Phase 1 tenant isolation — add account_id to all tenant tables
2026-04-10 00:57:53 -04:00
chihlasm
b9da0e7107 chore: resolve merge conflicts with main
- deps.py: keep require_tenant_context + require_admin_db (RLS deps);
  drop unused get_tenant_context stub from Phase 0
- categories.py: keep both PLATFORM_ACCOUNT_ID and tenant_filter imports
  (body uses both)
- tenant-isolation spec: keep main's resolved TargetList/teams audit answers

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 04:57:39 +00:00
chihlasm
8f044849d4 fix: get_tree returns 404 (not 403) for inaccessible trees — don't leak resource existence
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 04:17:31 +00:00
chihlasm
14304be383 fix: correct RLS test fixtures — tree_structure NOT NULL, tree_tags schema, session-scoped set_config
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 04:15:41 +00:00
chihlasm
a5c5eb6cc3 fix: convert DATABASE_URL_SYNC from property to overridable field for Alembic superuser URL
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 04:03:32 +00:00
chihlasm
c4f919f3a5 feat: migration — enable RLS on trees, tags, categories, psa_connections, flow_proposals
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 04:02:10 +00:00
chihlasm
8de6ee7aa4 feat: migration — create resolutionflow_app and resolutionflow_admin DB roles
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:59:28 +00:00
chihlasm
83ad2e0661 feat: migrate admin endpoints to get_admin_db (BYPASSRLS) before RLS switch
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:57:18 +00:00
chihlasm
ce4056c6b9 test: add failing RLS isolation tests (green after Task 10 migration + Task 11 URL switch)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:54:42 +00:00
chihlasm
9d60b9a244 feat: apply require_tenant_context to all user-facing routers
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:52:52 +00:00
chihlasm
df9ecf2d29 feat: add require_tenant_context and require_admin_db dependencies
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:50:59 +00:00
chihlasm
b0e5f12897 feat: register RLS transaction-begin listener on app engine at startup
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:49:49 +00:00
chihlasm
b4f8694f6b feat: add tenant_context module — ContextVar, transaction listener, tenant_filter
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:48:34 +00:00
chihlasm
6f1becf21f feat: add admin_engine and get_admin_db for BYPASSRLS admin endpoints
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:46:29 +00:00
chihlasm
acbfb3fb37 feat: add ADMIN_DATABASE_URL setting with fallback to DATABASE_URL
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 03:45:52 +00:00
chihlasm
a394a1d464 fix: replace account_id=None with PLATFORM_ACCOUNT_ID for global content
After migration 174f442795b7 enforces NOT NULL on account_id, all
platform/global content must use the sentinel platform account instead
of NULL. Three categories of fixes:

1. trees.py: is_default trees now get PLATFORM_ACCOUNT_ID (not None)
2. admin_categories.py: global category CRUD now uses PLATFORM_ACCOUNT_ID
3. categories.py, tags.py, step_categories.py: creation endpoints coerce
   None → PLATFORM_ACCOUNT_ID; IS NULL filter queries updated to
   == PLATFORM_ACCOUNT_ID (IS NULL queries returned empty after migration
   backfilled all global rows to the platform account)

Defines PLATFORM_ACCOUNT_ID constant in app/core/service_account.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:35:52 +00:00
chihlasm
d2ebc4f182 fix: correct tree tags subquery in template_trees migration
The INSERT into template_trees incorrectly referenced `tags` as a column
on the `trees` table. Tags are a relationship via the `tree_tag_assignments`
join table — there is no direct column. Migration was failing with:

  UndefinedColumn: column "tags" does not exist ... FROM trees

Fixed by replacing COALESCE(tags, '[]') with a correlated subquery that
aggregates tag names from tree_tag_assignments → tree_tags.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 17:30:05 +00:00
chihlasm
8bcf08ae06 fix: persist account ownership for script templates and generations 2026-04-09 17:18:38 +00:00
Claude
85575839f2 docs: update CHANGELOG with tenant isolation Phase 0 and security fixes
- Add Tenant Isolation Phase 0 (#132) — app-layer filtering, cross-tenant audit, UUID isolation
- Document CRITICAL copilot tree query isolation fix (#131)
- Add AI session search, analytics, category, PSA retry, and task lane fixes
- Note 404 (not 403) responses for cross-tenant access to avoid confirming resource existence

https://claude.ai/code/session_014EUBLi2jHrnzJupcetmdwV
2026-04-09 10:41:21 +00:00
chihlasm
478205c208 fix: platform account fallback for script_templates seeded without team/user
Migration 057 inserts 6 AD script templates with NULL team_id and NULL
created_by. Neither backfill path (created_by→users, team_id→team admin)
could attribute them to an account, causing the verify check to fail.

Fix: pre-create the platform sentinel account (ON CONFLICT DO NOTHING,
safe since 3a40fe11b427 also creates it idempotently) and add a final
fallback UPDATE assigning any remaining NULL script_templates to it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 06:41:00 +00:00
chihlasm
0f33feb6d6 fix: use correlated subquery in psa_post_log backfill to avoid invalid FROM-clause reference
PostgreSQL UPDATE...FROM does not allow the updated table to be
referenced inside the FROM clause's JOIN conditions. Replace the
LEFT JOIN psa_connections with a correlated subquery.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 06:31:17 +00:00
chihlasm
034b858fc9 fix: add depends_on 067 to cc214c63aa30 to fix fresh-DB migration order
session_resolution_outputs is created in migration 067 (sequential branch
from 064). On fresh databases, Alembic could run cc214c63aa30 before 067,
causing "table does not exist" errors. depends_on ensures 067 always runs
first regardless of branch traversal order.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 06:20:00 +00:00
chihlasm
b937cb41e4 fix: merge Phase 1 account_id chain with main head to resolve multiple-heads error
Combines the Phase 1 tenant isolation chain (064 → ... → 174f442795b7)
with the main sequential chain (064 → ... → 070) into a single Alembic
head (a9f3b2c1d4e5) so `alembic upgrade head` in the Dockerfile works
without ambiguity.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 06:14:04 +00:00
chihlasm
0d475c71ed fix: correct Phase 1 down_revision — chain from 064 not b8d2f4a6c091
b8d2f4a6c091 was NOT the production head. The true head was 064
(064_normalize_script_builder_messages) via the chain:
b8d2f4a6c091 → f0aad74ea51b → 062 → 063 → 064

This caused 'multiple head revisions' on Railway deployment.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 06:04:10 +00:00
chihlasm
417fa562ce fix: Task 9 migration — include tags in template_trees INSERT
The tags column was accidentally omitted from the is_default tree copy.
Now uses COALESCE(tags, '[]'::jsonb) to preserve source tree tags.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:34:59 +00:00
chihlasm
42937b24a4 feat: Phase 1 Group 9 — enforce NOT NULL on all account_id columns
All previously-nullable account_id columns are now NOT NULL.
tree_embeddings and feedback backfilled before constraint applied.
Global content assigned to platform sentinel account (00000000-...-0001)
in preceding migration.

Tables updated: users, trees, tree_categories, tree_tags,
step_categories, step_library, tree_embeddings, feedback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:34:32 +00:00
chihlasm
b4b8c67d3b feat: Phase 1 Group 10 — create global content tables and platform account
Creates template_trees and platform_steps (no account_id, no RLS).
Migrates is_default=TRUE trees and public steps into them.
Creates sentinel platform account (00000000-...-0001) for global
tree_categories, tree_tags, step_categories, step_library, and
is_default trees — clearing all NULL account_id rows in those tables
as prerequisite for Group 9 SET NOT NULL.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:31:33 +00:00
chihlasm
d24da77604 feat: Phase 1 Group 8 — add account_id to target_lists (keep team_id)
Zero rows in production — this is a schema-only migration in practice.
team_id kept for app code compatibility. Drop deferred to later cleanup.
Backfill: team_id → team admin user → account_id; fallback: created_by.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:25:24 +00:00
chihlasm
857e782d14 feat: Phase 1 Group 7 — add account_id to script tables (keep team_id)
team_id is kept in all three tables — drop deferred until app code
is fully migrated off team_id references.

Tables: script_builder_sessions, script_templates, script_generations
Backfill: user_id/created_by → users.account_id

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:23:35 +00:00
chihlasm
086c4580f1 feat: Phase 1 Group 6 — add account_id to maintenance_schedules
Primary backfill: tree_id → trees.account_id
Fallback: created_by → users.account_id (for is_default tree rows)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:20:56 +00:00
chihlasm
0d69474128 feat: Phase 1 Group 5 — add account_id to PSA and notification tables
psa_post_log: backfill via psa_connection, fallback to posted_by user
psa_member_mappings: backfill via psa_connection
notification_logs: backfill via notification_config

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:19:12 +00:00
chihlasm
b5fdb488b3 feat: Phase 1 Group 4 — add account_id to user_folders and user_pinned_trees
Backfill: user_id → users.account_id

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:16:50 +00:00
chihlasm
de5ecf4fb2 feat: Phase 1 Group 3 — add account_id to step_ratings and step_usage_log
Backfill from rater/user's account_id (not the step's account_id).
This is an explicit design decision — step rating data is attributed
to the account that performed the rating.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:15:10 +00:00
chihlasm
2779a41b94 feat: Phase 1 Group 2 — add account_id to AI branching tables
Tables: session_branches, session_handoffs, fork_points,
        ai_session_steps, ai_suggestions
Backfill: session_id → ai_sessions.account_id (all except
ai_suggestions which uses user_id → users.account_id)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:12:18 +00:00
chihlasm
4666c4f6d2 feat: Phase 1 Group 1 — add account_id to core session tables
Migration sequence: add nullable → backfill via user_id/ai_session chain
→ verify zero NULLs → SET NOT NULL → CREATE INDEX.

Tables: sessions, attachments, session_supporting_data,
        session_resolution_outputs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 05:09:14 +00:00
chihlasm
2837c6e4cf docs: add Phase 1 tenant isolation schema migrations implementation plan
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 04:58:24 +00:00
chihlasm
b3dba57bc5 feat: tenant isolation Phase 0 — app-layer filters, UUID audit, CI gate (#132)
* docs: add tenant data isolation design spec

Complete architecture plan for multi-tenant data isolation across
all layers (PostgreSQL RLS, application-layer filtering, schema
migration, testing strategy, and phased rollout checklist).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs: add background job isolation policy to tenant isolation spec

Documents policy for all 5 existing background jobs:
- Knowledge Flywheel and PSA Retry flagged for account_id threading
- Chat Retention already follows correct pattern (model for others)
- Maintenance Schedule Firing needs account_id in queries + Session creation
- AI Conversation Expiry approved as cross-tenant with justification

Adds approved cross-tenant query registry and Phase 2 checklist items.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs: add tenant isolation Phase 0 implementation plan

8 tasks covering: CRITICAL copilot hotfix, tenant_filter() helper,
get_tenant_context dependency, analytics/category/AI session gap fixes,
full UUID endpoint audit, TargetList dead code audit, teams orphan
check, and CI grep check for missing tenant filters.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat: add tenant_filter() helper and get_tenant_context dependency

tenant_filter(model, account_id) is the canonical app-layer tenant
scoping expression. Every query on a tenant table must use it.
build_tree_access_filter and build_step_visibility_filter updated
to call tenant_filter() internally for the account_id match.

get_tenant_context is a FastAPI dependency that returns account_id
or raises 403 if the user has no account — prevents raw access to
current_user.account_id and centralises the null check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: scope analytics/flows/{tree_id} to requesting account

Any authenticated user could read flow analytics (session counts,
completion rates, CSAT) for any tree UUID. Now returns 404 if the
tree doesn't belong to the requesting account.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: scope category tree_count to requesting account

tree_count on GET /categories/{id} was including trees from all
accounts, leaking cross-tenant row counts.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: restrict AI session search to current user only

Search endpoint used OR(user_id, account_id), exposing other users'
problem_summary and problem_domain within the same account. Sessions
are user-scoped only — cross-user access requires explicit escalation
or sharing. List and search endpoints now behave consistently.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: add ownership check and 404 responses to ai-sessions endpoints

Cross-tenant isolation audit found:
- retry-psa-push had NO ownership check (CRITICAL) — any user could retry any session's PSA push
- save_task_lane used db.get() without ownership filter, returned 403 revealing existence
- get_session returned 403 instead of 404 for unauthorized access
- stream_documentation returned 403 instead of 404

All now use query-level user_id filtering and return 404 to avoid revealing existence.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-tenant session access

All session endpoints (get, update, complete, scratchpad, variables, export,
ticket-link) now return 404 instead of 403 when a user tries to access
another user's session. This prevents confirming existence of resources
across tenant boundaries.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-tenant tree access

get_tree and update_tree now return 404 when a user cannot access a tree
(private tree from another account). Prevents confirming resource existence
across tenant boundaries.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-tenant step access

get_step_or_404 now returns 404 when can_view_step or can_edit_step fails,
preventing confirmation of step existence across tenant boundaries.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-tenant upload access

get_upload_url and delete_upload now return 404 when the upload belongs to
a different account/user, preventing resource existence confirmation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-tenant share access

revoke_share and create_share now return 404 when the caller is not the
owner, preventing resource existence confirmation across users.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-team tree access in maintenance schedules

_get_tree_or_403 now returns 404 when the user's team does not match,
preventing confirmation of tree existence across teams.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-account tag access

get_tag now returns 404 for account-specific tags that belong to another
account, preventing resource existence confirmation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 instead of 403 for cross-account step category access

get_step_category now returns 404 for account-specific categories that
belong to another account, preventing resource existence confirmation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* test: add cross-tenant isolation tests for Task 6 UUID audit

Tests cover:
- Tree GET/PUT returns 404 for cross-account access
- Session GET returns 404 for cross-user access
- AI session GET returns 404 for cross-user access
- AI session retry-psa-push requires ownership
- Upload URL returns 404 for cross-account access
- Share revoke returns 404 for cross-user access

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: return 404 (not 403) for get_documentation cross-user access; add missing Task 6 tests

get_documentation was revealing session existence via 403. Added pre-check
query filtering by session_id AND user_id before calling the engine.

Also add cross-tenant isolation tests for steps, tags, step_categories,
and maintenance_schedules endpoints fixed in Task 6 (TDD was skipped).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: address Task 6 quality review — rename helper, restore 403 for intra-account, add docs test

- Rename _get_tree_or_403 → _get_tree_or_404 in maintenance_schedules.py
  (function now raises 404, old name was misleading)
- Restore HTTP 403 for intra-account permission failures in update_tree:
  same-account users who can see a tree but can't edit it got 404 (wrong);
  only cross-account lookups should return 404 to avoid confirming existence
- Apply same 403/404 distinction to update_tree_visibility
- Add test: get_documentation must return 404 for cross-user session access
- Add comment documenting owner-only design for documentation endpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* chore: Task 7+8 — TargetList audit, CI tenant-filter grep check

Task 7: TargetList dead code audit
- Found active code references in 12+ files across backend and frontend
  (full CRUD API + frontend page + MaintenanceScheduleSection + BatchLaunchModal)
- Decision: migrate to account_id in Phase 1 (cannot drop)
- DB row count not available from code-server — must verify from VPS SSH
  before Phase 1 migration
- Teams orphan check query documented; must run from VPS SSH before Phase 1
- Results documented in spec Section 9

Task 8: CI tenant-filter enforcement check (warn mode)
- Create backend/scripts/check_tenant_filters.py
  Scans endpoint and service files for select() on tenant tables without
  tenant_filter/account_id/user_id in surrounding context. Currently
  reports 109 warnings (Phase 1 backlog). Exits 0 (warn mode).
- Add Check tenant filter enforcement step to backend CI job
  Add --fail flag after Phase 1 backlog clears to make it blocking.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs: record Phase 0 audit results — 0 orphaned teams, 0 target_list rows

Both checks confirmed 2026-04-09 from production DB.
Phase 1 migration is safe to proceed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:42:19 -04:00
chihlasm
29a9573d6e fix: CRITICAL — scope copilot tree query to current account (#131)
* docs: add tenant data isolation design spec

Complete architecture plan for multi-tenant data isolation across
all layers (PostgreSQL RLS, application-layer filtering, schema
migration, testing strategy, and phased rollout checklist).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs: add background job isolation policy to tenant isolation spec

Documents policy for all 5 existing background jobs:
- Knowledge Flywheel and PSA Retry flagged for account_id threading
- Chat Retention already follows correct pattern (model for others)
- Maintenance Schedule Firing needs account_id in queries + Session creation
- AI Conversation Expiry approved as cross-tenant with justification

Adds approved cross-tenant query registry and Phase 2 checklist items.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* docs: add tenant isolation Phase 0 implementation plan

8 tasks covering: CRITICAL copilot hotfix, tenant_filter() helper,
get_tenant_context dependency, analytics/category/AI session gap fixes,
full UUID endpoint audit, TargetList dead code audit, teams orphan
check, and CI grep check for missing tenant filters.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: CRITICAL — scope copilot tree query to current account

A user who knew another account's tree UUID could start a copilot
conversation, causing the tree's full node structure, names, and
descriptions to be sent to the AI as part of the system prompt.

Fix: add account_id (or is_default / visibility='public') filter to
the tree SELECT in copilot_service.start_conversation(). Returns 404
for inaccessible trees. Test added in test_tenant_isolation_p0.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:41:30 -04:00
chihlasm
56775eca04 docs: add tenant isolation Phase 0 implementation plan
8 tasks covering: CRITICAL copilot hotfix, tenant_filter() helper,
get_tenant_context dependency, analytics/category/AI session gap fixes,
full UUID endpoint audit, TargetList dead code audit, teams orphan
check, and CI grep check for missing tenant filters.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 03:02:19 +00:00
chihlasm
82bb7967d8 docs: add background job isolation policy to tenant isolation spec
Documents policy for all 5 existing background jobs:
- Knowledge Flywheel and PSA Retry flagged for account_id threading
- Chat Retention already follows correct pattern (model for others)
- Maintenance Schedule Firing needs account_id in queries + Session creation
- AI Conversation Expiry approved as cross-tenant with justification

Adds approved cross-tenant query registry and Phase 2 checklist items.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:38:55 +00:00
chihlasm
a7dff9e143 docs: add tenant data isolation design spec
Complete architecture plan for multi-tenant data isolation across
all layers (PostgreSQL RLS, application-layer filtering, schema
migration, testing strategy, and phased rollout checklist).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 02:24:38 +00:00
Claude
ba0680ce06 docs: update CHANGELOG with image support, header actions, and design token normalization
- Added image support in Assistant Chat with S3 upload and vision integration
- Moved session lifecycle actions to header bar in AssistantChatPage
- Normalized design system tokens across FlowPilot, AssistantChat, ScriptBuilder
- Fixed 'sorry something went wrong' errors and image display in chat
- Fixed Task Lane stale data and chat ref invalidation race conditions

https://claude.ai/code/session_01LGJSDQqPi3sPWjC6vh9Uyj
2026-04-08 10:40:44 +00:00
chihlasm
290f2be2fd fix: resolve "sorry something went wrong" errors and show images in chat
Three fixes from beta tester session feedback:

1. MCP error handling (backend/app/services/assistant_chat_service.py)
   - The MCP Microsoft Learn integration was catching only BadRequestError.
     Any other error type (APIStatusError, APIConnectionError, timeout) from
     the external MCP server propagated as a 502, causing the generic error.
   - Now catches all Exception types when MCP is active and retries without
     MCP using the stable client.messages.create endpoint.

2. Frontend error UX (frontend/src/pages/AssistantChatPage.tsx)
   - catch {} was silently swallowing all errors and inserting a generic
     assistant message. Now: differentiates 429 (rate limit) vs 502/503
     (AI unavailable), removes the optimistic user message on failure,
     restores the failed message to the input so users can retry without
     retyping, and logs errors to console for debugging.

3. Image attachments visible in chat (frontend/src/components/assistant/ChatMessage.tsx)
   - Uploaded images were sent to the AI correctly but never shown in the
     chat thread. Now captures preview URLs before clearing pendingUploads
     and renders thumbnails above the user bubble, clickable to full size.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 13:09:16 +00:00
chihlasm
e8e12cc7e5 fix: move session lifecycle actions to header bar in AssistantChatPage
- Add persistent session header with title, status badge, Resolve,
  Escalate, and Update Ticket/Share Update buttons — mirrors
  FlowPilotSessionPage pattern exactly
- Update Ticket label when psa_ticket_id present, Share Update otherwise
- Full mobile support via ⋯ overflow menu (Resolve, Escalate, Update, Pause)
- Strip _(not yet completed)_ markers from stored conversation_messages
  in unified_chat_service to prevent stale task lane items from prior
  turns leaking into new sessions via the AI's re-include instruction
- Add currentChatRef guard to handleResumeNew (was missing unlike handleSend)
- Remove Update/Conclude from chatbar — toolbar is now input utilities only

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 06:31:24 +00:00
chihlasm
bf45322c46 Merge pull request #126 from resolutionflow/refactor/dashboard-design-critique
refactor: normalize FlowPilot/Assistant/ScriptBuilder to design system tokens
2026-04-06 20:23:50 -04:00
Michael Chihlas
f45b045943 refactor: resolve merge conflicts — combine main improvements with token normalization
- .gitignore: keep both graphify-out/ entries and main's .gitnexus entry
- ScriptCodeBlock/ScriptPreviewModal: take main's border-border and text-accent-text
  for filename labels; use neutral ghost style for Save button in ScriptCodeBlock;
  use bg-accent (normalized from bg-primary) for Save button in ScriptPreviewModal

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:23:36 -04:00
Michael Chihlas
cef853d7ea refactor: normalize FlowPilot/Assistant/ScriptBuilder to design system tokens
Replace hardcoded Tailwind color utilities with semantic CSS variable tokens
across 31 files in the FlowPilot, Assistant Chat, and Script Builder feature
communities — the areas graphify identified as design-system-free.

- text-blue-400 → text-accent, bg-blue-500/10 → bg-accent-dim, border-blue-500/20 → border-accent/20
- text-amber-400 → text-warning, bg-amber-400/10 → bg-warning-dim, border-l-amber-500 → border-l-warning
- text-rose-400/500 → text-danger, bg-rose-500/10 → bg-danger-dim
- text-emerald-400 → text-success, bg-emerald-500/10 → bg-success-dim, border-l-emerald-500 → border-l-success
- bg-white/[0.08] → bg-elevated (opacity hack → semantic surface token)
- bg-gradient-to-r from-blue-500 to-blue-400 → bg-accent (no gradient surfaces)
- bg-[#60a5fa] → bg-accent (hard-coded hex removed)

Also adds graphify-out/ to .gitignore.

Theme resilience: accent color has changed twice in 5 weeks. Semantic tokens
mean the next change is a 1-line edit in index.css, not 110 grep-and-replace.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:20:07 -04:00
chihlasm
87cf874199 fix: invalidate currentChatRef before await in handleNewChat and handleResumeNew
The previous fix (990f044) moved state clears before the createChatSession
await but left currentChatRef.current pointing at the old session during the
entire network call. Any in-flight handleSend/handleTaskSubmit for the old
session would pass the guard (oldId === oldId) and re-apply stale task lane
data to the new empty session.

Setting currentChatRef.current = null before the await ensures in-flight
handlers from the previous session see a mismatch and bail — matching the
same pattern already used correctly in selectChat.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:56:10 +00:00
chihlasm
2b53315cc9 Merge pull request #125 from resolutionflow/fix/task-lane-partial-submit
fix: resolve task lane stale state, partial submit, and closure bugs
2026-04-06 16:31:41 -04:00
chihlasm
1811889ed9 chore: update docs and redesign landing page hero
- CLAUDE.md: correct Docker container names, update migration format
  docs (hash IDs now default), fix Node path in Lesson 63, update
  design system values to electric blue accent, add retracted lessons
  note, add GitNexus section
- .gitignore: add .gitnexus
- Landing page: replace animated chat preview with ticket-comparison
  hero layout; remove backdrop-filter from scrolled nav (aligns with
  design system); clean up removed chat animation CSS

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:17:47 +00:00
chihlasm
990f04489f fix: prevent TaskLane showing stale data when starting new chat
Three race conditions in AssistantChatPage:

1. handleNewChat cleared showTaskLane/activeQuestions/activeActions
   AFTER the createChatSession await — old lane was visible during
   the network call. Moved clears before the await.

2. handleResumeNew never cleared old TaskLane state at all. Added
   upfront clears before the first await.

3. handleSend and handleTaskSubmit had no stale-session guard. If
   the user switched chats while sendChatMessage was in flight, the
   response would set showTaskLane on the wrong session. Added
   sentForChatId snapshot + currentChatRef guard (same pattern
   already used in selectChat).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 20:17:39 +00:00
chihlasm
ba815d3ee5 Merge remote-tracking branch 'origin/main' into fix/task-lane-partial-submit 2026-04-06 20:14:45 +00:00
chihlasm
8bd395a0c7 fix: resolve task lane stale state, partial submit, and closure bugs
- Import and call clearTaskState before updating questions/actions in
  handleSend and handleTaskSubmit so new AI tasks always replace stale
  sessionStorage cache instead of being overridden by it
- Include pending (not yet completed) tasks in the AI message on partial
  submit so the AI knows which tasks were left unanswered
- Fix stale closure in TaskLane saveTaskLane useEffect — use refs for
  questions/actions so the debounced backend save always uses current values
- Add responses field to pending_task_lane TypeScript type, removing the
  unsafe double-cast in selectChat
- Instruct the AI to re-surface incomplete tasks unless ≥75% confident
  the information is no longer needed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-06 16:53:48 +00:00
Claude
7198c165b2 docs: update CHANGELOG with session documentation overhaul and client communications
Added entries for:
- Session documentation overhaul with reformatted PSA notes, decimal hour display,
  and follow-up recommendations
- Client communication improvements with request_info audience type
- PSA documentation formatting enhancements
- Status update generation improvements
- Option label resolution fix

https://claude.ai/code/session_01GpyJYk4F3eGiJXwsgycChK
2026-04-06 10:35:01 +00:00
chihlasm
58fe3574bf docs: resolve all contract decisions from codex readiness review
Addresses every Red and Yellow item from the codex review:
- Canonical handoff: ResolutionOutputGenerator is the source of truth
- AI vs manual authority: manual edits win, AI never overwrites
- evidence_items: full-list replacement, frontend is merge authority
- TaskLane persistence: lifted into hook, StepsPanel is presentation-only
- Quick replies: immediate-send, full-stack contract change
- issue_category + asset_name: free text in v1
- Adds 5 implementation guardrails and Phase 2 gate for triage extraction
- Execution order updated to 37 steps with persistence extraction step

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:41:43 +00:00
chihlasm
63a84be921 docs: merge codex insights into claude super plan
Adds key architectural choices summary, assumptions section,
sidebar visual demotion (F9), message click-to-expand in compact
log, and backend-first rationale from the codex plan.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:41:43 +00:00
chihlasm
75971d8b97 docs: add MSP assistant harness super plan (claude synthesis)
Merges MSP_Assistant_Harness_Implementation_Plan.docx with the
brainstorming design spec into a single executable plan. Resolves
all open questions from the original docx, expands scope to include
backend changes, and adds a 35-step phased execution order.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:41:43 +00:00
chihlasm
7998dd237d docs: add MSP assistant harness cockpit design spec
Design spec for evolving /assistant into a live triage cockpit.
Covers layout decisions (stacked zones, drag-resizable split),
incident header (labelled fields, AI-inferred + editable),
work zone (steps checklist + FlowPilot Asks + What We Know),
conclude modal redesign, and all required backend changes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:41:43 +00:00
chihlasm
f4143e52a1 feat: overhaul session documentation, PSA notes, and client communications
- Reformat PSA resolution/escalation notes: clean single-line header,
  steps with engineer responses inline, remove duplicate timing blocks,
  remove AI confidence section, add follow-up recommendations
- Standardize time display to decimal hours (e.g. 0.25 hrs) across all
  note formatters and status update context
- Add follow_up_recommendations to SessionDocumentation schema and
  surface in SessionDocView; extracted from resolution suggestion steps
- Add _build_what_we_know() helper: uses session.evidence_items when
  cockpit branch merges, falls back to deriving findings from steps
- Fix option label lookup in generate_status_update (was passing raw
  machine values to AI instead of human-readable labels)
- Add 'What We Know' section to status update ticket notes prompt
- Improve _build_session_context in resolution_output_generator to
  include intake text and full step details instead of truncated chat
- Add request_info audience type: client-facing information request
  that skips the length step and generates a numbered question list
- Improve client_update and email_draft prompts with per-context
  guidance (status/resolution/escalation) and fix escalation subject
  line from 'Specialist Review' to 'Specialist Assistance'

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 15:18:31 +00:00
199 changed files with 10885 additions and 17010 deletions

View File

@@ -47,6 +47,11 @@ jobs:
- name: Install dependencies
run: pip install -r backend/requirements.txt -r backend/requirements-dev.txt
- name: Check tenant filter enforcement
run: cd backend && python scripts/check_tenant_filters.py
# Warn mode only (exits 0). Switch to --fail after Phase 1 backlog clears.
# See: docs/superpowers/specs/2026-04-09-tenant-data-isolation-design.md Section 3f
- name: Run tests with coverage
run: cd backend && python -m pytest --override-ini="addopts=" --cov=app --cov-report=term-missing --cov-report=json:coverage.json --cov-fail-under=50

5
.gitignore vendored
View File

@@ -233,3 +233,8 @@ package.json
package-lock.json
.worktrees/
.gstack/
.gitnexus
# graphify knowledge graph outputs
graphify-out/
.graphify_python

View File

@@ -4,37 +4,46 @@ All notable changes to ResolutionFlow are documented here.
## [Unreleased]
## [2026-04-04] Network Diagram Editor UX Improvements
### Added
- Snap-to-grid (20px) on Network Diagram canvas — nodes align consistently when dragged
- NodeResizer on group nodes (subnet/VLAN/site/DMZ) — select a group and drag its handles to resize
- Group node dimensions now saved to and restored from the backend on reload
### Fixed
- Connection edges now render as straight lines instead of orthogonal bent paths
- ISP device now appears inside the Cloud category in the sidebar instead of a standalone "Internet" section; respects search and item count
- Group nodes now restore correctly as `type: 'group'` on diagram load (previously loaded as `type: 'device'`, breaking group display after save)
---
### Added
- Tree Templates + Import/Export marketplace (#66)
- Recurring Issue Detection — client-specific pattern alerts (#60)
- Step Feedback Flag — "This Step is Wrong" reporting (#58)
- **Tenant Isolation Phase 0** — multi-tenant data isolation (#132) with app-layer filtering helpers (`tenant_filter()`, `get_tenant_context`), cross-tenant access audit (analytics, categories, AI sessions, trees), UUID endpoint isolation with 404 responses for unauthorized access, ownership checks on all sensitive operations, and CI grep gate for missing tenant filters
- **Tenant Isolation Phase 1** — PostgreSQL Row-Level Security (RLS) enforcement across all core tables (trees, tags, categories, psa_connections, flow_proposals) with database role separation (`resolutionflow_app` for user operations, `resolutionflow_admin` with BYPASSRLS for admin endpoints), admin database engine isolation, tenant context via `ContextVar` with automatic transaction-scoped enforcement, `account_id` column backfill on 35+ tables (sessions, AI branching, PSA, notifications, scripts, targets, folders), global content separation via platform account, fresh-DB migration order fixes
- **Script Library default view** — "All Scripts" tab now displays all accessible scripts (team + library)
- **Session documentation overhaul** — reformatted PSA resolution/escalation notes with cleaner headers, inline engineer responses, decimal hour display (0.25 hrs), follow-up recommendations, and improved "What We Know" section from evidence items
- **Client communication improvements** — new `request_info` audience type for client-facing information requests, improved status update and email draft prompts with per-context guidance
- **Image support in Assistant Chat** — paste/attach images in chat input, uploaded to S3, resized for vision model, displayed in conversation history
### Changed
- **Edit Procedure page** — layout overhaul and color system refinements for better visual hierarchy
- **Flows sidebar navigation** — collapsed to reduce visual noise; session recovery removed from library view
- **Account settings page** — audit fixes for improved consistency and usability
- **PSA documentation formatting** — removed duplicate timing blocks and AI confidence sections; added client-facing communication context guidance
- **Status update generation** — fixed option label lookup to use human-readable labels instead of machine values
- **Assistant Chat session actions** — moved Pause/Resume/Close actions from action bar to page header for consistency with FlowPilot
- **Design system token normalization** — unified FlowPilot, AssistantChat, and ScriptBuilder components to use consistent design tokens
- **Tenant data boundaries** — all session and tree endpoints now return 404 (not 403) for cross-tenant access attempts to avoid confirming resource existence
- **Admin database routing** — privileged operations (analytics, user management) now bypass RLS via dedicated admin engine
### Fixed
- **CRITICAL: Copilot tree query isolation** (#131) — user could access any tree UUID if known, exposing full tree structure to AI. Now scoped to current account with 404 for inaccessible trees.
- **AI session search isolation** — search endpoint leaked other users' sessions via OR(user_id, account_id). Now restricted to current user only.
- **Analytics endpoint isolation** — GET `/analytics/flows/{tree_id}` exposed session counts for any tree UUID. Now returns 404 if tree doesn't belong to requesting account.
- **Category tree counts** — cross-tenant row count leakage via tree_count field in GET `/categories/{id}`. Now scoped to requesting account.
- **PSA retry ownership check** — retry-psa-push had no ownership validation (CRITICAL). Now validates user ownership before allowing retry.
- **Task Lane save operation** — invalid task_lane_item UUIDs returned 403 revealing existence. Now returns 404 and uses query-level filtering.
- Dark text rendering on blue accent step-number badges across all flow types
- Script Library tab ownership filter now preserved across category and search changes
- Race conditions in script builder session creation and slug generation
- Stale async results in Assistant Chat (selectChat) no longer clobber new session task lane
- Sentry DSN hardcoded fallback removed — now uses environment variable only
- Option label resolution in status update context generation
- "Sorry something went wrong" errors in chat when rendering unsupported message types
- Task Lane stale data when creating new chat or resuming from concluded session
- Chat ref invalidation race condition between handleNewChat and async data loads
- Images now properly display in chat message history instead of blank placeholders
- Non-default, no-team trees now properly handled in global content migration
---

156
CLAUDE.md
View File

@@ -1,6 +1,6 @@
# CLAUDE.md - Patherly / ResolutionFlow Project Context
> **Last Updated:** March 27, 2026
> **Last Updated:** April 6, 2026
---
@@ -16,7 +16,8 @@
| Context | Name Used |
|---------|-----------|
| Repository / directory / database / Docker | `patherly` / `patherly_postgres` |
| Repository / directory / database | `patherly` (internal name) |
| Docker containers | `resolutionflow_postgres`, `resolutionflow_frontend`, `resolutionflow_backend` |
| Backend, frontend UI, production URLs | **ResolutionFlow** |
- **Design system:** [DESIGN-SYSTEM.md](DESIGN-SYSTEM.md) — THE source of truth for all design decisions
@@ -44,7 +45,7 @@
- **Phase:** Go-to-Market Validation (Pre-PMF)
- **Backend:** Complete (55+ API endpoints, 100+ integration tests)
- **Frontend:** Core features complete, Tree Editor functional
- **Database:** PostgreSQL with Docker, 98 migrations
- **Database:** PostgreSQL with Docker, 101 migrations
- **Detailed status:** [CURRENT-STATE.md](CURRENT-STATE.md)
### What's In Progress
@@ -96,7 +97,7 @@ patherly/
│ │ ├── services/knowledge_flywheel.py # AI session analysis → flow proposals
│ │ ├── services/knowledge_flywheel_scheduler.py # APScheduler job for batch analysis
│ │ └── services/knowledge_gap_service.py # Weak options & escalation signal detection
│ ├── alembic/ # Database migrations (001-029+)
│ ├── alembic/ # Database migrations (001-070 sequential, then hash IDs)
│ ├── scripts/ # seed_data.py, seed_trees.py
│ └── tests/ # pytest integration tests
├── frontend/
@@ -188,8 +189,8 @@ Official ConnectWise developer guides live in `docs/connectwise/best-practices/`
## Development Commands
```powershell
# Start PostgreSQL
docker start patherly_postgres
# Start PostgreSQL (run from VPS SSH — docker not available inside code-server, see Lesson 103)
docker start resolutionflow_postgres
# Backend (from backend/)
source venv/bin/activate # Linux/Mac
@@ -203,21 +204,19 @@ npm run dev
pytest --override-ini="addopts="
# First time only: create test database
docker exec -it patherly_postgres psql -U postgres -c "CREATE DATABASE patherly_test;"
docker exec -it resolutionflow_postgres psql -U postgres -c "CREATE DATABASE resolutionflow_test;"
# Frontend build (IMPORTANT: stricter than tsc --noEmit — always use as final check)
cd frontend && npm run build
# Database migrations
cd backend && alembic upgrade head
alembic revision --autogenerate -m "Description" --rev-id=NNN # NNN = next sequential number
# IMPORTANT: Migrations use sequential 3-digit IDs (001, 002, ..., 068, 069).
# Check the latest: ls backend/alembic/versions/ | grep -E '^\d{3}_' | sort | tail -1
# The revision ID and filename prefix MUST match (e.g., revision="068", file=068_description.py).
# down_revision MUST point to the previous sequential number. Never use hex hash IDs for new migrations.
alembic revision --autogenerate -m "Description"
# Sequential 3-digit IDs (001070) were used historically. New migrations use Alembic's default hex hash IDs.
# Do NOT pass --rev-id — let Alembic generate the hash automatically.
# Access PostgreSQL
docker exec -it patherly_postgres psql -U postgres -d patherly
# Access PostgreSQL (run from VPS SSH — docker not available inside code-server, see Lesson 103)
docker exec -it resolutionflow_postgres psql -U postgres -d resolutionflow
# Seed data
cd backend && pip install httpx && python -m scripts.seed_trees
@@ -292,7 +291,7 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
**62. Playwright strict mode — scope selectors to avoid ambiguity:** Step titles appear in both the sidebar checklist and main content heading. Use `getByRole('heading', { name })` for the main content, or scope with `page.locator('.animate-scale-in')` for command palette items. `getByText()` frequently matches multiple elements due to the sidebar + main content layout.
**63. Node 20 required for frontend builds:** Vite 7+ requires Node 20.19+. The system Node may be v18; use nvm: `export NVM_DIR="$HOME/.nvm" && source "$NVM_DIR/nvm.sh" && nvm use 20`. For direct binary access without nvm sourcing: `PATH="/home/michaelchihlas/.nvm/versions/node/v20.19.0/bin:$PATH"`.
**63. Node 20 required for frontend builds:** Vite 7+ requires Node 20.19+. The system Node may be v18; use nvm: `export NVM_DIR="$HOME/.nvm" && source "$NVM_DIR/nvm.sh" && nvm use 20`. For direct binary access without nvm sourcing: `PATH="$HOME/.nvm/versions/node/v20.19.0/bin:$PATH"`.
**64. PostHog product analytics:** Initialized via `PostHogProvider` in `main.tsx` with explicit `posthog.init()` + `client` prop pattern. Event helpers in `lib/analytics.ts` — use `analytics.eventName(props)` to track. `identifyUser()` called in `authStore.fetchUser()`, `resetAnalytics()` on logout. Env vars: `VITE_PUBLIC_POSTHOG_KEY`, `VITE_PUBLIC_POSTHOG_HOST`. Autocapture enabled.
@@ -332,7 +331,7 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
**82. `bun` requires PATH setup on devserver01:** `export BUN_INSTALL="$HOME/.bun" && export PATH="$BUN_INSTALL/bin:$PATH"`. The gstack browse binary and Playwright need this. Chromium system deps: `libatk1.0-0 libatk-bridge2.0-0 libcups2 libxkbcommon0 libatspi2.0-0 libxcomposite1 libxdamage1 libxfixes3 libxrandr2 libgbm1 libasound2`.
**83. FlowPilot ActionBar is `position: fixed; bottom: 0`:** Any UI element placed in normal document flow below the session content will be hidden behind it. New fixed-position elements (like the message bar) must use `bottom: 68px` (action bar height) and the same `left: var(--sidebar-w)` pattern. The conversation column uses `pb-32` for clearance.
**83. ~~FlowPilot ActionBar fixed bottom~~ (Superseded by Lesson 93):** Actions moved to the page header. `FlowPilotActionBar` component exists but is no longer used in the main session flow. The only fixed-bottom element is the message input.
**84. AI session `abandoned` status is fully wired:** `POST /ai-sessions/{id}/abandon` sets status to `abandoned` with optional `reason` param. Frontend: `aiSessionsApi.abandonSession()`, `useFlowPilotSession().abandonSession()`, "Close" button in `FlowPilotActionBar`. Redirects to `/sessions` after closing.
@@ -344,6 +343,7 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
**88. Charcoal palette — sidebar-darkest approach:** Sidebar `#0e1016`, page `#16181f`, cards `#1e2028`, borders `#2a2e3a`. This gives more contrast range than true-dark. All colors via CSS variables in `index.css` `@theme` block. Accent is electric blue (#60a5fa), not orange or cyan.
*(Lessons 8991 were retracted.)*
**92. `tsc -b` in Dockerfile is stricter than `npx tsc --noEmit`:** The production build (`tsc -b && vite build`) enforces `noUnusedLocals` and `noUnusedParameters` as hard errors. After any refactor that moves logic between components or removes features, trace every import and destructured prop to remove orphans. IDE warnings (yellow squiggles) flag these — check them before pushing.
@@ -353,7 +353,7 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
**95. Image upload → AI vision pipeline:** Paste/attach images → upload to Railway S3 bucket via `uploadsApi.upload()` → send `upload_ids` with chat message → backend fetches from S3 via `storage_service.download_file()` → resized via `storage_service.resize_image_for_vision()` (Pillow, 1568px max, PNG→JPEG) → base64-encoded → sent as Claude multimodal content blocks. Max 3 images/message. Images are NOT stored in conversation history (text-only). Vision helpers live in `storage_service.py`.
**96. `bg-accent` is ember orange — never use for code/kbd elements:** In Tailwind v4, `bg-accent` maps to `--color-accent: #f97316`. Use `bg-code` for code blocks, `bg-white/[0.12] border border-white/[0.06]` for inline code/badges, `bg-white/[0.08]` for kbd shortcuts. Orange is reserved for interactive elements only (buttons, active nav, links).
**96. `bg-accent` is electric blue — never use for code/kbd elements:** In Tailwind v4, `bg-accent` maps to `--color-accent: #60a5fa` (dark) / `#2563eb` (light). Use `bg-code` for code blocks, `bg-white/[0.12] border border-white/[0.06]` for inline code/badges, `bg-white/[0.08]` for kbd shortcuts. Blue accent is reserved for interactive elements only (buttons, active nav, links). Ember orange (#f97316) is deprecated — do not use.
**97. Railway Object Storage (S3 bucket) is provisioned:** Bucket `resolutionflow-uploads` on Railway canvas. Variables: `STORAGE_ENDPOINT`, `STORAGE_ACCESS_KEY`, `STORAGE_SECRET_KEY`, `STORAGE_BUCKET_NAME`, `STORAGE_REGION` — mapped via variable references on the `patherly` backend service. Accessed via boto3 in `storage_service.py`. Pillow (`Pillow>=10.0.0`) + `libjpeg-dev`/`zlib1g-dev` in Dockerfile for image resize.
@@ -390,16 +390,16 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
**Source of truth:** [DESIGN-SYSTEM.md](DESIGN-SYSTEM.md) — always read this before making visual or UI decisions.
- **Theme:** Flat, high-contrast dark theme (Sentry/PostHog-inspired). No glass morphism, no backdrop blur, no ambient orbs, no gradient backgrounds on surfaces. Light mode planned.
- **Backgrounds:** `bg-page` (`#1a1c23`), `bg-sidebar` (`#10121a`), `bg-card` (`#22252e`), `bg-elevated` (`#2e3140`)
- **Cards:** `bg-card` with 1px `border-default` (`#2e3240`), 8px radius. No shadows, no blur, no gradients. Hover: `border-hover` (`#3d4252`)
- **Buttons:** Primary: solid `accent` (#f97316), white text, 5px radius. Ghost: transparent + 1px border, hover `bg-elevated`
- **Inputs:** `bg-input` (`#282b35`) with 1px `border-default`, 5px radius. Focus: `border-color: accent` + `box-shadow: 0 0 0 2px accent-dim`
- **Text:** `text-heading` (`#f0f2f5`) → `text-primary` (`#e2e5eb`) → `text-muted-foreground` (`#848b9b`) → `text-muted` (`#4f5666`). NEVER use `text-secondary` — in Tailwind v4 it maps to a surface color (#2e3140), not a text color.
- **Borders:** `border-default` (`#2e3240`), `border-hover` (`#3d4252`)
- **Functional colors:** `#34d399` (success), `#eab308` (warning), `#f87171` (danger) — each with `-dim` variant at 10% opacity
- **Accent:** Ember orange `#f97316` — used sparingly (≤5% of UI). `accent-dim` = `rgba(249,115,22,0.10)`, `accent-text` = `#fdba74`
- **Deprecated:** Do NOT use `glass-card`, `glass-stat`, `bg-gradient-brand`, `text-gradient-brand`, `backdrop-filter: blur()`, ambient orbs, purple gradients, or cyan accent (`#22d3ee`)
- **Theme:** Flat, high-contrast dark theme (Sentry/PostHog-inspired). No glass morphism, no backdrop blur, no ambient orbs, no gradient backgrounds on surfaces. Light mode fully specified (v6).
- **Backgrounds:** `bg-page` (`#16181f`), `bg-sidebar` (`#0e1016`), `bg-card` (`#1e2028`), `bg-elevated` (`#2a2d38`)
- **Cards:** `bg-card` with 1px `border-default` (`#2a2e3a`), 8px radius. No shadows, no blur, no gradients. Hover: `border-hover` (`#3d4252`)
- **Buttons:** Primary: solid `accent` (#60a5fa dark / #2563eb light), white text, 5px radius. Ghost: transparent + 1px border, hover `bg-elevated`
- **Inputs:** `bg-input` (`#252830`) with 1px `border-default`, 5px radius. Focus: `border-color: accent` + `box-shadow: 0 0 0 2px accent-dim`
- **Text:** `text-heading` (`#f0f2f5`) → `text-primary` (`#e2e5eb`) → `text-muted-foreground` (`#848b9b`) → `text-muted` (`#4f5666`). NEVER use `text-secondary` — in Tailwind v4 it maps to a surface color, not a text color.
- **Borders:** `border-default` (`#2a2e3a`), `border-hover` (`#3d4252`)
- **Functional colors:** `#34d399` (success), `#fbbf24` (warning/amber), `#f87171` (danger), `#67e8f9` (info/cyan) — each with `-dim` variant at 10% opacity
- **Accent:** Electric blue `#60a5fa` (dark) / `#2563eb` (light) — used sparingly (≤5% of UI). `accent-dim` = `rgba(96,165,250,0.10)`, `accent-text` = `#93c5fd`
- **Deprecated:** Do NOT use `glass-card`, `glass-stat`, `bg-gradient-brand`, `text-gradient-brand`, `backdrop-filter: blur()`, ambient orbs, purple gradients, ember orange (`#f97316`), or cyan (`#22d3ee`) as accent — cyan is now the info color only
---
@@ -518,3 +518,105 @@ When a feature, fix, or significant piece of work is finished and merged/committ
| Bugs & Fixes | CLAUDE.md → Critical Lessons Learned section |
| Design System | [DESIGN-SYSTEM.md](DESIGN-SYSTEM.md) |
| Dev Environment | [DEV-ENV.md](DEV-ENV.md) — 46.202.92.250 setup, Docker, CORS, networking |
<!-- gitnexus:start -->
# GitNexus — Code Intelligence
This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
## Always Do
- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user.
- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows.
- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits.
- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance.
- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`.
## When Debugging
1. `gitnexus_query({query: "<error or symptom>"})` — find execution flows related to the issue
2. `gitnexus_context({name: "<suspect function>"})` — see all callers, callees, and process participation
3. `READ gitnexus://repo/resolutionflow/process/{processName}` — trace the full execution flow step by step
4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed
## When Refactoring
- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`.
- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code.
- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed.
## Never Do
- NEVER edit a function, class, or method without first running `gitnexus_impact` on it.
- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis.
- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph.
- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope.
## Tools Quick Reference
| Tool | When to use | Command |
|------|-------------|---------|
| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` |
| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` |
| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` |
| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` |
| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` |
| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` |
## Impact Risk Levels
| Depth | Meaning | Action |
|-------|---------|--------|
| d=1 | WILL BREAK — direct callers/importers | MUST update these |
| d=2 | LIKELY AFFECTED — indirect deps | Should test |
| d=3 | MAY NEED TESTING — transitive | Test if critical path |
## Resources
| Resource | Use for |
|----------|---------|
| `gitnexus://repo/resolutionflow/context` | Codebase overview, check index freshness |
| `gitnexus://repo/resolutionflow/clusters` | All functional areas |
| `gitnexus://repo/resolutionflow/processes` | All execution flows |
| `gitnexus://repo/resolutionflow/process/{name}` | Step-by-step execution trace |
## Self-Check Before Finishing
Before completing any code modification task, verify:
1. `gitnexus_impact` was run for all modified symbols
2. No HIGH/CRITICAL risk warnings were ignored
3. `gitnexus_detect_changes()` confirms changes match expected scope
4. All d=1 (WILL BREAK) dependents were updated
## Keeping the Index Fresh
After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it:
```bash
npx gitnexus analyze
```
If the index previously included embeddings, preserve them by adding `--embeddings`:
```bash
npx gitnexus analyze --embeddings
```
To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.**
> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`.
## CLI
| Task | Read this skill file |
|------|---------------------|
| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` |
| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` |
| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` |
| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` |
| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` |
| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` |
<!-- gitnexus:end -->

View File

@@ -2,7 +2,7 @@
> **Purpose:** Quick-reference file showing exactly where the project stands.
> **For Claude Code:** Read this first to understand what's done and what's next.
> **Last Updated:** April 4, 2026 (evening)
> **Last Updated:** March 23, 2026
---
@@ -13,8 +13,8 @@
## What's Complete
### Core Platform
- FastAPI project structure with 55+ API endpoints
- PostgreSQL database with Docker, 100+ Alembic migrations
- FastAPI project structure with 35+ API endpoints
- PostgreSQL database with Docker, 75+ Alembic migrations
- User authentication (JWT, register, login, refresh, logout, invite codes)
- Refresh token rotation with JTI-based revocation
- Trees CRUD with full-text search (FTS index)
@@ -29,7 +29,7 @@
### Frontend Core
- React 19 + Vite + TypeScript + Tailwind CSS v4 (`@tailwindcss/vite`)
- **Charcoal Design System v6** — Flat, high-contrast dark theme (Sentry/PostHog-inspired), charcoal palette; accent color is electric blue (#60a5fa), replacing ember orange
- **Charcoal Design System** — Flat, high-contrast dark theme (Sentry/PostHog-inspired), charcoal palette with sidebar-darkest approach
- **Brand fonts:** Bricolage Grotesque (headings), IBM Plex Sans (body), JetBrains Mono (code)
- Authentication UI (login, register, email verification)
- Tree library/browsing page with grid/list/table views
@@ -130,36 +130,6 @@
- Enhanced PSA metrics: time entries, hours logged, push success funnel, daily trend chart
- 13 new backend tests for coverage and flow quality endpoints
### Conversational Branching (Complete)
- SessionBranch, ForkPoint, SessionHandoff, SessionResolutionOutput models + migration (4 tables, 13 columns)
- BranchManager service, BranchAwarePromptBuilder, HandoffManager service with integration tests
- Branch API endpoints: `session_branches.py`, `session_handoffs.py`, `session_resolutions.py`
- Integrated into `unified_chat_service.py` and AI session step creation
- Frontend: BranchNode, ForkCard, BranchMap, BranchRevivalCard, BranchTransitionBar, HandoffModal, ResolutionOutputPanel components
- Wired into FlowPilotSession and `useFlowPilotSession` hook
### Script Library Enhancements (Complete)
- ParameterizeAndSavePanel replaces SaveToLibraryDialog — accepts `script_body` and `parameters_schema` in save flow
- "New from Script" button on ScriptLibraryPage for one-click script creation from template
- Default tab is "All Scripts" (previously filtered to owned scripts)
- Ownership filter state preserved across category and search changes
- Backend: `save-to-library` endpoint accepts `script_body` + `parameters_schema`
### AI Vision Support (Complete)
- Image uploads (paste/drag-drop) wired into AI assistant chat via `upload_ids`
- Server-side image resize before sending to Claude (Pillow, 1568px max, PNG→JPEG)
- `storage_service.resize_image_for_vision()` handles vision pipeline
- Images are NOT stored in conversation history (text-only history)
### Mid-Session Status Updates (Complete)
- AI assistant can generate `status_update` steps (step_type added to CHECK constraint)
- Status update generation wired into `unified_chat_service.py`
- Frontend renders status update cards in session view
### Search & Recall + Evidence-Rich Sessions (Complete)
**Evidence:**
@@ -193,7 +163,7 @@
- SQL wildcard escaping in tag search
- PSA credentials encrypted at rest (Fernet)
### Copilot-First Dashboard (MarchApril 2026)
### Copilot-First Dashboard (March 2026)
- Redesigned dashboard as FlowPilot copilot launchpad (ChatGPT-style input)
- Chat-style input with paste images, drag-drop files, attach button, paste logs
@@ -203,33 +173,9 @@
- Unified Command Palette (Cmd+K) — merged QuickLaunch into omnibar
- "Solutions Library" rename (from "Step Library") site-wide
- Maintenance flows hidden from UI for pilot (backend still supports them)
- Charcoal color palette: sidebar `#0e1016`, page `#16181f`, cards `#1e2028`
- **Landing page redesign** — scroll-driven reveal animations, live chat animation, FAQ section, improved trust signals; copy: "Resolve tickets faster. Notes write themselves."
- **Session History redesign** — tabbed layout with Load More pagination
- **Edit Procedure page** — layout and color system overhaul
- **TaskLane UX** improvements in assistant chat; persistence across page reload
- TaskLane answers persist in sessionStorage; correct behavior on all three chat paths (send, prefill, resume)
- **Action bar consolidation** — Deduplicated actions across FlowPilot/Cockpit headers and chat toolbars; chat toolbar now only has input tools (Attach, Paste Logs, Tasks)
- **ViewToggle redesigned** as persistent tab bar with bottom-border active indicator and ARIA attributes (FlowPilot/Cockpit switcher)
- **Standardized action naming** across all session pages: Resolve (emerald), Update (blue), Close (rose), Pause (muted)
- **ConcludeSessionModal copy refresh** — Forward-facing action verbs, "Close & Generate" CTA, consistent outcome labels
- Deleted unused FlowPilotActionBar component (227 lines dead code)
### Network Diagrams (In Progress)
- Network diagram editor with React Flow (@xyflow/react v12) canvas
- Device node system: 27 device types across 7 categories (network, compute, storage, cloud, endpoint, infrastructure, security)
- Custom device type creation via DeviceToolbar
- Connection edges with 6 types (ethernet, fiber, wifi, vpn, vlan, wan) — color-coded, dashed for wireless/VPN
- Properties panel for editing device and connection details
- AI-assisted diagram generation (describe network → auto-layout)
- Auto-save every 30 seconds, manual save, JSON export
- **React Flow UI Components** — Cherry-picked and Charcoal-restyled: BaseNode (structured header/content/footer slots), BaseHandle (styled connection handles), LabeledHandle (named port labels), NodeStatusIndicator (status border effect: emerald/red/yellow), NodeTooltip (hover details via NodeToolbar), LabeledGroupNode (subnet/VLAN/site/DMZ containers), AnimatedSvgEdge (traffic flow visualization)
- Grouping category in toolbar: Subnet, VLAN, Site, DMZ drag-drop to canvas
- Traffic flow toggle on edges (switches between static and animated)
- Context menu with copy/paste/duplicate/select all shortcuts
- Drop position uses `screenToFlowPosition()` for correct placement at any zoom/pan level
- **Bug fix:** PropertiesPanel inputs now work — selection uses IDs instead of stale object snapshots
- Landing page copy rewrite: "Resolve tickets faster. Notes write themselves."
- Spring bounce hover animation on dashboard cards
- Charcoal color palette: sidebar `#10121a`, page `#1a1c23`, cards `#22252e`
### Maintenance Flows (Hidden from UI)
@@ -289,22 +235,21 @@
### Start Development
```bash
# Start PostgreSQL (Docker — container name resolutionflow_postgres, port 5433, DB resolutionflow)
docker start resolutionflow_postgres
# Start PostgreSQL (Docker Compose)
docker compose up -d
# Backend (from backend/)
source venv/bin/activate
uvicorn app.main:app --reload
# Frontend (from frontend/, requires Node 20)
# Frontend (from frontend/)
npm run dev
```
### URLs
- Frontend: http://46.202.92.250:5173 (or https via Traefik reverse proxy)
- Backend API: http://46.202.92.250:8000
- API Docs: http://46.202.92.250:8000/api/docs
- Dev env runs on Hostinger VPS (46.202.92.250) with Traefik + HTTPS; see [DEV-ENV.md](DEV-ENV.md)
- Frontend: http://192.168.0.9:5173
- Backend API: http://192.168.0.9:8000
- API Docs: http://192.168.0.9:8000/api/docs
### Run Tests
```bash

View File

@@ -1,31 +0,0 @@
"""add triage fields to ai_sessions for cockpit harness
Revision ID: 071
Revises: 070
Create Date: 2026-04-01
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
revision = "071"
down_revision = "070"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("ai_sessions", sa.Column("client_name", sa.String(255), nullable=True))
op.add_column("ai_sessions", sa.Column("asset_name", sa.String(255), nullable=True))
op.add_column("ai_sessions", sa.Column("issue_category", sa.String(100), nullable=True))
op.add_column("ai_sessions", sa.Column("triage_hypothesis", sa.Text(), nullable=True))
op.add_column("ai_sessions", sa.Column("evidence_items", JSONB(), nullable=True))
def downgrade() -> None:
op.drop_column("ai_sessions", "evidence_items")
op.drop_column("ai_sessions", "triage_hypothesis")
op.drop_column("ai_sessions", "issue_category")
op.drop_column("ai_sessions", "asset_name")
op.drop_column("ai_sessions", "client_name")

View File

@@ -1,61 +0,0 @@
"""Seed flowpilot_cockpit feature flag with plan defaults.
Revision ID: 072
Revises: 071
Create Date: 2026-04-02
"""
from alembic import op
import sqlalchemy as sa
revision = "072"
down_revision = "071"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Insert the feature flag
op.execute(
sa.text(
"INSERT INTO feature_flags (id, flag_key, display_name, description) "
"VALUES (gen_random_uuid(), 'flowpilot_cockpit', 'FlowPilot Cockpit', "
"'Access to the FlowPilot Cockpit triage view') "
"ON CONFLICT (flag_key) DO NOTHING"
)
)
# Set plan defaults: disabled for free, enabled for pro and team
op.execute(
sa.text(
"INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) "
"SELECT gen_random_uuid(), 'free', id, false FROM feature_flags WHERE flag_key = 'flowpilot_cockpit' "
"ON CONFLICT (plan, flag_id) DO NOTHING"
)
)
op.execute(
sa.text(
"INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) "
"SELECT gen_random_uuid(), 'pro', id, true FROM feature_flags WHERE flag_key = 'flowpilot_cockpit' "
"ON CONFLICT (plan, flag_id) DO NOTHING"
)
)
op.execute(
sa.text(
"INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) "
"SELECT gen_random_uuid(), 'team', id, true FROM feature_flags WHERE flag_key = 'flowpilot_cockpit' "
"ON CONFLICT (plan, flag_id) DO NOTHING"
)
)
def downgrade() -> None:
op.execute(
sa.text(
"DELETE FROM plan_feature_defaults WHERE flag_id IN "
"(SELECT id FROM feature_flags WHERE flag_key = 'flowpilot_cockpit')"
)
)
op.execute(
sa.text("DELETE FROM feature_flags WHERE flag_key = 'flowpilot_cockpit'")
)

View File

@@ -1,95 +0,0 @@
"""Add device_types table with system seed data.
Revision ID: 073
Revises: 072
Create Date: 2026-04-04
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
import uuid
revision = "073"
down_revision = "072"
branch_labels = None
depends_on = None
SYSTEM_DEVICE_TYPES = [
("router", "Router", "network", 0),
("switch", "Switch", "network", 1),
("firewall", "Firewall", "network", 2),
("access-point", "Access Point", "network", 3),
("load-balancer", "Load Balancer", "network", 4),
("server", "Server", "compute", 0),
("workstation", "Workstation", "compute", 1),
("vm", "Virtual Machine", "compute", 2),
("container", "Container", "compute", 3),
("nas", "NAS", "storage", 0),
("san", "SAN", "storage", 1),
("cloud-storage", "Cloud Storage", "storage", 2),
("cloud", "Cloud", "cloud", 0),
("aws", "AWS", "cloud", 1),
("azure", "Azure", "cloud", 2),
("gcp", "Google Cloud", "cloud", 3),
("printer", "Printer", "endpoint", 0),
("phone", "Phone", "endpoint", 1),
("iot", "IoT Device", "endpoint", 2),
("camera", "Camera", "endpoint", 3),
("tablet", "Tablet", "endpoint", 4),
("laptop", "Laptop", "endpoint", 5),
("ups", "UPS", "infrastructure", 0),
("pdu", "PDU", "infrastructure", 1),
("rack", "Rack", "infrastructure", 2),
("patch-panel", "Patch Panel", "infrastructure", 3),
("nvr", "NVR", "security", 0),
("badge-reader", "Badge Reader", "security", 1),
]
def upgrade() -> None:
op.create_table(
"device_types",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("slug", sa.String(50), nullable=False),
sa.Column("label", sa.String(100), nullable=False),
sa.Column("category", sa.String(50), nullable=False),
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=True),
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.execute(
"ALTER TABLE device_types ADD CONSTRAINT uq_device_types_slug_team "
"UNIQUE NULLS NOT DISTINCT (slug, team_id)"
)
op.create_index("idx_device_types_team", "device_types", ["team_id"])
device_types_table = sa.table(
"device_types",
sa.column("id", UUID(as_uuid=True)),
sa.column("slug", sa.String),
sa.column("label", sa.String),
sa.column("category", sa.String),
sa.column("is_system", sa.Boolean),
sa.column("team_id", UUID(as_uuid=True)),
sa.column("sort_order", sa.Integer),
)
op.bulk_insert(device_types_table, [
{
"id": uuid.uuid4(),
"slug": slug,
"label": label,
"category": category,
"is_system": True,
"team_id": None,
"sort_order": sort_order,
}
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
])
def downgrade() -> None:
op.drop_table("device_types")

View File

@@ -1,41 +0,0 @@
"""Add network_diagrams table.
Revision ID: 074
Revises: 073
Create Date: 2026-04-04
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
revision = "074"
down_revision = "073"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"network_diagrams",
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
sa.Column("team_id", UUID(as_uuid=True), sa.ForeignKey("teams.id", ondelete="CASCADE"), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("client_name", sa.String(255), nullable=True),
sa.Column("asset_name", sa.String(255), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("nodes", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
sa.Column("edges", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
sa.Column("thumbnail_url", sa.Text(), nullable=True),
sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
)
op.create_index("idx_network_diagrams_team", "network_diagrams", ["team_id"])
op.create_index("idx_network_diagrams_client", "network_diagrams", ["team_id", "client_name"])
def downgrade() -> None:
op.drop_table("network_diagrams")

View File

@@ -0,0 +1,102 @@
"""create_db_roles
Revision ID: 0b470d9e6cf1
Revises: a9f3b2c1d4e5
Create Date: 2026-04-10 03:58:10.207919
"""
import os
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision: str = '0b470d9e6cf1'
down_revision: Union[str, None] = 'a9f3b2c1d4e5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Passwords from env vars. For local dev, defaults are sufficient.
# For production (Railway), set DB_APP_ROLE_PASSWORD and
# DB_ADMIN_ROLE_PASSWORD as environment variables before running migrations.
# Passwords must not contain single quotes.
app_pw = os.environ.get("DB_APP_ROLE_PASSWORD", "app_secret_change_me")
admin_pw = os.environ.get("DB_ADMIN_ROLE_PASSWORD", "admin_secret_change_me")
# Fetch the current database name dynamically — avoids hardcoding
# (the DB is named 'resolutionflow' in dev, potentially different elsewhere).
conn = op.get_bind()
db_name = conn.execute(text("SELECT current_database()")).scalar()
# ── Application role ────────────────────────────────────────────────────
# Subject to RLS. Used by FastAPI at runtime via DATABASE_URL.
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_app') THEN
CREATE ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}';
ELSE
ALTER ROLE resolutionflow_app LOGIN PASSWORD '{app_pw}';
END IF;
END $$
""")
op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_app")
op.execute("GRANT USAGE ON SCHEMA public TO resolutionflow_app")
op.execute(
"GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public "
"TO resolutionflow_app"
)
op.execute(
"GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO resolutionflow_app"
)
# Ensure future tables automatically get the same permissions
op.execute(
"ALTER DEFAULT PRIVILEGES IN SCHEMA public "
"GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO resolutionflow_app"
)
op.execute(
"ALTER DEFAULT PRIVILEGES IN SCHEMA public "
"GRANT USAGE, SELECT ON SEQUENCES TO resolutionflow_app"
)
# ── Admin role ──────────────────────────────────────────────────────────
# BYPASSRLS. Used by Alembic (DATABASE_URL_SYNC) and /admin/* endpoints
# (ADMIN_DATABASE_URL) after Task 11.
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'resolutionflow_admin') THEN
CREATE ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}';
ELSE
ALTER ROLE resolutionflow_admin LOGIN PASSWORD '{admin_pw}';
END IF;
END $$
""")
op.execute("GRANT resolutionflow_app TO resolutionflow_admin")
op.execute("ALTER ROLE resolutionflow_admin BYPASSRLS")
op.execute(f"GRANT CONNECT ON DATABASE {db_name} TO resolutionflow_admin")
def downgrade() -> None:
conn = op.get_bind()
db_name = conn.execute(text("SELECT current_database()")).scalar()
op.execute(
"REVOKE ALL ON ALL TABLES IN SCHEMA public FROM resolutionflow_app"
)
op.execute(
"REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM resolutionflow_app"
)
op.execute(
f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_app"
)
op.execute(
f"REVOKE CONNECT ON DATABASE {db_name} FROM resolutionflow_admin"
)
op.execute("DROP ROLE IF EXISTS resolutionflow_admin")
op.execute("DROP ROLE IF EXISTS resolutionflow_app")

View File

@@ -0,0 +1,86 @@
"""set NOT NULL on all previously-nullable account_id columns
Revision ID: 174f442795b7
Revises: 3a40fe11b427
Create Date: 2026-04-09 00:00:00.000000
All tables in this migration had account_id set to nullable previously.
Task 9 (create_global_content_tables) cleared all NULL rows.
This migration enforces the NOT NULL constraint.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '174f442795b7'
down_revision: Union[str, None] = '3a40fe11b427'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# tree_embeddings: backfill from trees (must happen before SET NOT NULL)
op.execute("""
UPDATE tree_embeddings te
SET account_id = t.account_id
FROM trees t
WHERE te.tree_id = t.id
AND te.account_id IS NULL
""")
# feedback: backfill from users
op.execute("""
UPDATE feedback f
SET account_id = u.account_id
FROM users u
WHERE f.user_id = u.id
AND f.account_id IS NULL
""")
# Verify ALL tables before touching any SET NOT NULL
tables_with_account_id = [
'users', 'trees', 'tree_categories', 'tree_tags',
'step_categories', 'step_library', 'tree_embeddings', 'feedback',
]
for table in tables_with_account_id:
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} NULL account_id rows in {table}. "
"Run Task 9 (create_global_content_tables) first, or "
"manually backfill/delete orphaned rows."
)
# SET NOT NULL on all
for table in tables_with_account_id:
op.alter_column(table, 'account_id', nullable=False)
# Create indexes where they don't already exist
new_indexes = [
('tree_embeddings', 'ix_tree_embeddings_account_id'),
('feedback', 'ix_feedback_account_id'),
]
for table, index_name in new_indexes:
result = op.get_bind().execute(sa.text(
f"SELECT 1 FROM pg_indexes WHERE tablename='{table}' AND indexname='{index_name}'"
))
if not result.fetchone():
op.create_index(index_name, table, ['account_id'])
def downgrade() -> None:
# Revert to nullable
for table in ('users', 'trees', 'tree_categories', 'tree_tags',
'step_categories', 'step_library', 'tree_embeddings', 'feedback'):
op.alter_column(table, 'account_id', nullable=True)
for table, index_name in (
('tree_embeddings', 'ix_tree_embeddings_account_id'),
('feedback', 'ix_feedback_account_id'),
):
try:
op.drop_index(index_name, table_name=table)
except Exception:
pass

View File

@@ -0,0 +1,62 @@
"""add account_id to target_lists (keep team_id)
Revision ID: 2c6aabd89bc6
Revises: 78fc200abac1
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '2c6aabd89bc6'
down_revision: Union[str, None] = '78fc200abac1'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('target_lists', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_target_lists_account_id', 'target_lists', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Primary: team_id → team admin user → account_id
op.execute("""
UPDATE target_lists tl
SET account_id = u.account_id
FROM users u
WHERE u.team_id = tl.team_id
AND u.is_team_admin = TRUE
AND u.account_id IS NOT NULL
AND tl.account_id IS NULL
""")
# Fallback: created_by → users.account_id
op.execute("""
UPDATE target_lists tl
SET account_id = u.account_id
FROM users u
WHERE tl.created_by = u.id
AND u.account_id IS NOT NULL
AND tl.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM target_lists WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} target_lists rows have NULL account_id. "
"No team admin found for these teams. Resolve before re-running."
)
op.alter_column('target_lists', 'account_id', nullable=False)
op.create_index('ix_target_lists_account_id', 'target_lists', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_target_lists_account_id', table_name='target_lists')
op.drop_constraint('fk_target_lists_account_id', 'target_lists', type_='foreignkey')
op.drop_column('target_lists', 'account_id')

View File

@@ -0,0 +1,175 @@
"""create template_trees and platform_steps global content tables
Revision ID: 3a40fe11b427
Revises: 2c6aabd89bc6
Create Date: 2026-04-09 00:00:00.000000
These tables hold platform-owned content that is readable by all
authenticated users. No account_id. No RLS. Ever.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSONB
revision: str = '3a40fe11b427'
down_revision: Union[str, None] = '2c6aabd89bc6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Create template_trees ─────────────────────────────────────────────────
op.create_table(
'template_trees',
sa.Column('id', UUID(), primary_key=True),
sa.Column('name', sa.String(255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('category', sa.String(100), nullable=True),
sa.Column('tree_type', sa.String(20), nullable=False),
sa.Column('tree_structure', JSONB(), nullable=False),
sa.Column('tags', JSONB(), nullable=False, server_default='[]'),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('source_tree_id', UUID(), sa.ForeignKey('trees.id', ondelete='SET NULL'), nullable=True),
)
op.create_index('ix_template_trees_tree_type', 'template_trees', ['tree_type'])
# ── Create platform_steps ────────────────────────────────────────────────
op.create_table(
'platform_steps',
sa.Column('id', UUID(), primary_key=True),
sa.Column('title', sa.String(255), nullable=False),
sa.Column('step_type', sa.String(50), nullable=False),
sa.Column('content', JSONB(), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('source_step_id', UUID(), sa.ForeignKey('step_library.id', ondelete='SET NULL'), nullable=True),
)
op.create_index('ix_platform_steps_step_type', 'platform_steps', ['step_type'])
# ── Copy is_default=TRUE trees → template_trees ─────────────────────────
# Note: trees.tags is a relationship via tree_tags join table — no direct column.
# Aggregate tag names via a correlated subquery.
op.execute("""
INSERT INTO template_trees
(id, name, description, category, tree_type, tree_structure,
tags, is_active, created_at, updated_at, source_tree_id)
SELECT
gen_random_uuid(), t.name, t.description, t.category, t.tree_type,
t.tree_structure,
COALESCE(
(SELECT jsonb_agg(tt.name ORDER BY tt.name)
FROM tree_tag_assignments ta
JOIN tree_tags tt ON tt.id = ta.tag_id
WHERE ta.tree_id = t.id),
'[]'::jsonb
),
t.is_active,
COALESCE(t.created_at, NOW()), COALESCE(t.updated_at, NOW()), t.id
FROM trees t
WHERE t.is_default = TRUE
""")
# ── Copy visibility='public' steps → platform_steps ─────────────────────
op.execute("""
INSERT INTO platform_steps
(id, title, step_type, content, is_active, created_at, updated_at, source_step_id)
SELECT
gen_random_uuid(), title, step_type, content, is_active,
COALESCE(created_at, NOW()), COALESCE(updated_at, NOW()), id
FROM step_library
WHERE visibility = 'public'
""")
# ── Create platform sentinel account ─────────────────────────────────────
op.execute("""
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
VALUES (
'00000000-0000-0000-0000-000000000001',
'ResolutionFlow Platform',
'PLATFORM',
NOW(),
NOW()
)
ON CONFLICT (id) DO NOTHING
""")
# ── Assign is_default trees to platform account ──────────────────────────
op.execute("""
UPDATE trees
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE is_default = TRUE
AND account_id IS NULL
""")
# ── Assign remaining trees to their author's account ─────────────────────
# Handles trees with no team_id that aren't is_default (e.g. inactive test
# trees, trees created before the team system existed).
op.execute("""
UPDATE trees
SET account_id = u.account_id
FROM users u
WHERE trees.author_id = u.id
AND trees.account_id IS NULL
AND u.account_id IS NOT NULL
""")
# ── Final fallback: any still-NULL trees go to platform account ───────────
# Covers trees whose author has no account (seeded content, system rows).
op.execute("""
UPDATE trees
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE account_id IS NULL
""")
# ── Assign global categories/tags/steps to platform account ─────────────
op.execute("""
UPDATE tree_categories
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE account_id IS NULL
""")
op.execute("""
UPDATE tree_tags
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE account_id IS NULL
""")
op.execute("""
UPDATE step_categories
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE account_id IS NULL
""")
op.execute("""
UPDATE step_library
SET account_id = '00000000-0000-0000-0000-000000000001'
WHERE account_id IS NULL
""")
# ── Verify zero NULLs in all 5 tables ───────────────────────────────────
for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'):
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} NULL account_id rows remain in {table} "
"after platform account assignment. Investigate before re-running."
)
def downgrade() -> None:
platform_id = '00000000-0000-0000-0000-000000000001'
for table in ('trees', 'tree_categories', 'tree_tags', 'step_categories', 'step_library'):
op.execute(f"UPDATE {table} SET account_id = NULL WHERE account_id = '{platform_id}'")
op.execute(f"DELETE FROM accounts WHERE id = '{platform_id}'")
op.drop_index('ix_platform_steps_step_type', table_name='platform_steps')
op.drop_index('ix_template_trees_tree_type', table_name='template_trees')
op.drop_table('platform_steps')
op.drop_table('template_trees')

View File

@@ -0,0 +1,77 @@
"""add account_id to AI branching tables
Revision ID: 478c159e5654
Revises: cc214c63aa30
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '478c159e5654'
down_revision: Union[str, None] = 'cc214c63aa30'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
ai_tables = ('session_branches', 'session_handoffs', 'fork_points', 'ai_session_steps')
# Step 1: ADD COLUMN (nullable)
for table in ai_tables:
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
op.add_column('ai_suggestions', sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_ai_suggestions_account_id', 'ai_suggestions', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Step 2: BACKFILL
for table in ai_tables:
op.execute(f"""
UPDATE {table} t
SET account_id = ai.account_id
FROM ai_sessions ai
WHERE t.session_id = ai.id
AND t.account_id IS NULL
""")
op.execute("""
UPDATE ai_suggestions s
SET account_id = u.account_id
FROM users u
WHERE s.user_id = u.id
AND s.account_id IS NULL
""")
# Step 3: VERIFY zero NULLs
for table in ai_tables + ('ai_suggestions',):
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} NULL account_id rows in {table}."
)
# Step 4: SET NOT NULL
for table in ai_tables + ('ai_suggestions',):
op.alter_column(table, 'account_id', nullable=False)
# Step 5: CREATE INDEX
for table in ai_tables + ('ai_suggestions',):
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('session_branches', 'session_handoffs', 'fork_points',
'ai_session_steps', 'ai_suggestions'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,46 @@
"""add account_id to step_ratings and step_usage_log
Revision ID: 7167e9374b0c
Revises: 478c159e5654
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '7167e9374b0c'
down_revision: Union[str, None] = '478c159e5654'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
for table in ('step_ratings', 'step_usage_log'):
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Backfill: from the RATER/LOGGER user's account (not the step's account)
op.execute(f"""
UPDATE {table} t
SET account_id = u.account_id
FROM users u
WHERE t.user_id = u.id
AND t.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
op.alter_column(table, 'account_id', nullable=False)
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('step_ratings', 'step_usage_log'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,103 @@
"""add account_id to script_builder_sessions, script_templates, script_generations
Revision ID: 78fc200abac1
Revises: 7f136778f5a8
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '78fc200abac1'
down_revision: Union[str, None] = '7f136778f5a8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
PLATFORM_ACCOUNT_ID = '00000000-0000-0000-0000-000000000001'
def upgrade() -> None:
# Ensure the platform sentinel account exists before any fallback assignments.
# Migration 3a40fe11b427 also inserts this with ON CONFLICT DO NOTHING — safe.
op.execute(f"""
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
VALUES (
'{PLATFORM_ACCOUNT_ID}',
'ResolutionFlow Platform',
'PLATFORM',
NOW(),
NOW()
)
ON CONFLICT (id) DO NOTHING
""")
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# script_builder_sessions: user_id → users.account_id
op.execute("""
UPDATE script_builder_sessions sbs
SET account_id = u.account_id
FROM users u
WHERE sbs.user_id = u.id
AND sbs.account_id IS NULL
""")
# script_templates: created_by → users.account_id (nullable created_by)
op.execute("""
UPDATE script_templates st
SET account_id = u.account_id
FROM users u
WHERE st.created_by = u.id
AND st.account_id IS NULL
""")
# Fallback: team_id → team admin user
op.execute("""
UPDATE script_templates st
SET account_id = u.account_id
FROM users u
WHERE u.team_id = st.team_id
AND u.is_team_admin = TRUE
AND u.account_id IS NOT NULL
AND st.account_id IS NULL
""")
# Final fallback: platform-seeded templates with NULL team_id AND NULL created_by
# (e.g. the 6 AD templates inserted by migration 057) → platform sentinel account
op.execute(f"""
UPDATE script_templates
SET account_id = '{PLATFORM_ACCOUNT_ID}'
WHERE account_id IS NULL
""")
# script_generations: user_id → users.account_id
op.execute("""
UPDATE script_generations sg
SET account_id = u.account_id
FROM users u
WHERE sg.user_id = u.id
AND sg.account_id IS NULL
""")
# VERIFY
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
op.alter_column(table, 'account_id', nullable=False)
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('script_builder_sessions', 'script_templates', 'script_generations'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,62 @@
"""add account_id to maintenance_schedules
Revision ID: 7f136778f5a8
Revises: 8aac5b372402
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '7f136778f5a8'
down_revision: Union[str, None] = '8aac5b372402'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('maintenance_schedules',
sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
'fk_maintenance_schedules_account_id', 'maintenance_schedules', 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Primary: tree_id → trees.account_id (only where tree.account_id is NOT NULL)
op.execute("""
UPDATE maintenance_schedules ms
SET account_id = t.account_id
FROM trees t
WHERE ms.tree_id = t.id
AND t.account_id IS NOT NULL
AND ms.account_id IS NULL
""")
# Fallback: created_by → users.account_id (for is_default trees with NULL account_id)
op.execute("""
UPDATE maintenance_schedules ms
SET account_id = u.account_id
FROM users u
WHERE ms.created_by = u.id
AND u.account_id IS NOT NULL
AND ms.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text("SELECT COUNT(*) FROM maintenance_schedules WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} maintenance_schedules rows have NULL account_id. "
"Check if created_by is NULL — those rows need manual resolution."
)
op.alter_column('maintenance_schedules', 'account_id', nullable=False)
op.create_index('ix_maintenance_schedules_account_id', 'maintenance_schedules', ['account_id'])
def downgrade() -> None:
op.drop_index('ix_maintenance_schedules_account_id', table_name='maintenance_schedules')
op.drop_constraint('fk_maintenance_schedules_account_id', 'maintenance_schedules', type_='foreignkey')
op.drop_column('maintenance_schedules', 'account_id')

View File

@@ -0,0 +1,81 @@
"""add account_id to PSA and notification tables
Revision ID: 8aac5b372402
Revises: a1d2a84b9abb
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '8aac5b372402'
down_revision: Union[str, None] = 'a1d2a84b9abb'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Step 1: ADD COLUMN
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
# Step 2: BACKFILL
# psa_post_log: prefer psa_connection → fallback to posted_by user
# Note: cannot reference the updated table (ppl) inside the FROM clause JOIN,
# so use a correlated subquery for psa_connections lookup instead.
op.execute("""
UPDATE psa_post_log ppl
SET account_id = COALESCE(
(SELECT account_id FROM psa_connections WHERE id = ppl.psa_connection_id),
u.account_id
)
FROM users u
WHERE ppl.posted_by = u.id
AND ppl.account_id IS NULL
""")
# psa_member_mappings: via psa_connection
op.execute("""
UPDATE psa_member_mappings pmm
SET account_id = pc.account_id
FROM psa_connections pc
WHERE pmm.psa_connection_id = pc.id
AND pmm.account_id IS NULL
""")
# notification_logs: via notification_config
op.execute("""
UPDATE notification_logs nl
SET account_id = nc.account_id
FROM notification_configs nc
WHERE nl.notification_config_id = nc.id
AND nl.account_id IS NULL
""")
# Step 3: VERIFY
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
# Step 4: SET NOT NULL
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
op.alter_column(table, 'account_id', nullable=False)
# Step 5: CREATE INDEX
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('psa_post_log', 'psa_member_mappings', 'notification_logs'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,45 @@
"""add account_id to user personalization tables
Revision ID: a1d2a84b9abb
Revises: 7167e9374b0c
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = 'a1d2a84b9abb'
down_revision: Union[str, None] = '7167e9374b0c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
for table in ('user_folders', 'user_pinned_trees'):
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id', table, 'accounts',
['account_id'], ['id'], ondelete='CASCADE',
)
op.execute(f"""
UPDATE {table} t
SET account_id = u.account_id
FROM users u
WHERE t.user_id = u.id
AND t.account_id IS NULL
""")
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(f"ROLLBACK: {count} NULL account_id rows in {table}.")
op.alter_column(table, 'account_id', nullable=False)
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('user_folders', 'user_pinned_trees'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -0,0 +1,24 @@
"""merge Phase 1 tenant isolation chain with main head
Revision ID: a9f3b2c1d4e5
Revises: 070, 174f442795b7
Create Date: 2026-04-09 00:00:00.000000
Merge migration: consolidates the Phase 1 account_id chain (cc214c63aa30 → … → 174f442795b7)
with the main sequential chain (… → 070) into a single head so that
`alembic upgrade head` works without ambiguity.
"""
from typing import Sequence, Union
revision: str = 'a9f3b2c1d4e5'
down_revision: Union[str, tuple] = ('070', '174f442795b7')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass

View File

@@ -0,0 +1,108 @@
"""enable_rls_phase1
Revision ID: c5f48b9890f9
Revises: 0b470d9e6cf1
Create Date: 2026-04-10 04:01:13.043321
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'c5f48b9890f9'
down_revision: Union[str, None] = '0b470d9e6cf1'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
_CURRENT_ACCOUNT = (
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
f"'{_NULL_UUID}')::uuid"
)
def upgrade() -> None:
# ── trees ───────────────────────────────────────────────────────────────
# Extended policy mirrors can_access_tree() in app/core/permissions.py.
# Tenant sees: own rows, platform rows, any default tree, any public tree,
# any gallery-featured tree.
# is_gallery_featured = TRUE is included because /public/templates is a
# no-auth endpoint — no tenant context is set, so gallery trees must pass
# RLS on their own flag rather than relying on account_id or is_public.
# Private/team trees from other accounts are hidden.
op.execute("ALTER TABLE trees ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE trees FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON trees
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
OR is_default = TRUE
OR is_public = TRUE
OR is_gallery_featured = TRUE
)
""")
# ── tree_tags ────────────────────────────────────────────────────────────
# Own account + platform tags (global tags visible to all tenants).
op.execute("ALTER TABLE tree_tags ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE tree_tags FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON tree_tags
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
# ── tree_categories ──────────────────────────────────────────────────────
op.execute("ALTER TABLE tree_categories ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE tree_categories FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON tree_categories
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
# ── step_categories ──────────────────────────────────────────────────────
op.execute("ALTER TABLE step_categories ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE step_categories FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON step_categories
USING (
account_id = {_CURRENT_ACCOUNT}
OR account_id = '{_PLATFORM_UUID}'::uuid
)
""")
# ── psa_connections ──────────────────────────────────────────────────────
# Tenant-only — PSA credentials must never cross tenant boundaries.
op.execute("ALTER TABLE psa_connections ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE psa_connections FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON psa_connections
USING (account_id = {_CURRENT_ACCOUNT})
""")
# ── flow_proposals ────────────────────────────────────────────────────────
# Tenant-only.
op.execute("ALTER TABLE flow_proposals ENABLE ROW LEVEL SECURITY")
op.execute("ALTER TABLE flow_proposals FORCE ROW LEVEL SECURITY")
op.execute(f"""
CREATE POLICY tenant_isolation ON flow_proposals
USING (account_id = {_CURRENT_ACCOUNT})
""")
def downgrade() -> None:
for table in ["trees", "tree_tags", "tree_categories", "step_categories",
"psa_connections", "flow_proposals"]:
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")

View File

@@ -0,0 +1,95 @@
"""add account_id to core session tables
Revision ID: cc214c63aa30
Revises: b8d2f4a6c091
Create Date: 2026-04-09 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = 'cc214c63aa30'
down_revision: Union[str, None] = '064'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = ('067',)
def upgrade() -> None:
# ── Step 1: ADD COLUMN (nullable) ────────────────────────────────────────
for table in ('sessions', 'attachments', 'session_supporting_data',
'session_resolution_outputs'):
op.add_column(table, sa.Column('account_id', sa.UUID(), nullable=True))
op.create_foreign_key(
f'fk_{table}_account_id',
table, 'accounts',
['account_id'], ['id'],
ondelete='CASCADE',
)
# ── Step 2: BACKFILL ─────────────────────────────────────────────────────
# sessions: direct join to users
op.execute("""
UPDATE sessions s
SET account_id = u.account_id
FROM users u
WHERE s.user_id = u.id
AND s.account_id IS NULL
""")
# attachments: chain through sessions (now backfilled above)
op.execute("""
UPDATE attachments a
SET account_id = s.account_id
FROM sessions s
WHERE a.session_id = s.id
AND a.account_id IS NULL
""")
# session_supporting_data: same chain
op.execute("""
UPDATE session_supporting_data sd
SET account_id = s.account_id
FROM sessions s
WHERE sd.session_id = s.id
AND sd.account_id IS NULL
""")
# session_resolution_outputs: FK is to ai_sessions, not sessions
op.execute("""
UPDATE session_resolution_outputs sro
SET account_id = ai.account_id
FROM ai_sessions ai
WHERE sro.session_id = ai.id
AND sro.account_id IS NULL
""")
# ── Step 3: VERIFY zero NULLs — raises if any remain ────────────────────
for table in ('sessions', 'attachments', 'session_supporting_data',
'session_resolution_outputs'):
result = op.get_bind().execute(
sa.text(f"SELECT COUNT(*) FROM {table} WHERE account_id IS NULL")
)
count = result.scalar()
if count > 0:
raise RuntimeError(
f"ROLLBACK: {count} NULL account_id rows remain in {table}. "
f"Fix the backfill before re-running."
)
# ── Step 4: SET NOT NULL ─────────────────────────────────────────────────
for table in ('sessions', 'attachments', 'session_supporting_data',
'session_resolution_outputs'):
op.alter_column(table, 'account_id', nullable=False)
# ── Step 5: CREATE INDEX ─────────────────────────────────────────────────
for table in ('sessions', 'attachments', 'session_supporting_data',
'session_resolution_outputs'):
op.create_index(f'ix_{table}_account_id', table, ['account_id'])
def downgrade() -> None:
for table in ('sessions', 'attachments', 'session_supporting_data',
'session_resolution_outputs'):
op.drop_index(f'ix_{table}_account_id', table_name=table)
op.drop_constraint(f'fk_{table}_account_id', table, type_='foreignkey')
op.drop_column(table, 'account_id')

View File

@@ -10,6 +10,8 @@ from app.core.database import get_db
from app.core.security import decode_token
from app.models.user import User
from app.models.plan_limits import PlanLimits
from app.core.tenant_context import set_current_account_id, clear_current_account_id
from app.core.admin_database import get_admin_db # noqa: F401 — re-exported for use in endpoints
# Routes that are allowed even when must_change_password is True
_PASSWORD_CHANGE_ALLOWLIST = {
@@ -190,3 +192,44 @@ async def get_plan_limits_for_user(
"""Get plan limits for the current user's account."""
from app.core.subscriptions import get_user_plan_limits
return await get_user_plan_limits(current_user.account_id, db)
async def require_tenant_context(
current_user: Annotated[User, Depends(get_current_active_user)],
):
"""Set per-request tenant context for RLS.
Raises 403 if the authenticated user has no account_id — never falls back
to PLATFORM_ACCOUNT_ID (that would grant platform-scope access to a
malformed account).
Sets the ContextVar that the SQLAlchemy transaction-begin listener reads to
issue set_config('app.current_account_id', …, true) on every transaction.
Applied to every user-facing router. NOT applied to /admin/* routers or
public endpoints (auth, shared, webhooks).
"""
if current_user.account_id is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account required",
)
token = set_current_account_id(current_user.account_id)
try:
yield
finally:
clear_current_account_id(token)
async def require_admin_db(
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
) -> AsyncSession:
"""Return a BYPASSRLS admin DB session after verifying super_admin role.
Use on /admin/* endpoints that query RLS-protected tables. Replaces
Depends(get_db) on the db parameter of those endpoints.
The current_user dep is still declared separately on the endpoint if
the user object is needed in the handler.
"""
return db

View File

@@ -5,10 +5,10 @@ from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, or_
from sqlalchemy.orm import selectinload, aliased
from sqlalchemy import select, func
from sqlalchemy.orm import selectinload
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.core.audit import log_audit
from app.core.config import settings
from app.core.security import get_password_hash, generate_temp_password, create_password_reset_token, decode_token, hash_token
@@ -24,44 +24,21 @@ from app.models.invite_code import InviteCode
from app.models.account_invite import AccountInvite
from app.models.tree import Tree
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
from app.schemas.admin import (
MoveUserAccount,
AdminUserCreate,
AdminUserCreateResponse,
AdminPasswordReset,
AdminPasswordResetResponse,
HardDeleteCheckResponse,
AdminUserListItem,
AdminUserListResponse,
AdminAccountMember,
AdminAccountListItem,
AdminAccountListResponse,
AdminAccountOwnerSummary,
AdminAccountSubscriptionSummary,
AdminAccountUsageSummary,
AdminAccountDetailResponse,
AdminAccountInviteSummary,
AdminAccountCreate,
AdminAccountUpdate,
)
from app.schemas.admin import MoveUserAccount, AdminUserCreate, AdminUserCreateResponse, AdminPasswordReset, AdminPasswordResetResponse, HardDeleteCheckResponse
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
from app.schemas.user_detail import (
UserDetailResponse, AccountSummary, SubscriptionSummary,
SessionSummary, AuditLogSummary, InviteCodeUsedSummary,
)
from app.api.deps import require_admin
from app.core.subscriptions import get_account_usage
router = APIRouter(prefix="/admin", tags=["admin"])
@router.get("/users", response_model=AdminUserListResponse)
@router.get("/users", response_model=list[UserResponse])
async def list_users(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
page: Optional[int] = Query(None, ge=1),
size: Optional[int] = Query(None, ge=1, le=100),
search: Optional[str] = Query(None, description="Search by user or account fields"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
is_active: Optional[bool] = Query(None, description="Filter by active status"),
@@ -69,240 +46,23 @@ async def list_users(
account_id: Optional[UUID] = Query(None, description="Filter by account"),
include_archived: bool = Query(False, description="Include archived (soft-deleted) users"),
):
"""List users for super admin global people search."""
resolved_limit = size or limit
resolved_skip = skip
current_page = 1
if page is not None:
resolved_skip = (page - 1) * resolved_limit
current_page = page
elif resolved_limit > 0:
current_page = (resolved_skip // resolved_limit) + 1
count_query = (
select(func.count())
.select_from(User)
.outerjoin(Account, User.account_id == Account.id)
)
query = (
select(
User,
Account.name.label("account_name"),
Account.display_code.label("account_display_code"),
)
.outerjoin(Account, User.account_id == Account.id)
)
"""List all users (super admin only)."""
query = select(User)
if not include_archived:
query = query.where(User.deleted_at.is_(None))
count_query = count_query.where(User.deleted_at.is_(None))
if is_active is not None:
query = query.where(User.is_active == is_active)
count_query = count_query.where(User.is_active == is_active)
if role:
query = query.where(User.role == role)
count_query = count_query.where(User.role == role)
if account_id:
query = query.where(User.account_id == account_id)
count_query = count_query.where(User.account_id == account_id)
if search:
search_term = f"%{search.strip()}%"
search_filter = or_(
User.name.ilike(search_term),
User.email.ilike(search_term),
Account.name.ilike(search_term),
Account.display_code.ilike(search_term),
)
query = query.where(search_filter)
count_query = count_query.where(search_filter)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
query = query.order_by(User.created_at.desc()).offset(resolved_skip).limit(resolved_limit)
result = await db.execute(query)
rows = result.all()
items = [
AdminUserListItem(
id=user.id,
email=user.email,
name=user.name,
role=user.role,
is_super_admin=user.is_super_admin,
is_active=user.is_active,
account_id=user.account_id,
account_role=user.account_role,
account_name=account_name,
account_display_code=account_display_code,
created_at=user.created_at,
last_login=user.last_login,
deleted_at=user.deleted_at,
)
for user, account_name, account_display_code in rows
]
return AdminUserListResponse(
items=items,
total=total,
page=current_page,
per_page=resolved_limit,
)
@router.get("/accounts", response_model=AdminAccountListResponse)
async def list_accounts(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
page: int = Query(1, ge=1),
size: int = Query(12, ge=1, le=100),
search: Optional[str] = Query(None, description="Search by account, display code, or owner"),
plan: Optional[str] = Query(None, description="Filter by subscription plan"),
status: Optional[str] = Query(None, description="Filter by subscription status"),
include_archived: bool = Query(False, description="Include archived users in account member lists"),
):
"""List accounts with embedded members for the admin panel."""
owner_user = aliased(User)
count_query = (
select(func.count(func.distinct(Account.id)))
.select_from(Account)
.outerjoin(owner_user, Account.owner_id == owner_user.id)
.outerjoin(Subscription, Subscription.account_id == Account.id)
)
accounts_query = (
select(
Account,
owner_user.id.label("owner_user_id"),
owner_user.name.label("owner_name"),
owner_user.email.label("owner_email"),
Subscription.id.label("subscription_id"),
Subscription.plan.label("subscription_plan"),
Subscription.status.label("subscription_status"),
Subscription.billing_interval.label("subscription_billing_interval"),
Subscription.current_period_end.label("subscription_current_period_end"),
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
)
.outerjoin(owner_user, Account.owner_id == owner_user.id)
.outerjoin(Subscription, Subscription.account_id == Account.id)
)
if search:
search_term = f"%{search.strip()}%"
search_filter = or_(
Account.name.ilike(search_term),
Account.display_code.ilike(search_term),
owner_user.name.ilike(search_term),
owner_user.email.ilike(search_term),
)
count_query = count_query.where(search_filter)
accounts_query = accounts_query.where(search_filter)
if plan:
count_query = count_query.where(Subscription.plan == plan)
accounts_query = accounts_query.where(Subscription.plan == plan)
if status:
count_query = count_query.where(Subscription.status == status)
accounts_query = accounts_query.where(Subscription.status == status)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
accounts_result = await db.execute(
accounts_query
.order_by(Account.created_at.desc())
.offset((page - 1) * size)
.limit(size)
)
rows = accounts_result.all()
accounts = [row.Account for row in rows]
account_ids = [account.id for account in accounts]
members_by_account: dict[UUID, list[AdminAccountMember]] = {account_id: [] for account_id in account_ids}
pending_invites_by_account: dict[UUID, int] = {account_id: 0 for account_id in account_ids}
usage_by_account: dict[UUID, AdminAccountUsageSummary] = {}
if account_ids:
members_query = select(User).where(User.account_id.in_(account_ids))
if not include_archived:
members_query = members_query.where(User.deleted_at.is_(None))
members_query = members_query.order_by(User.created_at.asc())
members_result = await db.execute(members_query)
for member in members_result.scalars().all():
members_by_account.setdefault(member.account_id, []).append(
AdminAccountMember(
id=member.id,
email=member.email,
name=member.name,
role=member.role,
is_super_admin=member.is_super_admin,
is_active=member.is_active,
account_role=member.account_role,
created_at=member.created_at,
last_login=member.last_login,
deleted_at=member.deleted_at,
)
)
pending_invites_result = await db.execute(
select(AccountInvite.account_id, func.count(AccountInvite.id))
.where(
AccountInvite.account_id.in_(account_ids),
AccountInvite.used_at.is_(None),
)
.group_by(AccountInvite.account_id)
)
pending_invites_by_account.update({row[0]: row[1] for row in pending_invites_result.all()})
for account_id in account_ids:
usage = await get_account_usage(account_id, db)
usage_by_account[account_id] = AdminAccountUsageSummary(
tree_count=usage.get("tree_count", 0),
session_count_this_month=usage.get("session_count_this_month", 0),
)
items = [
AdminAccountListItem(
id=row.Account.id,
name=row.Account.name,
display_code=row.Account.display_code,
created_at=row.Account.created_at,
owner_id=row.Account.owner_id,
owner=(
AdminAccountOwnerSummary(
id=row.owner_user_id,
name=row.owner_name,
email=row.owner_email,
) if row.owner_user_id and row.owner_name and row.owner_email else None
),
subscription=(
AdminAccountSubscriptionSummary(
id=row.subscription_id,
plan=row.subscription_plan,
status=row.subscription_status,
billing_interval=row.subscription_billing_interval,
current_period_end=row.subscription_current_period_end,
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
),
usage=usage_by_account.get(row.Account.id, AdminAccountUsageSummary()),
member_count=len(members_by_account.get(row.Account.id, [])),
active_member_count=sum(1 for member in members_by_account.get(row.Account.id, []) if member.is_active),
pending_invite_count=pending_invites_by_account.get(row.Account.id, 0),
sso_enabled=row.Account.sso_enabled,
branding_company_name=row.Account.branding_company_name,
members=members_by_account.get(row.Account.id, []),
)
for row in rows
]
return AdminAccountListResponse(
items=items,
total=total,
page=page,
per_page=size,
)
users = result.scalars().all()
return users
def _generate_display_code() -> str:
@@ -311,187 +71,10 @@ def _generate_display_code() -> str:
return ''.join(secrets.choice(chars) for _ in range(8))
async def _generate_unique_display_code(db: AsyncSession) -> str:
"""Generate a unique display code for a new account."""
while True:
display_code = _generate_display_code()
existing = await db.execute(select(Account.id).where(Account.display_code == display_code))
if existing.scalar_one_or_none() is None:
return display_code
async def _get_account_detail_payload(
account_id: UUID,
db: AsyncSession,
include_archived: bool = False,
) -> AdminAccountDetailResponse:
owner_user = aliased(User)
result = await db.execute(
select(
Account,
owner_user.id.label("owner_user_id"),
owner_user.name.label("owner_name"),
owner_user.email.label("owner_email"),
Subscription.id.label("subscription_id"),
Subscription.plan.label("subscription_plan"),
Subscription.status.label("subscription_status"),
Subscription.billing_interval.label("subscription_billing_interval"),
Subscription.current_period_end.label("subscription_current_period_end"),
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
)
.outerjoin(owner_user, Account.owner_id == owner_user.id)
.outerjoin(Subscription, Subscription.account_id == Account.id)
.where(Account.id == account_id)
)
row = result.one_or_none()
if not row:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
members_query = select(User).where(User.account_id == account_id).order_by(User.created_at.asc())
if not include_archived:
members_query = members_query.where(User.deleted_at.is_(None))
members_result = await db.execute(members_query)
members = [
AdminAccountMember(
id=member.id,
email=member.email,
name=member.name,
role=member.role,
is_super_admin=member.is_super_admin,
is_active=member.is_active,
account_role=member.account_role,
created_at=member.created_at,
last_login=member.last_login,
deleted_at=member.deleted_at,
)
for member in members_result.scalars().all()
]
invites_result = await db.execute(
select(AccountInvite)
.where(AccountInvite.account_id == account_id)
.order_by(AccountInvite.created_at.desc())
)
invites = [
AdminAccountInviteSummary(
id=invite.id,
email=invite.email,
role=invite.role,
expires_at=invite.expires_at,
created_at=invite.created_at,
used_at=invite.used_at,
)
for invite in invites_result.scalars().all()
if invite.used_at is None
]
usage = await get_account_usage(account_id, db)
return AdminAccountDetailResponse(
id=row.Account.id,
name=row.Account.name,
display_code=row.Account.display_code,
created_at=row.Account.created_at,
owner_id=row.Account.owner_id,
owner=(
AdminAccountOwnerSummary(
id=row.owner_user_id,
name=row.owner_name,
email=row.owner_email,
) if row.owner_user_id and row.owner_name and row.owner_email else None
),
subscription=(
AdminAccountSubscriptionSummary(
id=row.subscription_id,
plan=row.subscription_plan,
status=row.subscription_status,
billing_interval=row.subscription_billing_interval,
current_period_end=row.subscription_current_period_end,
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
),
usage=AdminAccountUsageSummary(
tree_count=usage.get("tree_count", 0),
session_count_this_month=usage.get("session_count_this_month", 0),
),
member_count=len(members),
active_member_count=sum(1 for member in members if member.is_active),
pending_invite_count=len(invites),
sso_enabled=row.Account.sso_enabled,
branding_company_name=row.Account.branding_company_name,
members=members,
invites=invites,
)
@router.post("/accounts", response_model=AdminAccountDetailResponse, status_code=status.HTTP_201_CREATED)
async def create_account(
data: AdminAccountCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Create a new account without requiring an initial user."""
display_code = await _generate_unique_display_code(db)
new_account = Account(
name=data.name.strip(),
display_code=display_code,
)
db.add(new_account)
await db.flush()
new_subscription = Subscription(
account_id=new_account.id,
plan=data.plan,
status="active",
)
db.add(new_subscription)
await log_audit(
db, current_user.id, "account.create_admin", "account", new_account.id,
{"name": new_account.name, "plan": data.plan},
)
await db.commit()
return await _get_account_detail_payload(new_account.id, db)
@router.get("/accounts/{account_id}", response_model=AdminAccountDetailResponse)
async def get_account_detail(
account_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
include_archived: bool = Query(False),
):
"""Get detailed account information for admin management."""
return await _get_account_detail_payload(account_id, db, include_archived=include_archived)
@router.put("/accounts/{account_id}", response_model=AdminAccountDetailResponse)
async def update_account(
account_id: UUID,
data: AdminAccountUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update account settings from the admin panel."""
result = await db.execute(select(Account).where(Account.id == account_id))
account = result.scalar_one_or_none()
if not account:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
old_name = account.name
account.name = data.name.strip()
await log_audit(
db, current_user.id, "account.update_admin", "account", account.id,
{"old_name": old_name, "new_name": account.name},
)
await db.commit()
return await _get_account_detail_payload(account.id, db)
@router.post("/users", response_model=AdminUserCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
data: AdminUserCreate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Create a new user with a temporary password (super admin only).
@@ -616,7 +199,7 @@ async def create_user(
@router.get("/users/{user_id}", response_model=UserDetailResponse)
async def get_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Get enriched user details (super admin only)."""
@@ -734,7 +317,7 @@ async def get_user(
async def update_user_role(
user_id: UUID,
role_data: RoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Change user role (super admin only)."""
@@ -766,7 +349,7 @@ async def update_user_role(
async def update_account_role(
user_id: UUID,
data: AccountRoleUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Change a user's account role (super admin only)."""
@@ -792,7 +375,7 @@ async def update_account_role(
async def update_super_admin_status(
user_id: UUID,
data: dict,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Promote or demote a user to/from super admin (super admin only)."""
@@ -831,7 +414,7 @@ async def update_super_admin_status(
@router.put("/users/{user_id}/deactivate", response_model=UserResponse)
async def deactivate_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Deactivate a user account (super admin only)."""
@@ -860,7 +443,7 @@ async def deactivate_user(
@router.put("/users/{user_id}/activate", response_model=UserResponse)
async def activate_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)]
):
"""Reactivate a user account (super admin only)."""
@@ -884,7 +467,7 @@ async def activate_user(
async def move_user_account(
user_id: UUID,
data: MoveUserAccount,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Move a user to a different account (super admin only)."""
@@ -933,33 +516,11 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
return user, subscription
async def _get_account_subscription(account_id: UUID, db: AsyncSession) -> tuple[Account, Subscription]:
"""Helper to load account and its subscription."""
account_result = await db.execute(select(Account).where(Account.id == account_id))
account = account_result.scalar_one_or_none()
if not account:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
sub_result = await db.execute(
select(Subscription).where(Subscription.account_id == account.id)
)
subscription = sub_result.scalar_one_or_none()
if not subscription:
subscription = Subscription(
account_id=account.id,
plan="free",
status="active",
)
db.add(subscription)
await db.flush()
return account, subscription
@router.put("/users/{user_id}/subscription/plan")
async def update_user_plan(
user_id: UUID,
data: SubscriptionPlanUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Change a user's subscription plan (super admin only)."""
@@ -974,36 +535,11 @@ async def update_user_plan(
return {"plan": subscription.plan, "status": subscription.status}
@router.put("/accounts/{account_id}/subscription/plan")
async def update_account_plan(
account_id: UUID,
data: SubscriptionPlanUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Change an account subscription plan (super admin only)."""
if data.plan not in ("free", "pro", "team"):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
account, subscription = await _get_account_subscription(account_id, db)
old_plan = subscription.plan
subscription.plan = data.plan
await log_audit(
db,
current_user.id,
"subscription.plan_change",
"subscription",
subscription.id,
{"old_plan": old_plan, "new_plan": data.plan, "account_id": str(account_id)},
)
await db.commit()
return {"plan": subscription.plan, "status": subscription.status}
@router.put("/users/{user_id}/subscription/extend-trial")
async def extend_user_trial(
user_id: UUID,
data: ExtendTrialRequest,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Extend or start a trial for a user's subscription (super admin only)."""
@@ -1029,48 +565,11 @@ async def extend_user_trial(
"current_period_end": subscription.current_period_end}
@router.put("/accounts/{account_id}/subscription/extend-trial")
async def extend_account_trial(
account_id: UUID,
data: ExtendTrialRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Extend or start a trial for an account subscription (super admin only)."""
if data.days < 1 or data.days > 90:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Days must be 1-90")
account, subscription = await _get_account_subscription(account_id, db)
now = datetime.now(timezone.utc)
if subscription.status == "trialing" and subscription.current_period_end:
new_end = subscription.current_period_end + timedelta(days=data.days)
else:
subscription.status = "trialing"
subscription.current_period_start = now
new_end = now + timedelta(days=data.days)
subscription.current_period_end = new_end
await log_audit(
db,
current_user.id,
"subscription.extend_trial",
"subscription",
subscription.id,
{"days": data.days, "new_end": new_end.isoformat(), "account_id": str(account.id)},
)
await db.commit()
return {
"plan": subscription.plan,
"status": subscription.status,
"current_period_end": subscription.current_period_end,
}
@router.post("/users/{user_id}/password-reset", response_model=AdminPasswordResetResponse)
async def admin_reset_password(
user_id: UUID,
data: AdminPasswordReset,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Admin-triggered password reset (super admin only).
@@ -1141,7 +640,7 @@ async def admin_reset_password(
@router.put("/users/{user_id}/archive", response_model=UserResponse)
async def archive_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Archive (soft delete) a user (super admin only)."""
@@ -1176,7 +675,7 @@ async def archive_user(
@router.put("/users/{user_id}/restore", response_model=UserResponse)
async def restore_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Restore an archived user (super admin only)."""
@@ -1201,7 +700,7 @@ async def restore_user(
@router.get("/users/{user_id}/hard-delete-check", response_model=HardDeleteCheckResponse)
async def hard_delete_check(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Check if a user can be hard-deleted (super admin only). Returns blockers."""
@@ -1274,7 +773,7 @@ async def hard_delete_check(
@router.delete("/users/{user_id}/hard-delete", status_code=status.HTTP_204_NO_CONTENT)
async def hard_delete_user(
user_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Permanently delete a user (super admin only). User must be archived first."""
@@ -1334,7 +833,7 @@ async def hard_delete_user(
@router.post("/invites", status_code=status.HTTP_201_CREATED)
async def admin_create_invite(
data: dict,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Quick-invite a user to an account (super admin only).

View File

@@ -4,25 +4,26 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.core.audit import log_audit
from app.models.user import User
from app.models.category import TreeCategory
from app.models.tree import Tree
from app.schemas.admin import GlobalCategoryCreate, GlobalCategoryUpdate, GlobalCategoryResponse
from app.api.deps import require_admin
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/admin/categories", tags=["admin-categories"])
@router.get("/global", response_model=list[GlobalCategoryResponse])
async def list_global_categories(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""List all global categories (account_id IS NULL)."""
result = await db.execute(
select(TreeCategory).where(TreeCategory.account_id.is_(None)).order_by(TreeCategory.name)
select(TreeCategory).where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID).order_by(TreeCategory.name)
)
categories = result.scalars().all()
@@ -45,36 +46,36 @@ async def list_global_categories(
@router.post("/global", response_model=GlobalCategoryResponse, status_code=status.HTTP_201_CREATED)
async def create_global_category(
data: GlobalCategoryCreate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Create a global category."""
# Check slug uniqueness for global categories
existing = await db.execute(
select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None))
select(TreeCategory).where(TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Global category with this slug already exists")
category = TreeCategory(name=data.name, slug=data.slug, account_id=None)
category = TreeCategory(name=data.name, slug=data.slug, account_id=PLATFORM_ACCOUNT_ID)
db.add(category)
await log_audit(db, current_user.id, "global_category.create", "category", details={"name": data.name})
await db.commit()
await db.refresh(category)
return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=None, tree_count=0)
return GlobalCategoryResponse(id=category.id, name=category.name, slug=category.slug, account_id=PLATFORM_ACCOUNT_ID, tree_count=0)
@router.put("/global/{category_id}", response_model=GlobalCategoryResponse)
async def update_global_category(
category_id: UUID,
data: GlobalCategoryUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update a global category."""
result = await db.execute(
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None))
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
)
category = result.scalar_one_or_none()
if not category:
@@ -86,7 +87,7 @@ async def update_global_category(
# Check slug uniqueness
existing = await db.execute(
select(TreeCategory).where(
TreeCategory.slug == data.slug, TreeCategory.account_id.is_(None), TreeCategory.id != category_id
TreeCategory.slug == data.slug, TreeCategory.account_id == PLATFORM_ACCOUNT_ID, TreeCategory.id != category_id
)
)
if existing.scalar_one_or_none():
@@ -103,19 +104,19 @@ async def update_global_category(
return GlobalCategoryResponse(
id=category.id, name=category.name, slug=category.slug,
account_id=None, tree_count=tree_count,
account_id=PLATFORM_ACCOUNT_ID, tree_count=tree_count,
)
@router.delete("/global/{category_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_global_category(
category_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Delete (archive) a global category."""
result = await db.execute(
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id.is_(None))
select(TreeCategory).where(TreeCategory.id == category_id, TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
)
category = result.scalar_one_or_none()
if not category:

View File

@@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.models.user import User
from app.models.subscription import Subscription
from app.models.tree import Tree
@@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin/dashboard", tags=["admin-dashboard"])
@router.get("/metrics", response_model=DashboardMetrics)
async def get_dashboard_metrics(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Get platform overview metrics."""
@@ -45,7 +45,7 @@ async def get_dashboard_metrics(
@router.get("/activity", response_model=list[ActivityEntry])
async def get_dashboard_activity(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Get recent audit log entries for activity feed."""

View File

@@ -12,7 +12,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import require_admin
from app.core.database import get_db
from app.core.admin_database import get_admin_db
from app.models.script_template import ScriptTemplate
from app.models.tree import Tree
from app.models.user import User
@@ -66,7 +66,7 @@ def _script_summary(script: ScriptTemplate) -> dict:
@router.get("/featured")
async def list_featured(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""List all featured flows and scripts (super admin only)."""
@@ -92,7 +92,7 @@ async def list_featured(
@router.get("/items")
async def list_all_items(
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""List ALL flows and scripts with their gallery status (super admin only)."""
@@ -119,7 +119,7 @@ async def list_all_items(
async def toggle_flow_featured(
flow_id: UUID,
body: FeatureToggle,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Toggle is_gallery_featured on a flow (super admin only)."""
@@ -138,7 +138,7 @@ async def toggle_flow_featured(
async def update_flow_sort_order(
flow_id: UUID,
body: SortOrderUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update gallery_sort_order on a flow (super admin only)."""
@@ -157,7 +157,7 @@ async def update_flow_sort_order(
async def toggle_script_featured(
script_id: UUID,
body: FeatureToggle,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Toggle is_gallery_featured on a script (super admin only)."""
@@ -176,7 +176,7 @@ async def toggle_script_featured(
async def update_script_sort_order(
script_id: UUID,
body: SortOrderUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
db: Annotated[AsyncSession, Depends(get_admin_db)],
current_user: Annotated[User, Depends(require_admin)],
):
"""Update gallery_sort_order on a script (super admin only)."""

View File

@@ -49,8 +49,6 @@ from app.schemas.ai_session import (
ChatMessageRequest,
ChatMessageResponse,
SaveTaskLaneRequest,
TriagePatchRequest,
TriagePatchResponse,
)
from app.services import flowpilot_engine
from app.services import unified_chat_service
@@ -122,11 +120,6 @@ def _build_session_detail(session: AISession) -> AISessionDetail:
pending_task_lane=session.pending_task_lane,
is_branching=getattr(session, 'is_branching', False),
active_branch_id=str(session.active_branch_id) if getattr(session, 'active_branch_id', None) else None,
client_name=getattr(session, 'client_name', None),
asset_name=getattr(session, 'asset_name', None),
issue_category=getattr(session, 'issue_category', None),
triage_hypothesis=getattr(session, 'triage_hypothesis', None),
evidence_items=getattr(session, 'evidence_items', None),
)
@@ -308,7 +301,7 @@ async def send_chat_message(
message = f"{message}\n\n[Attached document content]\n{doc_context}"
try:
ai_content, suggested_flows, session, fork_metadata, actions_data, questions_data, triage_update_data = await unified_chat_service.send_chat_message(
ai_content, suggested_flows, session, fork_metadata, actions_data, questions_data = await unified_chat_service.send_chat_message(
session_id=session_id,
user_id=user_id,
account_id=account_id,
@@ -353,7 +346,6 @@ async def send_chat_message(
fork=fork_metadata,
actions=actions_data,
questions=questions_data,
triage_update=triage_update_data,
)
@@ -450,12 +442,7 @@ async def resolve_session(
try:
from app.services.resolution_output_generator import ResolutionOutputGenerator
gen = ResolutionOutputGenerator(db)
await gen.generate_all(
session_id,
root_cause=data.root_cause,
steps_taken=data.steps_taken,
recommendations=data.recommendations,
)
await gen.generate_all(session_id)
except Exception:
logger.exception(f"Failed to generate resolution outputs for session {session_id}")
@@ -532,11 +519,15 @@ async def save_task_lane(
_: None = Depends(require_engineer_or_admin),
):
"""Save the current task lane state including user's in-progress responses."""
session = await db.get(AISession, session_id)
result = await db.execute(
select(AISession).where(
AISession.id == session_id,
AISession.user_id == current_user.id,
)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not your session")
payload = {
"questions": [q.model_dump() for q in body.questions],
@@ -553,122 +544,6 @@ async def save_task_lane(
await db.commit()
# ── Triage Metadata ──
@router.patch("/{session_id}/triage", response_model=TriagePatchResponse)
@limiter.limit("30/minute")
async def update_triage(
request: Request,
session_id: UUID,
body: TriagePatchRequest,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Update triage metadata on a session (incident header fields)."""
session = await db.get(AISession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not your session")
patch_data = body.model_dump(exclude_unset=True)
for field, value in patch_data.items():
setattr(session, field, value)
await db.commit()
await db.refresh(session)
return TriagePatchResponse(
id=session.id,
client_name=session.client_name,
asset_name=session.asset_name,
issue_category=session.issue_category,
triage_hypothesis=session.triage_hypothesis,
evidence_items=session.evidence_items,
)
# ── Handoff Draft ──
@router.post("/{session_id}/handoff-draft")
@limiter.limit("10/minute")
async def handoff_draft(
request: Request,
session_id: UUID,
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
_: None = Depends(require_engineer_or_admin),
):
"""Stream a structured handoff draft for the conclude modal."""
from fastapi.responses import StreamingResponse
from app.services.assistant_chat_service import _call_ai
session = await db.get(AISession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not your session")
# Build context from session data
context_parts = [
f"Problem: {session.problem_summary or 'Unknown'}",
f"Domain: {session.problem_domain or 'Unknown'}",
f"Client: {session.client_name or 'Unknown'}",
f"Asset: {session.asset_name or 'Unknown'}",
f"Hypothesis: {session.triage_hypothesis or 'None'}",
]
if session.evidence_items:
context_parts.append("\nEvidence collected:")
for item in session.evidence_items:
status_icon = {"confirmed": "", "ruled_out": "", "pending": "?"}.get(item.get("status", ""), "?")
context_parts.append(f" {status_icon} {item.get('text', '')}")
# Include task lane steps if available
if session.pending_task_lane:
actions = session.pending_task_lane.get("actions", [])
if actions:
context_parts.append("\nSteps taken:")
for a in actions:
context_parts.append(f" - {a.get('label', '')}")
# Include last 20 conversation messages
msgs = session.conversation_messages or []
if msgs:
context_parts.append("\nRecent conversation:")
for msg in msgs[-20:]:
role = msg.get("role", "unknown")
content = msg.get("content", "")[:300]
context_parts.append(f" [{role}]: {content}")
context = "\n".join(context_parts)
prompt = (
"Generate a structured handoff summary for this troubleshooting session.\n"
"Return ONLY valid JSON with exactly these four fields:\n"
'{"root_cause": "...", "resolution": "...", "steps_taken": ["step1", "step2"], "recommendations": "..."}\n\n'
f"Session context:\n{context}"
)
async def generate():
try:
content, _, _ = await _call_ai(
system_base="You are a concise technical documentation assistant for MSP teams. Return only JSON.",
rag_context="",
history=[],
new_message=prompt,
max_tokens=1024,
)
yield f"data: {content}\n\n"
except Exception as e:
logger.exception(f"Handoff draft generation failed for session {session_id}")
import json
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
# ── Resume ──
@router.post("/{session_id}/resume", status_code=204)
@@ -891,13 +766,13 @@ async def search_sessions(
limit: int = Query(5, ge=1, le=20),
):
"""Search AI sessions by content using full-text search. Used by Command Palette."""
# Sessions are user-scoped. The list endpoint uses user_id only;
# search must be consistent. Cross-user access requires explicit
# escalation or session sharing — not ambient account membership.
result = await db.execute(
select(AISession)
.where(
or_(
AISession.user_id == current_user.id,
AISession.account_id == current_user.account_id,
),
AISession.user_id == current_user.id,
text("ai_sessions.search_vector @@ plainto_tsquery('english', :q)"),
)
.params(q=q)
@@ -1030,7 +905,7 @@ async def get_session(
pkg = session.escalation_package or {}
is_handler = pkg.get("picked_up_by") == str(current_user.id)
if session.user_id != current_user.id and session.escalated_to_id != current_user.id and not is_handler:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
return _build_session_detail(session)
@@ -1046,6 +921,18 @@ async def get_documentation(
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Get auto-generated documentation for a session."""
# Verify session ownership — owner only. Documentation endpoints require direct
# ownership; escalated_to_id / picked_up_by handlers use get_session (read-only).
# This is consistent with stream_documentation which has the same owner-only check.
result = await db.execute(
select(AISession).where(
AISession.id == session_id,
AISession.user_id == current_user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="Session not found")
try:
return await flowpilot_engine.get_session_documentation(
session_id=session_id,
@@ -1071,13 +958,14 @@ async def stream_documentation(
# Verify session ownership
result = await db.execute(
select(AISession).where(AISession.id == session_id)
select(AISession).where(
AISession.id == session_id,
AISession.user_id == current_user.id,
)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Not authorized")
async def event_generator():
try:
@@ -1172,6 +1060,19 @@ async def retry_psa_push_endpoint(
"""Manually retry a failed PSA documentation push."""
from app.models.psa_post_log import PsaPostLog
# Verify the session belongs to the current user
session_result = await db.execute(
select(AISession).where(
AISession.id == session_id,
AISession.user_id == current_user.id,
)
)
if not session_result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found",
)
# Find the latest failed push log for this session
result = await db.execute(
select(PsaPostLog)

View File

@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.core.filters import tenant_filter
from app.models import User, Session, Tree, SessionRating
from app.schemas.analytics import (
TeamAnalyticsResponse, PersonalAnalyticsResponse, FlowAnalyticsResponse,
@@ -290,8 +291,13 @@ async def get_flow_analytics(
current_user: User = Depends(get_current_active_user),
):
"""Analytics for a specific flow."""
# Verify tree exists
result = await db.execute(select(Tree).where(Tree.id == tree_id))
# Verify tree exists and belongs to the requesting user's account.
result = await db.execute(
select(Tree).where(
Tree.id == tree_id,
tenant_filter(Tree, current_user.account_id),
)
)
tree = result.scalar_one_or_none()
if not tree:
raise HTTPException(status_code=404, detail="Flow not found")

View File

@@ -1,6 +1,5 @@
import secrets
import string
import uuid
from datetime import datetime, timezone, timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
@@ -27,7 +26,6 @@ from app.models.refresh_token import RefreshToken
from app.models.account import Account
from app.models.subscription import Subscription
from app.models.account_invite import AccountInvite
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
from app.schemas.token import Token
from app.schemas.auth_password import (
@@ -720,59 +718,3 @@ async def verify_email(
await db.commit()
return {"message": "Email verified successfully"}
@router.get("/me/feature-flags", response_model=dict[str, bool])
async def get_my_feature_flags(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
) -> dict[str, bool]:
"""Resolve feature flags for the current user's account and plan."""
plan = "free"
if current_user.account_id:
sub_result = await db.execute(
select(Subscription).where(
Subscription.account_id == current_user.account_id,
Subscription.status.in_(["active", "trialing"]),
)
)
sub = sub_result.scalar_one_or_none()
if sub:
plan = sub.plan
flags_result = await db.execute(select(FeatureFlag))
flags = flags_result.scalars().all()
if not flags:
return {}
flag_ids = [f.id for f in flags]
defaults_result = await db.execute(
select(PlanFeatureDefault).where(
PlanFeatureDefault.flag_id.in_(flag_ids),
PlanFeatureDefault.plan == plan,
)
)
plan_defaults = {d.flag_id: d.enabled for d in defaults_result.scalars().all()}
overrides: dict[uuid.UUID, bool] = {}
if current_user.account_id:
overrides_result = await db.execute(
select(AccountFeatureOverride).where(
AccountFeatureOverride.flag_id.in_(flag_ids),
AccountFeatureOverride.account_id == current_user.account_id,
)
)
overrides = {o.flag_id: o.enabled for o in overrides_result.scalars().all()}
resolved = {}
for flag in flags:
if flag.id in overrides:
resolved[flag.flag_key] = overrides[flag.id]
elif flag.id in plan_defaults:
resolved[flag.flag_key] = plan_defaults[flag.id]
else:
resolved[flag.flag_key] = False
return resolved

View File

@@ -12,6 +12,8 @@ from app.models.user import User
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse, CategoryListResponse
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_category, can_create_category
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.core.filters import tenant_filter
router = APIRouter(prefix="/categories", tags=["categories"])
@@ -47,13 +49,13 @@ async def list_categories(
elif current_user.account_id:
query = query.where(
or_(
TreeCategory.account_id.is_(None), # Global
TreeCategory.account_id == PLATFORM_ACCOUNT_ID, # Global
TreeCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no account, only show global categories
query = query.where(TreeCategory.account_id.is_(None))
query = query.where(TreeCategory.account_id == PLATFORM_ACCOUNT_ID)
query = query.order_by(TreeCategory.display_order, TreeCategory.name)
@@ -108,10 +110,12 @@ async def get_category(
detail="You don't have access to this category"
)
# Get tree count
# Get tree count — scoped to the requesting account so cross-account
# trees in shared categories are not counted.
count_query = select(func.count(Tree.id)).where(
Tree.category_id == category.id,
Tree.is_active == True
Tree.is_active == True,
tenant_filter(Tree, current_user.account_id),
)
count_result = await db.execute(count_query)
tree_count = count_result.scalar() or 0
@@ -173,7 +177,7 @@ async def create_category(
name=category_data.name,
slug=slug,
description=category_data.description,
account_id=category_data.account_id,
account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID,
display_order=max_order + 1,
created_by=current_user.id
)

View File

@@ -1,119 +0,0 @@
"""Device types API endpoints."""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.schemas.device_type import (
DeviceTypeCreate,
DeviceTypeUpdate,
DeviceTypeResponse,
)
router = APIRouter(prefix="/device-types", tags=["device-types"])
@router.get("/", response_model=list[DeviceTypeResponse])
async def list_device_types(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[DeviceTypeResponse]:
stmt = (
select(DeviceType)
.where(
or_(
DeviceType.is_system.is_(True),
DeviceType.team_id == current_user.team_id,
)
)
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
)
result = await db.execute(stmt)
rows = result.scalars().all()
return [DeviceTypeResponse.model_validate(r) for r in rows]
@router.post("/", response_model=DeviceTypeResponse, status_code=201)
async def create_device_type(
data: DeviceTypeCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DeviceTypeResponse:
existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.team_id == current_user.team_id,
)
)
if existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your team")
system_existing = await db.execute(
select(DeviceType).where(
DeviceType.slug == data.slug,
DeviceType.is_system.is_(True),
)
)
if system_existing.scalar_one_or_none():
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' conflicts with a system type")
device_type = DeviceType(
slug=data.slug,
label=data.label,
category=data.category,
is_system=False,
team_id=current_user.team_id,
sort_order=data.sort_order,
)
db.add(device_type)
await db.commit()
await db.refresh(device_type)
return DeviceTypeResponse.model_validate(device_type)
@router.put("/{device_type_id}", response_model=DeviceTypeResponse)
async def update_device_type(
device_type_id: UUID,
data: DeviceTypeUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DeviceTypeResponse:
device_type = await db.get(DeviceType, device_type_id)
if not device_type:
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot modify system device types")
if device_type.team_id != current_user.team_id:
raise HTTPException(status_code=404, detail="Device type not found")
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(device_type, field, value)
await db.commit()
await db.refresh(device_type)
return DeviceTypeResponse.model_validate(device_type)
@router.delete("/{device_type_id}", status_code=204)
async def delete_device_type(
device_type_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
device_type = await db.get(DeviceType, device_type_id)
if not device_type:
raise HTTPException(status_code=404, detail="Device type not found")
if device_type.is_system:
raise HTTPException(status_code=403, detail="Cannot delete system device types")
if device_type.team_id != current_user.team_id:
raise HTTPException(status_code=404, detail="Device type not found")
await db.delete(device_type)
await db.commit()

View File

@@ -29,8 +29,8 @@ def _compute_next_run(cron_expression: str, tz_name: str) -> datetime:
return cron.get_next(datetime).astimezone(timezone.utc)
async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree":
"""Fetch tree and verify the current user's team owns it."""
async def _get_tree_or_404(tree_id: UUID, current_user: User, db: AsyncSession) -> "Tree":
"""Fetch tree and verify the current user's team owns it. Raises 404 if not found or access denied."""
result = await db.execute(select(Tree).where(Tree.id == tree_id))
tree = result.scalar_one_or_none()
if not tree:
@@ -38,7 +38,7 @@ async def _get_tree_or_403(tree_id: UUID, current_user: User, db: AsyncSession)
# Super admins can access any tree; regular users must be on the same team
if not getattr(current_user, 'is_super_admin', False):
if tree.team_id != current_user.team_id:
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=404, detail="Tree not found")
return tree
@@ -51,7 +51,7 @@ async def create_schedule(
):
"""Create a cron schedule for a maintenance flow. One per flow."""
# Verify user's team owns the tree
tree = await _get_tree_or_403(data.tree_id, current_user, db)
tree = await _get_tree_or_404(data.tree_id, current_user, db)
if tree.tree_type != "maintenance":
raise HTTPException(status_code=400, detail="Schedules are only supported for maintenance flows")
@@ -94,7 +94,7 @@ async def get_schedule_for_tree(
):
"""Get the schedule for a specific maintenance flow."""
# Verify user's team owns the tree before returning schedule data
await _get_tree_or_403(tree_id, current_user, db)
await _get_tree_or_404(tree_id, current_user, db)
result = await db.execute(
select(MaintenanceSchedule).where(MaintenanceSchedule.tree_id == tree_id)
@@ -122,7 +122,7 @@ async def update_schedule(
raise HTTPException(status_code=404, detail="Schedule not found")
# Verify user's team owns the tree this schedule belongs to
await _get_tree_or_403(schedule.tree_id, current_user, db)
await _get_tree_or_404(schedule.tree_id, current_user, db)
update_fields = data.model_fields_set
was_active = schedule.is_active

View File

@@ -1,332 +0,0 @@
"""Network diagrams API endpoints."""
import logging
from datetime import datetime, timezone
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db
from app.api.deps import get_current_active_user
from app.models.user import User
from app.models.device_type import DeviceType
from app.models.network_diagram import NetworkDiagram
from app.schemas.network_diagram import (
NetworkDiagramCreate,
NetworkDiagramUpdate,
NetworkDiagramResponse,
NetworkDiagramListItem,
AIGenerateRequest,
AIGenerateResponse,
DiagramImportRequest,
DiagramImportResponse,
DiagramExportResponse,
DiagramNode,
DiagramEdge,
)
from app.services import network_diagram_ai_service
# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts
_SLUG_CATEGORY: dict[str, str] = {
"router": "network", "switch": "network", "access-point": "network", "load-balancer": "network",
"firewall": "security", "badge-reader": "security",
"server": "compute", "vm": "compute", "container": "compute",
"nas": "storage", "san": "storage", "cloud-storage": "storage",
"cloud": "cloud", "aws": "cloud", "azure": "cloud", "gcp": "cloud", "isp": "cloud",
"workstation": "endpoint", "laptop": "endpoint", "tablet": "endpoint",
"phone": "endpoint", "printer": "endpoint",
"ups": "infrastructure", "pdu": "infrastructure", "rack": "infrastructure",
"patch-panel": "infrastructure", "camera": "infrastructure",
"nvr": "infrastructure", "iot": "infrastructure",
}
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
async def _get_diagram_or_404(
diagram_id: UUID,
team_id: UUID,
db: AsyncSession,
) -> NetworkDiagram:
diagram = await db.get(NetworkDiagram, diagram_id)
if not diagram or diagram.team_id != team_id or diagram.is_archived:
raise HTTPException(status_code=404, detail="Diagram not found")
return diagram
def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
return NetworkDiagramResponse.model_validate(diagram)
def _diagram_to_list_item(
diagram: NetworkDiagram,
custom_slug_category: dict[str, str] | None = None,
) -> NetworkDiagramListItem:
nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}
category_counts: dict[str, int] = {}
for node in nodes:
slug = node.get("type", "") if isinstance(node, dict) else ""
cat = slug_to_cat.get(slug, "other")
category_counts[cat] = category_counts.get(cat, 0) + 1
return NetworkDiagramListItem(
id=diagram.id,
name=diagram.name,
client_name=diagram.client_name,
description=diagram.description,
node_count=len(nodes),
category_counts=category_counts,
created_by=diagram.created_by,
created_at=diagram.created_at,
updated_at=diagram.updated_at,
)
async def _get_available_slugs(team_id: UUID, db: AsyncSession) -> set[str]:
stmt = select(DeviceType.slug).where(
or_(DeviceType.is_system.is_(True), DeviceType.team_id == team_id)
)
result = await db.execute(stmt)
return {row[0] for row in result.all()}
@router.get("/clients", response_model=list[str])
async def list_client_names(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> list[str]:
stmt = (
select(NetworkDiagram.client_name)
.where(
NetworkDiagram.team_id == current_user.team_id,
NetworkDiagram.is_archived.is_(False),
NetworkDiagram.client_name.isnot(None),
NetworkDiagram.client_name != "",
)
.distinct()
.order_by(NetworkDiagram.client_name)
)
result = await db.execute(stmt)
return [row[0] for row in result.all()]
@router.get("/", response_model=list[NetworkDiagramListItem])
async def list_diagrams(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
client_name: str | None = Query(default=None),
search: str | None = Query(default=None),
) -> list[NetworkDiagramListItem]:
stmt = (
select(NetworkDiagram)
.where(
NetworkDiagram.team_id == current_user.team_id,
NetworkDiagram.is_archived.is_(False),
)
.order_by(NetworkDiagram.updated_at.desc())
)
if client_name:
stmt = stmt.where(NetworkDiagram.client_name == client_name)
if search:
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
search_filter = f"%{escaped}%"
stmt = stmt.where(
or_(
NetworkDiagram.name.ilike(search_filter),
NetworkDiagram.client_name.ilike(search_filter),
)
)
# Single query for custom device types so category_counts is accurate
dt_stmt = select(DeviceType.slug, DeviceType.category).where(
DeviceType.is_system.is_(False),
DeviceType.team_id == current_user.team_id,
)
dt_result = await db.execute(dt_stmt)
custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
result = await db.execute(stmt)
rows = result.scalars().all()
return [_diagram_to_list_item(r, custom_slug_category) for r in rows]
@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
async def create_diagram(
data: NetworkDiagramCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
if current_user.team_id is None:
raise HTTPException(
status_code=422,
detail="Network Diagrams require a team account. Assign your account to a team first.",
)
diagram = NetworkDiagram(
team_id=current_user.team_id,
name=data.name,
client_name=data.client_name,
asset_name=data.asset_name,
description=data.description,
nodes=[n.model_dump() for n in data.nodes],
edges=[e.model_dump() for e in data.edges],
created_by=current_user.id,
)
db.add(diagram)
await db.commit()
await db.refresh(diagram)
return _diagram_to_response(diagram)
@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
async def get_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
return _diagram_to_response(diagram)
@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
async def update_diagram(
diagram_id: UUID,
data: NetworkDiagramUpdate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
update_data = data.model_dump(exclude_unset=True)
if "nodes" in update_data and update_data["nodes"] is not None:
update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]]
if "edges" in update_data and update_data["edges"] is not None:
update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]]
for field, value in update_data.items():
setattr(diagram, field, value)
diagram.updated_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(diagram)
return _diagram_to_response(diagram)
@router.delete("/{diagram_id}", status_code=204)
async def archive_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> None:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
diagram.is_archived = True
diagram.updated_at = datetime.now(timezone.utc)
await db.commit()
@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
async def duplicate_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> NetworkDiagramResponse:
source = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
copy = NetworkDiagram(
team_id=current_user.team_id,
name=f"Copy of {source.name}",
client_name=source.client_name,
asset_name=source.asset_name,
description=source.description,
nodes=source.nodes,
edges=source.edges,
created_by=current_user.id,
)
db.add(copy)
await db.commit()
await db.refresh(copy)
return _diagram_to_response(copy)
@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
async def export_diagram(
diagram_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramExportResponse:
diagram = await _get_diagram_or_404(diagram_id, current_user.team_id, db)
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
return DiagramExportResponse(
schemaVersion=1,
name=diagram.name,
client_name=diagram.client_name,
description=diagram.description,
nodes=nodes,
edges=edges,
exportedAt=datetime.now(timezone.utc).isoformat(),
)
@router.post("/import", response_model=DiagramImportResponse, status_code=201)
async def import_diagram(
data: DiagramImportRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> DiagramImportResponse:
available_slugs = await _get_available_slugs(current_user.team_id, db)
warnings: list[str] = []
for node in data.nodes:
if node.type not in available_slugs:
warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
diagram = NetworkDiagram(
team_id=current_user.team_id,
name=data.name,
client_name=data.client_name,
description=data.description,
nodes=[n.model_dump() for n in data.nodes],
edges=[e.model_dump() for e in data.edges],
created_by=current_user.id,
)
db.add(diagram)
await db.commit()
await db.refresh(diagram)
return DiagramImportResponse(
diagram=_diagram_to_response(diagram),
warnings=warnings,
)
@router.post("/ai-generate", response_model=AIGenerateResponse)
async def ai_generate_diagram(
data: AIGenerateRequest,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
) -> AIGenerateResponse:
available_slugs_set = await _get_available_slugs(current_user.team_id, db)
available_slugs = list(available_slugs_set)
existing_node_ids: list[str] | None = None
if data.mode == "merge" and data.existingBounds:
existing_node_ids = []
try:
return await network_diagram_ai_service.generate_diagram(
request=data,
available_slugs=available_slugs,
existing_node_ids=existing_node_ids,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
except Exception:
logger.exception("AI diagram generation failed")
raise HTTPException(status_code=500, detail="Diagram generation failed")

View File

@@ -197,6 +197,7 @@ async def create_template(
template = ScriptTemplate(
category_id=data.category_id,
team_id=current_user.team_id,
account_id=current_user.account_id,
created_by=current_user.id,
name=data.name,
slug=slug,
@@ -364,6 +365,7 @@ async def generate_script(
generation = ScriptGeneration(
template_id=template.id,
user_id=current_user.id,
account_id=current_user.account_id,
team_id=current_user.team_id,
session_id=data.session_id,
ai_session_id=data.ai_session_id,

View File

@@ -143,8 +143,8 @@ async def get_session(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
return session
@@ -234,8 +234,8 @@ async def update_session(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
if session.completed_at:
@@ -281,8 +281,8 @@ async def complete_session(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
if session.completed_at:
@@ -319,8 +319,8 @@ async def update_scratchpad(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
session.scratchpad = data.scratchpad
@@ -348,8 +348,8 @@ async def update_session_variables(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
if session.completed_at:
@@ -387,8 +387,8 @@ async def export_session(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
# PDF export — separate path with binary response
@@ -830,8 +830,8 @@ async def link_ticket(
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
if not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session",
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found",
)
# Unlink

View File

@@ -72,8 +72,8 @@ async def create_share(
if session.user_id != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the session owner can create share links"
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
# Require account_id for account-scoped shares
@@ -170,8 +170,8 @@ async def revoke_share(
if share.created_by != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the share creator can revoke it"
status_code=status.HTTP_404_NOT_FOUND,
detail="Share not found"
)
share.is_active = False

View File

@@ -16,6 +16,7 @@ from app.schemas.step_category import (
)
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_step_category, can_create_step_category
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/step-categories", tags=["step-categories"])
@@ -44,13 +45,13 @@ async def list_step_categories(
elif current_user.account_id:
query = query.where(
or_(
StepCategory.account_id.is_(None), # Global
StepCategory.account_id == PLATFORM_ACCOUNT_ID, # Global
StepCategory.account_id == current_user.account_id # User's account
)
)
else:
# User has no account, only show global categories
query = query.where(StepCategory.account_id.is_(None))
query = query.where(StepCategory.account_id == PLATFORM_ACCOUNT_ID)
query = query.order_by(StepCategory.display_order, StepCategory.name)
@@ -94,8 +95,8 @@ async def get_step_category(
# Check access: global categories visible to all, account categories only to account members
if category.account_id and category.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this step category"
status_code=status.HTTP_404_NOT_FOUND,
detail="Step category not found"
)
return StepCategoryResponse(
@@ -155,7 +156,7 @@ async def create_step_category(
name=category_data.name,
slug=slug,
description=category_data.description,
account_id=category_data.account_id,
account_id=category_data.account_id if category_data.account_id is not None else PLATFORM_ACCOUNT_ID,
display_order=max_order + 1,
created_by=current_user.id
)

View File

@@ -47,10 +47,10 @@ async def get_step_or_404(
raise HTTPException(status_code=404, detail="Step not found")
if check_view and not can_view_step(current_user, step):
raise HTTPException(status_code=403, detail="Not authorized to view this step")
raise HTTPException(status_code=404, detail="Step not found")
if check_edit and not can_edit_step(current_user, step):
raise HTTPException(status_code=403, detail="Not authorized to modify this step")
raise HTTPException(status_code=404, detail="Step not found")
return step

View File

@@ -12,6 +12,7 @@ from app.models.user import User
from app.schemas.tag import TagCreate, TagResponse, TagListResponse, TagAssignment
from app.api.deps import get_current_active_user
from app.core.permissions import can_manage_tree_tags, can_create_tag
from app.core.service_account import PLATFORM_ACCOUNT_ID
router = APIRouter(prefix="/tags", tags=["tags"])
@@ -33,13 +34,13 @@ async def list_tags(
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.account_id.is_(None), # Global
TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global
TreeTag.account_id == current_user.account_id # User's account
)
)
else:
# Only show global tags
query = query.where(TreeTag.account_id.is_(None))
query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID)
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name)
@@ -71,12 +72,12 @@ async def search_tags(
if include_account and current_user.account_id:
query = query.where(
or_(
TreeTag.account_id.is_(None),
TreeTag.account_id == PLATFORM_ACCOUNT_ID,
TreeTag.account_id == current_user.account_id
)
)
else:
query = query.where(TreeTag.account_id.is_(None))
query = query.where(TreeTag.account_id == PLATFORM_ACCOUNT_ID)
query = query.order_by(TreeTag.usage_count.desc(), TreeTag.name).limit(limit)
@@ -105,8 +106,8 @@ async def get_tag(
# Check access: global tags visible to all, account tags only to account members
if tag.account_id and tag.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tag"
status_code=status.HTTP_404_NOT_FOUND,
detail="Tag not found"
)
return TagResponse.model_validate(tag)
@@ -147,7 +148,7 @@ async def create_tag(
new_tag = TreeTag(
name=tag_data.name,
slug=slug,
account_id=tag_data.account_id,
account_id=tag_data.account_id if tag_data.account_id is not None else PLATFORM_ACCOUNT_ID,
created_by=current_user.id
)
db.add(new_tag)
@@ -206,7 +207,7 @@ async def add_tags_to_tree(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.account_id.is_(None), # Global tag
TreeTag.account_id == PLATFORM_ACCOUNT_ID, # Global tag
TreeTag.account_id == tag_account_id # Account tag
)
)
@@ -340,7 +341,7 @@ async def replace_tree_tags(
tag_query = select(TreeTag).where(
TreeTag.slug == slug,
or_(
TreeTag.account_id.is_(None),
TreeTag.account_id == PLATFORM_ACCOUNT_ID,
TreeTag.account_id == tag_account_id
)
)

View File

@@ -29,6 +29,7 @@ from app.core.subscriptions import check_tree_limit, get_account_subscription, g
from app.core.audit import log_audit
from app.core.config import settings
from app.core.tree_validation import can_publish_tree
from app.core.service_account import PLATFORM_ACCOUNT_ID
from app.core.step_sync import sync_steps_from_tree, deactivate_synced_steps_for_tree
from app.services.rag_service import index_tree as rag_index_tree
@@ -391,9 +392,10 @@ async def get_tree(
)
if not tree.is_active or not can_access_tree(current_user, tree):
# Always 404, never 403. A 403 confirms the resource exists.
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
return build_full_tree_response(tree)
@@ -470,7 +472,7 @@ async def create_tree(
tree_structure=tree_data.tree_structure,
intake_form=intake_form_data,
author_id=service_account_id if is_default else current_user.id,
account_id=None if is_default else current_user.account_id,
account_id=PLATFORM_ACCOUNT_ID if is_default else current_user.account_id,
is_public=True if is_default else tree_data.is_public, # Default trees are always public
is_default=is_default,
status=tree_data.status
@@ -610,9 +612,17 @@ async def update_tree(
)
if not can_edit_tree(current_user, tree):
# If the user can see this tree (same account, team visibility), give a 403 with
# a clear message — returning 404 here would be confusing since GET returns 200.
# For truly inaccessible trees (cross-account), return 404 to avoid confirming existence.
if can_access_tree(current_user, tree):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to edit this flow"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only edit your own trees"
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
# Extract tags for separate handling
@@ -1144,9 +1154,17 @@ async def update_tree_visibility(
)
if not can_edit_tree(current_user, tree):
# If the user can see this tree (same account, team visibility), give a 403 with
# a clear message — returning 404 here would be confusing since GET returns 200.
# For truly inaccessible trees (cross-account), return 404 to avoid confirming existence.
if can_access_tree(current_user, tree):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to edit this flow"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only edit your own trees"
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
# Update visibility

View File

@@ -255,9 +255,9 @@ async def get_upload_url(
if upload is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
# Verify the upload belongs to the user's account
# Verify the upload belongs to the user's account — 404 to avoid revealing existence
if upload.account_id != current_user.account_id and not current_user.is_super_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
url = storage_service.get_presigned_url(upload.storage_key)
return {"url": url}
@@ -311,9 +311,9 @@ async def delete_upload(
if upload is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
# Verify ownership
# Verify ownership — 404 to avoid revealing existence
if upload.uploaded_by != current_user.id and not current_user.is_super_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found")
# Delete from S3
await storage_service.delete_file(upload.storage_key)

View File

@@ -1,53 +1,89 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, sidebar, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares, shared, tree_markdown
from app.api.endpoints import admin_dashboard, admin_audit, admin_plan_limits, admin_feature_flags, admin_settings, admin_categories
from app.api.endpoints import ratings, analytics
from app.api.endpoints import target_lists
from app.api.endpoints import maintenance_schedules
from app.api.endpoints import feedback
from app.api.endpoints import ai_builder
from app.api.endpoints import ai_fix
from app.api.endpoints import ai_chat
from app.api.endpoints import copilot
from app.api.endpoints import assistant_chat
from app.api.endpoints import survey
from app.api.endpoints import admin_survey
from app.api.endpoints import tree_transfer
from app.api.endpoints import ai_suggestions
from app.api.endpoints import kb_accelerator
from app.api.endpoints import beta_signup
from app.api.endpoints import scripts
from app.api.endpoints import integrations
from app.api.endpoints import onboarding
from app.api.endpoints import branding
from app.api.endpoints import supporting_data
from app.api.endpoints import ai_sessions
from app.api.endpoints import flow_proposals
from app.api.endpoints import flowpilot_analytics
from app.api.endpoints import notifications
from app.api.endpoints import public_templates
from app.api.endpoints import admin_gallery
from app.api.endpoints import uploads
from app.api.endpoints import script_builder
from app.api.endpoints import beta_feedback
from app.api.endpoints import session_branches
from app.api.endpoints import session_handoffs
from app.api.endpoints import session_resolutions
from app.api.endpoints import device_types
from app.api.endpoints import network_diagrams
from fastapi import APIRouter, Depends
from app.api.deps import require_tenant_context
from app.api.endpoints import (
admin,
admin_audit,
admin_categories,
admin_dashboard,
admin_feature_flags,
admin_gallery,
admin_plan_limits,
admin_settings,
admin_survey,
ai_builder,
ai_chat,
ai_fix,
ai_sessions,
ai_suggestions,
analytics,
assistant_chat,
auth,
beta_feedback,
beta_signup,
branding,
categories,
copilot,
feedback,
flow_proposals,
flowpilot_analytics,
folders,
integrations,
invite,
kb_accelerator,
maintenance_schedules,
notifications,
onboarding,
public_templates,
ratings,
scripts,
script_builder,
session_branches,
session_handoffs,
session_resolutions,
sessions,
shared,
shares,
sidebar,
step_categories,
steps,
supporting_data,
survey,
tags,
target_lists,
tree_markdown,
tree_transfer,
trees,
uploads,
webhooks,
accounts,
)
api_router = APIRouter()
# ---------------------------------------------------------------------------
# Public / unauthenticated endpoints — no tenant context
#
# Note: auth.router contains both public endpoints (register, login,
# forgot-password, reset-password, email/verify) and authenticated endpoints
# (GET/PATCH /me, logout, change-password, email/send-verification).
# The authenticated auth endpoints only query the `users` table, which is
# excluded from Phase 1 RLS. They work correctly without tenant context
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
# ---------------------------------------------------------------------------
api_router.include_router(auth.router)
api_router.include_router(trees.router)
api_router.include_router(sidebar.router)
api_router.include_router(sessions.router)
api_router.include_router(invite.router)
api_router.include_router(categories.router)
api_router.include_router(tags.router)
api_router.include_router(folders.router)
api_router.include_router(step_categories.router)
api_router.include_router(steps.router)
api_router.include_router(shared.router) # Public share links (no auth)
api_router.include_router(beta_signup.router)
api_router.include_router(webhooks.router) # Stripe webhook receiver
api_router.include_router(public_templates.router) # Public gallery (no auth, rate-limited)
# ---------------------------------------------------------------------------
# Admin endpoints — super_admin only
# admin_categories, admin_gallery, admin_dashboard, admin query Phase 1 RLS
# tables and MUST use get_admin_db (migrated in Task 8). The remaining admin
# endpoints (admin_audit, admin_plan_limits, admin_feature_flags,
# admin_settings, admin_survey) are safe until Phase 2 extends RLS.
# ---------------------------------------------------------------------------
api_router.include_router(admin.router)
api_router.include_router(admin_dashboard.router)
api_router.include_router(admin_audit.router)
@@ -55,44 +91,54 @@ api_router.include_router(admin_plan_limits.router)
api_router.include_router(admin_feature_flags.router)
api_router.include_router(admin_settings.router)
api_router.include_router(admin_categories.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)
api_router.include_router(shares.router)
api_router.include_router(shared.router) # Public endpoints (no auth)
api_router.include_router(tree_markdown.router)
api_router.include_router(ratings.router)
api_router.include_router(analytics.router)
api_router.include_router(target_lists.router)
api_router.include_router(maintenance_schedules.router)
api_router.include_router(feedback.router)
api_router.include_router(ai_builder.router)
api_router.include_router(ai_fix.router)
api_router.include_router(ai_chat.router)
api_router.include_router(copilot.router)
api_router.include_router(assistant_chat.router)
api_router.include_router(survey.router)
api_router.include_router(admin_survey.router)
api_router.include_router(tree_transfer.router)
api_router.include_router(ai_suggestions.router)
api_router.include_router(kb_accelerator.router)
api_router.include_router(beta_signup.router)
api_router.include_router(scripts.router)
api_router.include_router(integrations.router)
api_router.include_router(onboarding.router)
api_router.include_router(branding.router)
api_router.include_router(supporting_data.router)
api_router.include_router(network_diagrams.router) # Must be before ai_sessions to avoid /{diagram_id} conflict
api_router.include_router(session_handoffs.queue_router) # Must be before ai_sessions to avoid /{session_id} conflict
api_router.include_router(session_resolutions.router) # Must be before ai_sessions to avoid /{session_id} conflict
api_router.include_router(ai_sessions.router)
api_router.include_router(flow_proposals.router)
api_router.include_router(flowpilot_analytics.router)
api_router.include_router(notifications.router)
api_router.include_router(public_templates.router)
api_router.include_router(admin_gallery.router)
api_router.include_router(uploads.router)
api_router.include_router(script_builder.router)
api_router.include_router(beta_feedback.router)
api_router.include_router(session_branches.router)
api_router.include_router(session_handoffs.router)
api_router.include_router(device_types.router)
# ---------------------------------------------------------------------------
# User-facing endpoints — tenant context required
# ---------------------------------------------------------------------------
_tenant_deps = [Depends(require_tenant_context)]
api_router.include_router(trees.router, dependencies=_tenant_deps)
api_router.include_router(sidebar.router, dependencies=_tenant_deps)
api_router.include_router(sessions.router, dependencies=_tenant_deps)
api_router.include_router(invite.router, dependencies=_tenant_deps)
api_router.include_router(categories.router, dependencies=_tenant_deps)
api_router.include_router(tags.router, dependencies=_tenant_deps)
api_router.include_router(folders.router, dependencies=_tenant_deps)
api_router.include_router(step_categories.router, dependencies=_tenant_deps)
api_router.include_router(steps.router, dependencies=_tenant_deps)
api_router.include_router(accounts.router, dependencies=_tenant_deps)
api_router.include_router(shares.router, dependencies=_tenant_deps)
api_router.include_router(tree_markdown.router, dependencies=_tenant_deps)
api_router.include_router(ratings.router, dependencies=_tenant_deps)
api_router.include_router(analytics.router, dependencies=_tenant_deps)
api_router.include_router(target_lists.router, dependencies=_tenant_deps)
api_router.include_router(maintenance_schedules.router, dependencies=_tenant_deps)
api_router.include_router(feedback.router, dependencies=_tenant_deps)
api_router.include_router(ai_builder.router, dependencies=_tenant_deps)
api_router.include_router(ai_fix.router, dependencies=_tenant_deps)
api_router.include_router(ai_chat.router, dependencies=_tenant_deps)
api_router.include_router(copilot.router, dependencies=_tenant_deps)
api_router.include_router(assistant_chat.router, dependencies=_tenant_deps)
api_router.include_router(survey.router, dependencies=_tenant_deps)
api_router.include_router(tree_transfer.router, dependencies=_tenant_deps)
api_router.include_router(ai_suggestions.router, dependencies=_tenant_deps)
api_router.include_router(kb_accelerator.router, dependencies=_tenant_deps)
api_router.include_router(scripts.router, dependencies=_tenant_deps)
api_router.include_router(integrations.router, dependencies=_tenant_deps)
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
api_router.include_router(branding.router, dependencies=_tenant_deps)
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
# session_handoffs queue router must come before ai_sessions to avoid conflict
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
api_router.include_router(ai_sessions.router, dependencies=_tenant_deps)
api_router.include_router(flow_proposals.router, dependencies=_tenant_deps)
api_router.include_router(flowpilot_analytics.router, dependencies=_tenant_deps)
api_router.include_router(notifications.router, dependencies=_tenant_deps)
api_router.include_router(uploads.router, dependencies=_tenant_deps)
api_router.include_router(script_builder.router, dependencies=_tenant_deps)
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)

View File

@@ -0,0 +1,36 @@
# backend/app/core/admin_database.py
"""
Admin database engine — connects as resolutionflow_admin (BYPASSRLS).
Use ONLY for /admin/* endpoints and internal tooling.
Never use this engine from user-facing endpoints.
"""
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
admin_engine = create_async_engine(
settings.ADMIN_DATABASE_URL,
echo=settings.DEBUG,
future=True,
)
_admin_session_factory = async_sessionmaker(
admin_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
"""Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints."""
async with _admin_session_factory() as session:
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()

View File

@@ -23,10 +23,33 @@ class Settings(BaseSettings):
return v.replace("postgresql://", "postgresql+asyncpg://", 1)
return v
@property
def DATABASE_URL_SYNC(self) -> str:
"""Get sync URL by removing asyncpg prefix from DATABASE_URL."""
return self.DATABASE_URL.replace("postgresql+asyncpg://", "postgresql://", 1)
# Sync URL for Alembic migrations. Defaults to DATABASE_URL (sync-converted).
# Set explicitly in .env to use a different role for migrations (e.g. superuser)
# when DATABASE_URL has been switched to the app role.
DATABASE_URL_SYNC: str = ""
@field_validator("DATABASE_URL_SYNC", mode="before")
@classmethod
def default_database_url_sync(cls, v: str, info) -> str:
"""Fall back to sync-converted DATABASE_URL if not explicitly set."""
if not v:
base = info.data.get("DATABASE_URL", "")
return base.replace("postgresql+asyncpg://", "postgresql://", 1)
return v
# Admin database — resolutionflow_admin role, BYPASSRLS.
# Used by /admin/* endpoints. Defaults to DATABASE_URL for local dev.
ADMIN_DATABASE_URL: str = ""
@field_validator("ADMIN_DATABASE_URL", mode="before")
@classmethod
def default_admin_database_url(cls, v: str, info) -> str:
"""Fall back to DATABASE_URL if ADMIN_DATABASE_URL is not set."""
if not v:
return info.data.get("DATABASE_URL", "")
if v.startswith("postgresql://"):
return v.replace("postgresql://", "postgresql+asyncpg://", 1)
return v
# JWT Settings
SECRET_KEY: str = _DEFAULT_SECRET_KEY
@@ -105,7 +128,6 @@ class Settings(BaseSettings):
"variable_inference": "fast",
"kb_convert": "standard",
"script_build": "standard",
"network_diagram_generate": "standard",
}
def get_model_for_action(self, action_type: str) -> str:

View File

@@ -1,6 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from .config import settings
from app.core.tenant_context import register_tenant_listener
# Create async engine
engine = create_async_engine(
@@ -16,6 +17,11 @@ async_session_maker = async_sessionmaker(
expire_on_commit=False
)
# Register the RLS tenant context listener on the app engine.
# Fires at the start of every transaction; issues set_config automatically.
# Must NOT be called on admin_engine — admin connections bypass RLS.
register_tenant_listener(engine)
class Base(DeclarativeBase):
"""Base class for all database models."""

View File

@@ -1,10 +1,12 @@
"""
Centralized query filters for ResolutionFlow.
Provides reusable SQLAlchemy filter builders for tree access control
and step visibility, used across multiple endpoint modules.
Provides reusable SQLAlchemy filter builders for tree access control,
step visibility, and the canonical tenant_filter used by all queries
on tenant-scoped tables.
"""
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING
from sqlalchemy import or_, and_, true as sa_true
@@ -13,6 +15,18 @@ if TYPE_CHECKING:
from app.models.user import User
def tenant_filter(model, account_id: uuid.UUID):
"""Primary app-layer tenant filter.
MUST be used in every SELECT/UPDATE/DELETE on tenant tables.
RLS (Phase 2) is the safety net — this is the primary enforcement.
Usage:
stmt = select(Tree).where(tenant_filter(Tree, current_user.account_id), ...)
"""
return model.account_id == account_id
def build_tree_access_filter(current_user: User):
"""Build the access filter for trees based on user permissions.
@@ -36,10 +50,11 @@ def build_tree_access_filter(current_user: User):
Tree.author_id == current_user.id,
]
if current_user.account_id:
# Team-visible trees: use tenant_filter as the account match
conditions.append(
and_(
Tree.visibility == 'team',
Tree.account_id == current_user.account_id
tenant_filter(Tree, current_user.account_id),
)
)
return or_(*conditions)
@@ -58,11 +73,14 @@ def build_step_visibility_filter(current_user: User):
if current_user.account_id:
return or_(
StepLibrary.visibility == 'public',
and_(StepLibrary.visibility == 'team', StepLibrary.account_id == current_user.account_id),
StepLibrary.created_by == current_user.id # Own private steps
and_(
StepLibrary.visibility == 'team',
tenant_filter(StepLibrary, current_user.account_id),
),
StepLibrary.created_by == current_user.id,
)
else:
return or_(
StepLibrary.visibility == 'public',
StepLibrary.created_by == current_user.id
StepLibrary.created_by == current_user.id,
)

View File

@@ -18,6 +18,10 @@ logger = logging.getLogger(__name__)
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
SERVICE_ACCOUNT_NAME = "ResolutionFlow"
# Well-known UUID for the platform account — owns all default/global content.
# Created by migration 3a40fe11b427_create_global_content_tables.
PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
SYSTEM_ACCOUNT_NAME = "ResolutionFlow System"
SYSTEM_ACCOUNT_DISPLAY_CODE = "RF-SYS-1"

View File

@@ -0,0 +1,92 @@
# backend/app/core/tenant_context.py
"""
Per-request tenant context for row-level security.
Flow:
1. require_tenant_context (FastAPI dep) calls set_current_account_id().
2. The SQLAlchemy transaction-begin listener fires on every new transaction
and calls set_config('app.current_account_id', <id>, true) automatically.
3. PostgreSQL RLS policies read current_setting('app.current_account_id', TRUE)
to filter rows.
The ContextVar is asyncio-task-scoped: each concurrent request has its own value.
set_config with is_local=true is transaction-scoped: it resets on COMMIT or
ROLLBACK, so the listener re-applies it at the start of every transaction.
"""
import contextvars
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import event, or_, text
from sqlalchemy.ext.asyncio import AsyncEngine
if TYPE_CHECKING:
from app.models.user import User
# One slot per async task — each concurrent request gets its own value.
_current_account_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"current_account_id", default=None
)
# Platform account — global content visible to all tenants.
PLATFORM_ACCOUNT_ID = UUID("00000000-0000-0000-0000-000000000001")
def set_current_account_id(account_id: UUID) -> contextvars.Token:
"""Set tenant context for the current request coroutine.
Returns a token so the caller can reset it after the request.
"""
return _current_account_id.set(str(account_id))
def clear_current_account_id(token: contextvars.Token) -> None:
"""Reset the ContextVar to its previous value (call in finally block)."""
_current_account_id.reset(token)
def get_current_account_id() -> str | None:
"""Return the account_id string for the current request, or None."""
return _current_account_id.get()
def register_tenant_listener(engine: AsyncEngine) -> None:
"""Register the transaction-begin listener on the given engine.
Must be called once at application startup, AFTER the engine is created.
The listener issues set_config() at the start of every transaction so that
the setting is re-applied automatically even when a request commits
mid-flight and starts a new transaction.
Do NOT call this on admin_engine — admin connections must never set tenant
context automatically.
"""
@event.listens_for(engine.sync_engine, "begin")
def _on_transaction_begin(conn) -> None: # noqa: ANN001
account_id = _current_account_id.get()
if account_id:
# set_config(name, value, is_local=true) ≡ SET LOCAL.
# Unlike SET LOCAL, set_config IS parameterisable.
conn.execute(
text("SELECT set_config('app.current_account_id', :id, true)"),
{"id": account_id},
)
# If no account_id is set, do nothing. The RLS policy falls back to a
# null-matching UUID and returns zero rows — fail-closed behaviour.
def tenant_filter(Model, current_user: "User"): # noqa: ANN001
"""SQLAlchemy filter clause for tables that contain platform-owned rows.
Use for: tree_tags, tree_categories, step_categories, step_library,
template_trees, platform_steps.
For tenant-only tables (trees, sessions, psa_connections, etc.) use:
Model.account_id == current_user.account_id
directly.
"""
return or_(
Model.account_id == current_user.account_id,
Model.account_id == PLATFORM_ACCOUNT_ID,
)

View File

@@ -54,8 +54,8 @@ from .session_branch import SessionBranch
from .fork_point import ForkPoint
from .session_handoff import SessionHandoff
from .session_resolution_output import SessionResolutionOutput
from .device_type import DeviceType
from .network_diagram import NetworkDiagram
from .template_tree import TemplateTree
from .platform_step import PlatformStep
__all__ = [
"User",
@@ -124,6 +124,6 @@ __all__ = [
"ForkPoint",
"SessionHandoff",
"SessionResolutionOutput",
"DeviceType",
"NetworkDiagram",
"TemplateTree",
"PlatformStep",
]

View File

@@ -137,28 +137,6 @@ class AISession(Base):
comment="Snapshot of PSA ticket data at session start",
)
# ── Triage / Cockpit Header ──
client_name: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True,
comment="MSP client name for incident header (AI-inferred or manual)",
)
asset_name: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True,
comment="Device, asset, or user being worked on",
)
issue_category: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True,
comment="Human-readable category (e.g. DNS / Networking)",
)
triage_hypothesis: Mapped[Optional[str]] = mapped_column(
Text, nullable=True,
comment="Current working hypothesis — AI-updated + engineer-editable",
)
evidence_items: Mapped[Optional[list[dict[str, Any]]]] = mapped_column(
JSONB, nullable=True,
comment='What We Know list: [{"text": str, "status": "confirmed"|"ruled_out"|"pending"}]',
)
# ── Resolution / Escalation ──
resolution_summary: Mapped[Optional[str]] = mapped_column(
Text, nullable=True,

View File

@@ -50,6 +50,13 @@ class AISessionStep(Base):
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Denormalized from ai_sessions.account_id for direct tenant filtering.",
)
step_order: Mapped[int] = mapped_column(
Integer, nullable=False,
comment="Sequential position in the session (0-indexed)",

View File

@@ -28,6 +28,12 @@ class AISuggestion(Base):
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
session_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("ai_chat_sessions.id", ondelete="SET NULL"),

View File

@@ -20,6 +20,12 @@ class Attachment(Base):
ForeignKey("sessions.id"),
nullable=False
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
node_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
file_type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)

View File

@@ -39,10 +39,10 @@ class TreeCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -1,47 +0,0 @@
"""Device type model for network diagrams."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class DeviceType(Base):
"""A device type for network diagram nodes (system or team-custom)."""
__tablename__ = "device_types"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
slug: Mapped[str] = mapped_column(
String(50), nullable=False,
comment="Unique identifier used in diagram node data",
)
label: Mapped[str] = mapped_column(
String(100), nullable=False,
comment="Display name",
)
category: Mapped[str] = mapped_column(
String(50), nullable=False,
comment="network, compute, storage, cloud, endpoint, infrastructure, security",
)
is_system: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False,
comment="True for built-in types that cannot be deleted",
)
team_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True,
comment="NULL for system types, set for team-custom types",
)
sort_order: Mapped[int] = mapped_column(
Integer, nullable=False, default=0,
comment="Display order within category",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)

View File

@@ -1,6 +1,5 @@
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import String, Text, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID
@@ -11,7 +10,7 @@ class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="SET NULL"), nullable=True)
account_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=False)
email: Mapped[str] = mapped_column(String(255), nullable=False)
feedback_type: Mapped[str] = mapped_column(String(50), nullable=False)

View File

@@ -46,6 +46,12 @@ class UserFolder(Base):
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(String(100), nullable=False)
color: Mapped[str] = mapped_column(String(7), nullable=False, default="#6366f1")
icon: Mapped[str] = mapped_column(String(50), nullable=False, default="folder")

View File

@@ -23,6 +23,12 @@ class ForkPoint(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
parent_branch_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=False)
trigger_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True)
fork_reason: Mapped[str] = mapped_column(Text, nullable=False)

View File

@@ -23,6 +23,12 @@ class MaintenanceSchedule(Base):
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
cron_expression: Mapped[str] = mapped_column(String(100), nullable=False)
timezone: Mapped[str] = mapped_column(String(100), nullable=False, default="UTC")
target_list_id: Mapped[Optional[uuid.UUID]] = mapped_column(

View File

@@ -1,53 +0,0 @@
"""Network diagram model."""
import uuid
from datetime import datetime, timezone
from typing import Any, TYPE_CHECKING
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
class NetworkDiagram(Base):
"""A network topology diagram, team-scoped."""
__tablename__ = "network_diagrams"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
team_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
is_archived: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False,
)
created_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])

View File

@@ -31,6 +31,12 @@ class NotificationLog(Base):
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
event: Mapped[str] = mapped_column(String(50), nullable=False)
payload: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
status: Mapped[str] = mapped_column(String(20), default="sent")

View File

@@ -0,0 +1,37 @@
"""Platform step model — platform-owned steps, readable by all users.
No account_id. No RLS. Readable by any authenticated user.
Populated by promoting visibility='public' steps from step_library.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from sqlalchemy import String, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
class PlatformStep(Base):
__tablename__ = "platform_steps"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(255), nullable=False)
step_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
content: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
source_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("step_library.id", ondelete="SET NULL"),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)

View File

@@ -25,6 +25,12 @@ class PsaMemberMapping(Base):
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),

View File

@@ -35,6 +35,12 @@ class PsaPostLog(Base):
ForeignKey("psa_connections.id", ondelete="SET NULL"),
nullable=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
ticket_id: Mapped[str] = mapped_column(String(100), nullable=False)
note_type: Mapped[str] = mapped_column(String(50), nullable=False)
content_posted: Mapped[str] = mapped_column(Text, nullable=False)

View File

@@ -29,6 +29,12 @@ class ScriptBuilderSession(Base):
nullable=False,
index=True,
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("teams.id", ondelete="SET NULL"),

View File

@@ -44,6 +44,12 @@ class ScriptTemplate(Base):
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"), nullable=True, index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
@@ -97,6 +103,12 @@ class ScriptGeneration(Base):
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
team_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="SET NULL"), nullable=True, index=True
)

View File

@@ -31,6 +31,12 @@ class Session(Base):
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
tree_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
path_taken: Mapped[list[str]] = mapped_column(JSONB, nullable=False, default=list)
decisions: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=list)

View File

@@ -35,6 +35,12 @@ class SessionBranch(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
parent_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="CASCADE"), nullable=True)
fork_point_step_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_session_steps.id", ondelete="SET NULL"), nullable=True)
branch_order: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

View File

@@ -27,6 +27,12 @@ class SessionHandoff(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
handed_off_by: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
intent: Mapped[str] = mapped_column(String(20), nullable=False)
source_branch_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True), ForeignKey("session_branches.id", ondelete="SET NULL"), nullable=True)

View File

@@ -23,6 +23,12 @@ class SessionResolutionOutput(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("ai_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
output_type: Mapped[str] = mapped_column(String(30), nullable=False)
generated_content: Mapped[str] = mapped_column(Text, nullable=False)
structured_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB, nullable=True, comment="For KB: {symptoms, root_cause, steps, tags}")

View File

@@ -38,10 +38,10 @@ class StepCategory(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
display_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -46,10 +46,10 @@ class StepLibrary(Base):
ForeignKey("teams.id", ondelete="CASCADE"),
nullable=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
@@ -143,6 +143,13 @@ class StepRating(Base):
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Account of the RATER (not the step owner).",
)
rating: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
was_helpful: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
review_text: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
@@ -187,6 +194,13 @@ class StepUsageLog(Base):
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Account of the user who logged this usage.",
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),

View File

@@ -14,6 +14,12 @@ class SessionSupportingData(Base):
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
session_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("sessions.id", ondelete="CASCADE"), nullable=False, index=True)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
label: Mapped[str] = mapped_column(String(255), nullable=False)
data_type: Mapped[str] = mapped_column(Enum("text_snippet", "screenshot", name="supporting_data_type"), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)

View File

@@ -51,10 +51,10 @@ class TreeTag(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

View File

@@ -9,6 +9,7 @@ from app.core.database import Base
if TYPE_CHECKING:
from app.models.user import User
from app.models.team import Team
from app.models.account import Account
class TargetList(Base):
@@ -21,6 +22,12 @@ class TargetList(Base):
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
nullable=False, index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)

View File

@@ -0,0 +1,40 @@
"""Template tree model — platform-owned troubleshooting trees, readable by all users.
No account_id. No RLS. Readable by any authenticated user.
Populated by promoting is_default=TRUE trees from the trees table.
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
class TemplateTree(Base):
__tablename__ = "template_trees"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
category: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
tree_type: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
tree_structure: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False)
tags: Mapped[list] = mapped_column(JSONB, nullable=False, default=list)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
source_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="SET NULL"),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)

View File

@@ -76,10 +76,10 @@ class Tree(Base):
nullable=True,
index=True
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
index=True
)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)

View File

@@ -37,10 +37,10 @@ class TreeEmbedding(Base):
ForeignKey("trees.id", ondelete="CASCADE"),
nullable=False,
)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=True,
nullable=False,
)
chunk_type: Mapped[str] = mapped_column(
String(30),

View File

@@ -43,10 +43,10 @@ class User(Base):
must_change_password: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
# Account-based multi-tenancy (new)
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="RESTRICT"),
nullable=True,
nullable=False,
index=True
)
account_role: Mapped[str] = mapped_column(String(50), nullable=False, default="engineer")

View File

@@ -24,6 +24,12 @@ class UserPinnedTree(Base):
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
tree_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="CASCADE"),

View File

@@ -28,110 +28,6 @@ class ActivityEntry(BaseModel):
from_attributes = True
# --- Admin Accounts & People Search ---
class AdminUserListItem(BaseModel):
id: UUID
email: EmailStr
name: str
role: str
is_super_admin: bool = False
is_active: bool = True
account_id: Optional[UUID] = None
account_role: Optional[str] = None
account_name: Optional[str] = None
account_display_code: Optional[str] = None
created_at: datetime
last_login: Optional[datetime] = None
deleted_at: Optional[datetime] = None
class AdminUserListResponse(BaseModel):
items: list[AdminUserListItem]
total: int
page: int
per_page: int
class AdminAccountMember(BaseModel):
id: UUID
email: EmailStr
name: str
role: str
is_super_admin: bool = False
is_active: bool = True
account_role: Optional[str] = None
created_at: datetime
last_login: Optional[datetime] = None
deleted_at: Optional[datetime] = None
class AdminAccountOwnerSummary(BaseModel):
id: UUID
name: str
email: EmailStr
class AdminAccountSubscriptionSummary(BaseModel):
id: UUID
plan: str
status: str
billing_interval: Optional[str] = None
current_period_end: Optional[datetime] = None
cancel_at_period_end: bool = False
class AdminAccountUsageSummary(BaseModel):
tree_count: int = 0
session_count_this_month: int = 0
class AdminAccountInviteSummary(BaseModel):
id: UUID
email: EmailStr
role: str
expires_at: Optional[datetime] = None
created_at: datetime
used_at: Optional[datetime] = None
class AdminAccountListItem(BaseModel):
id: UUID
name: str
display_code: str
created_at: datetime
owner_id: Optional[UUID] = None
owner: Optional[AdminAccountOwnerSummary] = None
subscription: Optional[AdminAccountSubscriptionSummary] = None
usage: AdminAccountUsageSummary = Field(default_factory=AdminAccountUsageSummary)
member_count: int = 0
active_member_count: int = 0
pending_invite_count: int = 0
sso_enabled: bool = False
branding_company_name: Optional[str] = None
members: list[AdminAccountMember] = Field(default_factory=list)
class AdminAccountListResponse(BaseModel):
items: list[AdminAccountListItem]
total: int
page: int
per_page: int
class AdminAccountDetailResponse(AdminAccountListItem):
invites: list[AdminAccountInviteSummary] = Field(default_factory=list)
class AdminAccountCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
plan: Literal["free", "pro", "team"] = "free"
class AdminAccountUpdate(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
# --- Audit Logs ---
class AuditLogEntry(BaseModel):

View File

@@ -102,20 +102,12 @@ class ResolveSessionRequest(BaseModel):
resolution_action: str | None = None
session_rating: int | None = Field(None, ge=1, le=5)
session_feedback: str | None = None
# Structured handoff fields (from cockpit conclude modal)
root_cause: str | None = None
steps_taken: list[str] | None = None
recommendations: str | None = None
class EscalateSessionRequest(BaseModel):
"""Escalate a session to another engineer."""
escalation_reason: str = Field(..., min_length=5, max_length=2000)
escalated_to_id: UUID | None = None
# Structured handoff fields (from cockpit conclude modal)
root_cause: str | None = None
steps_taken: list[str] | None = None
recommendations: str | None = None
class DocumentationStep(BaseModel):
@@ -240,12 +232,6 @@ class AISessionDetail(AISessionSummary):
pending_task_lane: dict[str, Any] | None = None
is_branching: bool = False
active_branch_id: str | None = None
# Triage / cockpit header fields
client_name: str | None = None
asset_name: str | None = None
issue_category: str | None = None
triage_hypothesis: str | None = None
evidence_items: list[dict[str, Any]] | None = None
model_config = {"from_attributes": True}
@@ -291,16 +277,6 @@ class QuestionItem(BaseModel):
"""A question the AI needs answered by the engineer."""
text: str
context: str = ""
options: list[str] | None = None # quick-reply button labels; null = free-text input
class TriageUpdate(BaseModel):
"""AI-inferred triage metadata returned with chat responses."""
client_name: str | None = None
asset_name: str | None = None
issue_category: str | None = None
triage_hypothesis: str | None = None
evidence_items: list[dict[str, Any]] | None = None # appends to existing list
class ChatMessageResponse(BaseModel):
@@ -310,7 +286,6 @@ class ChatMessageResponse(BaseModel):
fork: ForkMetadata | None = None
actions: list[ActionItem] | None = None
questions: list[QuestionItem] | None = None
triage_update: TriageUpdate | None = None
class SaveTaskLaneRequest(BaseModel):
@@ -333,24 +308,3 @@ class AISessionSearchResult(BaseModel):
created_at: datetime
model_config = {"from_attributes": True}
# ── Triage / Cockpit ──
class TriagePatchRequest(BaseModel):
"""Update triage metadata on a session (incident header fields)."""
client_name: str | None = None
asset_name: str | None = None
issue_category: str | None = None
triage_hypothesis: str | None = None
evidence_items: list[dict[str, Any]] | None = None
class TriagePatchResponse(BaseModel):
"""Updated triage metadata after a PATCH."""
id: UUID
client_name: str | None = None
asset_name: str | None = None
issue_category: str | None = None
triage_hypothesis: str | None = None
evidence_items: list[dict[str, Any]] | None = None

View File

@@ -1,37 +0,0 @@
"""Pydantic schemas for device types."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class DeviceTypeCreate(BaseModel):
slug: str = Field(min_length=1, max_length=50, pattern=r"^[a-z0-9\-]+$")
label: str = Field(min_length=1, max_length=100)
category: str = Field(
min_length=1, max_length=50,
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
)
sort_order: int = Field(default=0, ge=0)
class DeviceTypeUpdate(BaseModel):
label: str | None = Field(default=None, min_length=1, max_length=100)
category: str | None = Field(
default=None, min_length=1, max_length=50,
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
)
sort_order: int | None = Field(default=None, ge=0)
class DeviceTypeResponse(BaseModel):
id: UUID
slug: str
label: str
category: str
is_system: bool
team_id: UUID | None = None
sort_order: int
created_at: datetime
model_config = {"from_attributes": True}

View File

@@ -1,136 +0,0 @@
"""Pydantic schemas for network diagrams."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class Position(BaseModel):
x: float
y: float
class DeviceProperties(BaseModel):
hostname: str | None = None
ip: str | None = None
subnet: str | None = None
vendor: str | None = None
model: str | None = None
role: str | None = None
vlan: str | None = None
notes: str | None = None
status: str = Field(default="unknown", pattern=r"^(unknown|online|offline|degraded)$")
class DiagramNode(BaseModel):
id: str
type: str
label: str
position: Position
properties: DeviceProperties = Field(default_factory=DeviceProperties)
class DiagramEdge(BaseModel):
id: str
source: str
target: str
label: str | None = None
connectionType: str = "ethernet"
speed: str | None = None
notes: str | None = None
routing: str | None = None
class NetworkDiagramCreate(BaseModel):
name: str = Field(min_length=1, max_length=255)
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
class NetworkDiagramUpdate(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=255)
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] | None = None
edges: list[DiagramEdge] | None = None
class NetworkDiagramResponse(BaseModel):
id: UUID
team_id: UUID
name: str
client_name: str | None = None
asset_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
thumbnail_url: str | None = None
is_archived: bool = False
created_by: UUID | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class NetworkDiagramListItem(BaseModel):
id: UUID
name: str
client_name: str | None = None
description: str | None = None
node_count: int = 0
category_counts: dict[str, int] = Field(default_factory=dict)
created_by: UUID | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class ExistingBounds(BaseModel):
minX: float
maxX: float
minY: float
maxY: float
class AIGenerateRequest(BaseModel):
description: str = Field(min_length=1, max_length=5000)
client_name: str | None = None
mode: str = Field(default="replace", pattern=r"^(replace|merge)$")
existingBounds: ExistingBounds | None = None
class AIGenerateResponse(BaseModel):
nodes: list[DiagramNode]
edges: list[DiagramEdge]
suggestedName: str | None = None
notes: str | None = None
class DiagramImportRequest(BaseModel):
schemaVersion: int = Field(ge=1, le=1)
name: str = Field(min_length=1, max_length=255)
client_name: str | None = None
description: str | None = None
nodes: list[DiagramNode] = Field(default_factory=list)
edges: list[DiagramEdge] = Field(default_factory=list)
class DiagramImportResponse(BaseModel):
diagram: NetworkDiagramResponse
warnings: list[str] = Field(default_factory=list)
class DiagramExportResponse(BaseModel):
schemaVersion: int = 1
name: str
client_name: str | None = None
description: str | None = None
nodes: list[DiagramNode]
edges: list[DiagramEdge]
exportedAt: str

View File

@@ -77,6 +77,9 @@ scope narrows it to this endpoint.
- JSON array of objects with `text` (required) and `context` (optional, 1 sentence)
- 1-3 questions per response
- Do NOT ask questions inline in your prose. ALL questions go in the marker.
- If the engineer's message contains tasks marked `_(not yet completed)_`, re-include \
those as questions/actions in your next response UNLESS you are ≥75% confident the \
information is no longer needed to resolve the issue. Default to keeping them.
**[ACTIONS] marker format:**
- JSON array of objects with `label` (required), `command` (optional), `description` (required)
@@ -84,38 +87,9 @@ scope narrows it to this endpoint.
- Commands should be PowerShell unless context indicates Linux/Mac
- For GUI-only steps, omit `command`
**[QUESTIONS] `options` field:**
When a question has a small, constrained set of answers (yes/no, 2-4 choices), include \
an `options` array with the answer labels. The engineer will see these as quick-reply buttons. \
Example: `{"text": "Did nslookup time out or return a wrong IP?", "options": ["Timed out", "Wrong IP", "Both"]}`
Omit `options` when the answer is open-ended.
**Both markers are stripped from display** — the engineer sees them as interactive UI cards, \
not raw JSON. Put analysis BEFORE markers. Markers go at the END of your response.
## Triage Context Extraction
When you learn NEW facts about the case from the engineer's messages, emit a \
[TRIAGE_UPDATE] marker with a JSON object containing ONLY the fields that changed. \
Do NOT repeat unchanged fields. Only emit this marker when you have grounded evidence — \
never guess or fabricate. If you are not confident, do not emit the marker.
Fields:
- `client_name` — the MSP client/company being helped (only from explicit mention or ticket data)
- `asset_name` — the device, user, or asset being troubleshot
- `issue_category` — human-readable category like "DNS / Networking", "Microsoft 365", "Active Directory"
- `triage_hypothesis` — your current working hypothesis about the root cause (update as evidence changes)
- `evidence_items` — NEW evidence to append: `[{"text": "description", "status": "confirmed|ruled_out|pending"}]`
Example (only include fields that have new information):
[TRIAGE_UPDATE]
{"issue_category": "DNS / Networking", "triage_hypothesis": "Corrupted DNS cache on NIC", "evidence_items": [{"text": "Gateway 192.168.1.1 reachable", "status": "confirmed"}, {"text": "DNS 1.1.1.1 timeout", "status": "ruled_out"}]}
[/TRIAGE_UPDATE]
Place [TRIAGE_UPDATE] AFTER [QUESTIONS]/[ACTIONS] markers, before [FORK] if present. \
This marker is optional — only emit it when you learn something new.
## Using the Team's Flow Library
Your team has built troubleshooting flows in ResolutionFlow. When relevant flows \
appear in the context below, reference them by name so the engineer can launch them \
@@ -184,6 +158,8 @@ To create a fork, append this marker AFTER your [QUESTIONS]/[ACTIONS] markers:
Every single response MUST contain [QUESTIONS] and/or [ACTIONS] markers with valid JSON. \
No exceptions. Not even when forking. A response without at least one of these markers \
will crash the UI. If you are unsure, include both. The markers are REQUIRED output, not optional.
If any tasks in the engineer's message are marked `_(not yet completed)_`, re-include them \
in your markers unless you are ≥75% confident that information is no longer relevant.
"""
@@ -327,6 +303,8 @@ async def _call_anthropic_cached(
}
]
_mcp_active = mcp_servers is not anthropic.NOT_GIVEN
try:
response = await client.beta.messages.create(
model=settings.AI_MODEL_ANTHROPIC,
@@ -337,12 +315,22 @@ async def _call_anthropic_cached(
tools=tools,
betas=["mcp-client-2025-11-20"],
)
except anthropic.BadRequestError as e:
# MCP server failures (rate limits, connection errors) should not
# block the assistant entirely — retry without MCP tools.
if "MCP server" in str(e) and mcp_servers is not anthropic.NOT_GIVEN:
logger.warning("MCP server error, retrying without MCP: %s", e)
response = await client.beta.messages.create(
except Exception as e:
# MCP server failures surface as many error types — BadRequestError,
# APIStatusError, APIConnectionError, APITimeoutError. Always retry
# without MCP when MCP was active, so a flaky external server never
# blocks the assistant entirely.
_is_mcp_error = _mcp_active and (
"MCP server" in str(e)
or "mcp" in type(e).__name__.lower()
or isinstance(e, (anthropic.BadRequestError, anthropic.APIStatusError))
)
if _is_mcp_error:
logger.warning(
"MCP server error (%s), retrying without MCP: %s",
type(e).__name__, e,
)
response = await client.messages.create(
model=settings.AI_MODEL_ANTHROPIC,
max_tokens=max_tokens,
system=system_blocks,

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta
from typing import Optional, Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -103,13 +103,23 @@ async def start_conversation(
Returns (conversation, greeting_message).
"""
# Load tree
# Load tree — must be accessible to this account.
# Allows own account's trees, default trees, and public trees.
# Raises ValueError (caught by endpoint as 404) if not found or not accessible.
result = await db.execute(
select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id)
select(Tree).options(selectinload(Tree.tags)).where(
Tree.id == tree_id,
or_(
Tree.account_id == account_id,
Tree.author_id == user_id,
Tree.is_default == True,
Tree.is_public == True,
),
)
)
tree = result.scalar_one_or_none()
if not tree:
raise ValueError(f"Tree {tree_id} not found")
raise ValueError(f"Tree {tree_id} not found or not accessible")
conversation = CopilotConversation(
user_id=user_id,

View File

@@ -1,151 +0,0 @@
"""AI service for generating network diagrams from natural language."""
import json
import logging
from app.core.ai_provider import get_ai_provider
from app.core.config import settings
from app.schemas.network_diagram import (
AIGenerateRequest,
AIGenerateResponse,
DiagramNode,
DiagramEdge,
DeviceProperties,
Position,
)
logger = logging.getLogger(__name__)
SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers.
Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble.
Return this exact structure:
{{
"nodes": [
{{
"id": "unique-string",
"type": "device-type-slug",
"label": "device label",
"position": {{ "x": number, "y": number }},
"properties": {{
"hostname": "string or null",
"ip": "string or null",
"subnet": "string or null",
"vendor": "string or null",
"model": "string or null",
"role": "string or null",
"vlan": "string or null",
"notes": "string or null",
"status": "unknown"
}}
}}
],
"edges": [
{{
"id": "unique-string",
"source": "node-id",
"target": "node-id",
"label": "connection label or null",
"connectionType": "ethernet|fiber|wifi|vpn|vlan|wan",
"speed": "string or null",
"notes": "string or null"
}}
],
"suggestedName": "short descriptive diagram name",
"notes": "any important assumptions or missing info, or null"
}}
Available device type slugs: {available_slugs}
Position nodes thoughtfully in a logical network topology layout.
Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y.
Place WAN/internet at top, core network in middle, endpoints at bottom.
{merge_instructions}"""
MERGE_INSTRUCTIONS = """
IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices.
The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}.
Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100).
You may create edges that connect new nodes to existing nodes if the description implies a connection.
Use these existing node IDs for connections: {existing_node_ids}"""
async def generate_diagram(
request: AIGenerateRequest,
available_slugs: list[str],
existing_node_ids: list[str] | None = None,
) -> AIGenerateResponse:
merge_instructions = ""
if request.mode == "merge" and request.existingBounds:
b = request.existingBounds
merge_instructions = MERGE_INSTRUCTIONS.format(
minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY,
existing_node_ids=", ".join(existing_node_ids or []),
)
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
available_slugs=", ".join(available_slugs),
merge_instructions=merge_instructions,
)
model = settings.get_model_for_action("network_diagram_generate")
provider = get_ai_provider(model)
messages = [{"role": "user", "content": request.description}]
response_text, input_tokens, output_tokens = await provider.generate_json(
system_prompt=system_prompt,
messages=messages,
max_tokens=4096,
)
logger.info(
"Network diagram AI generation: input_tokens=%d, output_tokens=%d",
input_tokens, output_tokens,
)
try:
data = json.loads(response_text)
except json.JSONDecodeError as e:
logger.error("Failed to parse AI response as JSON: %s", e)
raise ValueError("AI generated an invalid response, please try again")
try:
nodes = []
for raw_node in data.get("nodes", []):
node_type = raw_node.get("type", "server")
if node_type not in available_slugs:
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
node_type = "server"
nodes.append(DiagramNode(
id=raw_node["id"],
type=node_type,
label=raw_node.get("label", node_type),
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
properties=DeviceProperties(**{
k: v for k, v in raw_node.get("properties", {}).items()
if k in DeviceProperties.model_fields
}),
))
edges = []
for raw_edge in data.get("edges", []):
edges.append(DiagramEdge(
id=raw_edge["id"],
source=raw_edge["source"],
target=raw_edge["target"],
label=raw_edge.get("label"),
connectionType=raw_edge.get("connectionType", "ethernet"),
speed=raw_edge.get("speed"),
notes=raw_edge.get("notes"),
))
except KeyError as e:
logger.warning("AI response missing required field: %s", e)
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
return AIGenerateResponse(
nodes=nodes,
edges=edges,
suggestedName=data.get("suggestedName"),
notes=data.get("notes"),
)

View File

@@ -19,13 +19,7 @@ class ResolutionOutputGenerator:
def __init__(self, db: AsyncSession):
self.db = db
async def generate_all(
self,
session_id: UUID,
root_cause: str | None = None,
steps_taken: list[str] | None = None,
recommendations: str | None = None,
) -> list[SessionResolutionOutput]:
async def generate_all(self, session_id: UUID) -> list[SessionResolutionOutput]:
result = await self.db.execute(
select(AISession).where(AISession.id == session_id)
)
@@ -33,12 +27,7 @@ class ResolutionOutputGenerator:
if not session:
raise ValueError(f"Session {session_id} not found")
context = self._build_session_context(
session,
root_cause=root_cause,
steps_taken=steps_taken,
recommendations=recommendations,
)
context = self._build_session_context(session)
outputs = []
for output_type, prompt in [
@@ -93,13 +82,7 @@ class ResolutionOutputGenerator:
await self.db.flush()
return output
def _build_session_context(
self,
session: AISession,
root_cause: str | None = None,
steps_taken: list[str] | None = None,
recommendations: str | None = None,
) -> str:
def _build_session_context(self, session: AISession) -> str:
intake = session.intake_content or {}
intake_text = intake.get("text", "") or str(intake)
parts = [
@@ -109,31 +92,10 @@ class ResolutionOutputGenerator:
f"Resolution: {session.resolution_summary or 'Not specified'}",
]
# Structured handoff fields from cockpit conclude modal
if root_cause:
parts.append(f"Root cause: {root_cause}")
if steps_taken:
parts.append("Steps performed:")
for step in steps_taken:
parts.append(f" - {step}")
if recommendations:
parts.append(f"Recommendations: {recommendations}")
# Triage metadata (cockpit branch)
if getattr(session, 'client_name', None):
parts.append(f"Client: {session.client_name}")
if getattr(session, 'triage_hypothesis', None):
parts.append(f"Hypothesis: {session.triage_hypothesis}")
if getattr(session, 'evidence_items', None):
parts.append("Evidence collected:")
for item in session.evidence_items:
icon = {"confirmed": "", "ruled_out": "", "pending": "?"}.get(item.get("status", ""), "?")
parts.append(f" {icon} {item.get('text', '')}")
# Diagnostic steps from FlowPilot session steps
steps = sorted(session.steps or [], key=lambda s: s.step_order)
diagnostic = []
follow_ups: list[str] = []
for step in sorted(session.steps or [], key=lambda s: s.step_order):
for step in steps:
content = step.content or {}
step_type = content.get("type", "")
if step_type == "resolution_suggestion":

View File

@@ -133,13 +133,10 @@ def _parse_questions_marker(ai_content: str) -> tuple[str, list[dict[str, Any]]
valid_questions = []
for q in questions:
if isinstance(q, dict) and q.get("text"):
item = {
valid_questions.append({
"text": q["text"],
"context": q.get("context", ""),
}
if q.get("options") and isinstance(q["options"], list):
item["options"] = q["options"]
valid_questions.append(item)
})
if not valid_questions:
return ai_content, None
@@ -150,43 +147,6 @@ def _parse_questions_marker(ai_content: str) -> tuple[str, list[dict[str, Any]]
return cleaned, valid_questions
def _parse_triage_update_marker(ai_content: str) -> tuple[str, dict[str, Any] | None]:
"""Extract [TRIAGE_UPDATE]...[/TRIAGE_UPDATE] JSON from AI response.
Returns (cleaned_content, triage_update_dict_or_None).
The marker is stripped from display text.
"""
match = re.search(r'\[TRIAGE_UPDATE\]\s*([\s\S]*?)\s*\[/TRIAGE_UPDATE\]', ai_content)
if not match:
return ai_content, None
try:
raw = match.group(1).strip()
if raw.startswith("```"):
raw = re.sub(r'^```(?:json)?\s*', '', raw)
raw = re.sub(r'\s*```$', '', raw)
triage = json.loads(raw)
except (json.JSONDecodeError, ValueError) as e:
logger.warning("Failed to parse [TRIAGE_UPDATE] marker: %s", e)
return ai_content, None
if not isinstance(triage, dict):
logger.warning("Invalid [TRIAGE_UPDATE] data — expected object")
return ai_content, None
# Only keep recognized fields
valid_fields = {"client_name", "asset_name", "issue_category", "triage_hypothesis", "evidence_items"}
filtered = {k: v for k, v in triage.items() if k in valid_fields and v is not None}
if not filtered:
return ai_content, None
cleaned = ai_content[:match.start()] + ai_content[match.end():]
cleaned = cleaned.strip()
return cleaned, filtered
async def create_chat_session(
user_id: UUID,
account_id: UUID,
@@ -223,14 +183,14 @@ async def send_chat_message(
message: str,
db: AsyncSession,
images: list[dict[str, Any]] | None = None,
) -> tuple[str, list[dict[str, Any]], AISession, dict[str, Any] | None, list[dict[str, Any]] | None, list[dict[str, Any]] | None, dict[str, Any] | None]:
) -> tuple[str, list[dict[str, Any]], AISession, dict[str, Any] | None, list[dict[str, Any]] | None, list[dict[str, Any]] | None]:
"""Send a message in a chat session and get AI response.
Args:
images: Optional list of {"media_type": str, "data": str (base64)}
for vision content attached to this message.
Returns (ai_content, suggested_flows, session, fork_metadata, actions_data, questions_data, triage_update_data).
Returns (ai_content, suggested_flows, session, fork_metadata, actions_data, questions_data).
"""
result = await db.execute(
select(AISession).where(
@@ -277,8 +237,10 @@ async def send_chat_message(
ai_content, input_tokens, output_tokens = await _call_ai(**prompt_args)
# Update branch conversation
# Strip _(not yet completed)_ markers before storage (same reason as main path)
stored_message = message.replace("_(not yet completed)_", "(pending)").replace("_(skipped)_", "(skipped)")
msgs = list(branch.conversation_messages or [])
msgs.append({"role": "user", "content": message})
msgs.append({"role": "user", "content": stored_message})
msgs.append({"role": "assistant", "content": ai_content})
branch.conversation_messages = msgs
@@ -293,19 +255,6 @@ async def send_chat_message(
branch_display, branch_fork_data = _parse_fork_marker(ai_content)
branch_display, branch_actions_data = _parse_actions_marker(branch_display)
branch_display, branch_questions_data = _parse_questions_marker(branch_display)
branch_display, branch_triage_data = _parse_triage_update_marker(branch_display)
# Auto-PATCH triage from branch response
if branch_triage_data:
for field in ("client_name", "asset_name", "issue_category", "triage_hypothesis"):
if field in branch_triage_data and getattr(session, field) is None:
setattr(session, field, branch_triage_data[field])
new_evidence = branch_triage_data.get("evidence_items")
if new_evidence and isinstance(new_evidence, list):
existing = list(session.evidence_items or [])
existing.extend(new_evidence)
session.evidence_items = existing
if branch_display != ai_content:
# Store stripped content in branch history
msgs[-1] = {"role": "assistant", "content": branch_display}
@@ -339,17 +288,19 @@ async def send_chat_message(
except Exception:
logger.exception("Failed to create fork within branch for session %s", session.id)
# Persist task lane state on session — only overwrite when new markers present
# Persist task lane state on session
if branch_questions_data or branch_actions_data:
session.pending_task_lane = {
"questions": branch_questions_data or [],
"actions": branch_actions_data or [],
}
else:
session.pending_task_lane = None
suggested_flows = extract_suggested_flows(
await rag_search(query=message, account_id=account_id, db=db, limit=8)
)
return branch_display, suggested_flows, session, branch_fork_metadata, branch_actions_data, branch_questions_data, branch_triage_data
return branch_display, suggested_flows, session, branch_fork_metadata, branch_actions_data, branch_questions_data
# Auto-title from first message if still default
if session.step_count == 0 and message.strip():
@@ -392,34 +343,23 @@ async def send_chat_message(
# Check for questions marker in AI response
display_content, questions_data = _parse_questions_marker(display_content)
# Check for triage update marker in AI response
display_content, triage_update_data = _parse_triage_update_marker(display_content)
logger.info(
"Marker parsing results — actions: %s, questions: %s, fork: %s, triage: %s, raw_length: %d, display_length: %d",
bool(actions_data), bool(questions_data), bool(fork_data), bool(triage_update_data),
"Marker parsing results — actions: %s, questions: %s, fork: %s, raw_length: %d, display_length: %d",
bool(actions_data), bool(questions_data), bool(fork_data),
len(ai_content), len(display_content),
)
# Auto-PATCH session with triage metadata if AI inferred new fields
if triage_update_data:
# Apply non-evidence fields directly (AI only fills null fields — manual edits win)
for field in ("client_name", "asset_name", "issue_category", "triage_hypothesis"):
if field in triage_update_data and getattr(session, field) is None:
setattr(session, field, triage_update_data[field])
# Append new evidence items (never modify existing)
new_evidence = triage_update_data.get("evidence_items")
if new_evidence and isinstance(new_evidence, list):
existing = list(session.evidence_items or [])
existing.extend(new_evidence)
session.evidence_items = existing
# Store DISPLAY content (markers stripped) in conversation_messages.
# The format reminder in the user message + system prompt final reminder
# are sufficient to keep the AI emitting markers on subsequent turns.
#
# Strip _(not yet completed)_ task markers from the stored user message.
# The AI processes them correctly on the current turn, but persisting them
# into history causes the AI to re-inject stale task lane items from prior
# turns — even across unrelated topics in a long session.
stored_message = message.replace("_(not yet completed)_", "(pending)").replace("_(skipped)_", "(skipped)")
msgs = list(session.conversation_messages or [])
msgs.append({"role": "user", "content": message})
msgs.append({"role": "user", "content": stored_message})
msgs.append({"role": "assistant", "content": display_content})
session.conversation_messages = msgs
session.step_count += 2 # message count for display
@@ -470,13 +410,15 @@ async def send_chat_message(
logger.exception("Failed to create fork for session %s", session_id)
# Fork failed but chat message still sent — don't break the response
# Persist task lane state on session — only overwrite when new markers present
# Persist task lane state on session
if questions_data or actions_data:
session.pending_task_lane = {
"questions": questions_data or [],
"actions": actions_data or [],
}
else:
session.pending_task_lane = None
suggested_flows = extract_suggested_flows(rag_results)
return display_content, suggested_flows, session, fork_metadata, actions_data, questions_data, triage_update_data
return display_content, suggested_flows, session, fork_metadata, actions_data, questions_data

View File

@@ -0,0 +1,91 @@
"""
Tenant filter enforcement check.
Scans endpoint and service files for SQLAlchemy select() calls on known
tenant tables and warns when account_id or tenant_filter is not present
in the surrounding 15 lines (the typical extent of a single query).
Usage:
python scripts/check_tenant_filters.py # warn mode (exits 0)
python scripts/check_tenant_filters.py --fail # block mode (exits 1 on findings)
"""
import re
import sys
from pathlib import Path
# Tables that must always be filtered by account_id or tenant_filter.
# Extend this list as new tenant tables are added.
TENANT_MODELS = [
"Tree", "AISession", "Session", "StepLibrary", "FlowProposal",
"CopilotConversation", "AssistantChat", "FileUpload", "KBImport",
"PsaConnection", "PsaPostLog", "PsaMemberMapping", "AIChatSession",
"AIConversation", "AIUsage", "Subscription", "AccountInvite",
"Notification", "NotificationConfig", "SessionShare", "UserFolder",
"UserPinnedTree", "SessionBranch", "SessionHandoff",
"SessionResolutionOutput", "ForkPoint", "AISessionStep",
"AISuggestion", "StepCategory", "TreeCategory", "TreeTag",
"Attachment", "SessionSupportingData", "MaintenanceSchedule",
"AuditLog", "ScriptBuilderSession", "ScriptTemplate",
"StepRating", "StepUsageLog", "TargetList",
]
# Directories to scan
SCAN_DIRS = [
Path("app/api/endpoints"),
Path("app/services"),
]
# Patterns that indicate the query is correctly scoped.
# NOTE: user_id scoping is accepted for user-owned resources (sessions, folders, notifications).
# For account-shared resources (trees, steps, etc.) use tenant_filter or account_id.
SAFE_PATTERNS = [
r"tenant_filter",
r"account_id",
r"user_id", # User-scoped resources (sessions, folders, notifications, etc.)
r"is_super_admin", # Super admin queries intentionally bypass tenant filter
r"# cross-tenant: approved", # Explicit approval comment
]
SKIP_FILES = {
"admin.py", # Super admin endpoints intentionally bypass tenant filter
"admin_gallery.py", # Gallery management — super admin only, no tenant scoping needed
"public_templates.py",# Public template browser — intentionally cross-tenant
"auth.py", # Auth/registration — no account context during login/register
"ratings.py", # Session ratings — user-scoped via session lookup chain
}
findings = []
for scan_dir in SCAN_DIRS:
if not scan_dir.exists():
continue
for path in sorted(scan_dir.glob("*.py")):
if path.name in SKIP_FILES:
continue
lines = path.read_text().splitlines()
for i, line in enumerate(lines):
for model in TENANT_MODELS:
if re.search(rf"\bselect\s*\(\s*{model}\b", line):
# Check surrounding 15 lines for a safe pattern
start = max(0, i - 2)
end = min(len(lines), i + 15)
context = "\n".join(lines[start:end])
if not any(re.search(p, context) for p in SAFE_PATTERNS):
findings.append(
f"{path}:{i + 1}: select({model}) — no tenant_filter or account_id found in context"
)
if findings:
print(f"\n⚠ Tenant filter check — {len(findings)} warning(s):\n")
for f in findings:
print(f" {f}")
print()
if "--fail" in sys.argv:
print("Run with --fail: exiting 1")
sys.exit(1)
else:
print("Run in warn mode — not blocking. Pass --fail to block.")
sys.exit(0)
else:
print("✓ Tenant filter check passed — no unscoped tenant table queries found.")
sys.exit(0)

View File

@@ -19,116 +19,8 @@ class TestAdminEndpoints:
"/api/v1/admin/users", headers=admin_auth_headers
)
assert response.status_code == 200
payload = response.json()
assert payload["total"] >= 2 # admin + test_user
assert len(payload["items"]) >= 2
@pytest.mark.asyncio
async def test_list_users_supports_search(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test admin people search by user email."""
response = await client.get(
"/api/v1/admin/users",
params={"search": test_user["email"]},
headers=admin_auth_headers,
)
assert response.status_code == 200
payload = response.json()
assert payload["total"] >= 1
assert any(item["email"] == test_user["email"] for item in payload["items"])
@pytest.mark.asyncio
async def test_list_accounts_as_admin(
self, client: AsyncClient, admin_auth_headers: dict
):
"""Test listing accounts with member data."""
response = await client.get(
"/api/v1/admin/accounts", headers=admin_auth_headers
)
assert response.status_code == 200
payload = response.json()
assert payload["total"] >= 1
assert len(payload["items"]) >= 1
assert "members" in payload["items"][0]
assert "subscription" in payload["items"][0]
@pytest.mark.asyncio
async def test_create_account_as_admin(
self, client: AsyncClient, admin_auth_headers: dict
):
"""Test creating an empty account from admin."""
response = await client.post(
"/api/v1/admin/accounts",
json={"name": "Acme Customer", "plan": "pro"},
headers=admin_auth_headers,
)
assert response.status_code == 201
payload = response.json()
assert payload["name"] == "Acme Customer"
assert payload["subscription"]["plan"] == "pro"
assert payload["display_code"]
@pytest.mark.asyncio
async def test_get_account_detail_as_admin(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test fetching account detail for management view."""
account_id = test_user["user_data"]["account_id"]
response = await client.get(
f"/api/v1/admin/accounts/{account_id}",
headers=admin_auth_headers,
)
assert response.status_code == 200
payload = response.json()
assert payload["id"] == account_id
assert "members" in payload
assert "invites" in payload
@pytest.mark.asyncio
async def test_update_account_name_as_admin(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test renaming an account from admin detail view."""
account_id = test_user["user_data"]["account_id"]
response = await client.put(
f"/api/v1/admin/accounts/{account_id}",
json={"name": "Renamed Customer Account"},
headers=admin_auth_headers,
)
assert response.status_code == 200
payload = response.json()
assert payload["id"] == account_id
assert payload["name"] == "Renamed Customer Account"
@pytest.mark.asyncio
async def test_update_account_plan(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test changing an account's subscription plan."""
account_id = test_user["user_data"]["account_id"]
response = await client.put(
f"/api/v1/admin/accounts/{account_id}/subscription/plan",
json={"plan": "pro"},
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json()["plan"] == "pro"
@pytest.mark.asyncio
async def test_extend_account_trial(
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
):
"""Test starting or extending an account trial."""
account_id = test_user["user_data"]["account_id"]
response = await client.put(
f"/api/v1/admin/accounts/{account_id}/subscription/extend-trial",
json={"days": 14},
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json()["status"] == "trialing"
assert response.json()["current_period_end"] is not None
users = response.json()
assert len(users) >= 2 # admin + test_user
@pytest.mark.asyncio
async def test_list_users_as_non_admin(

View File

@@ -1,107 +0,0 @@
"""Integration tests for feature flag resolution endpoint."""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
async def _seed_feature_flag(db: AsyncSession, flag_key: str, display_name: str):
"""Insert a feature flag and return its id."""
result = await db.execute(
text(
"INSERT INTO feature_flags (id, flag_key, display_name) "
"VALUES (gen_random_uuid(), :key, :name) RETURNING id"
),
{"key": flag_key, "name": display_name},
)
await db.commit()
return result.scalar_one()
async def _seed_plan_default(db: AsyncSession, flag_id, plan: str, enabled: bool):
"""Insert a plan default for a flag."""
await db.execute(
text(
"INSERT INTO plan_feature_defaults (id, plan, flag_id, enabled) "
"VALUES (gen_random_uuid(), :plan, :flag_id, :enabled)"
),
{"plan": plan, "flag_id": flag_id, "enabled": enabled},
)
await db.commit()
async def _seed_account_override(db: AsyncSession, flag_id, account_id, enabled: bool):
"""Insert an account override for a flag."""
await db.execute(
text(
"INSERT INTO account_feature_overrides (id, account_id, flag_id, enabled) "
"VALUES (gen_random_uuid(), :account_id, :flag_id, :enabled)"
),
{"account_id": account_id, "flag_id": flag_id, "enabled": enabled},
)
await db.commit()
async def _get_account_id(db: AsyncSession, user_id: str):
"""Get account_id for a user."""
result = await db.execute(
text("SELECT account_id FROM users WHERE id = :uid"),
{"uid": user_id},
)
return result.scalar_one()
class TestFeatureFlagResolution:
"""Tests for GET /auth/me/feature-flags."""
@pytest.mark.asyncio
async def test_no_flags_returns_empty(self, client: AsyncClient, auth_headers: dict):
"""When no flags exist, returns empty dict."""
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
assert response.status_code == 200
assert response.json() == {}
@pytest.mark.asyncio
async def test_plan_default_resolves(
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
):
"""Flag with plan default for 'free' plan resolves correctly."""
flag_id = await _seed_feature_flag(test_db, "test_feature", "Test Feature")
await _seed_plan_default(test_db, flag_id, "free", True)
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
assert response.status_code == 200
assert response.json()["test_feature"] is True
@pytest.mark.asyncio
async def test_no_plan_default_resolves_false(
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
):
"""Flag with no plan default for user's plan resolves to false."""
flag_id = await _seed_feature_flag(test_db, "pro_only", "Pro Only")
await _seed_plan_default(test_db, flag_id, "pro", True)
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
assert response.status_code == 200
assert response.json()["pro_only"] is False
@pytest.mark.asyncio
async def test_account_override_beats_plan_default(
self, client: AsyncClient, auth_headers: dict, test_user: dict, test_db: AsyncSession
):
"""Account override takes precedence over plan default."""
flag_id = await _seed_feature_flag(test_db, "overridden", "Overridden Flag")
await _seed_plan_default(test_db, flag_id, "free", False)
account_id = await _get_account_id(test_db, test_user["user_data"]["id"])
await _seed_account_override(test_db, flag_id, account_id, True)
response = await client.get("/api/v1/auth/me/feature-flags", headers=auth_headers)
assert response.status_code == 200
assert response.json()["overridden"] is True
@pytest.mark.asyncio
async def test_unauthenticated_returns_401(self, client: AsyncClient):
"""Unauthenticated request returns 401."""
response = await client.get("/api/v1/auth/me/feature-flags")
assert response.status_code == 401

View File

@@ -0,0 +1,545 @@
"""Phase 1 migration tests — verify account_id backfill correctness.
These tests create objects via ORM (which uses the updated models),
then verify account_id is populated correctly. They run against a
real PostgreSQL test DB (same as all other integration tests).
"""
import pytest
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.models.session import Session
from app.models.attachment import Attachment
from app.models.supporting_data import SessionSupportingData
from app.models.session_resolution_output import SessionResolutionOutput
from app.models.ai_session import AISession
from app.core.security import get_password_hash
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _make_account_and_user(db: AsyncSession, suffix: str) -> tuple[Account, User]:
account = Account(name=f"Corp {suffix}", display_code=uuid.uuid4().hex[:8])
db.add(account)
await db.flush()
user = User(
email=f"user-{suffix}-{uuid.uuid4().hex[:6]}@example.com",
name=f"User {suffix}",
password_hash=get_password_hash("TestPass123!"),
is_active=True,
account_id=account.id,
account_role="engineer",
)
db.add(user)
await db.flush()
return account, user
async def _make_tree(db: AsyncSession, account: Account, user: User) -> Tree:
tree = Tree(
name=f"Tree {uuid.uuid4().hex[:6]}",
account_id=account.id,
author_id=user.id,
visibility="team",
tree_type="troubleshooting",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
db.add(tree)
await db.flush()
return tree
async def _make_session(db: AsyncSession, account: Account, user: User, tree: Tree) -> Session:
s = Session(
tree_id=tree.id,
user_id=user.id,
account_id=account.id,
tree_snapshot={},
)
db.add(s)
await db.flush()
return s
# ── Group 1: Core sessions ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_account_id_matches_user(test_db: AsyncSession):
"""sessions.account_id must equal the user's account_id."""
account, user = await _make_account_and_user(test_db, "s1")
tree = await _make_tree(test_db, account, user)
session = await _make_session(test_db, account, user, tree)
await test_db.commit()
result = await test_db.execute(select(Session).where(Session.id == session.id))
row = result.scalar_one()
assert row.account_id == account.id, f"Expected {account.id}, got {row.account_id}"
@pytest.mark.asyncio
async def test_attachment_account_id_matches_session(test_db: AsyncSession):
"""attachments.account_id must match the parent session's account_id."""
account, user = await _make_account_and_user(test_db, "att1")
tree = await _make_tree(test_db, account, user)
session = await _make_session(test_db, account, user, tree)
attachment = Attachment(
session_id=session.id,
account_id=account.id,
file_name="test.png",
file_type="image/png",
)
test_db.add(attachment)
await test_db.commit()
result = await test_db.execute(select(Attachment).where(Attachment.id == attachment.id))
row = result.scalar_one()
assert row.account_id == account.id
@pytest.mark.asyncio
async def test_session_supporting_data_account_id(test_db: AsyncSession):
"""session_supporting_data.account_id must match parent session's account_id."""
account, user = await _make_account_and_user(test_db, "sd1")
tree = await _make_tree(test_db, account, user)
session = await _make_session(test_db, account, user, tree)
sd = SessionSupportingData(
session_id=session.id,
account_id=account.id,
label="Log snippet",
data_type="text_snippet",
content="error: connection refused",
)
test_db.add(sd)
await test_db.commit()
result = await test_db.execute(
select(SessionSupportingData).where(SessionSupportingData.id == sd.id)
)
row = result.scalar_one()
assert row.account_id == account.id
@pytest.mark.asyncio
async def test_session_resolution_output_account_id(test_db: AsyncSession):
"""session_resolution_outputs.account_id must match the parent ai_session's account_id.
NOTE: session_resolution_outputs.session_id FK points to ai_sessions (not sessions).
"""
account, user = await _make_account_and_user(test_db, "sro1")
ai_session = AISession(
user_id=user.id,
account_id=account.id,
problem_summary="test resolution output",
problem_domain="networking",
status="active",
)
test_db.add(ai_session)
await test_db.flush()
output = SessionResolutionOutput(
session_id=ai_session.id,
account_id=account.id,
output_type="psa_ticket_notes",
generated_content="Ticket notes content",
generated_by_model="gpt-4",
)
test_db.add(output)
await test_db.commit()
result = await test_db.execute(
select(SessionResolutionOutput).where(SessionResolutionOutput.id == output.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 2: AI & branching ───────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_session_branch_account_id_matches_ai_session(test_db: AsyncSession):
"""session_branches.account_id must match parent ai_session.account_id."""
from app.models.session_branch import SessionBranch
account, user = await _make_account_and_user(test_db, "sb1")
ai_session = AISession(
user_id=user.id,
account_id=account.id,
problem_summary="test",
problem_domain="networking",
status="active",
)
test_db.add(ai_session)
await test_db.flush()
branch = SessionBranch(
session_id=ai_session.id,
account_id=account.id,
label="Branch A",
branch_order=1,
conversation_messages=[],
)
test_db.add(branch)
await test_db.commit()
result = await test_db.execute(
select(SessionBranch).where(SessionBranch.id == branch.id)
)
row = result.scalar_one()
assert row.account_id == account.id
@pytest.mark.asyncio
async def test_ai_suggestion_account_id_matches_user(test_db: AsyncSession):
"""ai_suggestions.account_id must match the creating user's account_id."""
from app.models.ai_suggestion import AISuggestion
account, user = await _make_account_and_user(test_db, "ais1")
tree = await _make_tree(test_db, account, user)
suggestion = AISuggestion(
tree_id=tree.id,
user_id=user.id,
account_id=account.id,
action_type="add_node",
changes_json={},
status="pending",
)
test_db.add(suggestion)
await test_db.commit()
result = await test_db.execute(
select(AISuggestion).where(AISuggestion.id == suggestion.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 3: Steps & ratings ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_step_rating_account_id_is_rater_account(test_db: AsyncSession):
"""step_ratings.account_id must be the RATER's account, not the step's account."""
from app.models.step_library import StepLibrary, StepRating
account_a, user_a = await _make_account_and_user(test_db, "sr-rater")
account_b, user_b = await _make_account_and_user(test_db, "sr-step-owner")
# Step owned by account_b
step = StepLibrary(
title="A step",
step_type="action",
content={"text": "do something"},
created_by=user_b.id,
account_id=account_b.id,
visibility="public",
)
test_db.add(step)
await test_db.flush()
# user_a (account_a) rates the step
rating = StepRating(
step_id=step.id,
user_id=user_a.id,
account_id=account_a.id, # rater's account, not step owner's
was_helpful=True,
is_verified_use=False,
is_visible=True,
)
test_db.add(rating)
await test_db.commit()
result = await test_db.execute(select(StepRating).where(StepRating.id == rating.id))
row = result.scalar_one()
assert row.account_id == account_a.id, (
f"account_id should be rater's account ({account_a.id}), got {row.account_id}"
)
@pytest.mark.asyncio
async def test_step_usage_log_account_id_is_logger_account(test_db: AsyncSession):
"""step_usage_log.account_id must be the LOGGER's account (user who used the step)."""
from app.models.step_library import StepLibrary, StepUsageLog
account, user = await _make_account_and_user(test_db, "sul1")
tree = await _make_tree(test_db, account, user)
session = await _make_session(test_db, account, user, tree)
step = StepLibrary(
title="A usage step",
step_type="action",
content={"text": "do something"},
created_by=user.id,
account_id=account.id,
visibility="team",
)
test_db.add(step)
await test_db.flush()
log = StepUsageLog(
step_id=step.id,
user_id=user.id,
account_id=account.id,
session_id=session.id,
)
test_db.add(log)
await test_db.commit()
result = await test_db.execute(select(StepUsageLog).where(StepUsageLog.id == log.id))
row = result.scalar_one()
assert row.account_id == account.id, (
f"account_id should be logger's account ({account.id}), got {row.account_id}"
)
# ── Group 4: User personalization ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_user_folder_account_id_matches_user(test_db: AsyncSession):
"""user_folders.account_id must match the owning user's account_id."""
from app.models.folder import UserFolder
account, user = await _make_account_and_user(test_db, "uf1")
folder = UserFolder(
user_id=user.id,
account_id=account.id,
name="My Folder",
color="#6366f1",
icon="folder",
display_order=0,
)
test_db.add(folder)
await test_db.commit()
result = await test_db.execute(select(UserFolder).where(UserFolder.id == folder.id))
row = result.scalar_one()
assert row.account_id == account.id
@pytest.mark.asyncio
async def test_user_pinned_tree_account_id_matches_user(test_db: AsyncSession):
"""user_pinned_trees.account_id must match the pinning user's account_id."""
from app.models.user_pinned_tree import UserPinnedTree
account, user = await _make_account_and_user(test_db, "pt1")
tree = await _make_tree(test_db, account, user)
pin = UserPinnedTree(
user_id=user.id,
tree_id=tree.id,
account_id=account.id,
display_order=0,
)
test_db.add(pin)
await test_db.commit()
result = await test_db.execute(select(UserPinnedTree).where(UserPinnedTree.id == pin.id))
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 5: PSA & notifications ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_psa_member_mapping_account_id_matches_connection(test_db: AsyncSession):
"""psa_member_mappings.account_id must match psa_connection's account_id."""
from app.models.psa_connection import PsaConnection
from app.models.psa_member_mapping import PsaMemberMapping
account, user = await _make_account_and_user(test_db, "psa1")
conn = PsaConnection(
account_id=account.id,
provider="connectwise",
display_name="Test CW",
site_url="https://cw.example.com",
company_id="TEST",
credentials_encrypted="placeholder",
)
test_db.add(conn)
await test_db.flush()
mapping = PsaMemberMapping(
psa_connection_id=conn.id,
user_id=user.id,
account_id=account.id,
external_member_id="cw-123",
external_member_name="Test User",
matched_by="manual_admin",
)
test_db.add(mapping)
await test_db.commit()
result = await test_db.execute(
select(PsaMemberMapping).where(PsaMemberMapping.id == mapping.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 6: Maintenance ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_maintenance_schedule_account_id_matches_tree(test_db: AsyncSession):
"""maintenance_schedules.account_id must match the tree's account_id."""
from app.models.maintenance_schedule import MaintenanceSchedule
account, user = await _make_account_and_user(test_db, "ms1")
tree = Tree(
name="Maintenance Flow",
account_id=account.id,
author_id=user.id,
visibility="team",
tree_type="maintenance",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
test_db.add(tree)
await test_db.flush()
schedule = MaintenanceSchedule(
tree_id=tree.id,
account_id=account.id,
created_by=user.id,
cron_expression="0 9 * * 1",
timezone="UTC",
is_active=True,
)
test_db.add(schedule)
await test_db.commit()
result = await test_db.execute(
select(MaintenanceSchedule).where(MaintenanceSchedule.id == schedule.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 7: Legacy team_id tables ───────────────────────────────────────────
@pytest.mark.asyncio
async def test_script_builder_session_account_id(test_db: AsyncSession):
"""script_builder_sessions.account_id must match user's account_id."""
from app.models.script_builder_session import ScriptBuilderSession
account, user = await _make_account_and_user(test_db, "sbs1")
sbs = ScriptBuilderSession(
user_id=user.id,
account_id=account.id,
language="powershell",
)
test_db.add(sbs)
await test_db.commit()
result = await test_db.execute(
select(ScriptBuilderSession).where(ScriptBuilderSession.id == sbs.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 8: TargetList ────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
"""target_lists.account_id must be set to the team admin's account_id."""
from app.models.target_list import TargetList
from app.models.team import Team
account, user = await _make_account_and_user(test_db, "tl1")
# Make user a team admin
team = Team(name=f"Team {uuid.uuid4().hex[:6]}")
test_db.add(team)
await test_db.flush()
user.team_id = team.id
user.is_team_admin = True
await test_db.flush()
target_list = TargetList(
team_id=team.id,
account_id=account.id,
created_by=user.id,
name="Server Targets",
targets=[{"label": "SRV-01"}],
)
test_db.add(target_list)
await test_db.commit()
result = await test_db.execute(
select(TargetList).where(TargetList.id == target_list.id)
)
row = result.scalar_one()
assert row.account_id == account.id
# ── Group 10 (runs first): Global content tables ──────────────────────────────
@pytest.mark.asyncio
async def test_template_trees_table_exists_and_has_no_account_id(test_db: AsyncSession):
"""template_trees must exist and must NOT have an account_id column."""
result = await test_db.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'template_trees'
"""))
columns = {row[0] for row in result.fetchall()}
assert 'id' in columns, "template_trees.id must exist"
assert 'account_id' not in columns, "template_trees must not have account_id (global content)"
@pytest.mark.asyncio
async def test_platform_steps_table_exists_and_has_no_account_id(test_db: AsyncSession):
"""platform_steps must exist and must NOT have an account_id column."""
result = await test_db.execute(text("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'platform_steps'
"""))
columns = {row[0] for row in result.fetchall()}
assert 'id' in columns, "platform_steps.id must exist"
assert 'account_id' not in columns, "platform_steps must not have account_id (global content)"
# ── Group 9: SET NOT NULL on existing nullable columns ────────────────────────
@pytest.mark.asyncio
async def test_tree_account_id_is_not_null(test_db: AsyncSession):
"""trees.account_id must be NOT NULL after Phase 1 — enforced at DB level."""
from sqlalchemy.exc import IntegrityError
with pytest.raises(IntegrityError):
test_db.add(Tree(
name="Bad tree",
# account_id intentionally omitted
author_id=None,
visibility="private",
tree_type="troubleshooting",
tree_structure={},
is_active=True,
status="draft",
))
await test_db.flush()
@pytest.mark.asyncio
async def test_user_account_id_is_not_null(test_db: AsyncSession):
"""users.account_id must be NOT NULL after Phase 1."""
from sqlalchemy.exc import IntegrityError
with pytest.raises(IntegrityError):
test_db.add(User(
email=f"orphan-{uuid.uuid4().hex[:6]}@example.com",
name="Orphan",
password_hash=get_password_hash("x"),
is_active=True,
role="engineer",
account_role="engineer",
# account_id intentionally omitted
))
await test_db.flush()

View File

@@ -0,0 +1,266 @@
# backend/tests/test_rls_isolation.py
"""
RLS foundation tests.
Connect directly as resolutionflow_app (not superuser) and verify:
- Tenant A cannot read Tenant B's rows
- No tenant context set → zero rows for private data (fail-closed)
- Platform rows (PLATFORM_ACCOUNT_ID) are visible to all tenants
Tests bypass FastAPI entirely — raw asyncpg connections only.
MUST FAIL before Task 10 (RLS migration) and PASS after it.
Run with:
DB_APP_ROLE_PASSWORD=app_secret_change_me pytest tests/test_rls_isolation.py -v
The test DB is patherly_test (matches conftest.py default).
"""
import os
import uuid
import asyncpg
import pytest
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
_APP_PASSWORD = os.getenv("DB_APP_ROLE_PASSWORD", "app_secret_change_me")
_ADMIN_DSN = f"postgresql://postgres:postgres@{_DB_HOST}:{_DB_PORT}/{_DB_NAME}"
PLATFORM_ACCOUNT_ID = "00000000-0000-0000-0000-000000000001"
ACCOUNT_A_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
async def admin_conn():
"""Superuser asyncpg connection for fixture setup and teardown."""
conn = await asyncpg.connect(_ADMIN_DSN)
yield conn
await conn.close()
@pytest.fixture(scope="module", autouse=True)
async def seed_rls_test_data(admin_conn):
"""
Create two isolated test accounts, one user per account, and one private
tree per account. Trees require a valid author_id FK to users, so users
must be created first.
accounts.display_code must be unique and 8 chars (NOT NULL constraint).
"""
# Insert accounts
await admin_conn.execute(f"""
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
VALUES
('{ACCOUNT_A_ID}', 'RLS Tenant A', 'RLSA0001', NOW(), NOW()),
('{ACCOUNT_B_ID}', 'RLS Tenant B', 'RLSB0001', NOW(), NOW())
ON CONFLICT (id) DO NOTHING
""")
# Insert one user per account (users.account_id NOT NULL, password_hash NOT NULL)
user_a_id = str(uuid.uuid4())
user_b_id = str(uuid.uuid4())
await admin_conn.execute(f"""
INSERT INTO users (
id, email, password_hash, name, role, is_active, account_id,
account_role, created_at
) VALUES
('{user_a_id}', 'rls-user-a@example.com',
'placeholder', 'RLS User A', 'engineer', TRUE,
'{ACCOUNT_A_ID}', 'engineer', NOW()),
('{user_b_id}', 'rls-user-b@example.com',
'placeholder', 'RLS User B', 'engineer', TRUE,
'{ACCOUNT_B_ID}', 'engineer', NOW())
ON CONFLICT (email) DO NOTHING
""")
# Look up the user IDs we just inserted (ON CONFLICT may have skipped)
row_a = await admin_conn.fetchrow(
"SELECT id FROM users WHERE email = 'rls-user-a@example.com'"
)
row_b = await admin_conn.fetchrow(
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
)
actual_user_a = str(row_a["id"])
actual_user_b = str(row_b["id"])
# Insert one private tree per account with explicit author_id
await admin_conn.execute(f"""
INSERT INTO trees (
id, name, tree_structure, account_id, author_id, is_active, is_default,
is_public, visibility, tree_type, created_at, updated_at
) VALUES
(gen_random_uuid(), 'RLS Tree A', '[]'::jsonb, '{ACCOUNT_A_ID}', '{actual_user_a}',
TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW()),
(gen_random_uuid(), 'RLS Tree B', '[]'::jsonb, '{ACCOUNT_B_ID}', '{actual_user_b}',
TRUE, FALSE, FALSE, 'private', 'troubleshooting', NOW(), NOW())
""")
# One platform-owned tree_tag (global, visible to all tenants)
await admin_conn.execute(f"""
INSERT INTO tree_tags (
id, name, slug, account_id, usage_count, created_at
) VALUES (
gen_random_uuid(), 'rls-global-tag', 'rls-global-tag',
'{PLATFORM_ACCOUNT_ID}', 0, NOW()
) ON CONFLICT DO NOTHING
""")
yield
# Cleanup
await admin_conn.execute(
f"DELETE FROM trees WHERE account_id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')"
)
await admin_conn.execute(
"DELETE FROM users WHERE email IN "
"('rls-user-a@example.com', 'rls-user-b@example.com')"
)
await admin_conn.execute(
f"DELETE FROM accounts WHERE id IN ('{ACCOUNT_A_ID}', '{ACCOUNT_B_ID}')"
)
await admin_conn.execute("DELETE FROM tree_tags WHERE slug = 'rls-global-tag'")
@pytest.fixture
async def conn_a():
"""App-role connection, tenant context = Account A."""
conn = await asyncpg.connect(
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
user="resolutionflow_app", password=_APP_PASSWORD,
)
await conn.execute(
"SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_A_ID
)
yield conn
await conn.close()
@pytest.fixture
async def conn_b():
"""App-role connection, tenant context = Account B."""
conn = await asyncpg.connect(
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
user="resolutionflow_app", password=_APP_PASSWORD,
)
await conn.execute(
"SELECT set_config('app.current_account_id', $1, false)", ACCOUNT_B_ID
)
yield conn
await conn.close()
@pytest.fixture
async def conn_no_context():
"""App-role connection with NO tenant context set."""
conn = await asyncpg.connect(
host=_DB_HOST, port=_DB_PORT, database=_DB_NAME,
user="resolutionflow_app", password=_APP_PASSWORD,
)
yield conn
await conn.close()
# ---------------------------------------------------------------------------
# trees
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0, "Account A should not see Account B trees"
@pytest.mark.asyncio
async def test_trees_account_a_can_see_own_rows(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
)
assert len(rows) >= 1, "Account A should see its own trees"
@pytest.mark.asyncio
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
rows = await conn_no_context.fetch(
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
)
assert len(rows) == 0, "No-context connection should see no private trees"
# ---------------------------------------------------------------------------
# tree_tags — platform visibility
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0
@pytest.mark.asyncio
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
rows_a = await conn_a.fetch(
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
)
rows_b = await conn_b.fetch(
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
)
assert len(rows_a) >= 1, "Account A should see platform tags"
assert len(rows_b) >= 1, "Account B should see platform tags"
# ---------------------------------------------------------------------------
# tree_categories — platform visibility
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0
# ---------------------------------------------------------------------------
# step_categories — platform visibility
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0
# ---------------------------------------------------------------------------
# psa_connections — tenant-only
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0
# ---------------------------------------------------------------------------
# flow_proposals — tenant-only
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
rows = await conn_a.fetch(
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
)
assert len(rows) == 0

View File

@@ -1,4 +1,6 @@
"""Integration tests for Script Template Editor permissions and share endpoint."""
from uuid import UUID as PyUUID
import pytest
from httpx import AsyncClient
from sqlalchemy import select
@@ -65,6 +67,9 @@ class TestScriptTemplatePermissions:
data = resp.json()
assert data["name"] == "Test Template"
assert data["created_by"] is not None
result = await test_db.execute(select(ScriptTemplate).where(ScriptTemplate.id == PyUUID(data["id"])))
template = result.scalar_one()
assert template.account_id is not None
@pytest.mark.asyncio
async def test_engineer_can_edit_own_template(self, client, auth_headers, test_db):

View File

@@ -6,14 +6,18 @@ from datetime import datetime, timezone
import pytest
import sqlalchemy as sa
from app.models.script_template import ScriptGeneration
from app.models.user import User
# ── Fixtures ──────────────────────────────────────────────────────────────
@pytest.fixture
async def seed_script_data(test_db):
async def seed_script_data(test_db, test_user):
"""Seed script categories and templates into the test database."""
now = datetime.now(timezone.utc)
cat_id = uuid.UUID("00000000-0000-0000-0000-000000000001")
user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"]))
user = user_result.scalar_one()
# Insert category
await test_db.execute(
@@ -142,20 +146,20 @@ async def seed_script_data(test_db):
await test_db.execute(
sa.text("""
INSERT INTO script_templates (
id, category_id, name, slug, description,
id, category_id, account_id, name, slug, description,
script_body, parameters_schema, default_values, validation_rules,
tags, complexity, estimated_runtime, requires_elevation,
requires_modules, version, is_verified, is_active, usage_count,
created_at, updated_at
) VALUES (
:id, :category_id, :name, :slug, :description,
:id, :category_id, :account_id, :name, :slug, :description,
:script_body, CAST(:parameters_schema AS jsonb), '{}'::jsonb, '{}'::jsonb,
CAST(:tags AS jsonb), :complexity, :estimated_runtime, :requires_elevation,
'[]'::jsonb, 1, true, true, 0,
:now, :now
)
"""),
{**tmpl, "category_id": cat_id, "now": now},
{**tmpl, "category_id": cat_id, "account_id": user.account_id, "now": now},
)
await test_db.commit()
@@ -245,7 +249,7 @@ async def test_get_template_detail_not_found(client, auth_headers):
# ── Generate ──────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_generate_script_success(client, auth_headers, seed_script_data):
async def test_generate_script_success(client, auth_headers, seed_script_data, test_db, test_user):
list_resp = await client.get(
"/api/v1/scripts/templates?search=unlock",
headers=auth_headers,
@@ -265,6 +269,13 @@ async def test_generate_script_success(client, auth_headers, seed_script_data):
assert "script" in data
assert "jsmith" in data["script"]
assert "id" in data
generation_result = await test_db.execute(
sa.select(ScriptGeneration).where(ScriptGeneration.id == uuid.UUID(data["id"]))
)
generation = generation_result.scalar_one()
user_result = await test_db.execute(sa.select(User).where(User.email == test_user["email"]))
user = user_result.scalar_one()
assert generation.account_id == user.account_id
@pytest.mark.asyncio

View File

@@ -0,0 +1,58 @@
import asyncio
from uuid import UUID
import pytest
from unittest.mock import MagicMock
from app.core.tenant_context import (
set_current_account_id,
clear_current_account_id,
get_current_account_id,
)
def test_contextvar_is_none_by_default():
assert get_current_account_id() is None
def test_set_and_clear():
account_id = UUID("aaaaaaaa-0000-0000-0000-000000000001")
token = set_current_account_id(account_id)
assert get_current_account_id() == str(account_id)
clear_current_account_id(token)
assert get_current_account_id() is None
def test_tasks_are_isolated():
"""Each asyncio task has its own ContextVar value."""
results = {}
async def set_in_task(name: str, value: str):
token = set_current_account_id(UUID(value))
await asyncio.sleep(0)
results[name] = get_current_account_id()
clear_current_account_id(token)
async def run():
await asyncio.gather(
set_in_task("a", "aaaaaaaa-0000-0000-0000-000000000001"),
set_in_task("b", "bbbbbbbb-0000-0000-0000-000000000002"),
)
asyncio.run(run())
assert results["a"] == "aaaaaaaa-0000-0000-0000-000000000001"
assert results["b"] == "bbbbbbbb-0000-0000-0000-000000000002"
@pytest.mark.asyncio
async def test_require_tenant_context_raises_403_when_no_account():
from fastapi import HTTPException
from app.api.deps import require_tenant_context
user = MagicMock()
user.account_id = None
gen = require_tenant_context(current_user=user)
with pytest.raises(HTTPException) as exc_info:
await gen.__anext__()
assert exc_info.value.status_code == 403
assert "account required" in exc_info.value.detail.lower()

View File

@@ -0,0 +1,578 @@
"""Phase 0 tenant-isolation tests.
Verifies that endpoints respect account boundaries and don't leak data
across tenants. Each task group tests a specific endpoint fix.
"""
import uuid
import datetime as dt
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.account import Account
from app.models.user import User
from app.models.tree import Tree
from app.core.security import get_password_hash
# ── Helpers ──────────────────────────────────────────────────────────────────
async def _create_account_and_user(db: AsyncSession, prefix: str):
"""Create a fresh account + engineer user. Returns (account, user, plain_password)."""
password = "TestPass123!"
account = Account(
name=f"{prefix}-corp",
display_code=uuid.uuid4().hex[:8],
)
db.add(account)
await db.flush()
user = User(
email=f"{prefix}-{uuid.uuid4().hex[:6]}@example.com",
name=f"{prefix} user",
password_hash=get_password_hash(password),
is_active=True,
account_id=account.id,
account_role="engineer",
)
db.add(user)
await db.flush()
return account, user, password
async def _login(client: AsyncClient, email: str, password: str) -> dict:
"""Log in and return Authorization headers."""
resp = await client.post(
"/api/v1/auth/login",
json={"email": email, "password": password},
)
assert resp.status_code == 200, f"Login failed: {resp.text}"
token = resp.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
async def _create_private_tree(db: AsyncSession, account: Account, user: User) -> Tree:
"""Create a private tree owned by the given account/user."""
tree = Tree(
name=f"Private Tree {uuid.uuid4().hex[:6]}",
account_id=account.id,
author_id=user.id,
visibility="private",
tree_type="troubleshooting",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
db.add(tree)
await db.flush()
return tree
# ── Task 3: Analytics flow endpoint ──────────────────────────────────────────
@pytest.mark.asyncio
async def test_analytics_flow_cannot_read_other_account_tree(
client: AsyncClient, test_db: AsyncSession
):
"""Account A cannot read flow analytics for Account B's private tree."""
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "anl-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "anl-b")
tree_b = await _create_private_tree(test_db, acct_b, user_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(
f"/api/v1/analytics/flows/{tree_b.id}",
headers=headers_a,
)
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}: {resp.text}"
# ── Task 4: Category tree count ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_category_tree_count_scoped_to_account(
client: AsyncClient, test_db: AsyncSession
):
"""tree_count on a category must not include trees from other accounts."""
from app.models.category import TreeCategory
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "cat-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "cat-b")
# Shared category (account_id=None means global)
category = TreeCategory(
name="Shared Category",
slug=f"shared-cat-{uuid.uuid4().hex[:6]}",
account_id=None,
is_active=True,
)
test_db.add(category)
await test_db.flush()
# 3 trees for account_b under this category
for i in range(3):
tree = Tree(
name=f"B Tree {i}",
account_id=acct_b.id,
author_id=user_b.id,
category_id=category.id,
visibility="team",
tree_type="troubleshooting",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
test_db.add(tree)
# 1 tree for account_a under this category
tree_a = Tree(
name="A Tree",
account_id=acct_a.id,
author_id=user_a.id,
category_id=category.id,
visibility="team",
tree_type="troubleshooting",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
test_db.add(tree_a)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(
f"/api/v1/categories/{category.id}",
headers=headers_a,
)
assert resp.status_code == 200, resp.text
# account_a should only see their 1 tree, not account_b's 3
assert resp.json()["tree_count"] == 1, (
f"Expected tree_count=1 (own trees only), got {resp.json()['tree_count']}"
)
# ── Task 5: AI session search scope ──────────────────────────────────────────
@pytest.mark.asyncio
async def test_ai_session_search_cannot_see_other_users_sessions(
client: AsyncClient, test_db: AsyncSession
):
"""User A cannot find User B's AI sessions via the search endpoint,
even when both users are in the same account."""
from app.models.ai_session import AISession
# Two users in the SAME account
account = Account(name="Shared Corp", display_code=uuid.uuid4().hex[:8])
test_db.add(account)
await test_db.flush()
password = "TestPass123!"
user_a = User(
email=f"user-a-{uuid.uuid4().hex[:6]}@shared.com",
name="User A",
password_hash=get_password_hash(password),
is_active=True,
account_id=account.id,
account_role="engineer",
)
user_b = User(
email=f"user-b-{uuid.uuid4().hex[:6]}@shared.com",
name="User B",
password_hash=get_password_hash(password),
is_active=True,
account_id=account.id,
account_role="engineer",
)
test_db.add_all([user_a, user_b])
await test_db.flush()
# Session belonging to user_b with distinctive problem_summary
session_b = AISession(
user_id=user_b.id,
account_id=account.id,
problem_summary="CONFIDENTIAL: user_b's session",
problem_domain="networking",
status="resolved",
)
test_db.add(session_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, password)
resp = await client.get(
"/api/v1/ai-sessions/search",
params={"q": "CONFIDENTIAL"},
headers=headers_a,
)
assert resp.status_code == 200, resp.text
results = resp.json()
ids = [r["id"] for r in results]
assert str(session_b.id) not in ids, (
"User A can see User B's session via search — cross-user leak within account"
)
# ── Task 6: Cross-tenant UUID audit ─────────────────────────────────────────
@pytest.mark.asyncio
async def test_get_tree_returns_404_not_403_for_other_account(
client: AsyncClient, test_db: AsyncSession
):
"""Account A gets 404 (not 403) when accessing Account B's private tree."""
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-tree-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-tree-b")
tree_b = await _create_private_tree(test_db, acct_b, user_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/trees/{tree_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-tenant tree access, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_update_tree_returns_404_not_403_for_other_account(
client: AsyncClient, test_db: AsyncSession
):
"""Account A gets 404 (not 403) when trying to update Account B's tree."""
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-upd-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-upd-b")
tree_b = await _create_private_tree(test_db, acct_b, user_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.put(
f"/api/v1/trees/{tree_b.id}",
json={"name": "Hacked"},
headers=headers_a,
)
assert resp.status_code == 404, (
f"Expected 404 for cross-tenant tree update, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_get_session_returns_404_not_403_for_other_user(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 (not 403) when accessing User B's session."""
from app.models.session import Session
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-sess-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-sess-b")
tree_b = await _create_private_tree(test_db, acct_b, user_b)
session_b = Session(
tree_id=tree_b.id,
user_id=user_b.id,
tree_snapshot={"id": "root", "type": "start", "children": []},
path_taken=[],
decisions=[],
)
test_db.add(session_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/sessions/{session_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-user session access, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_ai_session_get_returns_404_not_403_for_other_user(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 (not 403) when accessing User B's AI session."""
from app.models.ai_session import AISession
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-ais-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-ais-b")
ai_session_b = AISession(
user_id=user_b.id,
account_id=acct_b.id,
problem_summary="Test session",
problem_domain="networking",
status="active",
)
test_db.add(ai_session_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/ai-sessions/{ai_session_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-user AI session access, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_ai_session_retry_psa_push_requires_ownership(
client: AsyncClient, test_db: AsyncSession
):
"""User A cannot retry PSA push for User B's AI session."""
from app.models.ai_session import AISession
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-psa-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-psa-b")
ai_session_b = AISession(
user_id=user_b.id,
account_id=acct_b.id,
problem_summary="PSA test",
problem_domain="networking",
status="resolved",
)
test_db.add(ai_session_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.post(
f"/api/v1/ai-sessions/{ai_session_b.id}/retry-psa-push",
headers=headers_a,
)
assert resp.status_code == 404, (
f"Expected 404 for cross-user retry-psa-push, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_upload_url_returns_404_not_403_for_other_account(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 (not 403) when accessing User B's upload URL."""
from app.models.file_upload import FileUpload
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-upl-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-upl-b")
upload_b = FileUpload(
account_id=acct_b.id,
uploaded_by=user_b.id,
filename="secret.png",
content_type="image/png",
size_bytes=1024,
storage_key="test/secret.png",
)
test_db.add(upload_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/uploads/{upload_b.id}/url", headers=headers_a)
assert resp.status_code in (404, 503), (
f"Expected 404 (or 503 if storage not configured) for cross-account upload, got {resp.status_code}"
)
@pytest.mark.asyncio
async def test_share_revoke_returns_404_not_403_for_other_user(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 (not 403) when revoking User B's share."""
from app.models.session import Session
from app.models.session_share import SessionShare
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-shr-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-shr-b")
tree_b = await _create_private_tree(test_db, acct_b, user_b)
session_b = Session(
tree_id=tree_b.id,
user_id=user_b.id,
tree_snapshot={"id": "root", "type": "start", "children": []},
path_taken=[],
decisions=[],
)
test_db.add(session_b)
await test_db.flush()
share_b = SessionShare(
session_id=session_b.id,
account_id=acct_b.id,
share_token="test-token-unique-" + uuid.uuid4().hex[:8],
share_name="Test",
visibility="public",
created_by=user_b.id,
)
test_db.add(share_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.delete(f"/api/v1/shares/{share_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-user share revoke, got {resp.status_code}"
)
# ── Task 6 (continued): steps, tags, step_categories, maintenance_schedules ──
@pytest.mark.asyncio
async def test_cannot_access_other_account_step(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 when reading a team-visibility step owned by Account B."""
from app.models.step_library import StepLibrary
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-step-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-step-b")
# Create a team-visibility step owned by account B
step_b = StepLibrary(
title="Account B Confidential Step",
step_type="action",
content={"description": "secret step"},
created_by=user_b.id,
account_id=acct_b.id,
visibility="team",
is_active=True,
)
test_db.add(step_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/steps/{step_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-account step access, got {resp.status_code}: {resp.text}"
)
@pytest.mark.asyncio
async def test_cannot_access_other_account_tag(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 when reading a tag scoped to Account B."""
from app.models.tag import TreeTag
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-tag-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-tag-b")
# Create an account-scoped tag for account B
tag_b = TreeTag(
name=f"account-b-tag-{uuid.uuid4().hex[:6]}",
slug=f"account-b-tag-{uuid.uuid4().hex[:6]}",
account_id=acct_b.id,
)
test_db.add(tag_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/tags/{tag_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-account tag access, got {resp.status_code}: {resp.text}"
)
@pytest.mark.asyncio
async def test_cannot_access_other_account_step_category(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 when reading a step category scoped to Account B."""
from app.models.step_category import StepCategory
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-scat-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-scat-b")
# Create an account-scoped step category for account B
category_b = StepCategory(
name=f"Account B Category {uuid.uuid4().hex[:6]}",
slug=f"account-b-cat-{uuid.uuid4().hex[:6]}",
account_id=acct_b.id,
is_active=True,
)
test_db.add(category_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/step-categories/{category_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-account step category access, got {resp.status_code}: {resp.text}"
)
@pytest.mark.asyncio
async def test_maintenance_schedule_returns_404_for_other_team(
client: AsyncClient, test_db: AsyncSession
):
"""User A gets 404 when reading a maintenance schedule belonging to Team B's tree."""
from app.models.team import Team
from app.models.maintenance_schedule import MaintenanceSchedule
# Create two separate teams
team_a = Team(name="Team A Corp")
team_b = Team(name="Team B Corp")
test_db.add_all([team_a, team_b])
await test_db.flush()
# Create accounts and users, assign to respective teams
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "t6-ms-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "t6-ms-b")
user_a.team_id = team_a.id
user_b.team_id = team_b.id
await test_db.flush()
# Create a maintenance tree owned by team B
tree_b = Tree(
name="Team B Maintenance Flow",
account_id=acct_b.id,
author_id=user_b.id,
team_id=team_b.id,
visibility="team",
tree_type="maintenance",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
status="published",
)
test_db.add(tree_b)
await test_db.flush()
# Create a schedule for that tree
schedule_b = MaintenanceSchedule(
tree_id=tree_b.id,
created_by=user_b.id,
cron_expression="0 2 * * 0",
timezone="UTC",
is_active=True,
next_run_at=dt.datetime(2026, 12, 31, tzinfo=dt.timezone.utc),
)
test_db.add(schedule_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(f"/api/v1/maintenance-schedules/tree/{tree_b.id}", headers=headers_a)
assert resp.status_code == 404, (
f"Expected 404 for cross-team maintenance schedule access, got {resp.status_code}: {resp.text}"
)
@pytest.mark.asyncio
async def test_get_documentation_returns_404_for_other_user_session(
client: AsyncClient, test_db: AsyncSession
):
"""GET /ai-sessions/{id}/documentation must return 404 (not 403) for cross-user access."""
from app.models.ai_session import AISession
acct_a, user_a, pass_a = await _create_account_and_user(test_db, "doc-a")
acct_b, user_b, pass_b = await _create_account_and_user(test_db, "doc-b")
session_b = AISession(
user_id=user_b.id,
account_id=acct_b.id,
problem_summary="B's confidential session",
problem_domain="networking",
status="resolved",
)
test_db.add(session_b)
await test_db.commit()
headers_a = await _login(client, user_a.email, pass_a)
resp = await client.get(
f"/api/v1/ai-sessions/{session_b.id}/documentation",
headers=headers_a,
)
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}: {resp.text}"

View File

@@ -447,3 +447,55 @@ class TestVisibilityFilter:
assert "author_name" in trees[0]
# visibility key should be present
assert "visibility" in trees[0]
@pytest.mark.asyncio
async def test_get_tree_returns_404_not_403_for_other_account_tree(
self, client: AsyncClient, auth_headers: dict, test_db: AsyncSession
):
"""Account A must not learn that Account B's private tree exists."""
from app.models.tree import Tree
from app.models.account import Account
from app.models.user import User
from app.core.security import get_password_hash
import uuid
# Create a second account and user
account_b = Account(name="Other Corp", display_code="OTH00001")
test_db.add(account_b)
await test_db.flush()
user_b = User(
email=f"user-b-{uuid.uuid4().hex[:6]}@example.com",
name="User B",
password_hash=get_password_hash("TestPass123!"),
is_active=True,
account_id=account_b.id,
account_role="engineer",
)
test_db.add(user_b)
await test_db.flush()
# Create a private tree belonging to account_b
private_tree = Tree(
name="Secret Tree",
account_id=account_b.id,
author_id=user_b.id,
visibility="private",
tree_type="troubleshooting",
tree_structure={"id": "root", "type": "start", "children": []},
is_active=True,
is_default=False,
is_public=False,
status="published",
)
test_db.add(private_tree)
await test_db.commit()
response = await client.get(
f"/api/v1/trees/{private_tree.id}",
headers=auth_headers,
)
assert response.status_code == 404, (
f"Expected 404 but got {response.status_code}"
"leaking tree existence to wrong tenant"
)

Some files were not shown because too many files have changed in this diff Show More