Compare commits
101 Commits
docs/updat
...
66968e4c59
| Author | SHA1 | Date | |
|---|---|---|---|
| 66968e4c59 | |||
| b0622f5511 | |||
| f3c3ee5b57 | |||
| b49772f1a1 | |||
| 210d310fb2 | |||
| 92fadfb90a | |||
| 3f0a132058 | |||
| da93ae55c3 | |||
| 56fd440b16 | |||
| b3be66652e | |||
| 0fbc1e0a57 | |||
| 46291f30b9 | |||
| f0ccf313a4 | |||
| 0d9babb986 | |||
| 567985402f | |||
| 08a4c6600d | |||
| 29fa48e71b | |||
| 908a867986 | |||
| 346576a730 | |||
| b18072e24b | |||
| e0f44e2985 | |||
| adfbb39297 | |||
| 6bae205a8c | |||
| ee2b2c2399 | |||
| 37bc47b75b | |||
| c8bdd0014e | |||
| 2a2b770405 | |||
| d6d0e9f3c1 | |||
| ab4bf3b32f | |||
|
|
d3c93cd006 | ||
|
|
4037a5213e | ||
|
|
0ed5977fee | ||
|
|
c5b8229ef6 | ||
|
|
eba50e1f95 | ||
|
|
8eb814283d | ||
|
|
b433b232dc | ||
|
|
015df1fe5f | ||
|
|
cf9c258f9e | ||
|
|
c063952f12 | ||
|
|
36721eb5af | ||
|
|
3cd4084f78 | ||
|
|
ed763d1cea | ||
|
|
c37e216e0b | ||
|
|
91cc9a4170 | ||
|
|
2a4220b496 | ||
|
|
c8f571db39 | ||
|
|
7efa22454d | ||
|
|
05421fc65c | ||
|
|
dfcad531e2 | ||
|
|
684fb07e47 | ||
|
|
4a12c9b37d | ||
|
|
e41d7bd960 | ||
|
|
f2c3bd7a9b | ||
|
|
9786c6b1fb | ||
|
|
4529955f7d | ||
|
|
b7b0d41f92 | ||
|
|
a4512dcf90 | ||
|
|
764db79060 | ||
|
|
f90e2c956f | ||
|
|
bdaea68dd3 | ||
|
|
02c19a7580 | ||
|
|
a392d24101 | ||
|
|
b9c9bb548d | ||
|
|
662df2907d | ||
|
|
b9547e6ce1 | ||
|
|
760e0f77f8 | ||
|
|
a71f082e25 | ||
|
|
abd79bc763 | ||
|
|
af5ceea7f9 | ||
|
|
f54d7ecd78 | ||
|
|
46593ba8ca | ||
|
|
52553d62d2 | ||
|
|
a48660700a | ||
|
|
3ff886363c | ||
|
|
501442e5f0 | ||
|
|
6f53ec06f5 | ||
|
|
ec322f7cdf | ||
|
|
f9248aeaa8 | ||
|
|
c6da4ebee5 | ||
|
|
64f004a62c | ||
|
|
ba36e37dab | ||
|
|
9e6965512b | ||
|
|
893b8a5008 | ||
|
|
e05472615b | ||
|
|
00fdd663bc | ||
|
|
8cf58add22 | ||
|
|
6c231ef1c6 | ||
|
|
758cd61621 | ||
|
|
b9fcdd5d73 | ||
|
|
4273ed0e5c | ||
|
|
0107d2d896 | ||
|
|
79ae34108a | ||
|
|
bd29f590a2 | ||
|
|
ce4cfc3240 | ||
|
|
82ee177d9b | ||
|
|
ed8de92c52 | ||
|
|
5bd331ca92 | ||
|
|
87fac02e9b | ||
|
|
4f4bc435da | ||
|
|
ac2b193909 | ||
|
|
b641ac6c55 |
154
.gitea/workflows/ci.yml
Normal file
154
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,154 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
backend:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg16
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: resolutionflow_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
env:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@postgres:5432/resolutionflow_test
|
||||
DATABASE_URL_SYNC: postgresql://postgres:postgres@postgres:5432/resolutionflow_test
|
||||
SECRET_KEY: ci-test-secret-key-not-for-production
|
||||
DEBUG: "true"
|
||||
APP_NAME: ResolutionFlow
|
||||
TEST_DB_NAME: resolutionflow_test
|
||||
DB_APP_ROLE_PASSWORD: app_secret_ci
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install --break-system-packages -r backend/requirements.txt -r backend/requirements-dev.txt
|
||||
|
||||
- name: Run Alembic migrations
|
||||
run: cd backend && alembic upgrade head
|
||||
|
||||
- name: Check tenant filter enforcement
|
||||
run: cd backend && python scripts/check_tenant_filters.py
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: cd backend && python -m pytest --override-ini="addopts=" --cov=app --cov-report=term-missing --cov-report=json:coverage.json --cov-fail-under=50
|
||||
|
||||
- name: Display coverage summary
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
python -c "
|
||||
import json
|
||||
with open('coverage.json') as f:
|
||||
data = json.load(f)
|
||||
total = data['totals']['percent_covered_display']
|
||||
print(f'Total coverage: {total}%')
|
||||
print()
|
||||
print('Module coverage:')
|
||||
for fname, fdata in sorted(data['files'].items()):
|
||||
pct = fdata['summary']['percent_covered_display']
|
||||
if float(pct) < 80:
|
||||
print(f' WARNING {fname}: {pct}%')
|
||||
else:
|
||||
print(f' OK {fname}: {pct}%')
|
||||
"
|
||||
|
||||
frontend:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: cd frontend && npm ci
|
||||
|
||||
- name: Lint
|
||||
run: cd frontend && npm run lint
|
||||
|
||||
- name: Test with coverage
|
||||
run: cd frontend && npm run test:coverage
|
||||
|
||||
- name: Build
|
||||
run: cd frontend && NODE_OPTIONS="--max-old-space-size=4096" npm run build
|
||||
|
||||
- name: Upload build artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: frontend/dist
|
||||
retention-days: 1
|
||||
|
||||
e2e:
|
||||
needs: [frontend]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg16
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: resolutionflow_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
env:
|
||||
PLAYWRIGHT_DATABASE_URL: postgresql+asyncpg://postgres:postgres@postgres:5432/resolutionflow_test
|
||||
PLAYWRIGHT_DATABASE_URL_SYNC: postgresql://postgres:postgres@postgres:5432/resolutionflow_test
|
||||
PLAYWRIGHT_API_ORIGIN: http://127.0.0.1:8000
|
||||
PLAYWRIGHT_BASE_URL: http://127.0.0.1:4173
|
||||
PLAYWRIGHT_SECRET_KEY: ci-playwright-secret-key
|
||||
PLAYWRIGHT_TEST_EMAIL: teamadmin@resolutionflow.example.com
|
||||
PLAYWRIGHT_TEST_PASSWORD: TestPass123!
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install backend dependencies
|
||||
run: pip install --break-system-packages -r backend/requirements.txt -r backend/requirements-dev.txt
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: cd frontend && npm ci
|
||||
|
||||
- name: Download frontend build
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: frontend-dist
|
||||
path: frontend/dist
|
||||
|
||||
- name: Install Playwright browser
|
||||
run: cd frontend && npx playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright smoke tests
|
||||
run: cd frontend && npm run test:e2e
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: |
|
||||
frontend/playwright-report
|
||||
frontend/test-results
|
||||
if-no-files-found: ignore
|
||||
19
.gitea/workflows/mirror-to-github.yml
Normal file
19
.gitea/workflows/mirror-to-github.yml
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Mirror to GitHub
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
mirror:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Push to GitHub
|
||||
run: |
|
||||
cd /tmp
|
||||
git clone --mirror https://gitea.resolutionflow.com/chihlasm/resolutionflow.git repo
|
||||
cd repo
|
||||
git remote add github https://x-access-token:${{ secrets.GH_MIRROR_TOKEN }}@github.com/${{ secrets.GH_MIRROR_REPO }}
|
||||
git push github --all --force
|
||||
git push github --tags --force
|
||||
43
.gitea/workflows/runner-probe.yml
Normal file
43
.gitea/workflows/runner-probe.yml
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Runner Probe
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
probe:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Runner labels and OS
|
||||
run: |
|
||||
echo "=== OS ==="
|
||||
uname -a
|
||||
cat /etc/os-release 2>/dev/null || true
|
||||
|
||||
- name: Python versions
|
||||
run: |
|
||||
echo "=== Python ==="
|
||||
which python3 && python3 --version || echo "python3 not found"
|
||||
which python && python --version || echo "python not found"
|
||||
ls /usr/bin/python* 2>/dev/null || true
|
||||
|
||||
- name: Node versions
|
||||
run: |
|
||||
echo "=== Node ==="
|
||||
which node && node --version || echo "node not found"
|
||||
which npm && npm --version || echo "npm not found"
|
||||
ls /usr/bin/node* 2>/dev/null || true
|
||||
ls ~/.nvm/versions/node/ 2>/dev/null || echo "no nvm versions"
|
||||
|
||||
- name: Docker
|
||||
run: |
|
||||
echo "=== Docker ==="
|
||||
which docker && docker --version || echo "docker not found"
|
||||
docker info 2>/dev/null | grep -E "Server Version|Operating System" || true
|
||||
|
||||
- name: User and home
|
||||
run: |
|
||||
echo "=== User ==="
|
||||
whoami
|
||||
echo "HOME=$HOME"
|
||||
echo "PATH=$PATH"
|
||||
5
.github/workflows/ci.yml
vendored
5
.github/workflows/ci.yml
vendored
@@ -31,6 +31,8 @@ jobs:
|
||||
SECRET_KEY: ci-test-secret-key-not-for-production
|
||||
DEBUG: "true"
|
||||
APP_NAME: ResolutionFlow
|
||||
TEST_DB_NAME: resolutionflow_test
|
||||
DB_APP_ROLE_PASSWORD: app_secret_ci
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -47,6 +49,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pip install -r backend/requirements.txt -r backend/requirements-dev.txt
|
||||
|
||||
- name: Run Alembic migrations
|
||||
run: cd backend && alembic upgrade head
|
||||
|
||||
- name: Check tenant filter enforcement
|
||||
run: cd backend && python scripts/check_tenant_filters.py
|
||||
# Warn mode only (exits 0). Switch to --fail after Phase 1 backlog clears.
|
||||
|
||||
@@ -9,7 +9,9 @@ All notable changes to ResolutionFlow are documented here.
|
||||
- Recurring Issue Detection — client-specific pattern alerts (#60)
|
||||
- Step Feedback Flag — "This Step is Wrong" reporting (#58)
|
||||
- **Tenant Isolation Phase 0** — multi-tenant data isolation (#132) with app-layer filtering helpers (`tenant_filter()`, `get_tenant_context`), cross-tenant access audit (analytics, categories, AI sessions, trees), UUID endpoint isolation with 404 responses for unauthorized access, ownership checks on all sensitive operations, and CI grep gate for missing tenant filters
|
||||
- **Tenant Isolation Phase 1** — PostgreSQL Row-Level Security (RLS) enforcement across all core tables (trees, tags, categories, psa_connections, flow_proposals) with database role separation (`resolutionflow_app` for user operations, `resolutionflow_admin` with BYPASSRLS for admin endpoints), admin database engine isolation, tenant context via `ContextVar` with automatic transaction-scoped enforcement, `account_id` column backfill on 35+ tables (sessions, AI branching, PSA, notifications, scripts, targets, folders), global content separation via platform account, fresh-DB migration order fixes
|
||||
- **Tenant Isolation Phase 2** — PostgreSQL Row Level Security (RLS) on 11 session-related tables (ai_sessions, session_steps, session_tags, etc.), account_id NOT NULL enforcement on all write paths, Alembic migrations with dual-env support (Railway native vars + explicit DATABASE_URL_SYNC), RLS test coverage with cross-account isolation verification, migration CI/CD integration
|
||||
- **Tenant Isolation Phase 3** — RLS on audit_logs and tree_shares tables, cross-tenant session access for public shares (via get_admin_db), complete account_id propagation across PSA integration write paths, final RLS policy enforcement
|
||||
- **Tenant Isolation Phase 4** (#136) — RLS enforcement on all 31 remaining tables (users, trees, teams, integrations, scripts, categories, templates, surveys, etc.), BYPASSRLS session pattern for auth deps and background jobs, admin session factory for startup routines (service accounts, seed data), global table exclusions (platform_steps, template_trees, script_categories, accounts), RLS tests with complete cross-tenant isolation verification, proper tree_shares ownership checks using tree owner's account_id
|
||||
- **Script Library default view** — "All Scripts" tab now displays all accessible scripts (team + library)
|
||||
- **Session documentation overhaul** — reformatted PSA resolution/escalation notes with cleaner headers, inline engineer responses, decimal hour display (0.25 hrs), follow-up recommendations, and improved "What We Know" section from evidence items
|
||||
- **Client communication improvements** — new `request_info` audience type for client-facing information requests, improved status update and email draft prompts with per-context guidance
|
||||
@@ -24,7 +26,6 @@ All notable changes to ResolutionFlow are documented here.
|
||||
- **Assistant Chat session actions** — moved Pause/Resume/Close actions from action bar to page header for consistency with FlowPilot
|
||||
- **Design system token normalization** — unified FlowPilot, AssistantChat, and ScriptBuilder components to use consistent design tokens
|
||||
- **Tenant data boundaries** — all session and tree endpoints now return 404 (not 403) for cross-tenant access attempts to avoid confirming resource existence
|
||||
- **Admin database routing** — privileged operations (analytics, user management) now bypass RLS via dedicated admin engine
|
||||
|
||||
### Fixed
|
||||
- **CRITICAL: Copilot tree query isolation** (#131) — user could access any tree UUID if known, exposing full tree structure to AI. Now scoped to current account with 404 for inaccessible trees.
|
||||
@@ -33,6 +34,7 @@ All notable changes to ResolutionFlow are documented here.
|
||||
- **Category tree counts** — cross-tenant row count leakage via tree_count field in GET `/categories/{id}`. Now scoped to requesting account.
|
||||
- **PSA retry ownership check** — retry-psa-push had no ownership validation (CRITICAL). Now validates user ownership before allowing retry.
|
||||
- **Task Lane save operation** — invalid task_lane_item UUIDs returned 403 revealing existence. Now returns 404 and uses query-level filtering.
|
||||
- **Phase 4 RLS enforcement** — fixed auth deps, user-mutation endpoints, background jobs, and lifespan routines to use BYPASSRLS sessions for reading/writing tenant-isolated tables; fixed seed scripts to use ADMIN_DATABASE_URL; bootstrap service account now initializes correctly with proper BYPASSRLS context
|
||||
- Dark text rendering on blue accent step-number badges across all flow types
|
||||
- Script Library tab ownership filter now preserved across category and search changes
|
||||
- Race conditions in script builder session creation and slug generation
|
||||
@@ -43,7 +45,6 @@ All notable changes to ResolutionFlow are documented here.
|
||||
- Task Lane stale data when creating new chat or resuming from concluded session
|
||||
- Chat ref invalidation race condition between handleNewChat and async data loads
|
||||
- Images now properly display in chat message history instead of blank placeholders
|
||||
- Non-default, no-team trees now properly handled in global content migration
|
||||
|
||||
---
|
||||
|
||||
|
||||
116
CLAUDE.md
116
CLAUDE.md
@@ -222,10 +222,9 @@ docker exec -it resolutionflow_postgres psql -U postgres -d resolutionflow
|
||||
cd backend && pip install httpx && python -m scripts.seed_trees
|
||||
|
||||
# CI/CD debugging
|
||||
gh run list --limit 5 # Recent CI runs
|
||||
gh run view <id> --log-failed # Failed job logs
|
||||
gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusion}'
|
||||
# NEVER use `gh run watch` — it holds context open and burns tokens while waiting
|
||||
# CI runs on Gitea (gitea.resolutionflow.com), NOT GitHub Actions — gh run list will return nothing useful
|
||||
# Check CI status at: https://gitea.resolutionflow.com/chihlasm/resolutionflow/actions
|
||||
# `gh` CLI is still used for GitHub Issues/PRs (mirrored repo), not for CI runs
|
||||
```
|
||||
|
||||
### URLs
|
||||
@@ -375,6 +374,16 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
|
||||
|
||||
**106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`).
|
||||
|
||||
**107. Startup routines must use `_admin_session_factory()` after Phase 4 RLS:** Any code that runs at startup (lifespan, `ensure_service_account`, seed scripts) and touches tenant-isolated tables (`users`, etc.) must use `_admin_session_factory()` — not `get_db()`. Phase 4 enabled RLS on `users`; a tenant-scoped session has no `app.current_account_id` set at startup, so all queries return 0 rows or fail. `get_service_account_id` in `deps.py` is safe — it reads from `app.state` cached at startup, never hits the DB per-request.
|
||||
|
||||
**108. Tables with no `account_id` column (never add to RLS migrations):** `script_categories`, `platform_steps`, `template_trees`, `plan_feature_defaults`, `accounts` — global/platform tables documented with "No account_id. No RLS." in their model files. When writing RLS migrations, scan at the class level (check for `account_id: Mapped` within the class block), not the file level — multiple classes in one `.py` file can have different columns (e.g. `ScriptCategory` vs `ScriptTemplate` in `script_template.py`).
|
||||
|
||||
**109. `tree_shares.account_id` must equal `tree.account_id`, not the actor's account:** When creating a `TreeShare`, always use `account_id=tree.account_id` (tree owner's tenant). A super admin in tenant A sharing tenant B's tree must produce a share row in tenant B's RLS context — using `current_user.account_id` instead makes the share invisible to the tree owner after RLS is enforced.
|
||||
|
||||
**110. Backfill migrations for `account_id` require a service-code audit:** When a migration adds `account_id` to an existing model via backfill (nullable → backfill → NOT NULL), grep for ALL `ModelClass(` instantiation sites in service code and verify `account_id=` is passed. SQLAlchemy accepts `None` silently with no warning; Phase 4 RLS WITH CHECK only surfaces the problem at runtime as `InsufficientPrivilegeError: new row violates row-level security policy`. Fixed example: `AISessionStep` — all 5 creation sites in `flowpilot_engine.py` were missing `account_id` until April 2026.
|
||||
|
||||
**111. Global Axios interceptor fires before component `.catch()` — fix optional-data endpoints at the source:** The global 5xx handler in `client.ts` fires for ALL non-401 5xx responses, even when a component does `.catch(() => {})`. If an endpoint returns optional UI data (e.g., board filters, PSA config), return `[]` / `{}` on provider failure rather than raising 502. Silencing the error in the component is not enough — the toast appears anyway. See `list_boards` in `integrations.py` for the fixed pattern.
|
||||
|
||||
## RBAC & Permissions
|
||||
|
||||
- **Role hierarchy:** super_admin > team_admin > engineer > viewer
|
||||
@@ -444,6 +453,7 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
|
||||
- Always include `Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>`
|
||||
- Always create feature branch BEFORE committing: `git checkout -b feat/feature-name`
|
||||
- Large features: commit per phase with `npm run build` validation
|
||||
- **Remote is Gitea, not GitHub directly:** Push to `gitea.resolutionflow.com/chihlasm/resolutionflow`. Gitea auto-mirrors to GitHub via `.gitea/workflows/mirror-to-github.yml` — never push directly to GitHub.
|
||||
|
||||
### After Completing Work
|
||||
|
||||
@@ -491,7 +501,7 @@ When a feature, fix, or significant piece of work is finished and merged/committ
|
||||
## Deployment (Railway)
|
||||
|
||||
- **Production:** `resolutionflow.com` (frontend), `api.resolutionflow.com` (backend)
|
||||
- Auto-deploys on push to `main`
|
||||
- Auto-deploys via: push to Gitea → Gitea mirrors to GitHub → Railway watches GitHub `main` and deploys
|
||||
- PR environments auto-created (need manual domain generation in Railway dashboard)
|
||||
- PR envs need `VITE_API_URL` set with `https://` prefix on frontend service
|
||||
- `ALLOW_RAILWAY_ORIGINS=true` enables CORS for `*.up.railway.app`
|
||||
@@ -519,104 +529,42 @@ When a feature, fix, or significant piece of work is finished and merged/committ
|
||||
| Design System | [DESIGN-SYSTEM.md](DESIGN-SYSTEM.md) |
|
||||
| Dev Environment | [DEV-ENV.md](DEV-ENV.md) — 46.202.92.250 setup, Docker, CORS, networking |
|
||||
|
||||
|
||||
<!-- gitnexus:start -->
|
||||
# GitNexus — Code Intelligence
|
||||
|
||||
This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||
This project is indexed by GitNexus as **resolutionflow**. Use it selectively — for routine additive work (new endpoints, new components, isolated fixes) just read the files directly. GitNexus earns its cost when you're about to touch something genuinely central with many callers.
|
||||
|
||||
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
|
||||
|
||||
## Always Do
|
||||
## When to Use It
|
||||
|
||||
- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user.
|
||||
- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows.
|
||||
- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits.
|
||||
- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance.
|
||||
- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`.
|
||||
**Use GitNexus when:**
|
||||
- Touching a core shared symbol with many callers — `flowpilot_engine`, `unified_chat_service`, auth middleware, `get_db`, shared hooks
|
||||
- Renaming anything used across multiple files
|
||||
- Tracing an unfamiliar bug through a call chain you haven't read
|
||||
- Assessing whether a refactor is safe before starting
|
||||
|
||||
## When Debugging
|
||||
**Skip GitNexus when:**
|
||||
- Adding a new endpoint, component, or isolated feature
|
||||
- Fixing a bug in a self-contained file
|
||||
- Making changes you can already see the full scope of by reading the file
|
||||
|
||||
1. `gitnexus_query({query: "<error or symptom>"})` — find execution flows related to the issue
|
||||
2. `gitnexus_context({name: "<suspect function>"})` — see all callers, callees, and process participation
|
||||
3. `READ gitnexus://repo/resolutionflow/process/{processName}` — trace the full execution flow step by step
|
||||
4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed
|
||||
|
||||
## When Refactoring
|
||||
|
||||
- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`.
|
||||
- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code.
|
||||
- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed.
|
||||
|
||||
## Never Do
|
||||
|
||||
- NEVER edit a function, class, or method without first running `gitnexus_impact` on it.
|
||||
- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis.
|
||||
- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph.
|
||||
- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope.
|
||||
|
||||
## Tools Quick Reference
|
||||
## Useful Tools
|
||||
|
||||
| Tool | When to use | Command |
|
||||
|------|-------------|---------|
|
||||
| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` |
|
||||
| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` |
|
||||
| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` |
|
||||
| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` |
|
||||
| `query` | Find code by concept when you don't know where to look | `gitnexus_query({query: "auth validation"})` |
|
||||
| `context` | See all callers/callees of a symbol before touching it | `gitnexus_context({name: "symbolName"})` |
|
||||
| `impact` | Blast radius check before editing a shared symbol | `gitnexus_impact({target: "X", direction: "upstream"})` |
|
||||
| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` |
|
||||
| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` |
|
||||
|
||||
## Impact Risk Levels
|
||||
|
||||
| Depth | Meaning | Action |
|
||||
|-------|---------|--------|
|
||||
| d=1 | WILL BREAK — direct callers/importers | MUST update these |
|
||||
| d=2 | LIKELY AFFECTED — indirect deps | Should test |
|
||||
| d=3 | MAY NEED TESTING — transitive | Test if critical path |
|
||||
|
||||
## Resources
|
||||
|
||||
| Resource | Use for |
|
||||
|----------|---------|
|
||||
| `gitnexus://repo/resolutionflow/context` | Codebase overview, check index freshness |
|
||||
| `gitnexus://repo/resolutionflow/clusters` | All functional areas |
|
||||
| `gitnexus://repo/resolutionflow/processes` | All execution flows |
|
||||
| `gitnexus://repo/resolutionflow/process/{name}` | Step-by-step execution trace |
|
||||
|
||||
## Self-Check Before Finishing
|
||||
|
||||
Before completing any code modification task, verify:
|
||||
1. `gitnexus_impact` was run for all modified symbols
|
||||
2. No HIGH/CRITICAL risk warnings were ignored
|
||||
3. `gitnexus_detect_changes()` confirms changes match expected scope
|
||||
4. All d=1 (WILL BREAK) dependents were updated
|
||||
|
||||
## Keeping the Index Fresh
|
||||
|
||||
After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it:
|
||||
A PostToolUse hook re-indexes automatically after `git commit`. To manually refresh:
|
||||
|
||||
```bash
|
||||
npx gitnexus analyze
|
||||
```
|
||||
|
||||
If the index previously included embeddings, preserve them by adding `--embeddings`:
|
||||
|
||||
```bash
|
||||
npx gitnexus analyze --embeddings
|
||||
```
|
||||
|
||||
To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.**
|
||||
|
||||
> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`.
|
||||
|
||||
## CLI
|
||||
|
||||
| Task | Read this skill file |
|
||||
|------|---------------------|
|
||||
| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` |
|
||||
| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` |
|
||||
| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` |
|
||||
| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` |
|
||||
| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` |
|
||||
| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` |
|
||||
|
||||
<!-- gitnexus:end -->
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
> **Purpose:** Quick-reference file showing exactly where the project stands.
|
||||
> **For Claude Code:** Read this first to understand what's done and what's next.
|
||||
> **Last Updated:** March 23, 2026
|
||||
> **Last Updated:** April 12, 2026
|
||||
|
||||
---
|
||||
|
||||
@@ -163,6 +163,13 @@
|
||||
- SQL wildcard escaping in tag search
|
||||
- PSA credentials encrypted at rest (Fernet)
|
||||
|
||||
### Tenant Isolation (Phases 1-4 Complete)
|
||||
- PostgreSQL RLS enabled across tenant-scoped tables in phased rollout
|
||||
- `account_id` propagation completed across core content, sessions, analytics, notifications, shares, and remaining Phase 4 tables
|
||||
- Global platform tables correctly excluded from tenant RLS where they have no `account_id` (`script_categories`, `platform_steps`, `template_trees`)
|
||||
- Runtime bootstrap paths updated to use BYPASSRLS/admin sessions where needed (auth/user mutations, startup service account, background jobs, seed scripts)
|
||||
- Preview Railway backend and frontend deployments green for PR 136 after the Phase 4 fixes
|
||||
|
||||
### Copilot-First Dashboard (March 2026)
|
||||
|
||||
- Redesigned dashboard as FlowPilot copilot launchpad (ChatGPT-style input)
|
||||
|
||||
633
DEV-ENV.md
633
DEV-ENV.md
@@ -1,262 +1,523 @@
|
||||
# ResolutionFlow Dev Environment Setup & Operations Guide
|
||||
# ResolutionFlow — Dev Environment Setup & Operations Guide
|
||||
|
||||
## Server Overview
|
||||
> **Scope:** Stand up a working ResolutionFlow dev environment from scratch on any Linux host (VPS, on-prem Proxmox LXC/VM, bare metal). Self-contained — do not read another doc to get the dev stack running.
|
||||
> **Last rewritten:** April 2026, post-Hostinger-VPS deprecation, ahead of Proxmox migration.
|
||||
> **Audience:** You (returning to the project), a teammate, or a fresh Claude Code session.
|
||||
|
||||
- **Provider:** Hostinger KVM VPS (srv1522117)
|
||||
- **IP Address:** 46.202.92.250
|
||||
- **OS:** Ubuntu 24.04 LTS
|
||||
- **CPU:** 2 vCPU cores
|
||||
- **RAM:** 8GB
|
||||
- **Disk:** 100GB NVMe SSD
|
||||
- **Swap:** 4GB (`/swapfile`, swappiness=10)
|
||||
If you're picking up mid-migration and need to know what code state is on the current branch, read `docs/FlowAssist_Migration/MIGRATION-HANDOFF.md` first.
|
||||
|
||||
## Architecture
|
||||
---
|
||||
|
||||
All services run as Docker containers on the host, managed via SSH or from the VS Code Server integrated terminal.
|
||||
## 1. What this project needs, regardless of host
|
||||
|
||||
```
|
||||
Host (root@srv1522117)
|
||||
├── Traefik → reverse proxy + auto SSL (Let's Encrypt)
|
||||
├── VS Code Server → browser IDE at https://code.resolutionflow.com
|
||||
└── ResolutionFlow Stack
|
||||
├── resolutionflow_frontend → Vite/React on port 5173
|
||||
├── resolutionflow_backend → FastAPI/Uvicorn on port 8000
|
||||
└── resolutionflow_postgres → PostgreSQL 16 + pgvector on port 5432
|
||||
```
|
||||
These are non-negotiable. If your host can't provide them, fix that before anything else.
|
||||
|
||||
## Access URLs
|
||||
| Component | Required version | Notes |
|
||||
|---|---|---|
|
||||
| **Linux** | any mainstream distro | Ubuntu 22.04+ / Debian 12+ tested; Alpine fine for containers |
|
||||
| **Python** | 3.11+ | Backend and migrations |
|
||||
| **Node.js** | 20.19+ | Vite 7 fails on older versions — CLAUDE.md Lesson 63 |
|
||||
| **PostgreSQL** | 16 | `gen_random_uuid()` + `jsonb` + RLS are all leaned on |
|
||||
| **Docker + Docker Compose** | recent | Only if you are running Postgres and/or backend as containers |
|
||||
| **Git** | recent | |
|
||||
|
||||
| Service | URL |
|
||||
Optional but recommended:
|
||||
|
||||
| Tool | Why |
|
||||
|---|---|
|
||||
| VS Code Server | https://code.resolutionflow.com |
|
||||
| Frontend (dev) | http://46.202.92.250:5173 |
|
||||
| Backend API | http://46.202.92.250:8000 |
|
||||
| API Docs | http://46.202.92.250:8000/docs |
|
||||
| **code-server** | Browser-based VS Code; how this project has historically been edited |
|
||||
| **`gh` CLI** | Mirror repo is on GitHub via Gitea; `gh` reads issues and PRs |
|
||||
| **bun** | Required for the gstack `/browse` + `/qa` skills (CLAUDE.md Lesson 82) |
|
||||
| **`npx gitnexus analyze`** | Code-graph for Phase 2+ work that touches `unified_chat_service` |
|
||||
| **Claude Code CLI** | If you want to run Claude Code locally on the host |
|
||||
|
||||
## Docker Layout
|
||||
---
|
||||
|
||||
## 2. Architectural shape
|
||||
|
||||
The project is three services plus your editor. Keep these facts in mind regardless of topology:
|
||||
|
||||
```
|
||||
/docker/
|
||||
├── traefik/
|
||||
│ ├── docker-compose.yml → Traefik reverse proxy
|
||||
│ └── .env → ACME_EMAIL for Let's Encrypt
|
||||
└── vscode/
|
||||
├── docker-compose.yml → VS Code Server
|
||||
└── .env → CODE_PASSWORD
|
||||
Your browser
|
||||
├─► code-server (editor, optional — usually port 8080 or behind TLS)
|
||||
├─► frontend (Vite) (dev server, port 5173)
|
||||
└─► backend (FastAPI) (dev server, port 8000)
|
||||
│
|
||||
└─► PostgreSQL (port 5432)
|
||||
```
|
||||
|
||||
Project lives inside the VS Code Server Docker volume:
|
||||
**The frontend calls the backend by URL at runtime.** The frontend does not proxy through the backend. Whatever URL your browser uses to reach the backend is what `VITE_API_URL` must be set to, **baked in at build time**. Changing `VITE_API_URL` requires rebuilding the frontend.
|
||||
|
||||
**The backend calls the database by URL at runtime.** The URL depends on where Postgres is relative to the backend — Docker service name if both are in the same compose network, `localhost` if Postgres is native on the same host, or a DNS name if they're in separate containers/VMs.
|
||||
|
||||
**CORS is configured explicitly.** The backend's `CORS_ORIGINS` list must include every origin your browser will use to reach the frontend. A missing origin shows up as failed preflight requests.
|
||||
|
||||
---
|
||||
|
||||
## 3. Topology choices — pick one before you start
|
||||
|
||||
The project is agnostic to topology, but each shape has different setup steps.
|
||||
|
||||
### Option A — all-in-one LXC/VM/host (simplest)
|
||||
|
||||
Postgres, backend, and frontend all run on one Linux host. code-server runs on the same host or a sibling. No Docker required. Best for a single-developer Proxmox LXC.
|
||||
|
||||
### Option B — Docker Compose on one host
|
||||
|
||||
Postgres, backend, and frontend run as Docker containers on one host. code-server runs outside the compose network (on the host or in another container). This is how the old Hostinger VPS was configured. Best if you want reproducible container images.
|
||||
|
||||
### Option C — split services across containers/VMs
|
||||
|
||||
Postgres in one container/VM, backend and frontend in another, code-server in a third. Most complex; requires explicit networking between them. Use only if you have a specific reason.
|
||||
|
||||
**Pick one and stick with it for the entire setup.** Mixing Options A and B halfway through is where setup runs off the rails.
|
||||
|
||||
---
|
||||
|
||||
## 4. Per-host configuration
|
||||
|
||||
These values are specific to your host. Fill them in once and reference them by name throughout the rest of the doc.
|
||||
|
||||
```
|
||||
/var/lib/docker/volumes/vscode_vscode-data/_data/resolutionflow/
|
||||
DEV_HOST = <hostname or IP your browser uses, e.g. dev.internal, 10.0.0.42>
|
||||
DEV_HOST_SCHEME = <http or https; http is fine for internal dev, https if behind a TLS proxy>
|
||||
FRONTEND_PORT = 5173
|
||||
BACKEND_PORT = 8000
|
||||
POSTGRES_PORT = 5432 # or 5433 if you're avoiding conflict with a host Postgres
|
||||
POSTGRES_DB_NAME = resolutionflow
|
||||
POSTGRES_USER = postgres
|
||||
POSTGRES_PASSWORD = <local-dev-password; anything, this is not prod>
|
||||
SECRET_KEY = <openssl rand -hex 32 — generate fresh per host, do not reuse>
|
||||
ANTHROPIC_API_KEY = <from https://console.anthropic.com>
|
||||
GOOGLE_AI_API_KEY = <optional, only if using Gemini as a fallback>
|
||||
```
|
||||
|
||||
## VS Code Server
|
||||
Store these somewhere you can copy from during setup. Do not commit them.
|
||||
|
||||
- **Container user:** `coder` (UID 1000)
|
||||
- **Home directory:** `/home/coder`
|
||||
- **Project location:** `/home/coder/resolutionflow`
|
||||
- **Host volume path:** `/var/lib/docker/volumes/vscode_vscode-data/_data`
|
||||
- **Access URL:** `https://code.resolutionflow.com`
|
||||
- **HTTPS:** Auto-provisioned via Traefik + Let's Encrypt
|
||||
> **Naming note:** the canonical database name is `resolutionflow`. If you see `patherly` in a config file, that's drift from an earlier rename and is being swept in a separate commit — use `resolutionflow`. CLAUDE.md tracks the live-code files that still reference `patherly`.
|
||||
|
||||
### Compose File Location
|
||||
`/docker/vscode/docker-compose.yml`
|
||||
---
|
||||
|
||||
## Traefik
|
||||
## 5. Setup procedure
|
||||
|
||||
Handles reverse proxying and automatic SSL for all services. HTTP automatically redirects to HTTPS.
|
||||
Run these in order. Stop at the first failure and investigate.
|
||||
|
||||
### Adding A New Service Behind Traefik
|
||||
|
||||
Add these labels to any new Docker service:
|
||||
|
||||
```yaml
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.<n>.rule=Host(`subdomain.resolutionflow.com`)"
|
||||
- "traefik.http.routers.<n>.entrypoints=websecure"
|
||||
- "traefik.http.routers.<n>.tls.certresolver=letsencrypt"
|
||||
- "traefik.http.services.<n>.loadbalancer.server.port=<port>"
|
||||
```
|
||||
|
||||
Also create an A record in DNS pointing the subdomain to `46.202.92.250`.
|
||||
|
||||
## ResolutionFlow Dev Stack
|
||||
|
||||
### Important: No Docker Inside VS Code Container
|
||||
|
||||
The VS Code Server container does NOT have Docker. All `docker compose` commands must be run via SSH as root on the host.
|
||||
|
||||
### Environment Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `.env` | Root — Docker Compose interpolation (`SECRET_KEY`, `ANTHROPIC_API_KEY`, `GOOGLE_AI_API_KEY`, `POSTGRES_PORT`) |
|
||||
| `backend/.env` | Backend source of truth — all FastAPI settings, API keys, DB URLs, CORS |
|
||||
| `frontend/.env` | Frontend — `VITE_API_URL` pointing to backend |
|
||||
|
||||
### Critical Remote Access Config
|
||||
|
||||
**`frontend/.env`:**
|
||||
```
|
||||
VITE_API_URL=http://46.202.92.250:8000
|
||||
```
|
||||
|
||||
**`backend/.env`:**
|
||||
```
|
||||
CORS_ORIGINS=["http://localhost:3000","http://localhost:5173","http://127.0.0.1:3000","http://127.0.0.1:5173","http://46.202.92.250:5173","http://46.202.92.250:3000","https://resolutionflow.com","https://www.resolutionflow.com"]
|
||||
FRONTEND_URL=http://46.202.92.250:5173
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/resolutionflow
|
||||
DATABASE_URL_SYNC=postgresql://postgres:postgres@db:5432/resolutionflow
|
||||
```
|
||||
|
||||
Note: `DATABASE_URL` uses `@db:5432` (Docker service name), not `@localhost`.
|
||||
|
||||
**`docker-compose.dev.yml`:**
|
||||
```yaml
|
||||
- VITE_API_URL=http://46.202.92.250:8000
|
||||
```
|
||||
|
||||
### Starting the Dev Environment
|
||||
|
||||
SSH into host as root:
|
||||
### 5.1 Install system dependencies
|
||||
|
||||
```bash
|
||||
cd /var/lib/docker/volumes/vscode_vscode-data/_data/resolutionflow
|
||||
docker compose -f docker-compose.dev.yml up -d
|
||||
# Ubuntu / Debian
|
||||
sudo apt update && sudo apt install -y \
|
||||
git curl build-essential \
|
||||
python3.11 python3.11-venv python3-pip \
|
||||
postgresql-client # not the server — only if running Postgres natively
|
||||
|
||||
# Node 20 via nvm (survives container rebuilds if stored in a volume)
|
||||
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash
|
||||
export NVM_DIR="$HOME/.nvm" && source "$NVM_DIR/nvm.sh"
|
||||
nvm install 20
|
||||
nvm alias default 20
|
||||
```
|
||||
|
||||
### Running Migrations (Fresh Database)
|
||||
For Option B (Docker Compose), also:
|
||||
|
||||
```bash
|
||||
cd /var/lib/docker/volumes/vscode_vscode-data/_data/resolutionflow
|
||||
curl -fsSL https://get.docker.com | sh
|
||||
sudo usermod -aG docker $USER # log out and back in for this to take effect
|
||||
```
|
||||
|
||||
### 5.2 Clone the repo
|
||||
|
||||
```bash
|
||||
git clone https://gitea.resolutionflow.com/chihlasm/resolutionflow.git
|
||||
# or the GitHub mirror:
|
||||
# git clone https://github.com/chihlasm/resolutionflow.git
|
||||
cd resolutionflow
|
||||
|
||||
# Check out the working branch if you're continuing mid-migration.
|
||||
git fetch origin
|
||||
git checkout feat/flowpilot-migration
|
||||
```
|
||||
|
||||
### 5.3 Start PostgreSQL
|
||||
|
||||
**Option A (native Postgres on the host):**
|
||||
|
||||
```bash
|
||||
sudo apt install -y postgresql-16
|
||||
sudo -u postgres psql -c "CREATE DATABASE resolutionflow;"
|
||||
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';"
|
||||
# Adjust pg_hba.conf if you need non-local connections.
|
||||
```
|
||||
|
||||
**Option B (Postgres via Docker Compose):** The repo has a `docker-compose.dev.yml` at the root. Check its Postgres service for the container name, port mapping, and volume. CLAUDE.md Lesson 65 notes the local compose defaults use container name `resolutionflow_postgres`, database `resolutionflow`, port `5433` mapped to the host. Confirm what the compose file actually says on your branch before trusting those values.
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml up -d db
|
||||
docker compose -f docker-compose.dev.yml logs db # wait for "ready to accept connections"
|
||||
```
|
||||
|
||||
**Verify:**
|
||||
|
||||
```bash
|
||||
# From the host (Option A) or the backend container/LXC (Option B):
|
||||
psql -h <db-host> -p <POSTGRES_PORT> -U postgres -d resolutionflow -c "SELECT now();"
|
||||
```
|
||||
|
||||
### 5.4 Write the `.env` files
|
||||
|
||||
The repo expects three env files. Create each one:
|
||||
|
||||
**`backend/.env`** — backend source of truth:
|
||||
|
||||
```bash
|
||||
APP_NAME=ResolutionFlow
|
||||
DEBUG=true
|
||||
|
||||
# DB URLs — `<db-host>` is `localhost` for Option A, the Docker service name
|
||||
# (e.g. `db`) for Option B, or the DB container/VM hostname for Option C.
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@<db-host>:<POSTGRES_PORT>/resolutionflow
|
||||
DATABASE_URL_SYNC=postgresql://postgres:postgres@<db-host>:<POSTGRES_PORT>/resolutionflow
|
||||
|
||||
# Auth
|
||||
SECRET_KEY=<SECRET_KEY>
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=5
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
REQUIRE_INVITE_CODE=true
|
||||
|
||||
# AI providers
|
||||
AI_PROVIDER=anthropic
|
||||
ANTHROPIC_API_KEY=<ANTHROPIC_API_KEY>
|
||||
GOOGLE_AI_API_KEY=<GOOGLE_AI_API_KEY or leave unset>
|
||||
|
||||
# FlowPilot MCP telemetry — leave on so the Phase 0.5 baseline data keeps accruing
|
||||
ENABLE_MCP_MICROSOFT_LEARN=true
|
||||
|
||||
# CORS + frontend URL
|
||||
FRONTEND_URL=<DEV_HOST_SCHEME>://<DEV_HOST>:<FRONTEND_PORT>
|
||||
CORS_ORIGINS=["http://localhost:5173","http://127.0.0.1:5173","<DEV_HOST_SCHEME>://<DEV_HOST>:<FRONTEND_PORT>"]
|
||||
```
|
||||
|
||||
**`frontend/.env.local`** — frontend build-time config:
|
||||
|
||||
```bash
|
||||
VITE_API_URL=<DEV_HOST_SCHEME>://<DEV_HOST>:<BACKEND_PORT>
|
||||
```
|
||||
|
||||
Optional PostHog (CLAUDE.md Lesson 64 — enables product analytics locally):
|
||||
|
||||
```bash
|
||||
VITE_PUBLIC_POSTHOG_KEY=<from PostHog project settings>
|
||||
VITE_PUBLIC_POSTHOG_HOST=https://us.i.posthog.com
|
||||
```
|
||||
|
||||
**Repo root `.env`** — only needed for Option B (Docker Compose interpolation):
|
||||
|
||||
```bash
|
||||
SECRET_KEY=<SECRET_KEY>
|
||||
ANTHROPIC_API_KEY=<ANTHROPIC_API_KEY>
|
||||
GOOGLE_AI_API_KEY=<GOOGLE_AI_API_KEY or leave unset>
|
||||
POSTGRES_PORT=<POSTGRES_PORT>
|
||||
```
|
||||
|
||||
> **Never commit any `.env` file.** The `.gitignore` already covers this.
|
||||
|
||||
### 5.5 Run the backend setup
|
||||
|
||||
**Option A (native):**
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python3.11 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Migrate the DB to head.
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
**Option B (Docker):**
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml up -d backend
|
||||
docker compose -f docker-compose.dev.yml run --rm backend alembic upgrade head
|
||||
```
|
||||
|
||||
### Seeding Test Users
|
||||
**Expected alembic head** (as of `feat/flowpilot-migration`): `f07010f17b01`. If `alembic current` shows anything else after `upgrade head`, something has gone wrong — stop and investigate.
|
||||
|
||||
### 5.6 Seed test users
|
||||
|
||||
```bash
|
||||
# Option A
|
||||
cd backend && source venv/bin/activate
|
||||
python -m scripts.seed_test_users
|
||||
|
||||
# Option B
|
||||
docker exec resolutionflow_backend python -m scripts.seed_test_users
|
||||
```
|
||||
|
||||
Test accounts (password: `TestPass123!`):
|
||||
Test users (all share password `TestPass123!`):
|
||||
|
||||
| Email | Role | Plan |
|
||||
|---|---|---|
|
||||
| admin@resolutionflow.example.com | Owner | Team |
|
||||
| pro@resolutionflow.example.com | Owner | Pro |
|
||||
| teamadmin@resolutionflow.example.com | Owner | Team |
|
||||
| engineer@resolutionflow.example.com | Engineer | Shared |
|
||||
| Email | Role |
|
||||
|---|---|
|
||||
| `admin@resolutionflow.example.com` | super admin |
|
||||
| `teamadmin@resolutionflow.example.com` | team admin |
|
||||
| `engineer@resolutionflow.example.com` | engineer |
|
||||
| `pro@resolutionflow.example.com` | solo pro |
|
||||
|
||||
### Rebuilding After Config Changes
|
||||
### 5.7 Run the backend
|
||||
|
||||
**Option A:**
|
||||
|
||||
```bash
|
||||
cd backend && source venv/bin/activate
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
**Option B:** Already running from `docker compose up -d backend`. Tail logs:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml logs -f backend
|
||||
```
|
||||
|
||||
**Verify:** `curl <DEV_HOST_SCHEME>://<DEV_HOST>:<BACKEND_PORT>/api/docs` — OpenAPI docs page loads.
|
||||
|
||||
### 5.8 Run the frontend
|
||||
|
||||
**Option A:**
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev -- --host 0.0.0.0 --port 5173
|
||||
```
|
||||
|
||||
**Option B:**
|
||||
|
||||
**Frontend** (Vite bakes env vars at build time — requires rebuild):
|
||||
```bash
|
||||
cd /var/lib/docker/volumes/vscode_vscode-data/_data/resolutionflow
|
||||
docker compose -f docker-compose.dev.yml up -d --build frontend
|
||||
```
|
||||
|
||||
**Backend** (restart only):
|
||||
**Verify:** Open `<DEV_HOST_SCHEME>://<DEV_HOST>:<FRONTEND_PORT>` in your browser. Log in with one of the test users. Navigate to `/pilot` — the FlowPilot session page should render.
|
||||
|
||||
---
|
||||
|
||||
## 6. Verification — proof the env actually works
|
||||
|
||||
Run these after setup. Every item has a concrete expected outcome.
|
||||
|
||||
### 6.1 Database schema is at the right version
|
||||
|
||||
```bash
|
||||
# Option A
|
||||
cd backend && source venv/bin/activate && alembic current
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml run --rm backend alembic current
|
||||
```
|
||||
|
||||
Expected: `f07010f17b01 (head)` on the `feat/flowpilot-migration` branch. On `main`, expected: `074 (head)`.
|
||||
|
||||
### 6.2 Alembic reversibility
|
||||
|
||||
```bash
|
||||
alembic downgrade -1 # should complete cleanly
|
||||
alembic upgrade head # should return to f07010f17b01
|
||||
```
|
||||
|
||||
If either step fails, the migration has a bug and Phase 2 cannot start.
|
||||
|
||||
### 6.3 Prompt-cache hit verification (the deferred Phase 0 TODO)
|
||||
|
||||
`backend/app/core/ai_provider.py` module docstring has a `TODO(phase0-verify)` note describing this. Procedure:
|
||||
|
||||
1. Confirm `AI_PROVIDER=anthropic` and `ANTHROPIC_API_KEY` is set in `backend/.env`.
|
||||
2. Start the backend with log level INFO or lower.
|
||||
3. In the UI, open `/pilot` and send a chat message. Wait a few seconds for the response.
|
||||
4. Send a second chat message in the same session, within 5 minutes of the first.
|
||||
5. In backend logs, grep for lines containing `anthropic.cache`:
|
||||
|
||||
```bash
|
||||
# Option A
|
||||
grep 'anthropic.cache' <log-path>
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml logs backend | grep 'anthropic.cache'
|
||||
```
|
||||
|
||||
6. Expected: two `anthropic.cache` log events. First has `cache_creation_input_tokens > 0`. Second has `cache_read_input_tokens > 0`.
|
||||
7. If the second shows zero reads, inspect the prompt prefix for silent invalidators (timestamps, unsorted JSON keys, varying tool list ordering). Fix before proceeding with any Phase 2 work.
|
||||
|
||||
### 6.4 Frontend build is TypeScript-clean
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npx tsc -b # no errors
|
||||
npm run build # no errors
|
||||
```
|
||||
|
||||
CLAUDE.md Lesson 105 notes that `npm run build` may fail with an `EACCES` on `dist/` inside code-server — that is a Docker filesystem permission issue, not a real build error. Use `npx tsc -b` to verify TypeScript cleanliness in that case.
|
||||
|
||||
### 6.5 `/assistant` → `/pilot` redirect
|
||||
|
||||
Open `<DEV_HOST_SCHEME>://<DEV_HOST>:<FRONTEND_PORT>/assistant/<some-real-session-id>` in the browser. Expected: URL changes to `/pilot/<that-id>`; the FlowPilot session page renders. Bare `/assistant` redirects to bare `/pilot`.
|
||||
|
||||
### 6.6 Dispatcher de-branching
|
||||
|
||||
Navigate to the dashboard. Click a session in `ActiveFlowPilotSessions` or `RecentFlowPilotSessions`. Expected: routes to `/pilot/:id` regardless of the session's `session_type` value. (Check the browser URL bar.)
|
||||
|
||||
### 6.7 CORS
|
||||
|
||||
Open the browser DevTools Network tab, navigate to any backend-hitting page. Expected: no CORS errors. If you see "blocked by CORS policy," the missing origin needs adding to `backend/.env`'s `CORS_ORIGINS`.
|
||||
|
||||
---
|
||||
|
||||
## 7. Runbook
|
||||
|
||||
Day-to-day commands after setup is complete.
|
||||
|
||||
### Restart services
|
||||
|
||||
```bash
|
||||
# Option A
|
||||
# backend — Ctrl-C and re-run uvicorn
|
||||
# frontend — Ctrl-C and re-run npm run dev
|
||||
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml restart backend
|
||||
docker compose -f docker-compose.dev.yml up -d --build frontend # rebuild required if VITE_* changed
|
||||
docker compose -f docker-compose.dev.yml down && docker compose -f docker-compose.dev.yml up -d # full restart
|
||||
```
|
||||
|
||||
**Full restart:**
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml down
|
||||
docker compose -f docker-compose.dev.yml up -d
|
||||
```
|
||||
|
||||
## Installed Tools (Inside VS Code Server Container)
|
||||
|
||||
Installed in `/home/coder` — persists via Docker volume:
|
||||
|
||||
- **nvm** — Node version manager
|
||||
- **Node.js 20.x** — via nvm, default alias set
|
||||
- **npm** — latest
|
||||
- **GitHub CLI (gh)** — authenticated via personal access token
|
||||
- **Claude Code CLI** — `@anthropic-ai/claude-code` (global npm)
|
||||
|
||||
### Permanent Tool Installs
|
||||
|
||||
Tools installed via `apt` inside the container do NOT survive container rebuilds. To add permanently, modify the VS Code Server Docker image and rebuild.
|
||||
|
||||
Temporary (session only):
|
||||
```bash
|
||||
sudo apt update && sudo apt install -y <tool>
|
||||
```
|
||||
|
||||
## SSH Access
|
||||
### Apply a new migration
|
||||
|
||||
```bash
|
||||
ssh root@46.202.92.250
|
||||
# Option A
|
||||
cd backend && source venv/bin/activate && alembic upgrade head
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml run --rm backend alembic upgrade head
|
||||
```
|
||||
|
||||
Key auth configured via `~/.ssh/authorized_keys` on host.
|
||||
### Create a new migration
|
||||
|
||||
## Useful Commands
|
||||
|
||||
### Check all running containers
|
||||
```bash
|
||||
docker ps --format "table {{.Names}}\t{{.Status}}\t{{.Ports}}"
|
||||
# Option A
|
||||
cd backend && source venv/bin/activate
|
||||
alembic revision -m "short description" # manual, preferred per CLAUDE.md Lesson 77
|
||||
# OR
|
||||
alembic revision --autogenerate -m "description" # pulls in drift; review carefully
|
||||
```
|
||||
|
||||
### View container logs
|
||||
Never pass `--rev-id` — let Alembic generate the hex hash.
|
||||
|
||||
### Inspect the database
|
||||
|
||||
```bash
|
||||
docker logs <container_name> --tail 30 -f
|
||||
# Option A (native Postgres)
|
||||
psql -h localhost -p 5432 -U postgres -d resolutionflow
|
||||
|
||||
# Option B (Docker)
|
||||
docker exec -it resolutionflow_postgres psql -U postgres -d resolutionflow
|
||||
```
|
||||
|
||||
### Restart VS Code Server
|
||||
### Run tests
|
||||
|
||||
```bash
|
||||
cd /docker/vscode && docker compose restart
|
||||
# Option A
|
||||
cd backend && source venv/bin/activate
|
||||
pytest --override-ini="addopts="
|
||||
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml run --rm backend pytest --override-ini="addopts="
|
||||
```
|
||||
|
||||
### Restart Traefik
|
||||
First time only, create the test database:
|
||||
|
||||
```bash
|
||||
cd /docker/traefik && docker compose restart
|
||||
# Option A
|
||||
sudo -u postgres psql -c "CREATE DATABASE resolutionflow_test;"
|
||||
|
||||
# Option B
|
||||
docker exec -it resolutionflow_postgres psql -U postgres -c "CREATE DATABASE resolutionflow_test;"
|
||||
```
|
||||
|
||||
### Restart dev stack
|
||||
### View backend logs
|
||||
|
||||
```bash
|
||||
cd /var/lib/docker/volumes/vscode_vscode-data/_data/resolutionflow
|
||||
docker compose -f docker-compose.dev.yml down
|
||||
docker compose -f docker-compose.dev.yml up -d
|
||||
# Option A: wherever you ran uvicorn
|
||||
# Option B
|
||||
docker compose -f docker-compose.dev.yml logs -f --tail=100 backend
|
||||
```
|
||||
|
||||
### Check swap
|
||||
Structured events to grep for:
|
||||
- `anthropic.cache` — prompt-cache hit/creation telemetry (Phase 0.1)
|
||||
- `mcp.turn` — per-turn MCP availability/invocation (Phase 0.5)
|
||||
- `mcp.fallback` — MCP silent-retry fallback fired (Phase 0.5)
|
||||
|
||||
---
|
||||
|
||||
## 8. Troubleshooting
|
||||
|
||||
### CORS errors in the browser
|
||||
|
||||
The backend did not accept the origin your browser used. Check `backend/.env`'s `CORS_ORIGINS` — it must include the exact scheme + host + port the browser sent. Restart the backend after editing.
|
||||
|
||||
### `VITE_API_URL` points at the wrong place
|
||||
|
||||
The frontend was built with a stale value. Rebuild the frontend. Option B: `docker compose up -d --build frontend`. Option A: restart `npm run dev`.
|
||||
|
||||
### `alembic upgrade head` fails with "target database is not up to date"
|
||||
|
||||
Your DB migration chain is out of sync with the code. On a dev box, the safe recovery is to drop the DB and re-migrate from scratch:
|
||||
|
||||
```bash
|
||||
free -h && swapon --show
|
||||
# Option A
|
||||
sudo -u postgres psql -c "DROP DATABASE resolutionflow;" -c "CREATE DATABASE resolutionflow;"
|
||||
cd backend && source venv/bin/activate && alembic upgrade head
|
||||
|
||||
# Option B
|
||||
docker exec resolutionflow_postgres psql -U postgres -c "DROP DATABASE resolutionflow;" -c "CREATE DATABASE resolutionflow;"
|
||||
docker compose -f docker-compose.dev.yml run --rm backend alembic upgrade head
|
||||
```
|
||||
|
||||
### Check disk
|
||||
```bash
|
||||
df -h
|
||||
```
|
||||
Only do this on a dev box — it destroys all local data.
|
||||
|
||||
### Check memory + container usage
|
||||
```bash
|
||||
free -h && docker stats --no-stream
|
||||
```
|
||||
### `alembic heads` shows more than one head
|
||||
|
||||
## DNS Records (resolutionflow.com)
|
||||
Only on a local branch that has diverged from `origin/main`. Production `main` has a single head. If this happens on a fresh clone, one of your local migration files has the wrong `down_revision`. Inspect each file's `down_revision` and reconnect the chain.
|
||||
|
||||
| Type | Name | Value | Purpose |
|
||||
|---|---|---|---|
|
||||
| A | code | 46.202.92.250 | VS Code Server |
|
||||
### Frontend build fails with "EACCES: permission denied" on `dist/`
|
||||
|
||||
## Security Notes
|
||||
Filesystem permission issue inside the code-server container (CLAUDE.md Lesson 105). TypeScript compilation itself completes — use `npx tsc -b` to verify cleanliness without needing to write to `dist/`.
|
||||
|
||||
- UFW is inactive — Traefik and Docker manage port exposure
|
||||
- All public-facing services run through Traefik with valid HTTPS certs
|
||||
- PostgreSQL port 5432 is exposed on all interfaces — restrict if needed in production
|
||||
- Rotate API keys (Anthropic, Voyage) if ever exposed in logs or chat
|
||||
- Never commit `.env` files to Git
|
||||
### `docker` command not found inside code-server
|
||||
|
||||
## VS Code Server Browser Tips
|
||||
If your code-server is itself inside a container, Docker is probably not exposed to it. CLAUDE.md Lesson 103 was written for this case on the old VPS. On Proxmox, the fix depends on topology — either SSH to the host to run Docker commands, or mount the host's Docker socket into the code-server container.
|
||||
|
||||
- **Command Palette:** `F1`
|
||||
- **Terminal:** Ctrl+`
|
||||
- **Rename file:** `F2`
|
||||
- **Go to definition:** `F12`
|
||||
- **Find references:** `Shift+F12`
|
||||
- **Context Menu:** `Alt + Right Click`
|
||||
### Backend returns 500 with `InsufficientPrivilegeError: new row violates row-level security policy`
|
||||
|
||||
RLS is enabled on a table your code wrote to without the right `account_id`. CLAUDE.md Lessons 107, 108, 110 cover this family of bugs. The fix is always at the service layer: make sure every model creation passes `account_id=` explicitly, and that startup routines that touch tenant-isolated tables use `_admin_session_factory()` rather than `get_db()`.
|
||||
|
||||
### Anthropic cache reads are zero on the second turn
|
||||
|
||||
Something in the cached prefix is changing between turns. Inspect the system-block list and the first N history messages for timestamps, `datetime.now()`, unsorted dict keys in JSON prompts, or varying tool-list order. The `anthropic.cache` telemetry shows exactly how many tokens were read vs created — use it to narrow down the invalidator.
|
||||
|
||||
---
|
||||
|
||||
## 9. Security posture for dev environments
|
||||
|
||||
This doc is about dev, not production. But:
|
||||
|
||||
- Never commit `.env` files. The `.gitignore` covers this.
|
||||
- `SECRET_KEY` should be generated per-host, not reused across environments.
|
||||
- `ANTHROPIC_API_KEY` is billable — rotate if leaked into logs or chat.
|
||||
- Postgres on a dev host should not be exposed to the internet. Bind it to `127.0.0.1` or to a private network interface only.
|
||||
- If you expose the frontend or backend publicly (for teammates to test against), put it behind TLS with a real certificate. Do not let dev credentials travel over plain HTTP on the public internet.
|
||||
|
||||
---
|
||||
|
||||
## 10. What's not in this doc
|
||||
|
||||
- **Production deployment.** This is a dev-env doc. Production lives on Railway — see `CLAUDE.md`'s Deployment section.
|
||||
- **How to set up Traefik or any particular reverse proxy.** Whichever proxy you use is your choice; the dev stack just needs something that routes `<host>:5173` and `<host>:8000` to the right services.
|
||||
- **How to configure code-server itself.** Install it however you prefer (native, Docker, LXC); point it at the repo, and the rest of this doc applies.
|
||||
- **Where to host the Proxmox instance.** Up to you.
|
||||
|
||||
If something in this doc turns out to be wrong on your host, fix the doc. This is a living document — the whole point of rewriting it from the Hostinger-specific version was to make it survive host changes.
|
||||
|
||||
@@ -29,13 +29,37 @@ from app.models.session_branch import SessionBranch # noqa: F401
|
||||
from app.models.fork_point import ForkPoint # noqa: F401
|
||||
from app.models.session_handoff import SessionHandoff # noqa: F401
|
||||
from app.models.session_resolution_output import SessionResolutionOutput # noqa: F401
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _alembic_sync_url() -> str:
|
||||
"""Return a psycopg2-compatible sync URL for Alembic.
|
||||
|
||||
Priority order:
|
||||
1. DATABASE_URL_SYNC — in Railway this is set as a reference variable
|
||||
(${{pgvector.DATABASE_URL}}) that resolves to the correct postgres
|
||||
superuser credentials for the current environment (production, PR preview,
|
||||
etc.). This always works even on fresh databases before any custom roles
|
||||
have been created, because it uses the postgres superuser.
|
||||
2. ADMIN_DATABASE_URL (resolutionflow_admin, BYPASSRLS) converted to a sync
|
||||
driver — fallback for local dev where DATABASE_URL_SYNC may not be set.
|
||||
"""
|
||||
if settings.DATABASE_URL_SYNC:
|
||||
return settings.DATABASE_URL_SYNC
|
||||
|
||||
admin_url = settings.ADMIN_DATABASE_URL
|
||||
if admin_url and "+asyncpg" in admin_url:
|
||||
return admin_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
return settings.DATABASE_URL_SYNC
|
||||
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Override sqlalchemy.url with the sync version for migrations
|
||||
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL_SYNC)
|
||||
config.set_main_option("sqlalchemy.url", _alembic_sync_url())
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None:
|
||||
@@ -86,7 +110,7 @@ def run_migrations_online() -> None:
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
connectable = create_engine(
|
||||
settings.DATABASE_URL_SYNC,
|
||||
_alembic_sync_url(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
|
||||
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Enable RLS on Phase 3 tables.
|
||||
|
||||
Tables covered:
|
||||
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
|
||||
- session_shares (account_id NOT NULL since session_share model)
|
||||
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
|
||||
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
|
||||
|
||||
All use a standard intra-tenant isolation policy.
|
||||
Token-based access to session_shares and tree_shares goes through
|
||||
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
|
||||
policy here is correct.
|
||||
|
||||
Revision ID: 04f013768235
|
||||
Revises: a05e1a1bea7c
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = '04f013768235'
|
||||
down_revision: Union[str, None] = 'a05e1a1bea7c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000')::uuid"
|
||||
)
|
||||
|
||||
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||
|
||||
_PHASE3_TABLES = [
|
||||
"step_ratings",
|
||||
"step_usage_log",
|
||||
"target_lists",
|
||||
"session_shares",
|
||||
"audit_logs",
|
||||
"tree_shares",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_STANDARD_USING})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
132
backend/alembic/versions/073_add_device_types_table.py
Normal file
132
backend/alembic/versions/073_add_device_types_table.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Add account-scoped device_types table with platform seed data.
|
||||
|
||||
Revision ID: 073
|
||||
Revises: b3c7e9f2a1d8
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
import uuid
|
||||
|
||||
|
||||
revision = "073"
|
||||
down_revision = "b3c7e9f2a1d8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
SYSTEM_DEVICE_TYPES = [
|
||||
("router", "Router", "network", 0),
|
||||
("switch", "Switch", "network", 1),
|
||||
("firewall", "Firewall", "network", 2),
|
||||
("access-point", "Access Point", "network", 3),
|
||||
("load-balancer", "Load Balancer", "network", 4),
|
||||
("server", "Server", "compute", 0),
|
||||
("workstation", "Workstation", "compute", 1),
|
||||
("vm", "Virtual Machine", "compute", 2),
|
||||
("container", "Container", "compute", 3),
|
||||
("nas", "NAS", "storage", 0),
|
||||
("san", "SAN", "storage", 1),
|
||||
("cloud-storage", "Cloud Storage", "storage", 2),
|
||||
("cloud", "Cloud", "cloud", 0),
|
||||
("aws", "AWS", "cloud", 1),
|
||||
("azure", "Azure", "cloud", 2),
|
||||
("gcp", "Google Cloud", "cloud", 3),
|
||||
("printer", "Printer", "endpoint", 0),
|
||||
("phone", "Phone", "endpoint", 1),
|
||||
("iot", "IoT Device", "endpoint", 2),
|
||||
("camera", "Camera", "endpoint", 3),
|
||||
("tablet", "Tablet", "endpoint", 4),
|
||||
("laptop", "Laptop", "endpoint", 5),
|
||||
("ups", "UPS", "infrastructure", 0),
|
||||
("pdu", "PDU", "infrastructure", 1),
|
||||
("rack", "Rack", "infrastructure", 2),
|
||||
("patch-panel", "Patch Panel", "infrastructure", 3),
|
||||
("nvr", "NVR", "security", 0),
|
||||
("badge-reader", "Badge Reader", "security", 1),
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"device_types",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("slug", sa.String(50), nullable=False),
|
||||
sa.Column("label", sa.String(100), nullable=False),
|
||||
sa.Column("category", sa.String(50), nullable=False),
|
||||
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"])
|
||||
op.create_index("ix_device_types_account_id", "device_types", ["account_id"])
|
||||
|
||||
device_types_table = sa.table(
|
||||
"device_types",
|
||||
sa.column("id", UUID(as_uuid=True)),
|
||||
sa.column("slug", sa.String),
|
||||
sa.column("label", sa.String),
|
||||
sa.column("category", sa.String),
|
||||
sa.column("is_system", sa.Boolean),
|
||||
sa.column("account_id", UUID(as_uuid=True)),
|
||||
sa.column("sort_order", sa.Integer),
|
||||
)
|
||||
|
||||
op.bulk_insert(device_types_table, [
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"slug": slug,
|
||||
"label": label,
|
||||
"category": category,
|
||||
"is_system": True,
|
||||
"account_id": uuid.UUID(_PLATFORM_UUID),
|
||||
"sort_order": sort_order,
|
||||
}
|
||||
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
|
||||
])
|
||||
|
||||
op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_select ON device_types
|
||||
FOR SELECT
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
)
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_insert ON device_types
|
||||
FOR INSERT
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_update ON device_types
|
||||
FOR UPDATE
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_delete ON device_types
|
||||
FOR DELETE
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_update ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_select ON device_types")
|
||||
op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_table("device_types")
|
||||
57
backend/alembic/versions/074_add_network_diagrams_table.py
Normal file
57
backend/alembic/versions/074_add_network_diagrams_table.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Add network_diagrams table.
|
||||
|
||||
Revision ID: 074
|
||||
Revises: 073
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
|
||||
revision = "074"
|
||||
down_revision = "073"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"network_diagrams",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("client_name", sa.String(255), nullable=True),
|
||||
sa.Column("asset_name", sa.String(255), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("nodes", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("edges", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("thumbnail_url", sa.Text(), nullable=True),
|
||||
sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"])
|
||||
op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"])
|
||||
op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON network_diagrams
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams")
|
||||
op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_table("network_diagrams")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Drop team_id from target_lists.
|
||||
|
||||
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
|
||||
All reads/writes use account_id via RLS + application filter.
|
||||
|
||||
Revision ID: 172ad76d7d20
|
||||
Revises: 04f013768235
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '172ad76d7d20'
|
||||
down_revision: Union[str, None] = '04f013768235'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
|
||||
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
|
||||
op.drop_column('target_lists', 'team_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'target_lists_team_id_fkey', 'target_lists', 'teams',
|
||||
['team_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add account_id to audit_logs and backfill via user_id.
|
||||
|
||||
Revision ID: 2a9056eddd90
|
||||
Revises: 70a5dd746e83
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '2a9056eddd90'
|
||||
down_revision: Union[str, None] = '70a5dd746e83'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the acting user's account
|
||||
op.execute("""
|
||||
UPDATE audit_logs al
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE al.user_id = u.id
|
||||
AND u.account_id IS NOT NULL
|
||||
AND al.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
|
||||
"All audit log entries must have an associated user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('audit_logs', 'account_id', nullable=False)
|
||||
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
|
||||
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
|
||||
op.drop_column('audit_logs', 'account_id')
|
||||
90
backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py
Normal file
90
backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Enable RLS on Phase 2 session and supporting tables.
|
||||
|
||||
10 tables use a standard tenant-only policy.
|
||||
step_library uses a visibility-aware policy — public steps visible to all tenants.
|
||||
|
||||
NOTE: session_messages does not exist in this codebase (removed from plan).
|
||||
script_generations is the correct table name (not script_template_generations).
|
||||
sessions and ai_sessions are two separate tables, both in scope.
|
||||
|
||||
Prerequisites:
|
||||
- Phase 1 migration must have run (resolutionflow_app role exists, Phase 1 tables have RLS)
|
||||
- NOT NULL write-path bugs fixed (P2-A commits b641ac6)
|
||||
- shares.py cross-tenant session fix deployed (P2-B commit ac2b193)
|
||||
|
||||
Revision ID: 70a5dd746e83
|
||||
Revises: c5f48b9890f9
|
||||
Create Date: 2026-04-10 06:54:49.431817
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '70a5dd746e83'
|
||||
down_revision: Union[str, None] = 'c5f48b9890f9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
_CURRENT_ACCOUNT = (
|
||||
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
f"'{_NULL_UUID}')::uuid"
|
||||
)
|
||||
|
||||
# Standard tenant-only policy — account_id must match the current tenant.
|
||||
# When no tenant context is set, COALESCE returns the nil UUID so zero rows
|
||||
# are visible (fail-closed).
|
||||
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||
|
||||
# Visibility-aware policy for step_library — public steps (visibility='public')
|
||||
# must be visible to ALL tenants regardless of account_id. This covers the
|
||||
# visibility='public' arm of build_step_visibility_filter() in app/core/filters.py.
|
||||
# The created_by arm (private steps visible to their author) is covered
|
||||
# transitively: private steps share account_id with their creator, so the
|
||||
# account_id match handles it. This relies on account_id NOT NULL on step_library.
|
||||
_STEP_LIBRARY_USING = f"account_id = {_CURRENT_ACCOUNT} OR visibility = 'public'"
|
||||
|
||||
# Standard tables: strict tenant isolation, no cross-tenant visibility.
|
||||
_STANDARD_TABLES = [
|
||||
"sessions",
|
||||
"ai_sessions",
|
||||
"session_branches",
|
||||
"session_supporting_data",
|
||||
"session_resolution_outputs",
|
||||
"session_handoffs",
|
||||
"script_templates",
|
||||
"script_generations",
|
||||
"maintenance_schedules",
|
||||
"psa_post_log",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Standard tenant-isolation tables ────────────────────────────────────
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_STANDARD_USING})
|
||||
""")
|
||||
|
||||
# ── step_library ────────────────────────────────────────────────────────
|
||||
# Public steps (visibility='public') must be readable by all tenants so
|
||||
# the Solutions Library browsing experience works without tenant context.
|
||||
# Private/team steps remain tenant-scoped.
|
||||
op.execute("ALTER TABLE step_library ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE step_library FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON step_library
|
||||
USING ({_STEP_LIBRARY_USING})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _STANDARD_TABLES + ["step_library"]:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Add account_id to tree_shares and backfill via tree owner's account.
|
||||
|
||||
The share belongs to the tree's tenant, not the actor who created it.
|
||||
A super admin in account A can share a tree owned by account B; that share
|
||||
must land in account B so account B's RLS filter sees it.
|
||||
|
||||
Revision ID: a05e1a1bea7c
|
||||
Revises: 2a9056eddd90
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = 'a05e1a1bea7c'
|
||||
down_revision: Union[str, None] = '2a9056eddd90'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the tree's account, not the creator's account.
|
||||
# A share lives in the same tenant as its tree so that the tree owner's
|
||||
# RLS context covers their own shares regardless of who created them.
|
||||
op.execute("""
|
||||
UPDATE tree_shares ts
|
||||
SET account_id = t.account_id
|
||||
FROM trees t
|
||||
WHERE ts.tree_id = t.id
|
||||
AND t.account_id IS NOT NULL
|
||||
AND ts.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
|
||||
"All share entries must have a creating user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('tree_shares', 'account_id', nullable=False)
|
||||
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
|
||||
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
|
||||
op.drop_column('tree_shares', 'account_id')
|
||||
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables.
|
||||
|
||||
All tables in this migration already have account_id NOT NULL (enforced by
|
||||
earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY,
|
||||
FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each.
|
||||
|
||||
Policy variants used:
|
||||
- Standard: account_id = current_setting(app.current_account_id)::uuid
|
||||
- Platform: standard OR account_id = PLATFORM_ACCOUNT_ID
|
||||
(for global content tables readable by all tenants)
|
||||
|
||||
Skipped intentionally:
|
||||
- accounts — IS the root table; no account_id column
|
||||
- plan_feature_defaults — platform config; no account_id column
|
||||
- script_categories — global lookup table; no account_id column
|
||||
- platform_steps — global content; no account_id column (readable by all)
|
||||
- template_trees — global content; no account_id column (readable by all)
|
||||
|
||||
Revision ID: b3c7e9f2a1d8
|
||||
Revises: 172ad76d7d20
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b3c7e9f2a1d8"
|
||||
down_revision: Union[str, None] = "172ad76d7d20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Standard policy — tenant sees only own rows.
|
||||
_STANDARD_TABLES = [
|
||||
"users",
|
||||
"account_invites",
|
||||
"account_limit_overrides",
|
||||
"account_feature_overrides",
|
||||
"subscriptions",
|
||||
"ai_chat_sessions",
|
||||
"ai_conversations",
|
||||
"ai_session_steps",
|
||||
"ai_session_embeddings",
|
||||
"ai_suggestions",
|
||||
"ai_usage",
|
||||
"assistant_chats",
|
||||
"attachments",
|
||||
"copilot_conversations",
|
||||
"feedback",
|
||||
"file_uploads",
|
||||
"fork_points",
|
||||
"kb_imports",
|
||||
"notifications",
|
||||
"notification_configs",
|
||||
"notification_logs",
|
||||
"psa_activity_logs",
|
||||
"psa_member_mappings",
|
||||
"script_builder_sessions",
|
||||
"session_ratings",
|
||||
"tree_embeddings",
|
||||
"user_folders",
|
||||
"user_pinned_trees",
|
||||
]
|
||||
|
||||
_POLICY_EXPR = (
|
||||
"account_id = COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_POLICY_EXPR})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
404
backend/alembic/versions/f07010f17b01_flowpilot_phase1_schema.py
Normal file
404
backend/alembic/versions/f07010f17b01_flowpilot_phase1_schema.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""FlowPilot migration Phase 1 — schema for the unified session surface.
|
||||
|
||||
Revision ID: f07010f17b01
|
||||
Revises: 074
|
||||
Create Date: 2026-04-17
|
||||
|
||||
Creates the backing store for the FlowPilot unified session surface:
|
||||
|
||||
- `session_facts` — "What we know" facts, keyed to a session, with a polymorphic
|
||||
`source_ref` pointing at a task-lane item inside `ai_sessions.pending_task_lane`
|
||||
(no DB-level FK; integrity enforced at the service layer per the design doc).
|
||||
- `session_suggested_fixes` — AI-proposed resolution paths. Only one active
|
||||
(`superseded_at IS NULL`) per session at a time.
|
||||
- `draft_templates` — scripts pending post-resolve templatization
|
||||
(Option 2 in the three-option dialog).
|
||||
- `account_settings` — new per-account key/value settings table with a JSONB
|
||||
`preferences` grab-bag. Rows are created lazily on first write.
|
||||
- Column additions to `ai_sessions` — resolution/escalation markdown + external IDs,
|
||||
plus `state_version` (incremented by any write that invalidates the resolution
|
||||
note preview cache).
|
||||
- Column additions to `script_templates` — provenance fields for templates
|
||||
promoted from draft_templates.
|
||||
|
||||
All four new tenant-scoped tables have RLS enabled + forced with a
|
||||
`tenant_isolation` policy matching the repo pattern (USING + WITH CHECK on
|
||||
`account_id = app.current_account_id`). Downgrade is reversible: drops in the
|
||||
inverse order of creation.
|
||||
|
||||
Chained from `074` (add_network_diagrams_table) per the single-head state of
|
||||
production; the other local heads on feat/flowpilot-migration are branch
|
||||
artifacts not present in production.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
|
||||
revision = "f07010f17b01"
|
||||
down_revision = "074"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── ai_sessions: resolution / escalation columns + state_version ───────
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("resolution_note_markdown", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("resolution_note_posted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("resolution_note_external_id", sa.String(128), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("escalation_package_markdown", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("escalation_package_posted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column("escalation_package_external_id", sa.String(128), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"ai_sessions",
|
||||
sa.Column(
|
||||
"state_version",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
),
|
||||
)
|
||||
|
||||
# ── script_templates: provenance for post-resolve promotion ────────────
|
||||
op.add_column(
|
||||
"script_templates",
|
||||
sa.Column(
|
||||
"source_session_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("ai_sessions.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"script_templates",
|
||||
sa.Column(
|
||||
"source_user_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"script_templates",
|
||||
sa.Column("source_ticket_ref", sa.String(64), nullable=True),
|
||||
)
|
||||
|
||||
# ── session_facts ──────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"session_facts",
|
||||
sa.Column(
|
||||
"id",
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("ai_sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"account_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("text", sa.Text(), nullable=False),
|
||||
sa.Column("source_type", sa.String(32), nullable=False),
|
||||
# `source_ref` is a polymorphic pointer to a task-lane item inside
|
||||
# ai_sessions.pending_task_lane JSON, NOT a FK to any table.
|
||||
# Integrity enforced at the service layer per Section 4.2 of the
|
||||
# migration design doc.
|
||||
sa.Column("source_ref", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("source_summary", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_by",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"source_type IN ('question', 'diagnostic_check', 'user_note', 'ai_synthesis')",
|
||||
name="ck_session_facts_source_type",
|
||||
),
|
||||
)
|
||||
# Active-facts-per-session; partial index excludes soft-deleted rows.
|
||||
op.create_index(
|
||||
"idx_session_facts_session",
|
||||
"session_facts",
|
||||
["session_id"],
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_session_facts_account",
|
||||
"session_facts",
|
||||
["account_id"],
|
||||
)
|
||||
op.execute("ALTER TABLE session_facts ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE session_facts FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON session_facts
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
# ── session_suggested_fixes ────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"session_suggested_fixes",
|
||||
sa.Column(
|
||||
"id",
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("ai_sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"account_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("title", sa.String(200), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=False),
|
||||
sa.Column("confidence_pct", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"script_template_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("script_templates.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("ai_drafted_script", sa.Text(), nullable=True),
|
||||
sa.Column("ai_drafted_parameters", JSONB(), nullable=True),
|
||||
sa.Column("user_decision", sa.String(32), nullable=True),
|
||||
sa.Column("superseded_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"confidence_pct BETWEEN 0 AND 100",
|
||||
name="ck_session_suggested_fixes_confidence_pct",
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"user_decision IS NULL OR user_decision IN ("
|
||||
"'one_off', 'draft_template', 'build_template', 'dismissed')",
|
||||
name="ck_session_suggested_fixes_user_decision",
|
||||
),
|
||||
)
|
||||
# Only-one-active-per-session is enforced by service-layer supersession;
|
||||
# this partial index serves the "find active fix" query.
|
||||
op.create_index(
|
||||
"idx_session_suggested_fixes_session_active",
|
||||
"session_suggested_fixes",
|
||||
["session_id"],
|
||||
postgresql_where=sa.text("superseded_at IS NULL"),
|
||||
)
|
||||
op.execute("ALTER TABLE session_suggested_fixes ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE session_suggested_fixes FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON session_suggested_fixes
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
# ── draft_templates ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"draft_templates",
|
||||
sa.Column(
|
||||
"id",
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
sa.Column(
|
||||
"account_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"source_session_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("ai_sessions.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"source_user_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("script_body", sa.Text(), nullable=False),
|
||||
sa.Column("proposed_parameters", JSONB(), nullable=False),
|
||||
sa.Column("proposed_name", sa.String(200), nullable=True),
|
||||
sa.Column(
|
||||
"proposed_category_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("script_categories.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(32),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"promoted_template_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("script_templates.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"status IN ('pending', 'accepted', 'rejected')",
|
||||
name="ck_draft_templates_status",
|
||||
),
|
||||
)
|
||||
# Supports the Script Library "N scripts ready to review" badge.
|
||||
op.create_index(
|
||||
"idx_draft_templates_account_pending",
|
||||
"draft_templates",
|
||||
["account_id"],
|
||||
postgresql_where=sa.text("status = 'pending'"),
|
||||
)
|
||||
op.execute("ALTER TABLE draft_templates ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE draft_templates FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON draft_templates
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
# ── account_settings ───────────────────────────────────────────────────
|
||||
# One row per account, created lazily on first write. The `preferences`
|
||||
# JSONB is a grab-bag for simple settings (e.g. templatize_prompt_enabled).
|
||||
# Settings graduate to typed columns via future migrations when they meet
|
||||
# the promotion criteria in Section 4.6 of the design doc (hot path /
|
||||
# validation / joins).
|
||||
op.create_table(
|
||||
"account_settings",
|
||||
sa.Column(
|
||||
"account_id",
|
||||
UUID(as_uuid=True),
|
||||
sa.ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"preferences",
|
||||
JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'{}'::jsonb"),
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
)
|
||||
op.execute("ALTER TABLE account_settings ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE account_settings FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON account_settings
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop in reverse order so FK dependencies unwind cleanly.
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON account_settings")
|
||||
op.execute("ALTER TABLE account_settings DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_table("account_settings")
|
||||
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON draft_templates")
|
||||
op.execute("ALTER TABLE draft_templates DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_index("idx_draft_templates_account_pending", table_name="draft_templates")
|
||||
op.drop_table("draft_templates")
|
||||
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON session_suggested_fixes")
|
||||
op.execute("ALTER TABLE session_suggested_fixes DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_index(
|
||||
"idx_session_suggested_fixes_session_active",
|
||||
table_name="session_suggested_fixes",
|
||||
)
|
||||
op.drop_table("session_suggested_fixes")
|
||||
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON session_facts")
|
||||
op.execute("ALTER TABLE session_facts DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_index("idx_session_facts_account", table_name="session_facts")
|
||||
op.drop_index("idx_session_facts_session", table_name="session_facts")
|
||||
op.drop_table("session_facts")
|
||||
|
||||
op.drop_column("script_templates", "source_ticket_ref")
|
||||
op.drop_column("script_templates", "source_user_id")
|
||||
op.drop_column("script_templates", "source_session_id")
|
||||
|
||||
op.drop_column("ai_sessions", "state_version")
|
||||
op.drop_column("ai_sessions", "escalation_package_external_id")
|
||||
op.drop_column("ai_sessions", "escalation_package_posted_at")
|
||||
op.drop_column("ai_sessions", "escalation_package_markdown")
|
||||
op.drop_column("ai_sessions", "resolution_note_external_id")
|
||||
op.drop_column("ai_sessions", "resolution_note_posted_at")
|
||||
op.drop_column("ai_sessions", "resolution_note_markdown")
|
||||
@@ -24,10 +24,14 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
token: Annotated[str, Depends(oauth2_scheme)]
|
||||
) -> User:
|
||||
"""Get current authenticated user from JWT token."""
|
||||
"""Get current authenticated user from JWT token.
|
||||
|
||||
Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context
|
||||
sets app.current_account_id, so the users table RLS would block the lookup.
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -77,10 +81,14 @@ async def get_refresh_token_payload(
|
||||
async def get_current_active_user(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> User:
|
||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
||||
Enforces must_change_password — blocks all routes except allowlist."""
|
||||
Enforces must_change_password — blocks all routes except allowlist.
|
||||
|
||||
Uses get_admin_db: runs before require_tenant_context sets the ContextVar,
|
||||
so tenant-scoped tables (subscriptions) would return 0 rows via app role.
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.core.audit import log_audit
|
||||
from app.models.refresh_token import RefreshToken
|
||||
@@ -148,7 +149,7 @@ async def update_member_role(
|
||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||
async def transfer_ownership(
|
||||
data: TransferOwnershipRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Transfer account ownership to another member (owner only)."""
|
||||
@@ -377,7 +378,7 @@ async def list_invites(
|
||||
|
||||
@router.post("/me/leave")
|
||||
async def leave_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Leave the current account (non-owners only). Creates a personal account."""
|
||||
@@ -423,7 +424,7 @@ class DeleteAccountRequest(BaseModel):
|
||||
@router.delete("/me")
|
||||
async def delete_account(
|
||||
data: DeleteAccountRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.orm import selectinload, aliased
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.audit import log_audit
|
||||
@@ -24,21 +24,44 @@ from app.models.invite_code import InviteCode
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.tree import Tree
|
||||
from app.schemas.user import UserResponse, RoleUpdate, AccountRoleUpdate
|
||||
from app.schemas.admin import MoveUserAccount, AdminUserCreate, AdminUserCreateResponse, AdminPasswordReset, AdminPasswordResetResponse, HardDeleteCheckResponse
|
||||
from app.schemas.admin import (
|
||||
MoveUserAccount,
|
||||
AdminUserCreate,
|
||||
AdminUserCreateResponse,
|
||||
AdminPasswordReset,
|
||||
AdminPasswordResetResponse,
|
||||
HardDeleteCheckResponse,
|
||||
AdminUserListItem,
|
||||
AdminUserListResponse,
|
||||
AdminAccountMember,
|
||||
AdminAccountListItem,
|
||||
AdminAccountListResponse,
|
||||
AdminAccountOwnerSummary,
|
||||
AdminAccountSubscriptionSummary,
|
||||
AdminAccountUsageSummary,
|
||||
AdminAccountDetailResponse,
|
||||
AdminAccountInviteSummary,
|
||||
AdminAccountCreate,
|
||||
AdminAccountUpdate,
|
||||
)
|
||||
from app.schemas.subscription import SubscriptionPlanUpdate, ExtendTrialRequest
|
||||
from app.schemas.user_detail import (
|
||||
UserDetailResponse, AccountSummary, SubscriptionSummary,
|
||||
SessionSummary, AuditLogSummary, InviteCodeUsedSummary,
|
||||
)
|
||||
from app.api.deps import require_admin
|
||||
from app.core.subscriptions import get_account_usage
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserResponse])
|
||||
@router.get("/users", response_model=AdminUserListResponse)
|
||||
async def list_users(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
page: Optional[int] = Query(None, ge=1),
|
||||
size: Optional[int] = Query(None, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="Search by user or account fields"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
@@ -46,23 +69,240 @@ async def list_users(
|
||||
account_id: Optional[UUID] = Query(None, description="Filter by account"),
|
||||
include_archived: bool = Query(False, description="Include archived (soft-deleted) users"),
|
||||
):
|
||||
"""List all users (super admin only)."""
|
||||
query = select(User)
|
||||
"""List users for super admin global people search."""
|
||||
resolved_limit = size or limit
|
||||
resolved_skip = skip
|
||||
current_page = 1
|
||||
|
||||
if page is not None:
|
||||
resolved_skip = (page - 1) * resolved_limit
|
||||
current_page = page
|
||||
elif resolved_limit > 0:
|
||||
current_page = (resolved_skip // resolved_limit) + 1
|
||||
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.outerjoin(Account, User.account_id == Account.id)
|
||||
)
|
||||
query = (
|
||||
select(
|
||||
User,
|
||||
Account.name.label("account_name"),
|
||||
Account.display_code.label("account_display_code"),
|
||||
)
|
||||
.outerjoin(Account, User.account_id == Account.id)
|
||||
)
|
||||
|
||||
if not include_archived:
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
count_query = count_query.where(User.deleted_at.is_(None))
|
||||
if is_active is not None:
|
||||
query = query.where(User.is_active == is_active)
|
||||
count_query = count_query.where(User.is_active == is_active)
|
||||
if role:
|
||||
query = query.where(User.role == role)
|
||||
count_query = count_query.where(User.role == role)
|
||||
if account_id:
|
||||
query = query.where(User.account_id == account_id)
|
||||
count_query = count_query.where(User.account_id == account_id)
|
||||
if search:
|
||||
search_term = f"%{search.strip()}%"
|
||||
search_filter = or_(
|
||||
User.name.ilike(search_term),
|
||||
User.email.ilike(search_term),
|
||||
Account.name.ilike(search_term),
|
||||
Account.display_code.ilike(search_term),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(skip).limit(limit)
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
query = query.order_by(User.created_at.desc()).offset(resolved_skip).limit(resolved_limit)
|
||||
result = await db.execute(query)
|
||||
users = result.scalars().all()
|
||||
return users
|
||||
rows = result.all()
|
||||
|
||||
items = [
|
||||
AdminUserListItem(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
role=user.role,
|
||||
is_super_admin=user.is_super_admin,
|
||||
is_active=user.is_active,
|
||||
account_id=user.account_id,
|
||||
account_role=user.account_role,
|
||||
account_name=account_name,
|
||||
account_display_code=account_display_code,
|
||||
created_at=user.created_at,
|
||||
last_login=user.last_login,
|
||||
deleted_at=user.deleted_at,
|
||||
)
|
||||
for user, account_name, account_display_code in rows
|
||||
]
|
||||
|
||||
return AdminUserListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=current_page,
|
||||
per_page=resolved_limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/accounts", response_model=AdminAccountListResponse)
|
||||
async def list_accounts(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(12, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="Search by account, display code, or owner"),
|
||||
plan: Optional[str] = Query(None, description="Filter by subscription plan"),
|
||||
status: Optional[str] = Query(None, description="Filter by subscription status"),
|
||||
include_archived: bool = Query(False, description="Include archived users in account member lists"),
|
||||
):
|
||||
"""List accounts with embedded members for the admin panel."""
|
||||
owner_user = aliased(User)
|
||||
|
||||
count_query = (
|
||||
select(func.count(func.distinct(Account.id)))
|
||||
.select_from(Account)
|
||||
.outerjoin(owner_user, Account.owner_id == owner_user.id)
|
||||
.outerjoin(Subscription, Subscription.account_id == Account.id)
|
||||
)
|
||||
accounts_query = (
|
||||
select(
|
||||
Account,
|
||||
owner_user.id.label("owner_user_id"),
|
||||
owner_user.name.label("owner_name"),
|
||||
owner_user.email.label("owner_email"),
|
||||
Subscription.id.label("subscription_id"),
|
||||
Subscription.plan.label("subscription_plan"),
|
||||
Subscription.status.label("subscription_status"),
|
||||
Subscription.billing_interval.label("subscription_billing_interval"),
|
||||
Subscription.current_period_end.label("subscription_current_period_end"),
|
||||
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
|
||||
)
|
||||
.outerjoin(owner_user, Account.owner_id == owner_user.id)
|
||||
.outerjoin(Subscription, Subscription.account_id == Account.id)
|
||||
)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search.strip()}%"
|
||||
search_filter = or_(
|
||||
Account.name.ilike(search_term),
|
||||
Account.display_code.ilike(search_term),
|
||||
owner_user.name.ilike(search_term),
|
||||
owner_user.email.ilike(search_term),
|
||||
)
|
||||
count_query = count_query.where(search_filter)
|
||||
accounts_query = accounts_query.where(search_filter)
|
||||
if plan:
|
||||
count_query = count_query.where(Subscription.plan == plan)
|
||||
accounts_query = accounts_query.where(Subscription.plan == plan)
|
||||
if status:
|
||||
count_query = count_query.where(Subscription.status == status)
|
||||
accounts_query = accounts_query.where(Subscription.status == status)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
accounts_result = await db.execute(
|
||||
accounts_query
|
||||
.order_by(Account.created_at.desc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
rows = accounts_result.all()
|
||||
accounts = [row.Account for row in rows]
|
||||
|
||||
account_ids = [account.id for account in accounts]
|
||||
members_by_account: dict[UUID, list[AdminAccountMember]] = {account_id: [] for account_id in account_ids}
|
||||
pending_invites_by_account: dict[UUID, int] = {account_id: 0 for account_id in account_ids}
|
||||
usage_by_account: dict[UUID, AdminAccountUsageSummary] = {}
|
||||
|
||||
if account_ids:
|
||||
members_query = select(User).where(User.account_id.in_(account_ids))
|
||||
if not include_archived:
|
||||
members_query = members_query.where(User.deleted_at.is_(None))
|
||||
members_query = members_query.order_by(User.created_at.asc())
|
||||
|
||||
members_result = await db.execute(members_query)
|
||||
for member in members_result.scalars().all():
|
||||
members_by_account.setdefault(member.account_id, []).append(
|
||||
AdminAccountMember(
|
||||
id=member.id,
|
||||
email=member.email,
|
||||
name=member.name,
|
||||
role=member.role,
|
||||
is_super_admin=member.is_super_admin,
|
||||
is_active=member.is_active,
|
||||
account_role=member.account_role,
|
||||
created_at=member.created_at,
|
||||
last_login=member.last_login,
|
||||
deleted_at=member.deleted_at,
|
||||
)
|
||||
)
|
||||
|
||||
pending_invites_result = await db.execute(
|
||||
select(AccountInvite.account_id, func.count(AccountInvite.id))
|
||||
.where(
|
||||
AccountInvite.account_id.in_(account_ids),
|
||||
AccountInvite.used_at.is_(None),
|
||||
)
|
||||
.group_by(AccountInvite.account_id)
|
||||
)
|
||||
pending_invites_by_account.update({row[0]: row[1] for row in pending_invites_result.all()})
|
||||
|
||||
for account_id in account_ids:
|
||||
usage = await get_account_usage(account_id, db)
|
||||
usage_by_account[account_id] = AdminAccountUsageSummary(
|
||||
tree_count=usage.get("tree_count", 0),
|
||||
session_count_this_month=usage.get("session_count_this_month", 0),
|
||||
)
|
||||
|
||||
items = [
|
||||
AdminAccountListItem(
|
||||
id=row.Account.id,
|
||||
name=row.Account.name,
|
||||
display_code=row.Account.display_code,
|
||||
created_at=row.Account.created_at,
|
||||
owner_id=row.Account.owner_id,
|
||||
owner=(
|
||||
AdminAccountOwnerSummary(
|
||||
id=row.owner_user_id,
|
||||
name=row.owner_name,
|
||||
email=row.owner_email,
|
||||
) if row.owner_user_id and row.owner_name and row.owner_email else None
|
||||
),
|
||||
subscription=(
|
||||
AdminAccountSubscriptionSummary(
|
||||
id=row.subscription_id,
|
||||
plan=row.subscription_plan,
|
||||
status=row.subscription_status,
|
||||
billing_interval=row.subscription_billing_interval,
|
||||
current_period_end=row.subscription_current_period_end,
|
||||
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
|
||||
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
|
||||
),
|
||||
usage=usage_by_account.get(row.Account.id, AdminAccountUsageSummary()),
|
||||
member_count=len(members_by_account.get(row.Account.id, [])),
|
||||
active_member_count=sum(1 for member in members_by_account.get(row.Account.id, []) if member.is_active),
|
||||
pending_invite_count=pending_invites_by_account.get(row.Account.id, 0),
|
||||
sso_enabled=row.Account.sso_enabled,
|
||||
branding_company_name=row.Account.branding_company_name,
|
||||
members=members_by_account.get(row.Account.id, []),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return AdminAccountListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=size,
|
||||
)
|
||||
|
||||
|
||||
def _generate_display_code() -> str:
|
||||
@@ -71,6 +311,192 @@ def _generate_display_code() -> str:
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
async def _generate_unique_display_code(db: AsyncSession) -> str:
|
||||
"""Generate a unique display code for a new account."""
|
||||
while True:
|
||||
display_code = _generate_display_code()
|
||||
existing = await db.execute(select(Account.id).where(Account.display_code == display_code))
|
||||
if existing.scalar_one_or_none() is None:
|
||||
return display_code
|
||||
|
||||
|
||||
async def _get_account_detail_payload(
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
include_archived: bool = False,
|
||||
) -> AdminAccountDetailResponse:
|
||||
owner_user = aliased(User)
|
||||
result = await db.execute(
|
||||
select(
|
||||
Account,
|
||||
owner_user.id.label("owner_user_id"),
|
||||
owner_user.name.label("owner_name"),
|
||||
owner_user.email.label("owner_email"),
|
||||
Subscription.id.label("subscription_id"),
|
||||
Subscription.plan.label("subscription_plan"),
|
||||
Subscription.status.label("subscription_status"),
|
||||
Subscription.billing_interval.label("subscription_billing_interval"),
|
||||
Subscription.current_period_end.label("subscription_current_period_end"),
|
||||
Subscription.cancel_at_period_end.label("subscription_cancel_at_period_end"),
|
||||
)
|
||||
.outerjoin(owner_user, Account.owner_id == owner_user.id)
|
||||
.outerjoin(Subscription, Subscription.account_id == Account.id)
|
||||
.where(Account.id == account_id)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
if not row:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
members_query = select(User).where(User.account_id == account_id).order_by(User.created_at.asc())
|
||||
if not include_archived:
|
||||
members_query = members_query.where(User.deleted_at.is_(None))
|
||||
members_result = await db.execute(members_query)
|
||||
members = [
|
||||
AdminAccountMember(
|
||||
id=member.id,
|
||||
email=member.email,
|
||||
name=member.name,
|
||||
role=member.role,
|
||||
is_super_admin=member.is_super_admin,
|
||||
is_active=member.is_active,
|
||||
account_role=member.account_role,
|
||||
created_at=member.created_at,
|
||||
last_login=member.last_login,
|
||||
deleted_at=member.deleted_at,
|
||||
)
|
||||
for member in members_result.scalars().all()
|
||||
]
|
||||
|
||||
invites_result = await db.execute(
|
||||
select(AccountInvite)
|
||||
.where(AccountInvite.account_id == account_id)
|
||||
.order_by(AccountInvite.created_at.desc())
|
||||
)
|
||||
invites = [
|
||||
AdminAccountInviteSummary(
|
||||
id=invite.id,
|
||||
email=invite.email,
|
||||
role=invite.role,
|
||||
expires_at=invite.expires_at,
|
||||
created_at=invite.created_at,
|
||||
used_at=invite.used_at,
|
||||
)
|
||||
for invite in invites_result.scalars().all()
|
||||
if invite.used_at is None
|
||||
]
|
||||
|
||||
usage = await get_account_usage(account_id, db)
|
||||
|
||||
return AdminAccountDetailResponse(
|
||||
id=row.Account.id,
|
||||
name=row.Account.name,
|
||||
display_code=row.Account.display_code,
|
||||
created_at=row.Account.created_at,
|
||||
owner_id=row.Account.owner_id,
|
||||
owner=(
|
||||
AdminAccountOwnerSummary(
|
||||
id=row.owner_user_id,
|
||||
name=row.owner_name,
|
||||
email=row.owner_email,
|
||||
) if row.owner_user_id and row.owner_name and row.owner_email else None
|
||||
),
|
||||
subscription=(
|
||||
AdminAccountSubscriptionSummary(
|
||||
id=row.subscription_id,
|
||||
plan=row.subscription_plan,
|
||||
status=row.subscription_status,
|
||||
billing_interval=row.subscription_billing_interval,
|
||||
current_period_end=row.subscription_current_period_end,
|
||||
cancel_at_period_end=row.subscription_cancel_at_period_end or False,
|
||||
) if row.subscription_id and row.subscription_plan and row.subscription_status else None
|
||||
),
|
||||
usage=AdminAccountUsageSummary(
|
||||
tree_count=usage.get("tree_count", 0),
|
||||
session_count_this_month=usage.get("session_count_this_month", 0),
|
||||
),
|
||||
member_count=len(members),
|
||||
active_member_count=sum(1 for member in members if member.is_active),
|
||||
pending_invite_count=len(invites),
|
||||
sso_enabled=row.Account.sso_enabled,
|
||||
branding_company_name=row.Account.branding_company_name,
|
||||
members=members,
|
||||
invites=invites,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/accounts", response_model=AdminAccountDetailResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_account(
|
||||
data: AdminAccountCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Create a new account without requiring an initial user."""
|
||||
owner_id = None
|
||||
if data.owner_email:
|
||||
result = await db.execute(select(User).where(User.email == data.owner_email.strip()))
|
||||
owner = result.scalar_one_or_none()
|
||||
if not owner:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"No user found with email '{data.owner_email}'")
|
||||
owner_id = owner.id
|
||||
|
||||
display_code = await _generate_unique_display_code(db)
|
||||
new_account = Account(
|
||||
name=data.name.strip(),
|
||||
display_code=display_code,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_subscription = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan=data.plan,
|
||||
status="active",
|
||||
)
|
||||
db.add(new_subscription)
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "account.create_admin", "account", new_account.id,
|
||||
{"name": new_account.name, "plan": data.plan, "owner_email": data.owner_email},
|
||||
)
|
||||
await db.commit()
|
||||
return await _get_account_detail_payload(new_account.id, db)
|
||||
|
||||
|
||||
@router.get("/accounts/{account_id}", response_model=AdminAccountDetailResponse)
|
||||
async def get_account_detail(
|
||||
account_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
include_archived: bool = Query(False),
|
||||
):
|
||||
"""Get detailed account information for admin management."""
|
||||
return await _get_account_detail_payload(account_id, db, include_archived=include_archived)
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}", response_model=AdminAccountDetailResponse)
|
||||
async def update_account(
|
||||
account_id: UUID,
|
||||
data: AdminAccountUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Update account settings from the admin panel."""
|
||||
result = await db.execute(select(Account).where(Account.id == account_id))
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
old_name = account.name
|
||||
account.name = data.name.strip()
|
||||
await log_audit(
|
||||
db, current_user.id, "account.update_admin", "account", account.id,
|
||||
{"old_name": old_name, "new_name": account.name},
|
||||
)
|
||||
await db.commit()
|
||||
return await _get_account_detail_payload(account.id, db)
|
||||
|
||||
|
||||
@router.post("/users", response_model=AdminUserCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
data: AdminUserCreate,
|
||||
@@ -516,6 +942,28 @@ async def _get_user_subscription(user_id: UUID, db: AsyncSession) -> tuple[User,
|
||||
return user, subscription
|
||||
|
||||
|
||||
async def _get_account_subscription(account_id: UUID, db: AsyncSession) -> tuple[Account, Subscription]:
|
||||
"""Helper to load account and its subscription."""
|
||||
account_result = await db.execute(select(Account).where(Account.id == account_id))
|
||||
account = account_result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
sub_result = await db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account.id)
|
||||
)
|
||||
subscription = sub_result.scalar_one_or_none()
|
||||
if not subscription:
|
||||
subscription = Subscription(
|
||||
account_id=account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(subscription)
|
||||
await db.flush()
|
||||
return account, subscription
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/subscription/plan")
|
||||
async def update_user_plan(
|
||||
user_id: UUID,
|
||||
@@ -535,6 +983,31 @@ async def update_user_plan(
|
||||
return {"plan": subscription.plan, "status": subscription.status}
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}/subscription/plan")
|
||||
async def update_account_plan(
|
||||
account_id: UUID,
|
||||
data: SubscriptionPlanUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Change an account subscription plan (super admin only)."""
|
||||
if data.plan not in ("free", "pro", "team"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid plan")
|
||||
account, subscription = await _get_account_subscription(account_id, db)
|
||||
old_plan = subscription.plan
|
||||
subscription.plan = data.plan
|
||||
await log_audit(
|
||||
db,
|
||||
current_user.id,
|
||||
"subscription.plan_change",
|
||||
"subscription",
|
||||
subscription.id,
|
||||
{"old_plan": old_plan, "new_plan": data.plan, "account_id": str(account_id)},
|
||||
)
|
||||
await db.commit()
|
||||
return {"plan": subscription.plan, "status": subscription.status}
|
||||
|
||||
|
||||
@router.put("/users/{user_id}/subscription/extend-trial")
|
||||
async def extend_user_trial(
|
||||
user_id: UUID,
|
||||
@@ -565,6 +1038,43 @@ async def extend_user_trial(
|
||||
"current_period_end": subscription.current_period_end}
|
||||
|
||||
|
||||
@router.put("/accounts/{account_id}/subscription/extend-trial")
|
||||
async def extend_account_trial(
|
||||
account_id: UUID,
|
||||
data: ExtendTrialRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_admin)],
|
||||
):
|
||||
"""Extend or start a trial for an account subscription (super admin only)."""
|
||||
if data.days < 1 or data.days > 90:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Days must be 1-90")
|
||||
account, subscription = await _get_account_subscription(account_id, db)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if subscription.status == "trialing" and subscription.current_period_end:
|
||||
new_end = subscription.current_period_end + timedelta(days=data.days)
|
||||
else:
|
||||
subscription.status = "trialing"
|
||||
subscription.current_period_start = now
|
||||
new_end = now + timedelta(days=data.days)
|
||||
|
||||
subscription.current_period_end = new_end
|
||||
await log_audit(
|
||||
db,
|
||||
current_user.id,
|
||||
"subscription.extend_trial",
|
||||
"subscription",
|
||||
subscription.id,
|
||||
{"days": data.days, "new_end": new_end.isoformat(), "account_id": str(account.id)},
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
"plan": subscription.plan,
|
||||
"status": subscription.status,
|
||||
"current_period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/password-reset", response_model=AdminPasswordResetResponse)
|
||||
async def admin_reset_password(
|
||||
user_id: UUID,
|
||||
|
||||
@@ -43,6 +43,7 @@ async def create_suggestion(
|
||||
suggestion = AISuggestion(
|
||||
tree_id=data.tree_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
session_id=data.session_id,
|
||||
action_type=data.action_type,
|
||||
target_node_id=data.target_node_id,
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.core.config import settings
|
||||
from app.core.settings_manager import SettingsManager
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.security import (
|
||||
verify_password,
|
||||
@@ -67,7 +67,7 @@ def _generate_display_code() -> str:
|
||||
async def register(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Register a new user.
|
||||
|
||||
@@ -232,7 +232,7 @@ async def register(
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login and get access token."""
|
||||
# Find user by email
|
||||
@@ -270,7 +270,7 @@ async def login(
|
||||
async def login_json(
|
||||
request: Request,
|
||||
credentials: UserLogin,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login with JSON body (alternative to form data)."""
|
||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||
@@ -304,7 +304,7 @@ async def login_json(
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
||||
user_id = payload.get("sub")
|
||||
@@ -368,7 +368,7 @@ async def get_me(
|
||||
async def update_me(
|
||||
data: UserUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Update current user's profile (name, email)."""
|
||||
update_fields = data.model_fields_set - {"current_password"}
|
||||
@@ -415,7 +415,7 @@ async def update_me(
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Logout user by revoking the refresh token."""
|
||||
jti = payload.get("jti")
|
||||
@@ -438,7 +438,7 @@ async def change_password(
|
||||
request: Request,
|
||||
data: ChangePasswordRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Change the current user's password."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
@@ -478,7 +478,7 @@ async def change_password(
|
||||
async def forgot_password(
|
||||
request: Request,
|
||||
data: ForgotPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
@@ -513,7 +513,7 @@ async def forgot_password(
|
||||
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
||||
async def verify_reset_token(
|
||||
data: VerifyResetTokenRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify a password reset token is valid."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -544,7 +544,7 @@ async def verify_reset_token(
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
data: ResetPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Reset password using a valid reset token."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -611,7 +611,7 @@ async def reset_password(
|
||||
|
||||
@router.get("/email/verification-status")
|
||||
async def get_verification_status(
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Check if email verification is enabled on the platform."""
|
||||
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -623,7 +623,7 @@ async def get_verification_status(
|
||||
async def send_verification_email(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Send an email verification link to the current user."""
|
||||
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -662,7 +662,7 @@ async def send_verification_email(
|
||||
@router.post("/email/verify")
|
||||
async def verify_email(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify an email using a token. Public endpoint."""
|
||||
token = data.get("token")
|
||||
|
||||
120
backend/app/api/endpoints/device_types.py
Normal file
120
backend/app/api/endpoints/device_types.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Device types API endpoints."""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.device_type import DeviceType
|
||||
from app.schemas.device_type import (
|
||||
DeviceTypeCreate,
|
||||
DeviceTypeUpdate,
|
||||
DeviceTypeResponse,
|
||||
)
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
router = APIRouter(prefix="/device-types", tags=["device-types"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[DeviceTypeResponse])
|
||||
async def list_device_types(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[DeviceTypeResponse]:
|
||||
stmt = (
|
||||
select(DeviceType)
|
||||
.where(
|
||||
or_(
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [DeviceTypeResponse.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("/", response_model=DeviceTypeResponse, status_code=201)
|
||||
async def create_device_type(
|
||||
data: DeviceTypeCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DeviceTypeResponse:
|
||||
existing = await db.execute(
|
||||
select(DeviceType).where(
|
||||
DeviceType.slug == data.slug,
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account")
|
||||
|
||||
system_existing = await db.execute(
|
||||
select(DeviceType).where(
|
||||
DeviceType.slug == data.slug,
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
)
|
||||
)
|
||||
if system_existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' conflicts with a system type")
|
||||
|
||||
device_type = DeviceType(
|
||||
slug=data.slug,
|
||||
label=data.label,
|
||||
category=data.category,
|
||||
is_system=False,
|
||||
account_id=current_user.account_id,
|
||||
sort_order=data.sort_order,
|
||||
)
|
||||
db.add(device_type)
|
||||
await db.commit()
|
||||
await db.refresh(device_type)
|
||||
return DeviceTypeResponse.model_validate(device_type)
|
||||
|
||||
|
||||
@router.put("/{device_type_id}", response_model=DeviceTypeResponse)
|
||||
async def update_device_type(
|
||||
device_type_id: UUID,
|
||||
data: DeviceTypeUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DeviceTypeResponse:
|
||||
device_type = await db.get(DeviceType, device_type_id)
|
||||
if not device_type:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
if device_type.is_system:
|
||||
raise HTTPException(status_code=403, detail="Cannot modify system device types")
|
||||
if device_type.account_id != current_user.account_id:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(device_type, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(device_type)
|
||||
return DeviceTypeResponse.model_validate(device_type)
|
||||
|
||||
|
||||
@router.delete("/{device_type_id}", status_code=204)
|
||||
async def delete_device_type(
|
||||
device_type_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> None:
|
||||
device_type = await db.get(DeviceType, device_type_id)
|
||||
if not device_type:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
if device_type.is_system:
|
||||
raise HTTPException(status_code=403, detail="Cannot delete system device types")
|
||||
if device_type.account_id != current_user.account_id:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
|
||||
await db.delete(device_type)
|
||||
await db.commit()
|
||||
@@ -27,6 +27,7 @@ from app.schemas.psa_connection import (
|
||||
PsaMemberMappingSaveRequest,
|
||||
PsaMemberResponse,
|
||||
AutoMatchResult,
|
||||
PSABoardResponse,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.services.psa.encryption import (
|
||||
@@ -345,26 +346,103 @@ async def update_flowpilot_settings(
|
||||
# ── ticket / status / company endpoints ──────────────────────────
|
||||
|
||||
|
||||
@router.get("/tickets/search", response_model=list[PSATicketSearchResult])
|
||||
async def search_tickets(
|
||||
@router.get("/boards", response_model=list[PSABoardResponse])
|
||||
async def list_boards(
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
query: str = "",
|
||||
board_id: int | None = None,
|
||||
status_id: int | None = None,
|
||||
include_closed: bool = False,
|
||||
):
|
||||
"""Search ConnectWise tickets."""
|
||||
"""List PSA service boards."""
|
||||
if not current_user.account_id:
|
||||
raise HTTPException(status_code=400, detail="User has no account")
|
||||
|
||||
from app.services.psa.registry import get_provider_for_account
|
||||
from app.services.psa.exceptions import PSAError
|
||||
|
||||
try:
|
||||
provider = await get_provider_for_account(current_user.account_id, db)
|
||||
boards = await provider.list_boards()
|
||||
return [PSABoardResponse(id=b.id, name=b.name) for b in boards]
|
||||
except PSAError:
|
||||
# Boards are optional UI chrome — degrade gracefully rather than surfacing a toast
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/tickets/search", response_model=list[PSATicketSearchResult])
|
||||
async def search_tickets(
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
query: str = "",
|
||||
board_id: int | None = None,
|
||||
status_id: int | None = None,
|
||||
include_closed: bool = False,
|
||||
assigned_to_me: bool = False,
|
||||
unassigned: bool = False,
|
||||
board_ids: str = "",
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
):
|
||||
"""Search ConnectWise tickets."""
|
||||
if not current_user.account_id:
|
||||
raise HTTPException(status_code=400, detail="User has no account")
|
||||
|
||||
from app.services.psa.registry import get_provider_for_account
|
||||
from app.services.psa.exceptions import PSAError
|
||||
|
||||
# Resolve assigned_to_me → member_identifier (CW login name for resources contains filter)
|
||||
member_identifier: str | None = None
|
||||
if assigned_to_me:
|
||||
conn_result = await db.execute(
|
||||
select(PsaConnection).where(
|
||||
PsaConnection.account_id == current_user.account_id,
|
||||
PsaConnection.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
conn = conn_result.scalar_one_or_none()
|
||||
if conn:
|
||||
mapping_result = await db.execute(
|
||||
select(PsaMemberMapping).where(
|
||||
PsaMemberMapping.psa_connection_id == conn.id,
|
||||
PsaMemberMapping.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
mapping = mapping_result.scalar_one_or_none()
|
||||
if not mapping:
|
||||
# No mapping for this user — return empty list
|
||||
return []
|
||||
|
||||
from app.services.psa.registry import get_provider_for_account as _get_provider
|
||||
from app.services.psa.exceptions import PSAError as _PSAError
|
||||
try:
|
||||
_provider = await _get_provider(current_user.account_id, db)
|
||||
cw_members = await _provider.list_members()
|
||||
matched = next((m for m in cw_members if m.id == mapping.external_member_id), None)
|
||||
if matched:
|
||||
member_identifier = matched.identifier
|
||||
else:
|
||||
return []
|
||||
except _PSAError:
|
||||
return []
|
||||
|
||||
# Parse comma-separated board_ids
|
||||
parsed_board_ids: list[int] = []
|
||||
if board_ids:
|
||||
try:
|
||||
parsed_board_ids = [int(bid.strip()) for bid in board_ids.split(",") if bid.strip()]
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="board_ids must be comma-separated integers")
|
||||
|
||||
try:
|
||||
provider = await get_provider_for_account(current_user.account_id, db)
|
||||
tickets = await provider.search_tickets(
|
||||
query, board_id=board_id, status_id=status_id, include_closed=include_closed
|
||||
query,
|
||||
board_id=board_id,
|
||||
status_id=status_id,
|
||||
include_closed=include_closed,
|
||||
member_identifier=member_identifier,
|
||||
unassigned=unassigned,
|
||||
board_ids=parsed_board_ids,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return [
|
||||
PSATicketSearchResult(
|
||||
@@ -517,31 +595,37 @@ async def get_member_mappings(
|
||||
current_user: Annotated[User, Depends(require_account_owner)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Get all member mappings for the account."""
|
||||
"""Get all account users with their PSA member mappings (unmapped users included)."""
|
||||
conn = await _get_account_connection(current_user.account_id, db)
|
||||
if not conn:
|
||||
return []
|
||||
|
||||
result = await db.execute(
|
||||
# Fetch all active account users
|
||||
users_result = await db.execute(
|
||||
select(User).where(User.account_id == current_user.account_id, User.is_active.is_(True))
|
||||
)
|
||||
users = users_result.scalars().all()
|
||||
|
||||
# Fetch all existing mappings keyed by user_id for O(1) lookup
|
||||
mappings_result = await db.execute(
|
||||
select(PsaMemberMapping).where(PsaMemberMapping.psa_connection_id == conn.id)
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
mapping_by_user: dict[str, PsaMemberMapping] = {
|
||||
str(m.user_id): m for m in mappings_result.scalars().all()
|
||||
}
|
||||
|
||||
response = []
|
||||
for m in mappings:
|
||||
user_result = await db.execute(select(User).where(User.id == m.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user:
|
||||
response.append(PsaMemberMappingResponse(
|
||||
id=str(m.id),
|
||||
user_id=str(m.user_id),
|
||||
return [
|
||||
PsaMemberMappingResponse(
|
||||
id=str(m.id) if (m := mapping_by_user.get(str(user.id))) else None,
|
||||
user_id=str(user.id),
|
||||
user_email=user.email,
|
||||
user_name=user.name,
|
||||
external_member_id=m.external_member_id,
|
||||
external_member_name=m.external_member_name,
|
||||
matched_by=m.matched_by,
|
||||
))
|
||||
return response
|
||||
external_member_id=m.external_member_id if m else None,
|
||||
external_member_name=m.external_member_name if m else None,
|
||||
matched_by=m.matched_by if m else None,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
|
||||
@router.post("/member-mappings", response_model=list[PsaMemberMappingResponse])
|
||||
@@ -564,6 +648,7 @@ async def save_member_mappings(
|
||||
for m in mappings:
|
||||
mapping = PsaMemberMapping(
|
||||
psa_connection_id=conn.id,
|
||||
account_id=current_user.account_id,
|
||||
user_id=UUID(m.user_id),
|
||||
external_member_id=m.external_member_id,
|
||||
external_member_name=m.external_member_name,
|
||||
@@ -624,6 +709,7 @@ async def auto_match_members(
|
||||
if not existing.scalar_one_or_none():
|
||||
mapping = PsaMemberMapping(
|
||||
psa_connection_id=conn.id,
|
||||
account_id=current_user.account_id,
|
||||
user_id=user.id,
|
||||
external_member_id=cw_member.id,
|
||||
external_member_name=cw_member.name,
|
||||
|
||||
@@ -69,6 +69,7 @@ async def create_schedule(
|
||||
|
||||
schedule = MaintenanceSchedule(
|
||||
tree_id=data.tree_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
cron_expression=data.cron_expression,
|
||||
timezone=data.timezone,
|
||||
|
||||
362
backend/app/api/endpoints/network_diagrams.py
Normal file
362
backend/app/api/endpoints/network_diagrams.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Network diagrams API endpoints."""
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.device_type import DeviceType
|
||||
from app.models.network_diagram import NetworkDiagram
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
from app.schemas.network_diagram import (
|
||||
NetworkDiagramCreate,
|
||||
NetworkDiagramUpdate,
|
||||
NetworkDiagramResponse,
|
||||
NetworkDiagramListItem,
|
||||
AIGenerateRequest,
|
||||
AIGenerateResponse,
|
||||
DiagramImportRequest,
|
||||
DiagramImportResponse,
|
||||
DiagramExportResponse,
|
||||
DiagramNode,
|
||||
DiagramEdge,
|
||||
)
|
||||
from app.services import network_diagram_ai_service, storage_service
|
||||
|
||||
# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts
|
||||
_SLUG_CATEGORY: dict[str, str] = {
|
||||
"router": "network", "switch": "network", "access-point": "network", "load-balancer": "network",
|
||||
"firewall": "security", "badge-reader": "security",
|
||||
"server": "compute", "vm": "compute", "container": "compute",
|
||||
"nas": "storage", "san": "storage", "cloud-storage": "storage",
|
||||
"cloud": "cloud", "aws": "cloud", "azure": "cloud", "gcp": "cloud", "isp": "cloud",
|
||||
"workstation": "endpoint", "laptop": "endpoint", "tablet": "endpoint",
|
||||
"phone": "endpoint", "printer": "endpoint",
|
||||
"ups": "infrastructure", "pdu": "infrastructure", "rack": "infrastructure",
|
||||
"patch-panel": "infrastructure", "camera": "infrastructure",
|
||||
"nvr": "infrastructure", "iot": "infrastructure",
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
|
||||
|
||||
|
||||
async def _get_diagram_or_404(
|
||||
diagram_id: UUID,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
) -> NetworkDiagram:
|
||||
diagram = await db.get(NetworkDiagram, diagram_id)
|
||||
if not diagram or diagram.account_id != account_id or diagram.is_archived:
|
||||
raise HTTPException(status_code=404, detail="Diagram not found")
|
||||
return diagram
|
||||
|
||||
|
||||
def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
|
||||
return NetworkDiagramResponse.model_validate(diagram)
|
||||
|
||||
|
||||
def _diagram_to_list_item(
|
||||
diagram: NetworkDiagram,
|
||||
custom_slug_category: dict[str, str] | None = None,
|
||||
) -> NetworkDiagramListItem:
|
||||
nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
|
||||
slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}
|
||||
|
||||
category_counts: dict[str, int] = {}
|
||||
for node in nodes:
|
||||
slug = node.get("type", "") if isinstance(node, dict) else ""
|
||||
cat = slug_to_cat.get(slug, "other")
|
||||
category_counts[cat] = category_counts.get(cat, 0) + 1
|
||||
|
||||
return NetworkDiagramListItem(
|
||||
id=diagram.id,
|
||||
name=diagram.name,
|
||||
client_name=diagram.client_name,
|
||||
description=diagram.description,
|
||||
node_count=len(nodes),
|
||||
category_counts=category_counts,
|
||||
thumbnail_url=diagram.thumbnail_url,
|
||||
created_by=diagram.created_by,
|
||||
created_at=diagram.created_at,
|
||||
updated_at=diagram.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
|
||||
stmt = select(DeviceType.slug).where(
|
||||
or_(
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
DeviceType.account_id == account_id,
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return {row[0] for row in result.all()}
|
||||
|
||||
|
||||
@router.get("/clients", response_model=list[str])
|
||||
async def list_client_names(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[str]:
|
||||
stmt = (
|
||||
select(NetworkDiagram.client_name)
|
||||
.where(
|
||||
NetworkDiagram.account_id == current_user.account_id,
|
||||
NetworkDiagram.is_archived.is_(False),
|
||||
NetworkDiagram.client_name.isnot(None),
|
||||
NetworkDiagram.client_name != "",
|
||||
)
|
||||
.distinct()
|
||||
.order_by(NetworkDiagram.client_name)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return [row[0] for row in result.all()]
|
||||
|
||||
|
||||
@router.get("/", response_model=list[NetworkDiagramListItem])
|
||||
async def list_diagrams(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
client_name: str | None = Query(default=None),
|
||||
search: str | None = Query(default=None),
|
||||
) -> list[NetworkDiagramListItem]:
|
||||
stmt = (
|
||||
select(NetworkDiagram)
|
||||
.where(
|
||||
NetworkDiagram.account_id == current_user.account_id,
|
||||
NetworkDiagram.is_archived.is_(False),
|
||||
)
|
||||
.order_by(NetworkDiagram.updated_at.desc())
|
||||
)
|
||||
|
||||
if client_name:
|
||||
stmt = stmt.where(NetworkDiagram.client_name == client_name)
|
||||
|
||||
if search:
|
||||
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
search_filter = f"%{escaped}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
NetworkDiagram.name.ilike(search_filter),
|
||||
NetworkDiagram.client_name.ilike(search_filter),
|
||||
)
|
||||
)
|
||||
|
||||
# Single query for custom device types so category_counts is accurate
|
||||
dt_stmt = select(DeviceType.slug, DeviceType.category).where(
|
||||
DeviceType.is_system.is_(False),
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
dt_result = await db.execute(dt_stmt)
|
||||
custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
|
||||
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [_diagram_to_list_item(r, custom_slug_category) for r in rows]
|
||||
|
||||
|
||||
@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
|
||||
async def create_diagram(
|
||||
data: NetworkDiagramCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=data.name,
|
||||
client_name=data.client_name,
|
||||
asset_name=data.asset_name,
|
||||
description=data.description,
|
||||
nodes=[n.model_dump() for n in data.nodes],
|
||||
edges=[e.model_dump() for e in data.edges],
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(diagram)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
|
||||
async def get_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
|
||||
async def update_diagram(
|
||||
diagram_id: UUID,
|
||||
data: NetworkDiagramUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
if "nodes" in update_data and update_data["nodes"] is not None:
|
||||
update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]]
|
||||
if "edges" in update_data and update_data["edges"] is not None:
|
||||
update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]]
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(diagram, field, value)
|
||||
|
||||
diagram.updated_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.delete("/{diagram_id}", status_code=204)
|
||||
async def archive_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> None:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
diagram.is_archived = True
|
||||
diagram.updated_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
|
||||
async def duplicate_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
copy = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=f"Copy of {source.name}",
|
||||
client_name=source.client_name,
|
||||
asset_name=source.asset_name,
|
||||
description=source.description,
|
||||
nodes=source.nodes,
|
||||
edges=source.edges,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(copy)
|
||||
await db.commit()
|
||||
await db.refresh(copy)
|
||||
return _diagram_to_response(copy)
|
||||
|
||||
|
||||
@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
|
||||
async def export_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DiagramExportResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
|
||||
edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
|
||||
return DiagramExportResponse(
|
||||
schemaVersion=1,
|
||||
name=diagram.name,
|
||||
client_name=diagram.client_name,
|
||||
description=diagram.description,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
exportedAt=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", response_model=DiagramImportResponse, status_code=201)
|
||||
async def import_diagram(
|
||||
data: DiagramImportRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DiagramImportResponse:
|
||||
available_slugs = await _get_available_slugs(current_user.account_id, db)
|
||||
|
||||
warnings: list[str] = []
|
||||
for node in data.nodes:
|
||||
if node.type not in available_slugs:
|
||||
warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
|
||||
|
||||
diagram = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=data.name,
|
||||
client_name=data.client_name,
|
||||
description=data.description,
|
||||
nodes=[n.model_dump() for n in data.nodes],
|
||||
edges=[e.model_dump() for e in data.edges],
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(diagram)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
|
||||
return DiagramImportResponse(
|
||||
diagram=_diagram_to_response(diagram),
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
|
||||
class ThumbnailUploadRequest(BaseModel):
|
||||
data_url: str # base64 PNG data URL: "data:image/png;base64,..."
|
||||
|
||||
|
||||
@router.post("/{diagram_id}/thumbnail", status_code=204)
|
||||
async def upload_thumbnail(
|
||||
diagram_id: UUID,
|
||||
body: ThumbnailUploadRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> None:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
try:
|
||||
header, encoded = body.data_url.split(",", 1)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid data URL format")
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
storage_key = await storage_service.upload_file(
|
||||
file_data=image_bytes,
|
||||
filename=f"thumbnail-{diagram_id}.png",
|
||||
content_type="image/png",
|
||||
account_id=str(current_user.account_id),
|
||||
)
|
||||
presigned_url = storage_service.get_presigned_url(storage_key)
|
||||
diagram.thumbnail_url = presigned_url
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/ai-generate", response_model=AIGenerateResponse)
|
||||
async def ai_generate_diagram(
|
||||
data: AIGenerateRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> AIGenerateResponse:
|
||||
available_slugs_set = await _get_available_slugs(current_user.account_id, db)
|
||||
available_slugs = list(available_slugs_set)
|
||||
|
||||
existing_node_ids: list[str] | None = None
|
||||
if data.mode == "merge" and data.existingBounds:
|
||||
existing_node_ids = []
|
||||
|
||||
try:
|
||||
return await network_diagram_ai_service.generate_diagram(
|
||||
request=data,
|
||||
available_slugs=available_slugs,
|
||||
existing_node_ids=existing_node_ids,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("AI diagram generation failed")
|
||||
raise HTTPException(status_code=500, detail="Diagram generation failed")
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.models.psa_connection import PsaConnection
|
||||
from app.models.session import Session
|
||||
@@ -98,7 +99,7 @@ async def get_onboarding_status(
|
||||
|
||||
@router.post("/onboarding-status/dismiss", response_model=OnboardingStatus)
|
||||
async def dismiss_onboarding(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> OnboardingStatus:
|
||||
"""Dismiss the onboarding checklist for the current user."""
|
||||
|
||||
@@ -91,6 +91,7 @@ async def submit_step_feedback(
|
||||
new_rating = StepRating(
|
||||
step_id=step_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
session_id=session_uuid,
|
||||
was_helpful=data.was_helpful,
|
||||
# rating is nullable now — thumbs-only mode
|
||||
|
||||
@@ -85,6 +85,7 @@ async def create_session(
|
||||
session = await script_builder_service.create_session(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
language=data.language,
|
||||
)
|
||||
|
||||
@@ -196,6 +196,7 @@ async def start_session(
|
||||
new_session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -693,6 +694,7 @@ async def prepare_session(
|
||||
new_session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=data.assigned_to_id or current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -770,6 +772,7 @@ async def batch_launch_sessions(
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -1102,6 +1105,7 @@ async def psa_post_to_ticket(
|
||||
# Log to audit trail
|
||||
log_entry = PsaPostLog(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
psa_connection_id=psa_connection.id if psa_connection else None,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
note_type=data.note_type,
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.session import Session
|
||||
from app.models.session_share import SessionShare, SessionShareView
|
||||
from app.models.user import User
|
||||
@@ -210,7 +211,7 @@ async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[Use
|
||||
async def access_share(
|
||||
share_token: str,
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
):
|
||||
"""Access a shared session via share token.
|
||||
|
||||
|
||||
@@ -460,6 +460,7 @@ async def rate_step(
|
||||
rating = StepRating(
|
||||
step_id=step_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
rating=rating_data.rating,
|
||||
was_helpful=rating_data.was_helpful,
|
||||
review_text=rating_data.review_text,
|
||||
|
||||
@@ -103,6 +103,7 @@ async def create_supporting_data(
|
||||
|
||||
item = SessionSupportingData(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
label=data.label,
|
||||
data_type=data.data_type,
|
||||
content=data.content,
|
||||
|
||||
@@ -18,12 +18,10 @@ async def list_target_lists(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""List all target lists for the current user's team."""
|
||||
if not current_user.team_id:
|
||||
return []
|
||||
"""List all target lists for the current user's account."""
|
||||
result = await db.execute(
|
||||
select(TargetList)
|
||||
.where(TargetList.team_id == current_user.team_id)
|
||||
.where(TargetList.account_id == current_user.account_id)
|
||||
.order_by(TargetList.name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -36,11 +34,9 @@ async def create_target_list(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Create a new target list for the current team."""
|
||||
if not current_user.team_id:
|
||||
raise HTTPException(status_code=400, detail="User must belong to a team")
|
||||
"""Create a new target list for the current account."""
|
||||
target_list = TargetList(
|
||||
team_id=current_user.team_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
@@ -61,7 +57,7 @@ async def get_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -81,7 +77,7 @@ async def update_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -91,7 +87,7 @@ async def update_target_list(
|
||||
if "name" in update_fields and data.name is not None:
|
||||
target_list.name = data.name
|
||||
if "description" in update_fields:
|
||||
target_list.description = data.description # allow setting to None
|
||||
target_list.description = data.description
|
||||
if "targets" in update_fields and data.targets is not None:
|
||||
target_list.targets = [t.model_dump() for t in data.targets]
|
||||
await db.commit()
|
||||
@@ -109,7 +105,7 @@ async def delete_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
|
||||
@@ -1048,6 +1048,7 @@ async def create_tree_share(
|
||||
# Create share
|
||||
tree_share = TreeShare(
|
||||
tree_id=tree.id,
|
||||
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
|
||||
share_token=share_token,
|
||||
created_by=current_user.id,
|
||||
allow_forking=share_data.allow_forking,
|
||||
|
||||
@@ -24,6 +24,7 @@ from app.api.endpoints import (
|
||||
branding,
|
||||
categories,
|
||||
copilot,
|
||||
device_types,
|
||||
feedback,
|
||||
flow_proposals,
|
||||
flowpilot_analytics,
|
||||
@@ -32,6 +33,7 @@ from app.api.endpoints import (
|
||||
invite,
|
||||
kb_accelerator,
|
||||
maintenance_schedules,
|
||||
network_diagrams,
|
||||
notifications,
|
||||
onboarding,
|
||||
public_templates,
|
||||
@@ -93,7 +95,6 @@ api_router.include_router(admin_settings.router)
|
||||
api_router.include_router(admin_categories.router)
|
||||
api_router.include_router(admin_survey.router)
|
||||
api_router.include_router(admin_gallery.router)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-facing endpoints — tenant context required
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -130,6 +131,7 @@ api_router.include_router(integrations.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(branding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
|
||||
# session_handoffs queue router must come before ai_sessions to avoid conflict
|
||||
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
|
||||
@@ -142,3 +144,4 @@ api_router.include_router(script_builder.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(device_types.router, dependencies=_tenant_deps)
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""
|
||||
Admin database engine — connects as resolutionflow_admin (BYPASSRLS).
|
||||
|
||||
Use ONLY for /admin/* endpoints and internal tooling.
|
||||
Never use this engine from user-facing endpoints.
|
||||
Use ONLY where explicit application-level access control makes database-layer
|
||||
tenant filtering unnecessary: /admin/* endpoints, internal tooling, and public
|
||||
endpoints that enforce their own authorization before returning data (e.g.
|
||||
share access via opaque token + visibility check).
|
||||
"""
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -25,7 +27,7 @@ _admin_session_factory = async_sessionmaker(
|
||||
|
||||
|
||||
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints."""
|
||||
"""Yield an admin DB session (BYPASSRLS). See module docstring for approved use cases."""
|
||||
async with _admin_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
@@ -199,7 +199,10 @@ async def generate_fixes(
|
||||
|
||||
try:
|
||||
text, in_tok, out_tok = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": FIX_SYSTEM_PROMPT},
|
||||
# cacheable: stable constant across all fix attempts
|
||||
],
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
@@ -232,7 +235,11 @@ async def generate_fixes(
|
||||
|
||||
try:
|
||||
text2, in_tok2, out_tok2 = await provider.generate_json(
|
||||
system_prompt=FIX_SYSTEM_PROMPT,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": FIX_SYSTEM_PROMPT},
|
||||
# cacheable: stable constant; retry reads the cached
|
||||
# system block from the first attempt above
|
||||
],
|
||||
messages=messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
@@ -3,16 +3,169 @@ AI Provider abstraction layer.
|
||||
|
||||
Supports Gemini (google-genai) and Anthropic (anthropic) as interchangeable
|
||||
backends for JSON generation used by the AI Flow Builder.
|
||||
|
||||
## Prompt caching (Anthropic only)
|
||||
|
||||
Callers may pass `system_prompt` as either:
|
||||
|
||||
- `str` — backward-compatible, uncached.
|
||||
- `list[SystemBlock]` — Anthropic structured system blocks. Each block is a
|
||||
dict of shape `{"type": "text", "text": str, "cache_control": {...}?}`.
|
||||
|
||||
Caching policy (policy α, per Phase 0.1 design):
|
||||
- If any block in the list carries an explicit `cache_control` key, that
|
||||
caller-authored configuration is honored verbatim.
|
||||
- If no block carries `cache_control`, the provider applies
|
||||
`cache_control: {"type": "ephemeral"}` to the first block only. First block
|
||||
is the common "large static prefix" case (e.g. system prompt, reference data).
|
||||
|
||||
Gemini ignores cache_control and concatenates list blocks into one system
|
||||
string — callers should not rely on Gemini for cache-hit behavior.
|
||||
|
||||
TODO(phase0-verify): When a dev environment is available, verify cache-hit
|
||||
behavior by hitting any FlowPilot endpoint twice within the 5-minute
|
||||
ephemeral TTL. First call should emit `anthropic.cache` with
|
||||
`cache_creation_input_tokens > 0`; second call with `cache_read_input_tokens > 0`.
|
||||
If the second call returns zero reads, inspect the prefix for silent
|
||||
invalidators (timestamps, unsorted JSON keys, varying tool list ordering).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Anthropic structured system block. See module docstring for caching policy.
|
||||
SystemBlock = dict[str, Any]
|
||||
|
||||
|
||||
def _normalize_system_for_anthropic(
|
||||
system_prompt: str | list[SystemBlock],
|
||||
) -> str | list[SystemBlock]:
|
||||
"""Return the value to pass as the `system=` parameter to the Anthropic API.
|
||||
|
||||
- Plain strings pass through untouched (uncached path).
|
||||
- Lists are returned as structured system blocks. If no block in the list
|
||||
carries an explicit `cache_control`, `cache_control: {"type": "ephemeral"}`
|
||||
is applied to the FIRST block only (policy α).
|
||||
- Caller-authored `cache_control` is never overwritten.
|
||||
"""
|
||||
if isinstance(system_prompt, str):
|
||||
return system_prompt
|
||||
|
||||
if not system_prompt:
|
||||
# Empty list is not a meaningful system prompt — pass empty string so
|
||||
# Anthropic treats this as "no system prompt" rather than erroring.
|
||||
return ""
|
||||
|
||||
blocks = [dict(b) for b in system_prompt]
|
||||
already_cached = any("cache_control" in b for b in blocks)
|
||||
|
||||
if not already_cached:
|
||||
blocks[0]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def _flatten_system_for_gemini(
|
||||
system_prompt: str | list[SystemBlock],
|
||||
) -> str:
|
||||
"""Gemini has no structured system blocks; concatenate list entries."""
|
||||
if isinstance(system_prompt, str):
|
||||
return system_prompt
|
||||
return "\n\n".join(b.get("text", "") for b in system_prompt)
|
||||
|
||||
|
||||
def build_anthropic_chat_messages(
|
||||
history: list[dict[str, Any]],
|
||||
new_message: str,
|
||||
images: list[dict[str, Any]] | None = None,
|
||||
format_reminder: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Construct the Anthropic `messages` payload for a cached multi-turn chat.
|
||||
|
||||
Responsibilities:
|
||||
- Copy the valid history messages in order.
|
||||
- Apply `cache_control: ephemeral` to the LAST history message so the entire
|
||||
conversation prefix is cached across turns. The new user message stays
|
||||
uncached (it changes each turn).
|
||||
- Append `format_reminder` to the new user message if provided. The reminder
|
||||
is invisible to storage (caller's concern) but helps enforce structured
|
||||
output compliance at generation time.
|
||||
- If `images` are provided, render the new user message as a multimodal
|
||||
content block list (images first, then text). Otherwise, render it as
|
||||
a plain string.
|
||||
|
||||
This helper is Anthropic-specific: the cache-breakpoint pattern, ephemeral
|
||||
cache_control, and multimodal block shape are all Anthropic conventions.
|
||||
Do not call it from Gemini code paths.
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
for msg in history:
|
||||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
# Cache breakpoint on the last existing history message so the entire
|
||||
# conversation prefix is cached across turns. Safe only when there IS a
|
||||
# history message; otherwise the new message is the only message.
|
||||
if messages:
|
||||
last = messages[-1]
|
||||
messages[-1] = {
|
||||
"role": last["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": last["content"],
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
effective_text = new_message + (format_reminder or "")
|
||||
|
||||
if images:
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
for img in images:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": img["media_type"],
|
||||
"data": img["data"],
|
||||
},
|
||||
}
|
||||
)
|
||||
content_blocks.append({"type": "text", "text": effective_text})
|
||||
messages.append({"role": "user", "content": content_blocks})
|
||||
else:
|
||||
messages.append({"role": "user", "content": effective_text})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _log_anthropic_cache_usage(usage: Any, model: str) -> None:
|
||||
"""Emit a structured log line capturing cache_read / cache_creation tokens."""
|
||||
cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
input_tokens = getattr(usage, "input_tokens", 0) or 0
|
||||
output_tokens = getattr(usage, "output_tokens", 0) or 0
|
||||
if cache_read or cache_creation:
|
||||
logger.info(
|
||||
"anthropic.cache",
|
||||
extra={
|
||||
"event": "anthropic.cache",
|
||||
"model": model,
|
||||
"cache_read_input_tokens": cache_read,
|
||||
"cache_creation_input_tokens": cache_creation,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers."""
|
||||
@@ -20,14 +173,16 @@ class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Generate a JSON response from the AI model.
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instruction for the model.
|
||||
system_prompt: System-level instruction. Plain `str` is uncached
|
||||
(Anthropic) or used as-is (Gemini). `list[SystemBlock]` enables
|
||||
Anthropic prompt caching per module-docstring policy.
|
||||
messages: List of message dicts with "role" and "content" keys.
|
||||
max_tokens: Maximum output tokens.
|
||||
|
||||
@@ -39,37 +194,25 @@ class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def generate_text(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Generate a text response from the AI model (no JSON constraint).
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instruction for the model.
|
||||
messages: List of message dicts with "role" and "content" keys.
|
||||
max_tokens: Maximum output tokens.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, input_tokens, output_tokens).
|
||||
See `generate_json` for argument semantics.
|
||||
"""
|
||||
...
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> "AsyncIterator[str]":
|
||||
"""Stream a text response token by token.
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instruction for the model.
|
||||
messages: List of message dicts with "role" and "content" keys.
|
||||
max_tokens: Maximum output tokens.
|
||||
|
||||
Yields:
|
||||
Text chunks as they are generated.
|
||||
See `generate_json` for argument semantics.
|
||||
"""
|
||||
raise NotImplementedError("Streaming not supported for this provider")
|
||||
# Make this an async generator to satisfy type checker
|
||||
@@ -85,14 +228,15 @@ class GeminiProvider(AIProvider):
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
client = genai.Client(api_key=self._api_key)
|
||||
system_text = _flatten_system_for_gemini(system_prompt)
|
||||
|
||||
# Convert messages to Gemini Content format
|
||||
contents: list[genai_types.Content] = []
|
||||
@@ -106,7 +250,7 @@ class GeminiProvider(AIProvider):
|
||||
)
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
system_instruction=system_text,
|
||||
max_output_tokens=max_tokens,
|
||||
response_mime_type="application/json",
|
||||
)
|
||||
@@ -137,14 +281,15 @@ class GeminiProvider(AIProvider):
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
client = genai.Client(api_key=self._api_key)
|
||||
system_text = _flatten_system_for_gemini(system_prompt)
|
||||
|
||||
contents: list[genai_types.Content] = []
|
||||
for msg in messages:
|
||||
@@ -157,7 +302,7 @@ class GeminiProvider(AIProvider):
|
||||
)
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
system_instruction=system_text,
|
||||
max_output_tokens=max_tokens,
|
||||
# No response_mime_type — allow free-form text
|
||||
)
|
||||
@@ -214,16 +359,17 @@ class AnthropicProvider(AIProvider):
|
||||
|
||||
async def generate_json(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
client = _get_anthropic_client(self._api_key, self._timeout)
|
||||
normalized_system = _normalize_system_for_anthropic(system_prompt)
|
||||
|
||||
response = await client.messages.create(
|
||||
model=self._model,
|
||||
max_tokens=max_tokens,
|
||||
system=system_prompt,
|
||||
system=normalized_system,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
@@ -231,12 +377,14 @@ class AnthropicProvider(AIProvider):
|
||||
input_tokens = response.usage.input_tokens
|
||||
output_tokens = response.usage.output_tokens
|
||||
|
||||
_log_anthropic_cache_usage(response.usage, self._model)
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
async def generate_text(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> tuple[str, int, int]:
|
||||
# Anthropic doesn't differentiate between JSON and text mode
|
||||
@@ -244,20 +392,28 @@ class AnthropicProvider(AIProvider):
|
||||
|
||||
async def generate_text_stream(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, str]],
|
||||
system_prompt: str | list[SystemBlock],
|
||||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 4096,
|
||||
) -> AsyncIterator[str]:
|
||||
client = _get_anthropic_client(self._api_key, self._timeout)
|
||||
normalized_system = _normalize_system_for_anthropic(system_prompt)
|
||||
|
||||
async with client.messages.stream(
|
||||
model=self._model,
|
||||
max_tokens=max_tokens,
|
||||
system=system_prompt,
|
||||
system=normalized_system,
|
||||
messages=messages,
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
# Per Anthropic SDK, get_final_message() resolves the stream's
|
||||
# final usage object (including cache_read/cache_creation tokens).
|
||||
try:
|
||||
final = await stream.get_final_message()
|
||||
_log_anthropic_cache_usage(final.usage, self._model)
|
||||
except Exception as exc: # best-effort telemetry, never fail the stream
|
||||
logger.debug("anthropic.cache streaming usage unavailable: %s", exc)
|
||||
|
||||
|
||||
def get_ai_provider(model: str | None = None) -> AIProvider:
|
||||
|
||||
@@ -146,7 +146,10 @@ async def scaffold_branches(
|
||||
user_message += f"Environment: {', '.join(tags)}\n"
|
||||
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=SCAFFOLD_SYSTEM_PROMPT,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": SCAFFOLD_SYSTEM_PROMPT},
|
||||
# cacheable: stable constant across all scaffold calls
|
||||
],
|
||||
messages=[{"role": "user", "content": user_message}],
|
||||
max_tokens=2048,
|
||||
)
|
||||
@@ -207,7 +210,13 @@ async def generate_branch_detail(
|
||||
|
||||
for attempt in range(3):
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=BRANCH_DETAIL_SYSTEM_PROMPT,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": BRANCH_DETAIL_SYSTEM_PROMPT},
|
||||
# cacheable: stable constant. Retries in this loop re-read the
|
||||
# cached system block rather than paying full input cost each
|
||||
# attempt — the ~2.5k-token prompt with few-shot example is
|
||||
# the dominant cost here.
|
||||
],
|
||||
messages=messages,
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
@@ -12,10 +12,19 @@ async def log_audit(
|
||||
resource_type: str,
|
||||
resource_id: Optional[UUID] = None,
|
||||
details: Optional[dict] = None,
|
||||
account_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
||||
if account_id is None:
|
||||
# Derive from the acting user's account as a fallback (one extra query).
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
result = await db.execute(select(User.account_id).where(User.id == user_id))
|
||||
account_id = result.scalar_one()
|
||||
|
||||
entry = AuditLog(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
|
||||
@@ -128,6 +128,7 @@ class Settings(BaseSettings):
|
||||
"variable_inference": "fast",
|
||||
"kb_convert": "standard",
|
||||
"script_build": "standard",
|
||||
"network_diagram_generate": "standard",
|
||||
}
|
||||
|
||||
def get_model_for_action(self, action_type: str) -> str:
|
||||
|
||||
@@ -425,7 +425,12 @@ async def convert_document(
|
||||
|
||||
try:
|
||||
raw_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": system_prompt},
|
||||
# cacheable: one of two stable constants (TROUBLESHOOTING_SYSTEM_PROMPT
|
||||
# or PROCEDURAL_SYSTEM_PROMPT) selected by target_type. Each
|
||||
# variant caches independently by text content.
|
||||
],
|
||||
messages=[{"role": "user", "content": user_message}],
|
||||
max_tokens=16384,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
"""Create batch sessions for a scheduled maintenance run."""
|
||||
# Import all models first to ensure SQLAlchemy mapper relationships resolve
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||
from app.models.session import Session
|
||||
from app.models.target_list import TargetList
|
||||
@@ -118,7 +118,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
async def _cleanup_expired_ai_conversations() -> None:
|
||||
"""Delete expired AI wizard conversations."""
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_conversation import AIConversation
|
||||
|
||||
async with async_session_maker() as db:
|
||||
|
||||
@@ -14,6 +14,8 @@ import logging
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import _admin_session_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
||||
@@ -52,13 +54,18 @@ async def _ensure_system_account(db: AsyncSession) -> uuid.UUID:
|
||||
async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||
"""Ensure the ResolutionFlow service account exists and return its ID.
|
||||
|
||||
Idempotent — safe to call on every startup. Creates the account if it
|
||||
does not exist. The account has no usable password and is_service_account=True
|
||||
so it can never log in via normal auth flows.
|
||||
Idempotent — safe to call on every startup. This lookup must bypass RLS
|
||||
because startup runs before any request-scoped tenant context exists and
|
||||
the users table is tenant-isolated in Phase 4. The service account is
|
||||
normally created by Alembic migration 1490781700bc; the runtime create path
|
||||
remains as a self-healing fallback for environments that predate that seed.
|
||||
"""
|
||||
_ = db # Retained for call-site compatibility in app lifespan startup.
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
result = await db.execute(
|
||||
async with _admin_session_factory() as admin_db:
|
||||
result = await admin_db.execute(
|
||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
@@ -66,10 +73,10 @@ async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||
if user is not None:
|
||||
if not user.is_service_account:
|
||||
user.is_service_account = True
|
||||
await db.commit()
|
||||
await admin_db.commit()
|
||||
return user.id
|
||||
|
||||
account_id = await _ensure_system_account(db)
|
||||
account_id = await _ensure_system_account(admin_db)
|
||||
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
@@ -85,7 +92,7 @@ async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
admin_db.add(new_user)
|
||||
await admin_db.commit()
|
||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||
return new_user.id
|
||||
|
||||
@@ -25,7 +25,8 @@ if settings.SENTRY_DSN:
|
||||
),
|
||||
)
|
||||
|
||||
from app.core.database import init_db, async_session_maker
|
||||
from app.core.database import init_db
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.core.logging_config import setup_logging
|
||||
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||
from app.core.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
@@ -56,6 +56,12 @@ from .session_handoff import SessionHandoff
|
||||
from .session_resolution_output import SessionResolutionOutput
|
||||
from .template_tree import TemplateTree
|
||||
from .platform_step import PlatformStep
|
||||
from .device_type import DeviceType
|
||||
from .network_diagram import NetworkDiagram
|
||||
from .session_fact import SessionFact
|
||||
from .session_suggested_fix import SessionSuggestedFix
|
||||
from .draft_template import DraftTemplate
|
||||
from .account_settings import AccountSettings
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -126,4 +132,10 @@ __all__ = [
|
||||
"SessionResolutionOutput",
|
||||
"TemplateTree",
|
||||
"PlatformStep",
|
||||
"DeviceType",
|
||||
"NetworkDiagram",
|
||||
"SessionFact",
|
||||
"SessionSuggestedFix",
|
||||
"DraftTemplate",
|
||||
"AccountSettings",
|
||||
]
|
||||
|
||||
99
backend/app/models/account_settings.py
Normal file
99
backend/app/models/account_settings.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Per-account settings with a JSONB preferences grab-bag.
|
||||
|
||||
Rows are created lazily on first write. Reads of a missing row return the
|
||||
caller-supplied default — no upfront row creation per account.
|
||||
|
||||
Settings live in `preferences` until they meet the promotion criteria in
|
||||
Section 4.6 of FLOWPILOT-MIGRATION.md (hot path / validation / joins), at
|
||||
which point a future migration adds a typed column and the helpers prefer it.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB, insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
class AccountSettings(Base):
|
||||
"""One row per account. Created lazily on first `set_setting` call."""
|
||||
__tablename__ = "account_settings"
|
||||
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
preferences: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb")
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
account: Mapped["Account"] = relationship("Account", foreign_keys=[account_id])
|
||||
|
||||
@classmethod
|
||||
async def get_setting(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
account_id: uuid.UUID,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""Return preferences[key] for the account, or `default` if no row/no key.
|
||||
|
||||
Never creates a row — this is the pure-read path.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(cls.preferences).where(cls.account_id == account_id)
|
||||
)
|
||||
prefs = result.scalar_one_or_none()
|
||||
if prefs is None:
|
||||
return default
|
||||
return prefs.get(key, default)
|
||||
|
||||
@classmethod
|
||||
async def set_setting(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
account_id: uuid.UUID,
|
||||
key: str,
|
||||
value: Any,
|
||||
) -> None:
|
||||
"""Upsert preferences[key] = value for the account.
|
||||
|
||||
Creates the row on first write; on subsequent writes, merges the key
|
||||
into the existing preferences JSON without clobbering other keys.
|
||||
Uses PostgreSQL's `||` jsonb merge operator via ON CONFLICT DO UPDATE.
|
||||
"""
|
||||
stmt = pg_insert(cls).values(
|
||||
account_id=account_id,
|
||||
preferences={key: value},
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[cls.account_id],
|
||||
set_={
|
||||
# Merge the new {key: value} into the existing preferences.
|
||||
# The `||` operator on jsonb overwrites matching keys and keeps
|
||||
# all other keys intact.
|
||||
"preferences": cls.preferences.op("||")(stmt.excluded.preferences),
|
||||
"updated_at": text("now()"),
|
||||
},
|
||||
)
|
||||
await db.execute(stmt)
|
||||
@@ -214,6 +214,38 @@ class AISession(Base):
|
||||
comment="Current task lane state: {questions: [...], actions: [...]}",
|
||||
)
|
||||
|
||||
# ── Resolution / Escalation artifacts (Phase 1 — FlowPilot migration) ──
|
||||
# Markdown of the posted note + PSA external ID for round-trip traceability.
|
||||
resolution_note_markdown: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True,
|
||||
comment="Final Resolve note markdown, as posted to the PSA",
|
||||
)
|
||||
resolution_note_posted_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True,
|
||||
)
|
||||
resolution_note_external_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(128), nullable=True,
|
||||
comment="PSA (e.g. CW) ticket-note ID returned at post time",
|
||||
)
|
||||
escalation_package_markdown: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True,
|
||||
comment="Final Escalate handoff package markdown, as posted to the PSA",
|
||||
)
|
||||
escalation_package_posted_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True,
|
||||
)
|
||||
escalation_package_external_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(128), nullable=True,
|
||||
comment="PSA ticket-note ID for the escalation package",
|
||||
)
|
||||
# Incremented atomically by any write that invalidates the resolution
|
||||
# note preview cache (facts, suggested fixes, script generations).
|
||||
# See FLOWPILOT-MIGRATION.md Section 5.5.
|
||||
state_version: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0, server_default=sa.text("0"),
|
||||
comment="Monotonic preview-cache version; bumped on state-changing writes",
|
||||
)
|
||||
|
||||
# ── Branching ──
|
||||
is_branching: Mapped[bool] = mapped_column(
|
||||
default=False,
|
||||
|
||||
@@ -21,6 +21,12 @@ class AuditLog(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
|
||||
47
backend/app/models/device_type.py
Normal file
47
backend/app/models/device_type.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Device type model for network diagrams."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class DeviceType(Base):
|
||||
"""A device type for network diagram nodes (platform or account-custom)."""
|
||||
__tablename__ = "device_types"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
slug: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
comment="Unique identifier used in diagram node data",
|
||||
)
|
||||
label: Mapped[str] = mapped_column(
|
||||
String(100), nullable=False,
|
||||
comment="Display name",
|
||||
)
|
||||
category: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
comment="network, compute, storage, cloud, endpoint, infrastructure, security",
|
||||
)
|
||||
is_system: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False,
|
||||
comment="True for built-in types that cannot be deleted",
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
comment="Platform account for system types, tenant account for custom types",
|
||||
)
|
||||
sort_order: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0,
|
||||
comment="Display order within category",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
91
backend/app/models/draft_template.py
Normal file
91
backend/app/models/draft_template.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Draft template model — scripts generated during a session, pending templatization.
|
||||
|
||||
Created when an engineer picks "Run now, templatize after resolve" in the
|
||||
three-option dialog. Post-resolve, the TemplatizePrompt component reads pending
|
||||
drafts and lets the engineer accept (promotes to `script_templates`) or reject.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
Text, DateTime, ForeignKey, String, CheckConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.account import Account
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.user import User
|
||||
from app.models.script_template import ScriptCategory, ScriptTemplate
|
||||
|
||||
|
||||
class DraftTemplate(Base):
|
||||
"""A session-generated script pending conversion to a reusable template."""
|
||||
__tablename__ = "draft_templates"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"status IN ('pending', 'accepted', 'rejected')",
|
||||
name="ck_draft_templates_status",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
)
|
||||
source_session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("ai_sessions.id"),
|
||||
nullable=False,
|
||||
)
|
||||
source_user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
)
|
||||
script_body: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
proposed_parameters: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB, nullable=False
|
||||
)
|
||||
proposed_name: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
proposed_category_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("script_categories.id"),
|
||||
nullable=True,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="pending"
|
||||
)
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# Set when status transitions to 'accepted' and the draft is promoted
|
||||
# to a real script_templates row.
|
||||
promoted_template_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("script_templates.id"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
account: Mapped["Account"] = relationship("Account", foreign_keys=[account_id])
|
||||
source_session: Mapped["AISession"] = relationship(
|
||||
"AISession", foreign_keys=[source_session_id]
|
||||
)
|
||||
source_user: Mapped["User"] = relationship("User", foreign_keys=[source_user_id])
|
||||
proposed_category: Mapped["ScriptCategory | None"] = relationship(
|
||||
"ScriptCategory", foreign_keys=[proposed_category_id]
|
||||
)
|
||||
promoted_template: Mapped["ScriptTemplate | None"] = relationship(
|
||||
"ScriptTemplate", foreign_keys=[promoted_template_id]
|
||||
)
|
||||
53
backend/app/models/network_diagram.py
Normal file
53
backend/app/models/network_diagram.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Network diagram model."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class NetworkDiagram(Base):
|
||||
"""A network topology diagram scoped to one account."""
|
||||
__tablename__ = "network_diagrams"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
is_archived: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False,
|
||||
)
|
||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])
|
||||
@@ -78,6 +78,20 @@ class ScriptTemplate(Base):
|
||||
is_gallery_featured: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"), index=True)
|
||||
gallery_sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default=text("0"))
|
||||
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default=text("0"))
|
||||
# ── Provenance (Phase 1 — FlowPilot migration) ──
|
||||
# Populated when a template is promoted from a post-resolve draft_templates row.
|
||||
# Powers the Script Library provenance chip:
|
||||
# "generated from CW #X · resolved by Y · used N times"
|
||||
source_session_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("ai_sessions.id"), nullable=True,
|
||||
)
|
||||
source_user_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id"), nullable=True,
|
||||
)
|
||||
source_ticket_ref: Mapped[Optional[str]] = mapped_column(
|
||||
String(64), nullable=True,
|
||||
comment="Human-readable PSA ticket ref for display, e.g. 'CW #48307'",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
79
backend/app/models/session_fact.py
Normal file
79
backend/app/models/session_fact.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Session fact model — the "What we know" backing store for a FlowPilot session.
|
||||
|
||||
A fact is an atomic, engineer-readable statement of what has been confirmed
|
||||
during troubleshooting. Facts accumulate across the session and drive the
|
||||
resolution note preview.
|
||||
|
||||
`source_ref` is a polymorphic pointer to a task-lane item inside
|
||||
`ai_sessions.pending_task_lane` JSON — it is NOT a FK. Integrity is enforced
|
||||
at the service layer per the FLOWPILOT-MIGRATION design doc Section 4.2.
|
||||
Phase 2 assigns stable UUIDs to those task-lane items so `source_ref` has
|
||||
something reliable to point to.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Text, DateTime, ForeignKey, String, CheckConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class SessionFact(Base):
|
||||
"""A single fact in the What-we-know section of a session's task lane."""
|
||||
__tablename__ = "session_facts"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"source_type IN ('question', 'diagnostic_check', 'user_note', 'ai_synthesis')",
|
||||
name="ck_session_facts_source_type",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("ai_sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
)
|
||||
text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
source_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
# Pointer to a task-lane item UUID inside ai_sessions.pending_task_lane.
|
||||
# NOT a FK. Null for `user_note` and `ai_synthesis` sources.
|
||||
source_ref: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), nullable=True
|
||||
)
|
||||
source_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_by: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
session: Mapped["AISession"] = relationship("AISession", foreign_keys=[session_id])
|
||||
account: Mapped["Account"] = relationship("Account", foreign_keys=[account_id])
|
||||
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
|
||||
80
backend/app/models/session_suggested_fix.py
Normal file
80
backend/app/models/session_suggested_fix.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Session suggested-fix model — AI-proposed resolution path for a session.
|
||||
|
||||
A session can have multiple suggested fixes over its lifetime as the AI's
|
||||
understanding evolves. Only one is active at a time (superseded_at IS NULL);
|
||||
emitting a new [SUGGEST_FIX] marker supersedes the prior active one.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
Text, DateTime, ForeignKey, String, Integer, CheckConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.ai_session import AISession
|
||||
from app.models.account import Account
|
||||
from app.models.script_template import ScriptTemplate
|
||||
|
||||
|
||||
class SessionSuggestedFix(Base):
|
||||
"""One AI-proposed fix for a FlowPilot session."""
|
||||
__tablename__ = "session_suggested_fixes"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"confidence_pct BETWEEN 0 AND 100",
|
||||
name="ck_session_suggested_fixes_confidence_pct",
|
||||
),
|
||||
CheckConstraint(
|
||||
"user_decision IS NULL OR user_decision IN ("
|
||||
"'one_off', 'draft_template', 'build_template', 'dismissed')",
|
||||
name="ck_session_suggested_fixes_user_decision",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("ai_sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id"),
|
||||
nullable=False,
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
confidence_pct: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
script_template_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("script_templates.id"),
|
||||
nullable=True,
|
||||
)
|
||||
# Populated only when there's no matching template and the AI has
|
||||
# drafted a session-specific script.
|
||||
ai_drafted_script: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
ai_drafted_parameters: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSONB, nullable=True
|
||||
)
|
||||
user_decision: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
# Set when a newer suggested fix supersedes this one.
|
||||
superseded_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
session: Mapped["AISession"] = relationship("AISession", foreign_keys=[session_id])
|
||||
account: Mapped["Account"] = relationship("Account", foreign_keys=[account_id])
|
||||
script_template: Mapped["ScriptTemplate | None"] = relationship(
|
||||
"ScriptTemplate", foreign_keys=[script_template_id]
|
||||
)
|
||||
@@ -8,7 +8,6 @@ from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
@@ -18,10 +17,6 @@ class TargetList(Base):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
team_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -25,6 +25,12 @@ class TreeShare(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
share_token: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
unique=True,
|
||||
|
||||
@@ -20,6 +20,7 @@ from .psa_connection import (
|
||||
PSATicketSearchResult, PSATicketStatusItem,
|
||||
PsaPostRequest, PsaPostResponse, PsaPreviewResponse, PsaPostLogResponse,
|
||||
PsaMemberMappingResponse, PsaMemberMappingSaveRequest, PsaMemberResponse, AutoMatchResult,
|
||||
PSABoardResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -50,4 +51,5 @@ __all__ = [
|
||||
"PSATicketSearchResult", "PSATicketStatusItem",
|
||||
"PsaPostRequest", "PsaPostResponse", "PsaPreviewResponse", "PsaPostLogResponse",
|
||||
"PsaMemberMappingResponse", "PsaMemberMappingSaveRequest", "PsaMemberResponse", "AutoMatchResult",
|
||||
"PSABoardResponse",
|
||||
]
|
||||
|
||||
@@ -28,6 +28,111 @@ class ActivityEntry(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# --- Admin Accounts & People Search ---
|
||||
|
||||
class AdminUserListItem(BaseModel):
|
||||
id: UUID
|
||||
email: EmailStr
|
||||
name: str
|
||||
role: str
|
||||
is_super_admin: bool = False
|
||||
is_active: bool = True
|
||||
account_id: Optional[UUID] = None
|
||||
account_role: Optional[str] = None
|
||||
account_name: Optional[str] = None
|
||||
account_display_code: Optional[str] = None
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class AdminUserListResponse(BaseModel):
|
||||
items: list[AdminUserListItem]
|
||||
total: int
|
||||
page: int
|
||||
per_page: int
|
||||
|
||||
|
||||
class AdminAccountMember(BaseModel):
|
||||
id: UUID
|
||||
email: EmailStr
|
||||
name: str
|
||||
role: str
|
||||
is_super_admin: bool = False
|
||||
is_active: bool = True
|
||||
account_role: Optional[str] = None
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class AdminAccountOwnerSummary(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class AdminAccountSubscriptionSummary(BaseModel):
|
||||
id: UUID
|
||||
plan: str
|
||||
status: str
|
||||
billing_interval: Optional[str] = None
|
||||
current_period_end: Optional[datetime] = None
|
||||
cancel_at_period_end: bool = False
|
||||
|
||||
|
||||
class AdminAccountUsageSummary(BaseModel):
|
||||
tree_count: int = 0
|
||||
session_count_this_month: int = 0
|
||||
|
||||
|
||||
class AdminAccountInviteSummary(BaseModel):
|
||||
id: UUID
|
||||
email: EmailStr
|
||||
role: str
|
||||
expires_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class AdminAccountListItem(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
display_code: str
|
||||
created_at: datetime
|
||||
owner_id: Optional[UUID] = None
|
||||
owner: Optional[AdminAccountOwnerSummary] = None
|
||||
subscription: Optional[AdminAccountSubscriptionSummary] = None
|
||||
usage: AdminAccountUsageSummary = Field(default_factory=AdminAccountUsageSummary)
|
||||
member_count: int = 0
|
||||
active_member_count: int = 0
|
||||
pending_invite_count: int = 0
|
||||
sso_enabled: bool = False
|
||||
branding_company_name: Optional[str] = None
|
||||
members: list[AdminAccountMember] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AdminAccountListResponse(BaseModel):
|
||||
items: list[AdminAccountListItem]
|
||||
total: int
|
||||
page: int
|
||||
per_page: int
|
||||
|
||||
|
||||
class AdminAccountDetailResponse(AdminAccountListItem):
|
||||
invites: list[AdminAccountInviteSummary] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AdminAccountCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
plan: Literal["free", "pro", "team"] = "free"
|
||||
owner_email: Optional[EmailStr] = Field(None, description="Email of an existing user to set as owner")
|
||||
|
||||
|
||||
class AdminAccountUpdate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
# --- Audit Logs ---
|
||||
|
||||
class AuditLogEntry(BaseModel):
|
||||
@@ -215,7 +320,7 @@ class AdminUserCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
account_mode: Literal["existing", "personal"]
|
||||
account_display_code: Optional[str] = Field(None, description="Required when account_mode='existing'")
|
||||
account_role: Optional[Literal["engineer", "viewer"]] = Field(None, description="Required when account_mode='existing'")
|
||||
account_role: Optional[Literal["owner", "admin", "engineer", "viewer"]] = Field(None, description="Required when account_mode='existing'")
|
||||
send_email: bool = True
|
||||
|
||||
|
||||
|
||||
37
backend/app/schemas/device_type.py
Normal file
37
backend/app/schemas/device_type.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Pydantic schemas for device types."""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DeviceTypeCreate(BaseModel):
|
||||
slug: str = Field(min_length=1, max_length=50, pattern=r"^[a-z0-9\-]+$")
|
||||
label: str = Field(min_length=1, max_length=100)
|
||||
category: str = Field(
|
||||
min_length=1, max_length=50,
|
||||
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
|
||||
)
|
||||
sort_order: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class DeviceTypeUpdate(BaseModel):
|
||||
label: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
category: str | None = Field(
|
||||
default=None, min_length=1, max_length=50,
|
||||
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
|
||||
)
|
||||
sort_order: int | None = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class DeviceTypeResponse(BaseModel):
|
||||
id: UUID
|
||||
slug: str
|
||||
label: str
|
||||
category: str
|
||||
is_system: bool
|
||||
account_id: UUID
|
||||
sort_order: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
145
backend/app/schemas/network_diagram.py
Normal file
145
backend/app/schemas/network_diagram.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Pydantic schemas for network diagrams."""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
|
||||
class DeviceProperties(BaseModel):
|
||||
hostname: str | None = None
|
||||
ip: str | None = None
|
||||
subnet: str | None = None
|
||||
vendor: str | None = None
|
||||
model: str | None = None
|
||||
role: str | None = None
|
||||
vlan: str | None = None
|
||||
notes: str | None = None
|
||||
status: str = Field(default="unknown", pattern=r"^(unknown|online|offline|degraded)$")
|
||||
|
||||
|
||||
class NodeStyle(BaseModel):
|
||||
width: float | None = None
|
||||
height: float | None = None
|
||||
|
||||
|
||||
class DiagramNode(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
label: str
|
||||
position: Position
|
||||
properties: DeviceProperties = Field(default_factory=DeviceProperties)
|
||||
nodeType: str | None = None
|
||||
style: NodeStyle | None = None
|
||||
parentId: str | None = None
|
||||
|
||||
|
||||
class DiagramEdge(BaseModel):
|
||||
id: str
|
||||
source: str
|
||||
target: str
|
||||
label: str | None = None
|
||||
connectionType: str = "ethernet"
|
||||
speed: str | None = None
|
||||
notes: str | None = None
|
||||
routing: str | None = None
|
||||
|
||||
|
||||
class NetworkDiagramCreate(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
|
||||
|
||||
class NetworkDiagramUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] | None = None
|
||||
edges: list[DiagramEdge] | None = None
|
||||
|
||||
|
||||
class NetworkDiagramResponse(BaseModel):
|
||||
id: UUID
|
||||
account_id: UUID
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
thumbnail_url: str | None = None
|
||||
is_archived: bool = False
|
||||
created_by: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class NetworkDiagramListItem(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
node_count: int = 0
|
||||
category_counts: dict[str, int] = Field(default_factory=dict)
|
||||
thumbnail_url: str | None = None
|
||||
created_by: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ExistingBounds(BaseModel):
|
||||
minX: float
|
||||
maxX: float
|
||||
minY: float
|
||||
maxY: float
|
||||
|
||||
|
||||
class AIGenerateRequest(BaseModel):
|
||||
description: str = Field(min_length=1, max_length=5000)
|
||||
client_name: str | None = None
|
||||
mode: str = Field(default="replace", pattern=r"^(replace|merge)$")
|
||||
existingBounds: ExistingBounds | None = None
|
||||
|
||||
|
||||
class AIGenerateResponse(BaseModel):
|
||||
nodes: list[DiagramNode]
|
||||
edges: list[DiagramEdge]
|
||||
suggestedName: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class DiagramImportRequest(BaseModel):
|
||||
schemaVersion: int = Field(ge=1, le=1)
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DiagramImportResponse(BaseModel):
|
||||
diagram: NetworkDiagramResponse
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DiagramExportResponse(BaseModel):
|
||||
schemaVersion: int = 1
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode]
|
||||
edges: list[DiagramEdge]
|
||||
exportedAt: str
|
||||
@@ -111,13 +111,13 @@ class PsaPostLogResponse(BaseModel):
|
||||
|
||||
|
||||
class PsaMemberMappingResponse(BaseModel):
|
||||
id: str
|
||||
id: str | None = None # None for users without a mapping
|
||||
user_id: str
|
||||
user_email: str
|
||||
user_name: str
|
||||
external_member_id: str
|
||||
external_member_name: str
|
||||
matched_by: str
|
||||
external_member_id: str | None = None
|
||||
external_member_name: str | None = None
|
||||
matched_by: str | None = None
|
||||
|
||||
|
||||
class PsaMemberMappingSaveRequest(BaseModel):
|
||||
@@ -136,3 +136,8 @@ class PsaMemberResponse(BaseModel):
|
||||
class AutoMatchResult(BaseModel):
|
||||
matched: list[PsaMemberMappingResponse]
|
||||
unmatched_users: int
|
||||
|
||||
|
||||
class PSABoardResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
|
||||
|
||||
class TargetListResponse(BaseModel):
|
||||
id: UUID
|
||||
team_id: UUID
|
||||
account_id: UUID
|
||||
created_by: Optional[UUID]
|
||||
name: str
|
||||
description: Optional[str]
|
||||
|
||||
@@ -68,4 +68,4 @@ class RoleUpdate(BaseModel):
|
||||
|
||||
|
||||
class AccountRoleUpdate(BaseModel):
|
||||
account_role: str = Field(..., pattern="^(engineer|viewer)$")
|
||||
account_role: str = Field(..., pattern="^(owner|admin|engineer|viewer)$")
|
||||
|
||||
@@ -10,10 +10,32 @@ Uses Anthropic prompt caching to reduce cost on multi-turn conversations:
|
||||
|
||||
Optionally connects to Microsoft Learn via Anthropic's MCP connector
|
||||
for real-time documentation lookups (controlled by ENABLE_MCP_MICROSOFT_LEARN).
|
||||
|
||||
## Architectural note — this module is the one MCP/beta chat caller
|
||||
|
||||
`chat_call_cached` below is the ONLY caller in the codebase that uses
|
||||
Anthropic's `client.beta.messages.create` endpoint, MCP servers, multimodal
|
||||
user messages, and the retry-without-MCP fallback. It is deliberately NOT
|
||||
routed through `AnthropicProvider` — MCP/beta/images are features of exactly
|
||||
one optional Anthropic beta endpoint and do not belong in a provider-agnostic
|
||||
abstraction that also serves Gemini.
|
||||
|
||||
If a new caller needs the same (MCP, beta, images, history caching) bundle,
|
||||
call `chat_call_cached` directly rather than pushing those concerns into
|
||||
`AnthropicProvider`. Cached-system-block plumbing is shared with the provider
|
||||
via `_normalize_system_for_anthropic` / `build_anthropic_chat_messages` /
|
||||
`_log_anthropic_cache_usage` in `app.core.ai_provider` — cache primitives are
|
||||
reusable, but the MCP/beta orchestration stays here.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.ai_provider import (
|
||||
_get_anthropic_client,
|
||||
_log_anthropic_cache_usage,
|
||||
_normalize_system_for_anthropic,
|
||||
build_anthropic_chat_messages,
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -184,7 +206,7 @@ async def _call_ai(
|
||||
to include alongside the new_message as vision content.
|
||||
"""
|
||||
if settings.AI_PROVIDER == "anthropic" and settings.ANTHROPIC_API_KEY:
|
||||
return await _call_anthropic_cached(
|
||||
return await chat_call_cached(
|
||||
system_base, rag_context, history, new_message, max_tokens,
|
||||
images=images,
|
||||
)
|
||||
@@ -202,7 +224,18 @@ async def _call_ai(
|
||||
)
|
||||
|
||||
|
||||
async def _call_anthropic_cached(
|
||||
# Appended to every chat turn's user message immediately before generation.
|
||||
# Invisible to storage (unified_chat_service strips markers before persisting),
|
||||
# but critical for structured output compliance — the model emits invalid
|
||||
# responses often enough without it that removing this reminder regresses UX.
|
||||
_CHAT_FORMAT_REMINDER = (
|
||||
"\n\n[SYSTEM: Remember — your response MUST end with [QUESTIONS] "
|
||||
"and/or [ACTIONS] markers containing valid JSON arrays. "
|
||||
"Responses without markers break the UI.]"
|
||||
)
|
||||
|
||||
|
||||
async def chat_call_cached(
|
||||
system_base: str,
|
||||
rag_context: str,
|
||||
history: list[dict[str, Any]],
|
||||
@@ -210,79 +243,56 @@ async def _call_anthropic_cached(
|
||||
max_tokens: int,
|
||||
images: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[str, int, int]:
|
||||
"""Call Anthropic with prompt caching on system prompt and history.
|
||||
"""Call Anthropic's chat surface with caching, MCP, images, and retry-without-MCP.
|
||||
|
||||
Uses structured system blocks so the static base prompt is cached
|
||||
independently from the per-query RAG context. Optionally connects
|
||||
to Microsoft Learn via MCP for real-time documentation lookups.
|
||||
This is the ONE MCP/beta/multimodal chat caller. It is deliberately NOT
|
||||
routed through `AnthropicProvider`. See module docstring for rationale.
|
||||
|
||||
Responsibilities unique to this function (not in the provider):
|
||||
- Anthropic beta endpoint (`client.beta.messages.create`)
|
||||
- Microsoft Learn MCP connector wiring (optional via ENABLE_MCP_MICROSOFT_LEARN)
|
||||
- Retry-without-MCP fallback when the MCP server misbehaves
|
||||
- Multimodal image blocks in the user message
|
||||
- Format-reminder append for structured-output compliance
|
||||
- Telemetry (`mcp.turn`, `mcp.fallback`) for Phase 0.5 MCP usage signal
|
||||
|
||||
Cache plumbing is shared with the provider via helpers in `ai_provider`:
|
||||
`_normalize_system_for_anthropic` (policy α — ephemeral on first block if
|
||||
none specified), `build_anthropic_chat_messages` (history cache breakpoint +
|
||||
multimodal user message + format reminder), `_log_anthropic_cache_usage`.
|
||||
"""
|
||||
import anthropic
|
||||
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=settings.ANTHROPIC_API_KEY,
|
||||
client = _get_anthropic_client(
|
||||
settings.ANTHROPIC_API_KEY,
|
||||
timeout=settings.AI_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# System prompt as structured blocks:
|
||||
# Block 1: static base prompt (cached)
|
||||
# Block 2: RAG context (changes per query, not cached)
|
||||
# System prompt as structured blocks. The static base is cacheable; the
|
||||
# RAG context changes per query and must NOT be cached — so we mark the
|
||||
# base explicitly and leave the RAG block unmarked. `_normalize_system`
|
||||
# honors caller-authored cache_control verbatim (policy α).
|
||||
system_blocks: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": system_base,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
# cacheable: static system prompt, stable across all turns of all sessions
|
||||
},
|
||||
]
|
||||
if rag_context:
|
||||
system_blocks.append({"type": "text", "text": rag_context})
|
||||
|
||||
# Build messages with cache breakpoint on conversation history
|
||||
messages: list[dict[str, Any]] = []
|
||||
for msg in history:
|
||||
messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
# Place cache breakpoint on the last history message so the entire
|
||||
# conversation prefix is cached across turns
|
||||
if messages:
|
||||
last = messages[-1]
|
||||
messages[-1] = {
|
||||
"role": last["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": last["content"],
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add the new user message (uncached — it's new each turn)
|
||||
# Append a format reminder to the user message so the model sees it
|
||||
# immediately before generating. This is invisible to the user (stripped
|
||||
# before storage) but critical for structured output compliance.
|
||||
format_reminder = (
|
||||
"\n\n[SYSTEM: Remember — your response MUST end with [QUESTIONS] "
|
||||
"and/or [ACTIONS] markers containing valid JSON arrays. "
|
||||
"Responses without markers break the UI.]"
|
||||
system_blocks.append(
|
||||
{"type": "text", "text": rag_context}
|
||||
# uncached: RAG retrieval varies per query
|
||||
)
|
||||
reminded_message = new_message + format_reminder
|
||||
normalized_system = _normalize_system_for_anthropic(system_blocks)
|
||||
|
||||
# If images are attached, build multimodal content blocks
|
||||
if images:
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
for img in images:
|
||||
content_blocks.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": img["media_type"],
|
||||
"data": img["data"],
|
||||
},
|
||||
})
|
||||
content_blocks.append({"type": "text", "text": reminded_message})
|
||||
messages.append({"role": "user", "content": content_blocks})
|
||||
else:
|
||||
messages.append({"role": "user", "content": reminded_message})
|
||||
messages = build_anthropic_chat_messages(
|
||||
history=history,
|
||||
new_message=new_message,
|
||||
images=images,
|
||||
format_reminder=_CHAT_FORMAT_REMINDER,
|
||||
)
|
||||
|
||||
# MCP server config (optional — controlled by settings)
|
||||
mcp_servers = anthropic.NOT_GIVEN
|
||||
@@ -304,12 +314,13 @@ async def _call_anthropic_cached(
|
||||
]
|
||||
|
||||
_mcp_active = mcp_servers is not anthropic.NOT_GIVEN
|
||||
_mcp_fallback_triggered = False
|
||||
|
||||
try:
|
||||
response = await client.beta.messages.create(
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
max_tokens=max_tokens,
|
||||
system=system_blocks,
|
||||
system=normalized_system,
|
||||
messages=messages,
|
||||
mcp_servers=mcp_servers,
|
||||
tools=tools,
|
||||
@@ -326,14 +337,24 @@ async def _call_anthropic_cached(
|
||||
or isinstance(e, (anthropic.BadRequestError, anthropic.APIStatusError))
|
||||
)
|
||||
if _is_mcp_error:
|
||||
_mcp_fallback_triggered = True
|
||||
logger.warning(
|
||||
"MCP server error (%s), retrying without MCP: %s",
|
||||
type(e).__name__, e,
|
||||
)
|
||||
# Phase 0.5 telemetry: per-turn fallback event.
|
||||
logger.info(
|
||||
"mcp.fallback",
|
||||
extra={
|
||||
"event": "mcp.fallback",
|
||||
"mcp_error_type": type(e).__name__,
|
||||
"mcp_error_message": str(e)[:500],
|
||||
},
|
||||
)
|
||||
response = await client.messages.create(
|
||||
model=settings.AI_MODEL_ANTHROPIC,
|
||||
max_tokens=max_tokens,
|
||||
system=system_blocks,
|
||||
system=normalized_system,
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
@@ -355,18 +376,27 @@ async def _call_anthropic_cached(
|
||||
input_tokens = usage.input_tokens
|
||||
output_tokens = usage.output_tokens
|
||||
|
||||
# Log MCP tool usage
|
||||
# Phase 0.5 telemetry: per-turn MCP event. Emitted for every turn that
|
||||
# reached this code path (i.e., AI_PROVIDER=anthropic chat). `mcp_available`
|
||||
# reflects whether MCP was actually wired into the request (scope (ii) from
|
||||
# the Phase 0.5 design — Anthropic code path AND flag on). `mcp_invoked`
|
||||
# reflects whether the model chose to call an MCP tool on this turn.
|
||||
logger.info(
|
||||
"mcp.turn",
|
||||
extra={
|
||||
"event": "mcp.turn",
|
||||
"mcp_available": _mcp_active,
|
||||
"mcp_invoked": bool(mcp_tools_used),
|
||||
"mcp_tools": mcp_tools_used,
|
||||
"mcp_fallback_triggered": _mcp_fallback_triggered,
|
||||
},
|
||||
)
|
||||
|
||||
# Human-readable log retained for grep-based inspection.
|
||||
if mcp_tools_used:
|
||||
logger.info("MCP tools used: %s", ", ".join(mcp_tools_used))
|
||||
|
||||
# Log cache performance
|
||||
cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
if cache_read or cache_creation:
|
||||
logger.info(
|
||||
"Anthropic cache: read=%d creation=%d input=%d output=%d",
|
||||
cache_read, cache_creation, input_tokens, output_tokens,
|
||||
)
|
||||
_log_anthropic_cache_usage(usage, settings.AI_MODEL_ANTHROPIC)
|
||||
|
||||
return text, input_tokens, output_tokens
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class BranchManager:
|
||||
root = SessionBranch(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
parent_branch_id=None,
|
||||
branch_order=1,
|
||||
label="Root",
|
||||
@@ -68,9 +69,17 @@ class BranchManager:
|
||||
"status": "untried",
|
||||
})
|
||||
|
||||
# Load session to get account_id for FK constraints
|
||||
session_result = await self.db.execute(
|
||||
select(AISession).where(AISession.id == session_id)
|
||||
)
|
||||
session = session_result.scalar_one_or_none()
|
||||
account_id = session.account_id if session else None
|
||||
|
||||
fork_point = ForkPoint(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
account_id=account_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
trigger_step_id=trigger_step_id,
|
||||
fork_reason=fork_reason,
|
||||
@@ -90,6 +99,7 @@ class BranchManager:
|
||||
branch = SessionBranch(
|
||||
id=branch_ids[i],
|
||||
session_id=session_id,
|
||||
account_id=account_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
fork_point_step_id=trigger_step_id,
|
||||
branch_order=i + 1,
|
||||
|
||||
@@ -330,6 +330,7 @@ async def start_session(
|
||||
# 7. Create first step
|
||||
step = _create_step_from_parsed(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
step_order=0,
|
||||
parsed=parsed,
|
||||
input_tokens=input_tokens,
|
||||
@@ -433,6 +434,7 @@ async def process_response(
|
||||
# Create new step
|
||||
step = _create_step_from_parsed(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
step_order=session.step_count - 1,
|
||||
parsed=parsed,
|
||||
input_tokens=input_tokens,
|
||||
@@ -694,6 +696,7 @@ async def pickup_session(
|
||||
briefing_step = AISessionStep(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
branch_id=session.active_branch_id if session.is_branching else None,
|
||||
step_order=session.step_count,
|
||||
step_type="action",
|
||||
@@ -765,6 +768,7 @@ async def pickup_session(
|
||||
|
||||
next_step = _create_step_from_parsed(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
step_order=session.step_count - 1,
|
||||
parsed=parsed,
|
||||
input_tokens=input_tokens,
|
||||
@@ -997,6 +1001,7 @@ async def generate_status_update(
|
||||
step = AISessionStep(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
branch_id=session.active_branch_id if session.is_branching else None,
|
||||
step_order=session.step_count,
|
||||
step_type="status_update",
|
||||
@@ -1440,6 +1445,7 @@ def _format_engineer_response(request: StepResponseRequest) -> str:
|
||||
|
||||
def _create_step_from_parsed(
|
||||
session_id: UUID,
|
||||
account_id: UUID,
|
||||
step_order: int,
|
||||
parsed: dict[str, Any],
|
||||
input_tokens: int,
|
||||
@@ -1487,6 +1493,7 @@ def _create_step_from_parsed(
|
||||
return AISessionStep(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
account_id=account_id,
|
||||
branch_id=branch_id,
|
||||
step_order=step_order,
|
||||
step_type=step_type if parsed["type"] != "resolution_suggestion" else "action",
|
||||
|
||||
@@ -56,6 +56,7 @@ class HandoffManager:
|
||||
|
||||
handoff = SessionHandoff(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
handed_off_by=user_id,
|
||||
intent=intent,
|
||||
source_branch_id=session.active_branch_id,
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_session import AISession
|
||||
from app.services.knowledge_flywheel import analyze_session
|
||||
|
||||
|
||||
151
backend/app/services/network_diagram_ai_service.py
Normal file
151
backend/app/services/network_diagram_ai_service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""AI service for generating network diagrams from natural language."""
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.config import settings
|
||||
from app.schemas.network_diagram import (
|
||||
AIGenerateRequest,
|
||||
AIGenerateResponse,
|
||||
DiagramNode,
|
||||
DiagramEdge,
|
||||
DeviceProperties,
|
||||
Position,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers.
|
||||
Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble.
|
||||
|
||||
Return this exact structure:
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "unique-string",
|
||||
"type": "device-type-slug",
|
||||
"label": "device label",
|
||||
"position": {{ "x": number, "y": number }},
|
||||
"properties": {{
|
||||
"hostname": "string or null",
|
||||
"ip": "string or null",
|
||||
"subnet": "string or null",
|
||||
"vendor": "string or null",
|
||||
"model": "string or null",
|
||||
"role": "string or null",
|
||||
"vlan": "string or null",
|
||||
"notes": "string or null",
|
||||
"status": "unknown"
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{
|
||||
"id": "unique-string",
|
||||
"source": "node-id",
|
||||
"target": "node-id",
|
||||
"label": "connection label or null",
|
||||
"connectionType": "ethernet|fiber|wifi|vpn|vlan|wan",
|
||||
"speed": "string or null",
|
||||
"notes": "string or null"
|
||||
}}
|
||||
],
|
||||
"suggestedName": "short descriptive diagram name",
|
||||
"notes": "any important assumptions or missing info, or null"
|
||||
}}
|
||||
|
||||
Available device type slugs: {available_slugs}
|
||||
|
||||
Position nodes thoughtfully in a logical network topology layout.
|
||||
Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y.
|
||||
Place WAN/internet at top, core network in middle, endpoints at bottom.
|
||||
{merge_instructions}"""
|
||||
|
||||
MERGE_INSTRUCTIONS = """
|
||||
IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices.
|
||||
The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}.
|
||||
Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100).
|
||||
You may create edges that connect new nodes to existing nodes if the description implies a connection.
|
||||
Use these existing node IDs for connections: {existing_node_ids}"""
|
||||
|
||||
|
||||
async def generate_diagram(
|
||||
request: AIGenerateRequest,
|
||||
available_slugs: list[str],
|
||||
existing_node_ids: list[str] | None = None,
|
||||
) -> AIGenerateResponse:
|
||||
merge_instructions = ""
|
||||
if request.mode == "merge" and request.existingBounds:
|
||||
b = request.existingBounds
|
||||
merge_instructions = MERGE_INSTRUCTIONS.format(
|
||||
minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY,
|
||||
existing_node_ids=", ".join(existing_node_ids or []),
|
||||
)
|
||||
|
||||
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
available_slugs=", ".join(available_slugs),
|
||||
merge_instructions=merge_instructions,
|
||||
)
|
||||
|
||||
model = settings.get_model_for_action("network_diagram_generate")
|
||||
provider = get_ai_provider(model)
|
||||
|
||||
messages = [{"role": "user", "content": request.description}]
|
||||
|
||||
response_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Network diagram AI generation: input_tokens=%d, output_tokens=%d",
|
||||
input_tokens, output_tokens,
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(response_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to parse AI response as JSON: %s", e)
|
||||
raise ValueError("AI generated an invalid response, please try again")
|
||||
|
||||
try:
|
||||
nodes = []
|
||||
for raw_node in data.get("nodes", []):
|
||||
node_type = raw_node.get("type", "server")
|
||||
if node_type not in available_slugs:
|
||||
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
||||
node_type = "server"
|
||||
|
||||
nodes.append(DiagramNode(
|
||||
id=raw_node["id"],
|
||||
type=node_type,
|
||||
label=raw_node.get("label", node_type),
|
||||
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
|
||||
properties=DeviceProperties(**{
|
||||
k: v for k, v in raw_node.get("properties", {}).items()
|
||||
if k in DeviceProperties.model_fields
|
||||
}),
|
||||
))
|
||||
|
||||
edges = []
|
||||
for raw_edge in data.get("edges", []):
|
||||
edges.append(DiagramEdge(
|
||||
id=raw_edge["id"],
|
||||
source=raw_edge["source"],
|
||||
target=raw_edge["target"],
|
||||
label=raw_edge.get("label"),
|
||||
connectionType=raw_edge.get("connectionType", "ethernet"),
|
||||
speed=raw_edge.get("speed"),
|
||||
notes=raw_edge.get("notes"),
|
||||
))
|
||||
except KeyError as e:
|
||||
logger.warning("AI response missing required field: %s", e)
|
||||
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
|
||||
|
||||
return AIGenerateResponse(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
suggestedName=data.get("suggestedName"),
|
||||
notes=data.get("notes"),
|
||||
)
|
||||
@@ -11,6 +11,7 @@ from app.services.psa.types import (
|
||||
PSAMember,
|
||||
PSAConfiguration,
|
||||
PSATimeEntry,
|
||||
PSABoard,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,6 +59,9 @@ class AutotaskProvider(PSAProvider):
|
||||
async def list_members(self) -> list[PSAMember]:
|
||||
raise NotImplementedError("Autotask integration coming soon")
|
||||
|
||||
async def list_boards(self) -> list[PSABoard]:
|
||||
raise NotImplementedError("list_boards not implemented for this provider")
|
||||
|
||||
async def get_ticket_configurations(self, ticket_id: str) -> list[PSAConfiguration]:
|
||||
raise NotImplementedError("Autotask integration coming soon")
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from .types import (
|
||||
PSAMember,
|
||||
PSAConfiguration,
|
||||
PSATimeEntry,
|
||||
PSABoard,
|
||||
)
|
||||
|
||||
|
||||
@@ -64,6 +65,10 @@ class PSAProvider(ABC):
|
||||
async def list_members(self) -> list[PSAMember]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_boards(self) -> list[PSABoard]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_ticket_configurations(self, ticket_id: str) -> list[PSAConfiguration]:
|
||||
...
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.services.psa.types import (
|
||||
PSAMember,
|
||||
PSAConfiguration,
|
||||
PSATimeEntry,
|
||||
PSABoard,
|
||||
)
|
||||
from .client import ConnectWiseClient
|
||||
|
||||
@@ -55,11 +56,16 @@ class ConnectWiseProvider(PSAProvider):
|
||||
return self._map_ticket(data)
|
||||
|
||||
async def search_tickets(self, query: str, **filters) -> list[PSATicket]:
|
||||
"""Search CW tickets by summary. Supports board_id and status_id filters."""
|
||||
"""Search CW tickets by summary. Supports board_id, status_id, member_id,
|
||||
unassigned, board_ids, page, and page_size filters."""
|
||||
page_size = filters.get("page_size", 10)
|
||||
page = filters.get("page", 1)
|
||||
|
||||
params: dict = {
|
||||
"fields": "id,summary,company,board,status,priority,closedFlag",
|
||||
"orderBy": "id desc",
|
||||
"pageSize": 25,
|
||||
"pageSize": page_size,
|
||||
"page": page,
|
||||
}
|
||||
|
||||
# Build CW condition query
|
||||
@@ -72,6 +78,14 @@ class ConnectWiseProvider(PSAProvider):
|
||||
conditions.append(f"status/id = {filters['status_id']}")
|
||||
if not filters.get("include_closed", False):
|
||||
conditions.append("closedFlag = false")
|
||||
if filters.get("member_identifier") is not None:
|
||||
conditions.append(f"resources contains '{filters['member_identifier']}'")
|
||||
if filters.get("unassigned", False):
|
||||
conditions.append("resources = null")
|
||||
board_ids: list[int] = filters.get("board_ids") or []
|
||||
if board_ids:
|
||||
board_list = ", ".join(str(bid) for bid in board_ids)
|
||||
conditions.append(f"board/id in ({board_list})")
|
||||
|
||||
if conditions:
|
||||
params["conditions"] = " and ".join(conditions)
|
||||
@@ -270,6 +284,32 @@ class ConnectWiseProvider(PSAProvider):
|
||||
psa_cache.set(cache_key, result, ttl_seconds=900)
|
||||
return result
|
||||
|
||||
async def list_boards(self) -> list[PSABoard]:
|
||||
"""List active CW service boards (cached 1 hour)."""
|
||||
cache_key = "boards"
|
||||
cached = psa_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
data = await self.client.get(
|
||||
"/service/boards",
|
||||
params={
|
||||
"fields": "id,name,inactiveFlag",
|
||||
"conditions": "inactiveFlag = false",
|
||||
"pageSize": 100,
|
||||
},
|
||||
)
|
||||
result = [
|
||||
PSABoard(
|
||||
id=b["id"],
|
||||
name=b["name"],
|
||||
inactive=b.get("inactiveFlag", False),
|
||||
)
|
||||
for b in (data if isinstance(data, list) else [])
|
||||
]
|
||||
psa_cache.set(cache_key, result, ttl_seconds=3600)
|
||||
return result
|
||||
|
||||
# ── Ticket Context ────────────────────────────────────────────────
|
||||
|
||||
async def get_ticket_context(
|
||||
@@ -536,7 +576,7 @@ class ConnectWiseProvider(PSAProvider):
|
||||
if work_type:
|
||||
payload["workType"] = {"name": work_type}
|
||||
|
||||
data = await self._client.post("/time/entries", payload)
|
||||
data = await self.client.post("/time/entries", payload)
|
||||
return PSATimeEntry(
|
||||
id=str(data["id"]),
|
||||
ticket_id=ticket_id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.services.psa.types import (
|
||||
PSAMember,
|
||||
PSAConfiguration,
|
||||
PSATimeEntry,
|
||||
PSABoard,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,6 +59,9 @@ class HaloPSAProvider(PSAProvider):
|
||||
async def list_members(self) -> list[PSAMember]:
|
||||
raise NotImplementedError("Halo PSA integration coming soon")
|
||||
|
||||
async def list_boards(self) -> list[PSABoard]:
|
||||
raise NotImplementedError("list_boards not implemented for this provider")
|
||||
|
||||
async def get_ticket_configurations(self, ticket_id: str) -> list[PSAConfiguration]:
|
||||
raise NotImplementedError("Halo PSA integration coming soon")
|
||||
|
||||
|
||||
@@ -67,6 +67,12 @@ class PSATimeEntry(BaseModel):
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class PSABoard(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
inactive: bool = False
|
||||
|
||||
|
||||
class NoteType:
|
||||
INTERNAL_ANALYSIS = "internal_analysis"
|
||||
RESOLUTION = "resolution"
|
||||
|
||||
@@ -371,6 +371,7 @@ async def push_documentation(
|
||||
# Log success
|
||||
log_entry = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
account_id=session.account_id,
|
||||
ai_session_id=session.id,
|
||||
psa_connection_id=session.psa_connection_id,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
@@ -394,6 +395,7 @@ async def push_documentation(
|
||||
# Log failure with retry scheduling
|
||||
log_entry = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
account_id=session.account_id,
|
||||
ai_session_id=session.id,
|
||||
psa_connection_id=session.psa_connection_id,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.psa_post_log import PsaPostLog
|
||||
from app.services.psa_documentation_service import retry_failed_push
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class ResolutionOutputGenerator:
|
||||
|
||||
output = SessionResolutionOutput(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
output_type=output_type,
|
||||
generated_content=content,
|
||||
status="draft",
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
team_id: UUID | None,
|
||||
language: str,
|
||||
initial_prompt: str | None = None,
|
||||
@@ -151,6 +152,7 @@ async def create_session(
|
||||
"""Create a new Script Builder session."""
|
||||
session = ScriptBuilderSession(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
team_id=team_id,
|
||||
language=language,
|
||||
)
|
||||
@@ -218,7 +220,15 @@ async def send_message(
|
||||
model = settings.get_model_for_action("script_build")
|
||||
provider = get_ai_provider(model=model)
|
||||
ai_text, input_tokens, output_tokens = await provider.generate_text(
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=[
|
||||
{"type": "text", "text": system_prompt},
|
||||
# cacheable: SYSTEM_PROMPT_TEMPLATE with a per-session language
|
||||
# substitution. Two sessions on the same language share a cache
|
||||
# entry; different languages cache independently. Conversation
|
||||
# history (ai_messages) is NOT cached at this layer — if that
|
||||
# becomes a cost driver, route script_builder through the chat
|
||||
# wrapper (0.4) which handles history caching.
|
||||
],
|
||||
messages=ai_messages,
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
@@ -80,7 +80,10 @@ def _display_code() -> str:
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
engine = create_async_engine(settings.DATABASE_URL, echo=False)
|
||||
# Must use ADMIN_DATABASE_URL (BYPASSRLS) — Phase 4 enabled RLS on users.
|
||||
# The app-role connection has no tenant context at seed time and would see 0 rows.
|
||||
admin_url = getattr(settings, "ADMIN_DATABASE_URL", None) or settings.DATABASE_URL
|
||||
engine = create_async_engine(admin_url, echo=False)
|
||||
password_hash = get_password_hash(SHARED_PASSWORD)
|
||||
now = datetime.now(timezone.utc)
|
||||
team_account_id: uuid.UUID | None = None
|
||||
|
||||
@@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
||||
"""))
|
||||
|
||||
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
|
||||
# global categories, gallery items, and other platform-owned content.
|
||||
await conn.execute(sa.text("""
|
||||
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
|
||||
VALUES (
|
||||
'00000000-0000-0000-0000-000000000001',
|
||||
'ResolutionFlow System',
|
||||
'RF-SYS-1',
|
||||
NOW(), NOW()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""))
|
||||
|
||||
# Create async session maker
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
|
||||
@@ -19,8 +19,116 @@ class TestAdminEndpoints:
|
||||
"/api/v1/admin/users", headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
users = response.json()
|
||||
assert len(users) >= 2 # admin + test_user
|
||||
payload = response.json()
|
||||
assert payload["total"] >= 2 # admin + test_user
|
||||
assert len(payload["items"]) >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_supports_search(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test admin people search by user email."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users",
|
||||
params={"search": test_user["email"]},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["total"] >= 1
|
||||
assert any(item["email"] == test_user["email"] for item in payload["items"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_accounts_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict
|
||||
):
|
||||
"""Test listing accounts with member data."""
|
||||
response = await client.get(
|
||||
"/api/v1/admin/accounts", headers=admin_auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["total"] >= 1
|
||||
assert len(payload["items"]) >= 1
|
||||
assert "members" in payload["items"][0]
|
||||
assert "subscription" in payload["items"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict
|
||||
):
|
||||
"""Test creating an empty account from admin."""
|
||||
response = await client.post(
|
||||
"/api/v1/admin/accounts",
|
||||
json={"name": "Acme Customer", "plan": "pro"},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
assert payload["name"] == "Acme Customer"
|
||||
assert payload["subscription"]["plan"] == "pro"
|
||||
assert payload["display_code"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_account_detail_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test fetching account detail for management view."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.get(
|
||||
f"/api/v1/admin/accounts/{account_id}",
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["id"] == account_id
|
||||
assert "members" in payload
|
||||
assert "invites" in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_name_as_admin(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test renaming an account from admin detail view."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/accounts/{account_id}",
|
||||
json={"name": "Renamed Customer Account"},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["id"] == account_id
|
||||
assert payload["name"] == "Renamed Customer Account"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_account_plan(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test changing an account's subscription plan."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/accounts/{account_id}/subscription/plan",
|
||||
json={"plan": "pro"},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["plan"] == "pro"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extend_account_trial(
|
||||
self, client: AsyncClient, admin_auth_headers: dict, test_user: dict
|
||||
):
|
||||
"""Test starting or extending an account trial."""
|
||||
account_id = test_user["user_data"]["account_id"]
|
||||
response = await client.put(
|
||||
f"/api/v1/admin/accounts/{account_id}/subscription/extend-trial",
|
||||
json={"days": 14},
|
||||
headers=admin_auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "trialing"
|
||||
assert response.json()["current_period_end"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_as_non_admin(
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestAdminGlobalCategories:
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Category"
|
||||
assert data["slug"] == "test-category"
|
||||
assert data["account_id"] is None
|
||||
assert data["account_id"] == "00000000-0000-0000-0000-000000000001" # PLATFORM_ACCOUNT_ID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_global_category(
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.tree import Tree
|
||||
from app.models.script_template import ScriptTemplate, ScriptCategory
|
||||
|
||||
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree:
|
||||
name="Gallery Test Flow",
|
||||
tree_type="troubleshooting",
|
||||
visibility="public",
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=False,
|
||||
gallery_sort_order=0,
|
||||
tree_structure={
|
||||
@@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate
|
||||
script = ScriptTemplate(
|
||||
id=uuid.uuid4(),
|
||||
category_id=category.id,
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
name="Gallery Test Script",
|
||||
slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}",
|
||||
script_body="Write-Host 'Test'",
|
||||
|
||||
@@ -594,6 +594,7 @@ class TestPsaMetrics:
|
||||
post_log = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
ai_session_id=push_session_id,
|
||||
account_id=account_id,
|
||||
ticket_id="TICKET-123",
|
||||
note_type="internal",
|
||||
content_posted="Session summary",
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.security import get_password_hash
|
||||
from app.models.account import Account
|
||||
from app.models.team import Team
|
||||
from app.models.user import User
|
||||
|
||||
@@ -23,6 +24,8 @@ async def _create_team_with_admin(
|
||||
team_name: str = "Branding Test Team",
|
||||
) -> tuple[dict, str, Team]:
|
||||
"""Create a team + team admin user. Returns (auth_headers, team_id_str, team)."""
|
||||
account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper())
|
||||
test_db.add(account)
|
||||
team = Team(name=team_name)
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
@@ -36,6 +39,8 @@ async def _create_team_with_admin(
|
||||
team_id=team.id,
|
||||
is_team_admin=True,
|
||||
role="engineer",
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(user)
|
||||
await test_db.commit()
|
||||
@@ -58,6 +63,15 @@ async def _create_team_member(
|
||||
is_team_admin: bool = False,
|
||||
) -> dict:
|
||||
"""Create a regular team member. Returns auth_headers."""
|
||||
# Look up the account associated with this team via an existing member
|
||||
from sqlalchemy import select as _select
|
||||
from app.models.user import User as _User
|
||||
result = await test_db.execute(
|
||||
_select(_User).where(_User.team_id == team.id).limit(1)
|
||||
)
|
||||
team_member = result.scalar_one_or_none()
|
||||
member_account_id = team_member.account_id if team_member else None
|
||||
|
||||
email = f"member_{uuid.uuid4().hex[:8]}@test.com"
|
||||
user = User(
|
||||
email=email,
|
||||
@@ -67,6 +81,8 @@ async def _create_team_member(
|
||||
team_id=team.id,
|
||||
is_team_admin=is_team_admin,
|
||||
role="engineer",
|
||||
account_id=member_account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(user)
|
||||
await test_db.commit()
|
||||
|
||||
@@ -334,12 +334,13 @@ class TestDraftTreesAPI:
|
||||
"""Test that migration defaults existing trees to published status."""
|
||||
# Create a tree without specifying status (relies on DB default)
|
||||
from uuid import UUID, uuid4
|
||||
_platform_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
tree = Tree(
|
||||
name="Legacy Tree",
|
||||
description="Created before status field",
|
||||
tree_structure={"id": "root", "type": "solution", "title": "Fix"},
|
||||
author_id=None,
|
||||
account_id=None
|
||||
account_id=_platform_id,
|
||||
)
|
||||
test_db.add(tree)
|
||||
await test_db.commit()
|
||||
|
||||
@@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header
|
||||
test_db.add(other_team)
|
||||
await test_db.flush()
|
||||
|
||||
from uuid import UUID as _UUID
|
||||
other_tree = Tree(
|
||||
name="Other Team Tree",
|
||||
tree_type="maintenance",
|
||||
team_id=other_team.id,
|
||||
account_id=_UUID("00000000-0000-0000-0000-000000000001"),
|
||||
tree_structure={
|
||||
"steps": [
|
||||
{"id": "s1", "type": "procedure_step", "title": "Step",
|
||||
|
||||
96
backend/tests/test_network_diagrams.py
Normal file
96
backend/tests/test_network_diagrams.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.device_type import DeviceType
|
||||
from app.models.user import User
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
|
||||
async def _login_headers(client, email: str, password: str) -> dict[str, str]:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user):
|
||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
user = result.scalar_one()
|
||||
|
||||
test_db.add(
|
||||
DeviceType(
|
||||
id=uuid.uuid4(),
|
||||
slug="platform-router",
|
||||
label="Platform Router",
|
||||
category="network",
|
||||
is_system=True,
|
||||
account_id=PLATFORM_ACCOUNT_ID,
|
||||
sort_order=0,
|
||||
)
|
||||
)
|
||||
await test_db.commit()
|
||||
|
||||
create_response = await client.post(
|
||||
"/api/v1/device-types/",
|
||||
json={
|
||||
"slug": "tenant-appliance",
|
||||
"label": "Tenant Appliance",
|
||||
"category": "network",
|
||||
"sort_order": 3,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
assert create_response.json()["account_id"] == str(user.account_id)
|
||||
|
||||
list_response = await client.get("/api/v1/device-types/", headers=auth_headers)
|
||||
assert list_response.status_code == 200
|
||||
payload = list_response.json()
|
||||
slugs = {item["slug"] for item in payload}
|
||||
|
||||
assert "platform-router" in slugs
|
||||
assert "tenant-appliance" in slugs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user):
|
||||
other_user = {
|
||||
"email": "other-network@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Other Network User",
|
||||
}
|
||||
register_response = await client.post("/api/v1/auth/register", json=other_user)
|
||||
assert register_response.status_code in (200, 201)
|
||||
other_headers = await _login_headers(client, other_user["email"], other_user["password"])
|
||||
|
||||
owner_result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
owner = owner_result.scalar_one()
|
||||
|
||||
create_response = await client.post(
|
||||
"/api/v1/network-diagrams/",
|
||||
json={
|
||||
"name": "HQ Core",
|
||||
"client_name": "Acme",
|
||||
"description": "Primary topology",
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
diagram = create_response.json()
|
||||
assert diagram["account_id"] == str(owner.account_id)
|
||||
|
||||
own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers)
|
||||
assert own_get.status_code == 200
|
||||
|
||||
other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers)
|
||||
assert other_get.status_code == 404
|
||||
@@ -200,6 +200,7 @@ class TestAccountPermissions:
|
||||
})
|
||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||
|
||||
# Outsider should NOT see the private tree
|
||||
# Outsider should NOT see the private tree.
|
||||
# With RLS, the tree is invisible to other tenants — 404 not 403.
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||
assert response.status_code == 403
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
|
||||
await test_db.flush()
|
||||
|
||||
target_list = TargetList(
|
||||
team_id=team.id,
|
||||
account_id=account.id,
|
||||
created_by=user.id,
|
||||
name="Server Targets",
|
||||
|
||||
@@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.script_template import ScriptCategory, ScriptTemplate
|
||||
from app.models.tree import Tree
|
||||
|
||||
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f
|
||||
description="A featured flow for the gallery",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure=_make_tree_structure(4),
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=featured,
|
||||
is_active=True,
|
||||
usage_count=42,
|
||||
@@ -74,6 +77,7 @@ async def _create_featured_script(
|
||||
) -> ScriptTemplate:
|
||||
script = ScriptTemplate(
|
||||
category_id=category.id,
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
name=name,
|
||||
slug=name.lower().replace(" ", "-"),
|
||||
description="A gallery-featured script",
|
||||
@@ -312,7 +316,7 @@ class TestCategoriesEndpoint:
|
||||
from app.models.category import TreeCategory
|
||||
|
||||
# Create a category and a featured tree in that category
|
||||
cat = TreeCategory(name="Networking", slug="networking", is_active=True)
|
||||
cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID)
|
||||
test_db.add(cat)
|
||||
await test_db.commit()
|
||||
await test_db.refresh(cat)
|
||||
@@ -321,6 +325,7 @@ class TestCategoriesEndpoint:
|
||||
name="Router Diagnostics",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure=_make_tree_structure(2),
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=True,
|
||||
is_active=True,
|
||||
usage_count=5,
|
||||
|
||||
@@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db
|
||||
|
||||
output = SessionResolutionOutput(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
output_type="psa_ticket_notes",
|
||||
generated_content="Original notes",
|
||||
status="draft",
|
||||
|
||||
@@ -16,11 +16,20 @@ Run with:
|
||||
The test DB is patherly_test (matches conftest.py default).
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
# All tests in this module use module-scoped async fixtures (admin_conn,
|
||||
# seed_rls_test_data) which run on the module event loop. Without this marker,
|
||||
# pytest-asyncio 0.23+ defaults tests to function-scoped loops, causing
|
||||
# "Future attached to a different loop" errors on the asyncpg connections.
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="module")
|
||||
|
||||
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
||||
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
||||
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
||||
@@ -37,7 +46,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def admin_conn():
|
||||
def _ensure_rls_schema():
|
||||
"""Re-apply Alembic migrations before the module runs.
|
||||
|
||||
Function-scoped test_db fixtures in other modules drop and recreate the
|
||||
public schema using Base.metadata.create_all, which does not enable RLS
|
||||
or create DB roles. This fixture re-runs 'alembic upgrade head' so that
|
||||
the full migration-managed schema (including RLS policies) is in place.
|
||||
"""
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "upgrade", "head"],
|
||||
cwd=backend_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def admin_conn(_ensure_rls_schema):
|
||||
"""Superuser asyncpg connection for fixture setup and teardown."""
|
||||
conn = await asyncpg.connect(_ADMIN_DSN)
|
||||
yield conn
|
||||
@@ -170,7 +197,6 @@ async def conn_no_context():
|
||||
# trees
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -178,7 +204,6 @@ async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
assert len(rows) == 0, "Account A should not see Account B trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
@@ -186,7 +211,6 @@ async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
assert len(rows) >= 1, "Account A should see its own trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
rows = await conn_no_context.fetch(
|
||||
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
||||
@@ -198,7 +222,6 @@ async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
# tree_tags — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -206,7 +229,6 @@ async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
rows_a = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||
@@ -222,7 +244,6 @@ async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
# tree_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -234,7 +255,6 @@ async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# step_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -246,7 +266,6 @@ async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# psa_connections — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -258,9 +277,782 @@ async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
# flow_proposals — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def session_row_ids(admin_conn):
|
||||
"""
|
||||
Insert one `sessions` row and one `ai_sessions` row for each of
|
||||
ACCOUNT_A and ACCOUNT_B using the superuser connection (BYPASSRLS).
|
||||
Returns a dict with the inserted IDs for use in tests.
|
||||
Cleans up on exit.
|
||||
"""
|
||||
# Resolve a valid tree_id and user_id for each account
|
||||
tree_a = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
|
||||
)
|
||||
tree_b = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
user_a = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
|
||||
)
|
||||
user_b = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
|
||||
assert tree_a is not None, f"No tree found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
|
||||
assert tree_b is not None, f"No tree found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
|
||||
assert user_a is not None, f"No user found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
|
||||
assert user_b is not None, f"No user found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
|
||||
|
||||
tree_a_id = str(tree_a["id"])
|
||||
tree_b_id = str(tree_b["id"])
|
||||
user_a_id = str(user_a["id"])
|
||||
user_b_id = str(user_b["id"])
|
||||
|
||||
session_a_id = str(uuid.uuid4())
|
||||
session_b_id = str(uuid.uuid4())
|
||||
ai_session_a_id = str(uuid.uuid4())
|
||||
ai_session_b_id = str(uuid.uuid4())
|
||||
|
||||
# Insert sessions rows (sessions uses started_at not created_at)
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES
|
||||
('{session_a_id}', '{tree_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()),
|
||||
('{session_b_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW())
|
||||
""")
|
||||
|
||||
# Insert ai_sessions rows
|
||||
# confidence_tier valid values: 'guided' | 'exploring' | 'discovery'
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_sessions (
|
||||
id, user_id, account_id, session_type, intake_type,
|
||||
intake_content, status, confidence_tier, confidence_score,
|
||||
created_at, updated_at
|
||||
) VALUES
|
||||
('{ai_session_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
|
||||
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
|
||||
NOW(), NOW()),
|
||||
('{ai_session_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
|
||||
NOW(), NOW())
|
||||
""")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Seed Account B rows for every "cannot-see" table that would otherwise be
|
||||
# empty. Without these, isolation tests pass vacuously even when RLS is off.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# session_branches (FK: ai_sessions.id)
|
||||
branch_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_branches (
|
||||
id, session_id, account_id, branch_order, label, status,
|
||||
conversation_messages, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 1, 'test-branch', 'active',
|
||||
'[]'::jsonb, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID)
|
||||
branch_b_id = str(branch_b_row["id"])
|
||||
|
||||
# session_supporting_data (FK: sessions.id)
|
||||
supporting_data_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_supporting_data (
|
||||
id, session_id, account_id, label, data_type, content,
|
||||
sort_order, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'test-data', 'text_snippet',
|
||||
'test content', 0, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", session_b_id, ACCOUNT_B_ID)
|
||||
supporting_data_b_id = str(supporting_data_b_row["id"])
|
||||
|
||||
# session_resolution_outputs (FK: ai_sessions.id)
|
||||
resolution_output_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_resolution_outputs (
|
||||
id, session_id, account_id, output_type, generated_content,
|
||||
status, generated_by_model, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'psa_ticket_notes',
|
||||
'test content', 'draft', 'test-model', NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID)
|
||||
resolution_output_b_id = str(resolution_output_b_row["id"])
|
||||
|
||||
# session_handoffs (FK: ai_sessions.id, users.id)
|
||||
handoff_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_handoffs (
|
||||
id, session_id, account_id, handed_off_by, intent, snapshot,
|
||||
priority, psa_note_pushed, notification_sent, created_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, 'park',
|
||||
'{}'::jsonb, 'normal', false, false, NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
|
||||
handoff_b_id = str(handoff_b_row["id"])
|
||||
|
||||
# maintenance_schedules (FK: trees.id)
|
||||
maintenance_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO maintenance_schedules (
|
||||
id, tree_id, account_id, cron_expression, timezone,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, '0 9 * * 1', 'UTC',
|
||||
NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", tree_b_id, ACCOUNT_B_ID)
|
||||
maintenance_b_id = str(maintenance_b_row["id"])
|
||||
|
||||
# psa_post_log (FK: ai_sessions.id, users.id)
|
||||
psa_log_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO psa_post_log (
|
||||
id, ai_session_id, account_id, ticket_id, note_type,
|
||||
content_posted, status, posted_by, posted_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'TEST-0001', 'internal',
|
||||
'test note', 'success', $3::uuid, NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
|
||||
psa_log_b_id = str(psa_log_b_row["id"])
|
||||
|
||||
# script_templates requires a script_categories row — insert a temporary one
|
||||
script_category_b_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO script_categories (id, name, slug, sort_order, is_active, created_at, updated_at)
|
||||
VALUES ('{script_category_b_id}', 'RLS Test Category', 'rls-test-category-{script_category_b_id[:8]}',
|
||||
0, true, NOW(), NOW())
|
||||
""")
|
||||
|
||||
script_template_b_row = await admin_conn.fetchrow(f"""
|
||||
INSERT INTO script_templates (
|
||||
id, category_id, account_id, name, slug, script_body,
|
||||
complexity, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), '{script_category_b_id}'::uuid, $1::uuid,
|
||||
'RLS Test Template', 'rls-test-template-b-' || gen_random_uuid()::text,
|
||||
'Write-Host "test"', 'beginner', true, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ACCOUNT_B_ID)
|
||||
script_template_b_id = str(script_template_b_row["id"])
|
||||
|
||||
# script_generations (FK: script_templates.id, users.id)
|
||||
script_gen_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO script_generations (
|
||||
id, template_id, user_id, account_id, parameters_used,
|
||||
generated_script, created_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, '{}'::jsonb,
|
||||
'test script', NOW()
|
||||
) RETURNING id
|
||||
""", script_template_b_id, user_b_id, ACCOUNT_B_ID)
|
||||
script_gen_b_id = str(script_gen_b_row["id"])
|
||||
|
||||
try:
|
||||
yield {
|
||||
"session_a": session_a_id,
|
||||
"session_b": session_b_id,
|
||||
"ai_session_a": ai_session_a_id,
|
||||
"ai_session_b": ai_session_b_id,
|
||||
}
|
||||
finally:
|
||||
# Cleanup in reverse FK order (children before parents)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_generations WHERE id = '{script_gen_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_branches WHERE id = '{branch_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_supporting_data WHERE id = '{supporting_data_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_resolution_outputs WHERE id = '{resolution_output_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_handoffs WHERE id = '{handoff_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM maintenance_schedules WHERE id = '{maintenance_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM psa_post_log WHERE id = '{psa_log_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_templates WHERE id = '{script_template_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_categories WHERE id = '{script_category_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B sessions"
|
||||
|
||||
|
||||
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
|
||||
)
|
||||
assert len(rows) == 1, "Account A should see its own sessions"
|
||||
|
||||
|
||||
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
|
||||
rows = await conn_no_context.fetch(
|
||||
f"SELECT id FROM sessions WHERE id IN "
|
||||
f"('{session_row_ids['session_a']}', '{session_row_ids['session_b']}')"
|
||||
)
|
||||
assert len(rows) == 0, "No-context connection should see no sessions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ai_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
|
||||
|
||||
|
||||
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
|
||||
)
|
||||
assert len(rows) == 1, "Account A should see its own ai_sessions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_branches"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_supporting_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_supporting_data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_resolution_outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_resolution_outputs"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_handoffs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_handoffs"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_templates"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_generations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_generations"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# maintenance_schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B maintenance_schedules"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# psa_post_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B psa_post_log"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_library — visibility-aware policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
|
||||
"""Private/non-public steps owned by Account B must not be visible to Account A."""
|
||||
private_step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{private_step_id}', '{ACCOUNT_B_ID}', 'RLS Private Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_library "
|
||||
f"WHERE id = '{private_step_id}' AND visibility != 'public'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B's private step_library rows"
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM step_library WHERE id = '{private_step_id}'"
|
||||
)
|
||||
|
||||
|
||||
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
|
||||
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
|
||||
public_step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{public_step_id}', '{ACCOUNT_B_ID}', 'RLS Public Step', 'action',
|
||||
'{{}}'::jsonb, 'public', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_library WHERE id = '{public_step_id}'"
|
||||
)
|
||||
assert len(rows) == 1, (
|
||||
"Account A should see public steps owned by Account B "
|
||||
"(cross-tenant public visibility policy)"
|
||||
)
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 3 RLS isolation tests
|
||||
# Tables: step_ratings, step_usage_log, target_lists,
|
||||
# session_shares, audit_logs, tree_shares
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers shared by Phase 3 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_user_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
async def _get_tree_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_ratings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step ratings belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
# Need a step_library row as FK target
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
rating_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_ratings (
|
||||
id, step_id, user_id, account_id, is_verified_use, is_visible,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
FALSE, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_ratings"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_usage_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step usage logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
# Need a sessions row as FK for usage log
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_usage_log (
|
||||
id, step_id, user_id, account_id, session_id, used_at
|
||||
) VALUES (
|
||||
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{session_id}', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# target_lists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see target lists belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
tl_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO target_lists (
|
||||
id, account_id, created_by, name, targets, created_at, updated_at
|
||||
) VALUES (
|
||||
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
|
||||
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B target_lists"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see session shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need a sessions row as FK
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-rls-test-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO session_shares (
|
||||
id, session_id, account_id, share_token, visibility,
|
||||
created_by, view_count, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', 'account', '{user_b_id}',
|
||||
0, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audit_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see audit logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO audit_logs (
|
||||
id, user_id, account_id, action, resource_type, created_at
|
||||
) VALUES (
|
||||
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'test.action', 'test_resource', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B audit_logs"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tree_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see tree shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-tree-rls-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO tree_shares (
|
||||
id, tree_id, account_id, share_token, created_by,
|
||||
allow_forking, created_at
|
||||
) VALUES (
|
||||
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', '{user_b_id}', TRUE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 4 RLS isolation tests
|
||||
# Tables: users, script_builder_sessions, ai_session_steps, notifications
|
||||
#
|
||||
# Note: platform_steps and template_trees have no account_id column and no RLS —
|
||||
# they are globally readable by all authenticated users.
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see users belonging to Account B."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B users"
|
||||
|
||||
|
||||
async def test_users_account_a_can_see_own(admin_conn, conn_a):
|
||||
"""Account A must be able to see its own users."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
)
|
||||
assert len(rows) > 0, "Account A should see its own users"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_builder_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see script builder sessions belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO script_builder_sessions (
|
||||
id, user_id, account_id, language, created_at, updated_at
|
||||
) VALUES (
|
||||
'{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'powershell', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_builder_sessions"
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ai_session_steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see ai_session_steps belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need an ai_sessions row as FK
|
||||
ai_session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_sessions (
|
||||
id, user_id, account_id, flow_type, status, confidence_tier,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'troubleshooting', 'active', 'guided', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_session_steps (
|
||||
id, session_id, account_id, step_type, content,
|
||||
created_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}',
|
||||
'question', 'Phase4 RLS test step', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_session_steps"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# notifications
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see notifications belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
notif_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO notifications (
|
||||
id, user_id, account_id, type, title, message,
|
||||
is_read, created_at
|
||||
) VALUES (
|
||||
'{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'info', 'Phase4 RLS Test', 'RLS isolation test notification',
|
||||
FALSE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B notifications"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'")
|
||||
|
||||
|
||||
@@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}],
|
||||
@@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=other_user.id,
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
|
||||
89
backend/tests/test_service_account.py
Normal file
89
backend/tests/test_service_account.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core import service_account as service_account_module
|
||||
from app.core.service_account import (
|
||||
SERVICE_ACCOUNT_EMAIL,
|
||||
SYSTEM_ACCOUNT_DISPLAY_CODE,
|
||||
ensure_service_account,
|
||||
)
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class _SessionFactoryOverride:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
async def __aenter__(self):
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_creates_and_reuses_seeded_user(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
service_account_id = await ensure_service_account(test_db)
|
||||
|
||||
created_user = (
|
||||
await test_db.execute(select(User).where(User.id == service_account_id))
|
||||
).scalar_one()
|
||||
assert created_user.email == SERVICE_ACCOUNT_EMAIL
|
||||
assert created_user.is_service_account is True
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
assert created_user.account_id == system_account.id
|
||||
|
||||
second_id = await ensure_service_account(test_db)
|
||||
assert second_id == service_account_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_marks_existing_user_as_service_account(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
existing_user = User(
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name="ResolutionFlow",
|
||||
password_hash="!service-account-no-login",
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=False,
|
||||
must_change_password=False,
|
||||
account_id=system_account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(existing_user)
|
||||
await test_db.commit()
|
||||
|
||||
resolved_id = await ensure_service_account(test_db)
|
||||
await test_db.refresh(existing_user)
|
||||
|
||||
assert resolved_id == existing_user.id
|
||||
assert existing_user.is_service_account is True
|
||||
@@ -3,37 +3,10 @@ import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.team import Team
|
||||
from app.models.user import User
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
|
||||
"""Override auth_headers to ensure the test user has a team_id assigned."""
|
||||
# Fetch the user from DB and assign a team
|
||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
user = result.scalar_one()
|
||||
|
||||
# Create a team and assign the user to it
|
||||
team = Team(name="Test Team")
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
|
||||
user.team_id = team.id
|
||||
await test_db.commit()
|
||||
|
||||
# Re-login to get a fresh token
|
||||
login_data = {
|
||||
"email": test_user["email"],
|
||||
"password": test_user["password"],
|
||||
}
|
||||
resp = await client.post("/api/v1/auth/login/json", json=login_data)
|
||||
assert resp.status_code == 200
|
||||
token_data = resp.json()
|
||||
return {"Authorization": f"Bearer {token_data['access_token']}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
|
||||
assert get.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from team B cannot access team A's list."""
|
||||
async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from account B cannot access account A's target list."""
|
||||
import uuid
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
# Create team A list using existing auth_headers
|
||||
# Create account A list using existing auth_headers
|
||||
create = await client.post(
|
||||
"/api/v1/target-lists/",
|
||||
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]},
|
||||
json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create.status_code == 201
|
||||
list_id = create.json()["id"]
|
||||
|
||||
# Create a separate team B with its own user
|
||||
team_b = Team(name=f"Team B {uuid.uuid4()}")
|
||||
test_db.add(team_b)
|
||||
# Create a separate account B with its own user
|
||||
account_b = Account(
|
||||
name=f"Account B {uuid.uuid4()}",
|
||||
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
|
||||
)
|
||||
test_db.add(account_b)
|
||||
await test_db.flush()
|
||||
|
||||
user_b = User(
|
||||
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
password_hash=get_password_hash("password123"),
|
||||
name="User B",
|
||||
is_active=True,
|
||||
team_id=team_b.id,
|
||||
account_id=account_b.id,
|
||||
account_role="engineer",
|
||||
role="engineer",
|
||||
)
|
||||
test_db.add(user_b)
|
||||
await test_db.flush()
|
||||
await test_db.commit()
|
||||
|
||||
# Get auth token for user B
|
||||
login = await client.post(
|
||||
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
token_b = login.json()["access_token"]
|
||||
headers_b = {"Authorization": f"Bearer {token_b}"}
|
||||
|
||||
# Team B cannot access Team A's list
|
||||
# Account B cannot access Account A's list
|
||||
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
|
||||
assert resp.status_code == 404
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user