Compare commits
79 Commits
docs/updat
...
feat/netwo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fb865ada8 | ||
|
|
89ca2a0fa5 | ||
|
|
112b1c2649 | ||
|
|
fd13a0618d | ||
|
|
3372e77a2a | ||
|
|
47353a68cd | ||
|
|
31324aa154 | ||
|
|
17ce5b1dfb | ||
|
|
dd1a13d713 | ||
|
|
2a6178e246 | ||
|
|
327a5c7c14 | ||
|
|
4527571d5f | ||
|
|
3c2b1dd16e | ||
|
|
bb24078d60 | ||
|
|
dd95b8892c | ||
|
|
0dc2801916 | ||
|
|
b490719667 | ||
|
|
663a96c8a5 | ||
|
|
2ea56f2563 | ||
|
|
6e5614e7b4 | ||
|
|
e6a4c93203 | ||
|
|
65ba60b2ae | ||
|
|
74c08f41c4 | ||
|
|
92ce84ef71 | ||
|
|
3c62a6993c | ||
|
|
a9c4bcc08b | ||
|
|
fe33ad1d5a | ||
|
|
3aaf0e58aa | ||
|
|
855cff07c2 | ||
|
|
87de51b06e | ||
|
|
f6e7613a5e | ||
|
|
ddd55167c1 | ||
|
|
2622258b04 | ||
|
|
90d7aa04a9 | ||
|
|
2a977e4d81 | ||
|
|
1371c2edd9 | ||
|
|
25233dbfae | ||
|
|
ab49635de2 | ||
|
|
354b44844c | ||
|
|
1ec7bbbbd3 | ||
|
|
b9e37ecdfb | ||
|
|
074548678f | ||
|
|
24afe5eb41 | ||
|
|
c16f3968d5 | ||
|
|
973efb1f81 | ||
|
|
bb35cff38d | ||
|
|
947516f81e | ||
|
|
f54d7ecd78 | ||
|
|
46593ba8ca | ||
|
|
52553d62d2 | ||
|
|
a48660700a | ||
|
|
3ff886363c | ||
|
|
501442e5f0 | ||
|
|
6f53ec06f5 | ||
|
|
ec322f7cdf | ||
|
|
f9248aeaa8 | ||
|
|
c6da4ebee5 | ||
|
|
64f004a62c | ||
|
|
ba36e37dab | ||
|
|
9e6965512b | ||
|
|
893b8a5008 | ||
|
|
e05472615b | ||
|
|
00fdd663bc | ||
|
|
8cf58add22 | ||
|
|
6c231ef1c6 | ||
|
|
758cd61621 | ||
|
|
b9fcdd5d73 | ||
|
|
4273ed0e5c | ||
|
|
0107d2d896 | ||
|
|
79ae34108a | ||
|
|
bd29f590a2 | ||
|
|
ce4cfc3240 | ||
|
|
82ee177d9b | ||
|
|
ed8de92c52 | ||
|
|
5bd331ca92 | ||
|
|
87fac02e9b | ||
|
|
4f4bc435da | ||
|
|
ac2b193909 | ||
|
|
b641ac6c55 |
5
.github/workflows/ci.yml
vendored
5
.github/workflows/ci.yml
vendored
@@ -31,6 +31,8 @@ jobs:
|
||||
SECRET_KEY: ci-test-secret-key-not-for-production
|
||||
DEBUG: "true"
|
||||
APP_NAME: ResolutionFlow
|
||||
TEST_DB_NAME: resolutionflow_test
|
||||
DB_APP_ROLE_PASSWORD: app_secret_ci
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -47,6 +49,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pip install -r backend/requirements.txt -r backend/requirements-dev.txt
|
||||
|
||||
- name: Run Alembic migrations
|
||||
run: cd backend && alembic upgrade head
|
||||
|
||||
- name: Check tenant filter enforcement
|
||||
run: cd backend && python scripts/check_tenant_filters.py
|
||||
# Warn mode only (exits 0). Switch to --fail after Phase 1 backlog clears.
|
||||
|
||||
@@ -9,6 +9,8 @@ All notable changes to ResolutionFlow are documented here.
|
||||
- Recurring Issue Detection — client-specific pattern alerts (#60)
|
||||
- Step Feedback Flag — "This Step is Wrong" reporting (#58)
|
||||
- **Tenant Isolation Phase 0** — multi-tenant data isolation (#132) with app-layer filtering helpers (`tenant_filter()`, `get_tenant_context`), cross-tenant access audit (analytics, categories, AI sessions, trees), UUID endpoint isolation with 404 responses for unauthorized access, ownership checks on all sensitive operations, and CI grep gate for missing tenant filters
|
||||
- **Tenant Isolation Phase 2** — PostgreSQL Row Level Security (RLS) on 11 session-related tables (ai_sessions, session_steps, session_tags, etc.), account_id NOT NULL enforcement on all write paths, Alembic migrations with dual-env support (Railway native vars + explicit DATABASE_URL_SYNC), RLS test coverage with cross-account isolation verification, migration CI/CD integration
|
||||
- **Tenant Isolation Phase 3** — RLS on audit_logs and tree_shares tables, cross-tenant session access for public shares (via get_admin_db), complete account_id propagation across PSA integration write paths, final RLS policy enforcement
|
||||
- **Script Library default view** — "All Scripts" tab now displays all accessible scripts (team + library)
|
||||
- **Session documentation overhaul** — reformatted PSA resolution/escalation notes with cleaner headers, inline engineer responses, decimal hour display (0.25 hrs), follow-up recommendations, and improved "What We Know" section from evidence items
|
||||
- **Client communication improvements** — new `request_info` audience type for client-facing information requests, improved status update and email draft prompts with per-context guidance
|
||||
|
||||
@@ -375,6 +375,12 @@ gh run view <id> --json jobs --jq '.jobs[] | {name: .name, conclusion: .conclusi
|
||||
|
||||
**106. Guard async "select item → load data → apply state" flows with a ref:** When a component lets the user switch between items (chat sessions, flows, scripts) and loads data asynchronously on each switch, the load for item A can complete *after* the user has already switched to item B — overwriting B's state with A's stale data. Fix pattern: keep a `currentSelectionRef = useRef(initialId)` and update it synchronously whenever the selection changes (in every creation/switch path). After every `await`, bail out if `currentSelectionRef.current !== thisItemId`. See `AssistantChatPage.tsx` `selectChat` for the reference implementation (`currentChatRef`).
|
||||
|
||||
**107. Startup routines must use `_admin_session_factory()` after Phase 4 RLS:** Any code that runs at startup (lifespan, `ensure_service_account`, seed scripts) and touches tenant-isolated tables (`users`, etc.) must use `_admin_session_factory()` — not `get_db()`. Phase 4 enabled RLS on `users`; a tenant-scoped session has no `app.current_account_id` set at startup, so all queries return 0 rows or fail. `get_service_account_id` in `deps.py` is safe — it reads from `app.state` cached at startup, never hits the DB per-request.
|
||||
|
||||
**108. Tables with no `account_id` column (never add to RLS migrations):** `script_categories`, `platform_steps`, `template_trees`, `plan_feature_defaults`, `accounts` — global/platform tables documented with "No account_id. No RLS." in their model files. When writing RLS migrations, scan at the class level (check for `account_id: Mapped` within the class block), not the file level — multiple classes in one `.py` file can have different columns (e.g. `ScriptCategory` vs `ScriptTemplate` in `script_template.py`).
|
||||
|
||||
**109. `tree_shares.account_id` must equal `tree.account_id`, not the actor's account:** When creating a `TreeShare`, always use `account_id=tree.account_id` (tree owner's tenant). A super admin in tenant A sharing tenant B's tree must produce a share row in tenant B's RLS context — using `current_user.account_id` instead makes the share invisible to the tree owner after RLS is enforced.
|
||||
|
||||
## RBAC & Permissions
|
||||
|
||||
- **Role hierarchy:** super_admin > team_admin > engineer > viewer
|
||||
@@ -522,7 +528,7 @@ When a feature, fix, or significant piece of work is finished and merged/committ
|
||||
<!-- gitnexus:start -->
|
||||
# GitNexus — Code Intelligence
|
||||
|
||||
This project is indexed by GitNexus as **resolutionflow** (14787 symbols, 31366 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||
This project is indexed by GitNexus as **resolutionflow** (16703 symbols, 35922 relationships, 300 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
|
||||
|
||||
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
> **Purpose:** Quick-reference file showing exactly where the project stands.
|
||||
> **For Claude Code:** Read this first to understand what's done and what's next.
|
||||
> **Last Updated:** March 23, 2026
|
||||
> **Last Updated:** April 12, 2026
|
||||
|
||||
---
|
||||
|
||||
@@ -163,6 +163,13 @@
|
||||
- SQL wildcard escaping in tag search
|
||||
- PSA credentials encrypted at rest (Fernet)
|
||||
|
||||
### Tenant Isolation (Phases 1-4 Complete)
|
||||
- PostgreSQL RLS enabled across tenant-scoped tables in phased rollout
|
||||
- `account_id` propagation completed across core content, sessions, analytics, notifications, shares, and remaining Phase 4 tables
|
||||
- Global platform tables correctly excluded from tenant RLS where they have no `account_id` (`script_categories`, `platform_steps`, `template_trees`)
|
||||
- Runtime bootstrap paths updated to use BYPASSRLS/admin sessions where needed (auth/user mutations, startup service account, background jobs, seed scripts)
|
||||
- Preview Railway backend and frontend deployments green for PR 136 after the Phase 4 fixes
|
||||
|
||||
### Copilot-First Dashboard (March 2026)
|
||||
|
||||
- Redesigned dashboard as FlowPilot copilot launchpad (ChatGPT-style input)
|
||||
|
||||
@@ -29,13 +29,37 @@ from app.models.session_branch import SessionBranch # noqa: F401
|
||||
from app.models.fork_point import ForkPoint # noqa: F401
|
||||
from app.models.session_handoff import SessionHandoff # noqa: F401
|
||||
from app.models.session_resolution_output import SessionResolutionOutput # noqa: F401
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _alembic_sync_url() -> str:
|
||||
"""Return a psycopg2-compatible sync URL for Alembic.
|
||||
|
||||
Priority order:
|
||||
1. DATABASE_URL_SYNC — in Railway this is set as a reference variable
|
||||
(${{pgvector.DATABASE_URL}}) that resolves to the correct postgres
|
||||
superuser credentials for the current environment (production, PR preview,
|
||||
etc.). This always works even on fresh databases before any custom roles
|
||||
have been created, because it uses the postgres superuser.
|
||||
2. ADMIN_DATABASE_URL (resolutionflow_admin, BYPASSRLS) converted to a sync
|
||||
driver — fallback for local dev where DATABASE_URL_SYNC may not be set.
|
||||
"""
|
||||
if settings.DATABASE_URL_SYNC:
|
||||
return settings.DATABASE_URL_SYNC
|
||||
|
||||
admin_url = settings.ADMIN_DATABASE_URL
|
||||
if admin_url and "+asyncpg" in admin_url:
|
||||
return admin_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
return settings.DATABASE_URL_SYNC
|
||||
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Override sqlalchemy.url with the sync version for migrations
|
||||
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL_SYNC)
|
||||
config.set_main_option("sqlalchemy.url", _alembic_sync_url())
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None:
|
||||
@@ -86,7 +110,7 @@ def run_migrations_online() -> None:
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
connectable = create_engine(
|
||||
settings.DATABASE_URL_SYNC,
|
||||
_alembic_sync_url(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
|
||||
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
59
backend/alembic/versions/04f013768235_enable_rls_phase3.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Enable RLS on Phase 3 tables.
|
||||
|
||||
Tables covered:
|
||||
- step_ratings (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- step_usage_log (account_id NOT NULL since migration 7167e9374b0c)
|
||||
- target_lists (account_id NOT NULL since migration 2c6aabd89bc6)
|
||||
- session_shares (account_id NOT NULL since session_share model)
|
||||
- audit_logs (account_id NOT NULL since migration 2a9056eddd90)
|
||||
- tree_shares (account_id NOT NULL since migration a05e1a1bea7c)
|
||||
|
||||
All use a standard intra-tenant isolation policy.
|
||||
Token-based access to session_shares and tree_shares goes through
|
||||
endpoints that use get_admin_db (BYPASSRLS), so a strict tenant
|
||||
policy here is correct.
|
||||
|
||||
Revision ID: 04f013768235
|
||||
Revises: a05e1a1bea7c
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = '04f013768235'
|
||||
down_revision: Union[str, None] = 'a05e1a1bea7c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000')::uuid"
|
||||
)
|
||||
|
||||
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||
|
||||
_PHASE3_TABLES = [
|
||||
"step_ratings",
|
||||
"step_usage_log",
|
||||
"target_lists",
|
||||
"session_shares",
|
||||
"audit_logs",
|
||||
"tree_shares",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_STANDARD_USING})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _PHASE3_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
132
backend/alembic/versions/073_add_device_types_table.py
Normal file
132
backend/alembic/versions/073_add_device_types_table.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Add account-scoped device_types table with platform seed data.
|
||||
|
||||
Revision ID: 073
|
||||
Revises: b3c7e9f2a1d8
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
import uuid
|
||||
|
||||
|
||||
revision = "073"
|
||||
down_revision = "b3c7e9f2a1d8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_PLATFORM_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
SYSTEM_DEVICE_TYPES = [
|
||||
("router", "Router", "network", 0),
|
||||
("switch", "Switch", "network", 1),
|
||||
("firewall", "Firewall", "network", 2),
|
||||
("access-point", "Access Point", "network", 3),
|
||||
("load-balancer", "Load Balancer", "network", 4),
|
||||
("server", "Server", "compute", 0),
|
||||
("workstation", "Workstation", "compute", 1),
|
||||
("vm", "Virtual Machine", "compute", 2),
|
||||
("container", "Container", "compute", 3),
|
||||
("nas", "NAS", "storage", 0),
|
||||
("san", "SAN", "storage", 1),
|
||||
("cloud-storage", "Cloud Storage", "storage", 2),
|
||||
("cloud", "Cloud", "cloud", 0),
|
||||
("aws", "AWS", "cloud", 1),
|
||||
("azure", "Azure", "cloud", 2),
|
||||
("gcp", "Google Cloud", "cloud", 3),
|
||||
("printer", "Printer", "endpoint", 0),
|
||||
("phone", "Phone", "endpoint", 1),
|
||||
("iot", "IoT Device", "endpoint", 2),
|
||||
("camera", "Camera", "endpoint", 3),
|
||||
("tablet", "Tablet", "endpoint", 4),
|
||||
("laptop", "Laptop", "endpoint", 5),
|
||||
("ups", "UPS", "infrastructure", 0),
|
||||
("pdu", "PDU", "infrastructure", 1),
|
||||
("rack", "Rack", "infrastructure", 2),
|
||||
("patch-panel", "Patch Panel", "infrastructure", 3),
|
||||
("nvr", "NVR", "security", 0),
|
||||
("badge-reader", "Badge Reader", "security", 1),
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"device_types",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("slug", sa.String(50), nullable=False),
|
||||
sa.Column("label", sa.String(100), nullable=False),
|
||||
sa.Column("category", sa.String(50), nullable=False),
|
||||
sa.Column("is_system", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("sort_order", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
op.create_unique_constraint("uq_device_types_slug_account", "device_types", ["slug", "account_id"])
|
||||
op.create_index("ix_device_types_account_id", "device_types", ["account_id"])
|
||||
|
||||
device_types_table = sa.table(
|
||||
"device_types",
|
||||
sa.column("id", UUID(as_uuid=True)),
|
||||
sa.column("slug", sa.String),
|
||||
sa.column("label", sa.String),
|
||||
sa.column("category", sa.String),
|
||||
sa.column("is_system", sa.Boolean),
|
||||
sa.column("account_id", UUID(as_uuid=True)),
|
||||
sa.column("sort_order", sa.Integer),
|
||||
)
|
||||
|
||||
op.bulk_insert(device_types_table, [
|
||||
{
|
||||
"id": uuid.uuid4(),
|
||||
"slug": slug,
|
||||
"label": label,
|
||||
"category": category,
|
||||
"is_system": True,
|
||||
"account_id": uuid.UUID(_PLATFORM_UUID),
|
||||
"sort_order": sort_order,
|
||||
}
|
||||
for slug, label, category, sort_order in SYSTEM_DEVICE_TYPES
|
||||
])
|
||||
|
||||
op.execute("ALTER TABLE device_types ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE device_types FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_select ON device_types
|
||||
FOR SELECT
|
||||
USING (
|
||||
account_id = {_CURRENT_ACCOUNT}
|
||||
OR account_id = '{_PLATFORM_UUID}'::uuid
|
||||
)
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_insert ON device_types
|
||||
FOR INSERT
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_update ON device_types
|
||||
FOR UPDATE
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
op.execute(f"""
|
||||
CREATE POLICY device_types_delete ON device_types
|
||||
FOR DELETE
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS device_types_delete ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_update ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_insert ON device_types")
|
||||
op.execute("DROP POLICY IF EXISTS device_types_select ON device_types")
|
||||
op.execute("ALTER TABLE device_types DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_table("device_types")
|
||||
57
backend/alembic/versions/074_add_network_diagrams_table.py
Normal file
57
backend/alembic/versions/074_add_network_diagrams_table.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Add network_diagrams table.
|
||||
|
||||
Revision ID: 074
|
||||
Revises: 073
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
|
||||
revision = "074"
|
||||
down_revision = "073"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_CURRENT_ACCOUNT = (
|
||||
"COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"network_diagrams",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("account_id", UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("client_name", sa.String(255), nullable=True),
|
||||
sa.Column("asset_name", sa.String(255), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("nodes", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("edges", JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("thumbnail_url", sa.Text(), nullable=True),
|
||||
sa.Column("is_archived", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("created_by", UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
op.create_index("ix_network_diagrams_account_id", "network_diagrams", ["account_id"])
|
||||
op.create_index("idx_network_diagrams_account_client", "network_diagrams", ["account_id", "client_name"])
|
||||
op.execute("ALTER TABLE network_diagrams ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE network_diagrams FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON network_diagrams
|
||||
USING (account_id = {_CURRENT_ACCOUNT})
|
||||
WITH CHECK (account_id = {_CURRENT_ACCOUNT})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP POLICY IF EXISTS tenant_isolation ON network_diagrams")
|
||||
op.execute("ALTER TABLE network_diagrams DISABLE ROW LEVEL SECURITY")
|
||||
op.drop_table("network_diagrams")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Drop team_id from target_lists.
|
||||
|
||||
account_id (NOT NULL) is now the tenant isolation key; team_id is redundant.
|
||||
All reads/writes use account_id via RLS + application filter.
|
||||
|
||||
Revision ID: 172ad76d7d20
|
||||
Revises: 04f013768235
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '172ad76d7d20'
|
||||
down_revision: Union[str, None] = '04f013768235'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_index('ix_target_lists_team_id', table_name='target_lists', if_exists=True)
|
||||
op.drop_constraint('target_lists_team_id_fkey', 'target_lists', type_='foreignkey')
|
||||
op.drop_column('target_lists', 'team_id')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column('target_lists', sa.Column('team_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'target_lists_team_id_fkey', 'target_lists', 'teams',
|
||||
['team_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
op.create_index('ix_target_lists_team_id', 'target_lists', ['team_id'])
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Add account_id to audit_logs and backfill via user_id.
|
||||
|
||||
Revision ID: 2a9056eddd90
|
||||
Revises: 70a5dd746e83
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '2a9056eddd90'
|
||||
down_revision: Union[str, None] = '70a5dd746e83'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('audit_logs', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_audit_logs_account_id', 'audit_logs', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the acting user's account
|
||||
op.execute("""
|
||||
UPDATE audit_logs al
|
||||
SET account_id = u.account_id
|
||||
FROM users u
|
||||
WHERE al.user_id = u.id
|
||||
AND u.account_id IS NOT NULL
|
||||
AND al.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM audit_logs WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} audit_logs rows have NULL account_id after backfill. "
|
||||
"All audit log entries must have an associated user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('audit_logs', 'account_id', nullable=False)
|
||||
op.create_index('ix_audit_logs_account_id', 'audit_logs', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_audit_logs_account_id', table_name='audit_logs')
|
||||
op.drop_constraint('fk_audit_logs_account_id', 'audit_logs', type_='foreignkey')
|
||||
op.drop_column('audit_logs', 'account_id')
|
||||
90
backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py
Normal file
90
backend/alembic/versions/70a5dd746e83_enable_rls_phase2.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Enable RLS on Phase 2 session and supporting tables.
|
||||
|
||||
10 tables use a standard tenant-only policy.
|
||||
step_library uses a visibility-aware policy — public steps visible to all tenants.
|
||||
|
||||
NOTE: session_messages does not exist in this codebase (removed from plan).
|
||||
script_generations is the correct table name (not script_template_generations).
|
||||
sessions and ai_sessions are two separate tables, both in scope.
|
||||
|
||||
Prerequisites:
|
||||
- Phase 1 migration must have run (resolutionflow_app role exists, Phase 1 tables have RLS)
|
||||
- NOT NULL write-path bugs fixed (P2-A commits b641ac6)
|
||||
- shares.py cross-tenant session fix deployed (P2-B commit ac2b193)
|
||||
|
||||
Revision ID: 70a5dd746e83
|
||||
Revises: c5f48b9890f9
|
||||
Create Date: 2026-04-10 06:54:49.431817
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '70a5dd746e83'
|
||||
down_revision: Union[str, None] = 'c5f48b9890f9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_NULL_UUID = "00000000-0000-0000-0000-000000000000"
|
||||
_CURRENT_ACCOUNT = (
|
||||
f"COALESCE(NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
f"'{_NULL_UUID}')::uuid"
|
||||
)
|
||||
|
||||
# Standard tenant-only policy — account_id must match the current tenant.
|
||||
# When no tenant context is set, COALESCE returns the nil UUID so zero rows
|
||||
# are visible (fail-closed).
|
||||
_STANDARD_USING = f"account_id = {_CURRENT_ACCOUNT}"
|
||||
|
||||
# Visibility-aware policy for step_library — public steps (visibility='public')
|
||||
# must be visible to ALL tenants regardless of account_id. This covers the
|
||||
# visibility='public' arm of build_step_visibility_filter() in app/core/filters.py.
|
||||
# The created_by arm (private steps visible to their author) is covered
|
||||
# transitively: private steps share account_id with their creator, so the
|
||||
# account_id match handles it. This relies on account_id NOT NULL on step_library.
|
||||
_STEP_LIBRARY_USING = f"account_id = {_CURRENT_ACCOUNT} OR visibility = 'public'"
|
||||
|
||||
# Standard tables: strict tenant isolation, no cross-tenant visibility.
|
||||
_STANDARD_TABLES = [
|
||||
"sessions",
|
||||
"ai_sessions",
|
||||
"session_branches",
|
||||
"session_supporting_data",
|
||||
"session_resolution_outputs",
|
||||
"session_handoffs",
|
||||
"script_templates",
|
||||
"script_generations",
|
||||
"maintenance_schedules",
|
||||
"psa_post_log",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Standard tenant-isolation tables ────────────────────────────────────
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_STANDARD_USING})
|
||||
""")
|
||||
|
||||
# ── step_library ────────────────────────────────────────────────────────
|
||||
# Public steps (visibility='public') must be readable by all tenants so
|
||||
# the Solutions Library browsing experience works without tenant context.
|
||||
# Private/team steps remain tenant-scoped.
|
||||
op.execute("ALTER TABLE step_library ENABLE ROW LEVEL SECURITY")
|
||||
op.execute("ALTER TABLE step_library FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON step_library
|
||||
USING ({_STEP_LIBRARY_USING})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _STANDARD_TABLES + ["step_library"]:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Add account_id to tree_shares and backfill via tree owner's account.
|
||||
|
||||
The share belongs to the tree's tenant, not the actor who created it.
|
||||
A super admin in account A can share a tree owned by account B; that share
|
||||
must land in account B so account B's RLS filter sees it.
|
||||
|
||||
Revision ID: a05e1a1bea7c
|
||||
Revises: 2a9056eddd90
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = 'a05e1a1bea7c'
|
||||
down_revision: Union[str, None] = '2a9056eddd90'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('tree_shares', sa.Column('account_id', sa.UUID(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_tree_shares_account_id', 'tree_shares', 'accounts',
|
||||
['account_id'], ['id'], ondelete='CASCADE',
|
||||
)
|
||||
|
||||
# Backfill: derive from the tree's account, not the creator's account.
|
||||
# A share lives in the same tenant as its tree so that the tree owner's
|
||||
# RLS context covers their own shares regardless of who created them.
|
||||
op.execute("""
|
||||
UPDATE tree_shares ts
|
||||
SET account_id = t.account_id
|
||||
FROM trees t
|
||||
WHERE ts.tree_id = t.id
|
||||
AND t.account_id IS NOT NULL
|
||||
AND ts.account_id IS NULL
|
||||
""")
|
||||
|
||||
result = op.get_bind().execute(
|
||||
sa.text("SELECT COUNT(*) FROM tree_shares WHERE account_id IS NULL")
|
||||
)
|
||||
count = result.scalar()
|
||||
if count > 0:
|
||||
raise RuntimeError(
|
||||
f"ROLLBACK: {count} tree_shares rows have NULL account_id after backfill. "
|
||||
"All share entries must have a creating user with an account."
|
||||
)
|
||||
|
||||
op.alter_column('tree_shares', 'account_id', nullable=False)
|
||||
op.create_index('ix_tree_shares_account_id', 'tree_shares', ['account_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_tree_shares_account_id', table_name='tree_shares')
|
||||
op.drop_constraint('fk_tree_shares_account_id', 'tree_shares', type_='foreignkey')
|
||||
op.drop_column('tree_shares', 'account_id')
|
||||
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
85
backend/alembic/versions/b3c7e9f2a1d8_enable_rls_phase4.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Enable RLS on Phase 4 tables — all remaining tenant-scoped tables.
|
||||
|
||||
All tables in this migration already have account_id NOT NULL (enforced by
|
||||
earlier migrations). This migration adds ENABLE ROW LEVEL SECURITY,
|
||||
FORCE ROW LEVEL SECURITY, and the appropriate tenant isolation policy to each.
|
||||
|
||||
Policy variants used:
|
||||
- Standard: account_id = current_setting(app.current_account_id)::uuid
|
||||
- Platform: standard OR account_id = PLATFORM_ACCOUNT_ID
|
||||
(for global content tables readable by all tenants)
|
||||
|
||||
Skipped intentionally:
|
||||
- accounts — IS the root table; no account_id column
|
||||
- plan_feature_defaults — platform config; no account_id column
|
||||
- script_categories — global lookup table; no account_id column
|
||||
- platform_steps — global content; no account_id column (readable by all)
|
||||
- template_trees — global content; no account_id column (readable by all)
|
||||
|
||||
Revision ID: b3c7e9f2a1d8
|
||||
Revises: 172ad76d7d20
|
||||
Create Date: 2026-04-12
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b3c7e9f2a1d8"
|
||||
down_revision: Union[str, None] = "172ad76d7d20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Standard policy — tenant sees only own rows.
|
||||
_STANDARD_TABLES = [
|
||||
"users",
|
||||
"account_invites",
|
||||
"account_limit_overrides",
|
||||
"account_feature_overrides",
|
||||
"subscriptions",
|
||||
"ai_chat_sessions",
|
||||
"ai_conversations",
|
||||
"ai_session_steps",
|
||||
"ai_session_embeddings",
|
||||
"ai_suggestions",
|
||||
"ai_usage",
|
||||
"assistant_chats",
|
||||
"attachments",
|
||||
"copilot_conversations",
|
||||
"feedback",
|
||||
"file_uploads",
|
||||
"fork_points",
|
||||
"kb_imports",
|
||||
"notifications",
|
||||
"notification_configs",
|
||||
"notification_logs",
|
||||
"psa_activity_logs",
|
||||
"psa_member_mappings",
|
||||
"script_builder_sessions",
|
||||
"session_ratings",
|
||||
"tree_embeddings",
|
||||
"user_folders",
|
||||
"user_pinned_trees",
|
||||
]
|
||||
|
||||
_POLICY_EXPR = (
|
||||
"account_id = COALESCE("
|
||||
"NULLIF(current_setting('app.current_account_id', TRUE), ''), "
|
||||
"'00000000-0000-0000-0000-000000000000'"
|
||||
")::uuid"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
|
||||
op.execute(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
|
||||
op.execute(f"""
|
||||
CREATE POLICY tenant_isolation ON {table}
|
||||
USING ({_POLICY_EXPR})
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for table in _STANDARD_TABLES:
|
||||
op.execute(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")
|
||||
op.execute(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
|
||||
@@ -24,10 +24,14 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
token: Annotated[str, Depends(oauth2_scheme)]
|
||||
) -> User:
|
||||
"""Get current authenticated user from JWT token."""
|
||||
"""Get current authenticated user from JWT token.
|
||||
|
||||
Must use get_admin_db (BYPASSRLS): this dep runs before require_tenant_context
|
||||
sets app.current_account_id, so the users table RLS would block the lookup.
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -77,10 +81,14 @@ async def get_refresh_token_payload(
|
||||
async def get_current_active_user(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> User:
|
||||
"""Ensure user is active (not disabled). Auto-downgrades expired trials.
|
||||
Enforces must_change_password — blocks all routes except allowlist."""
|
||||
Enforces must_change_password — blocks all routes except allowlist.
|
||||
|
||||
Uses get_admin_db: runs before require_tenant_context sets the ContextVar,
|
||||
so tenant-scoped tables (subscriptions) would return 0 rows via app role.
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.core.audit import log_audit
|
||||
from app.models.refresh_token import RefreshToken
|
||||
@@ -148,7 +149,7 @@ async def update_member_role(
|
||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||
async def transfer_ownership(
|
||||
data: TransferOwnershipRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Transfer account ownership to another member (owner only)."""
|
||||
@@ -377,7 +378,7 @@ async def list_invites(
|
||||
|
||||
@router.post("/me/leave")
|
||||
async def leave_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Leave the current account (non-owners only). Creates a personal account."""
|
||||
@@ -423,7 +424,7 @@ class DeleteAccountRequest(BaseModel):
|
||||
@router.delete("/me")
|
||||
async def delete_account(
|
||||
data: DeleteAccountRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
||||
|
||||
@@ -43,6 +43,7 @@ async def create_suggestion(
|
||||
suggestion = AISuggestion(
|
||||
tree_id=data.tree_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
session_id=data.session_id,
|
||||
action_type=data.action_type,
|
||||
target_node_id=data.target_node_id,
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update as sa_update
|
||||
from app.core.config import settings
|
||||
from app.core.settings_manager import SettingsManager
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.security import (
|
||||
verify_password,
|
||||
@@ -67,7 +67,7 @@ def _generate_display_code() -> str:
|
||||
async def register(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Register a new user.
|
||||
|
||||
@@ -232,7 +232,7 @@ async def register(
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login and get access token."""
|
||||
# Find user by email
|
||||
@@ -270,7 +270,7 @@ async def login(
|
||||
async def login_json(
|
||||
request: Request,
|
||||
credentials: UserLogin,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Login with JSON body (alternative to form data)."""
|
||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||
@@ -304,7 +304,7 @@ async def login_json(
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Refresh access token using refresh token (rotation: old token is revoked)."""
|
||||
user_id = payload.get("sub")
|
||||
@@ -368,7 +368,7 @@ async def get_me(
|
||||
async def update_me(
|
||||
data: UserUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Update current user's profile (name, email)."""
|
||||
update_fields = data.model_fields_set - {"current_password"}
|
||||
@@ -415,7 +415,7 @@ async def update_me(
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Logout user by revoking the refresh token."""
|
||||
jti = payload.get("jti")
|
||||
@@ -438,7 +438,7 @@ async def change_password(
|
||||
request: Request,
|
||||
data: ChangePasswordRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Change the current user's password."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
@@ -478,7 +478,7 @@ async def change_password(
|
||||
async def forgot_password(
|
||||
request: Request,
|
||||
data: ForgotPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Request a password reset email. Always returns success (anti-enumeration)."""
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
@@ -513,7 +513,7 @@ async def forgot_password(
|
||||
@router.post("/password/verify-reset-token", response_model=VerifyResetTokenResponse)
|
||||
async def verify_reset_token(
|
||||
data: VerifyResetTokenRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify a password reset token is valid."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -544,7 +544,7 @@ async def verify_reset_token(
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
data: ResetPasswordRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Reset password using a valid reset token."""
|
||||
payload = decode_token(data.token)
|
||||
@@ -611,7 +611,7 @@ async def reset_password(
|
||||
|
||||
@router.get("/email/verification-status")
|
||||
async def get_verification_status(
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Check if email verification is enabled on the platform."""
|
||||
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -623,7 +623,7 @@ async def get_verification_status(
|
||||
async def send_verification_email(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Send an email verification link to the current user."""
|
||||
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
@@ -662,7 +662,7 @@ async def send_verification_email(
|
||||
@router.post("/email/verify")
|
||||
async def verify_email(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Verify an email using a token. Public endpoint."""
|
||||
token = data.get("token")
|
||||
|
||||
120
backend/app/api/endpoints/device_types.py
Normal file
120
backend/app/api/endpoints/device_types.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Device types API endpoints."""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.device_type import DeviceType
|
||||
from app.schemas.device_type import (
|
||||
DeviceTypeCreate,
|
||||
DeviceTypeUpdate,
|
||||
DeviceTypeResponse,
|
||||
)
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
router = APIRouter(prefix="/device-types", tags=["device-types"])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[DeviceTypeResponse])
|
||||
async def list_device_types(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[DeviceTypeResponse]:
|
||||
stmt = (
|
||||
select(DeviceType)
|
||||
.where(
|
||||
or_(
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
.order_by(DeviceType.category, DeviceType.sort_order, DeviceType.label)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [DeviceTypeResponse.model_validate(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("/", response_model=DeviceTypeResponse, status_code=201)
|
||||
async def create_device_type(
|
||||
data: DeviceTypeCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DeviceTypeResponse:
|
||||
existing = await db.execute(
|
||||
select(DeviceType).where(
|
||||
DeviceType.slug == data.slug,
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' already exists for your account")
|
||||
|
||||
system_existing = await db.execute(
|
||||
select(DeviceType).where(
|
||||
DeviceType.slug == data.slug,
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
)
|
||||
)
|
||||
if system_existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Device type '{data.slug}' conflicts with a system type")
|
||||
|
||||
device_type = DeviceType(
|
||||
slug=data.slug,
|
||||
label=data.label,
|
||||
category=data.category,
|
||||
is_system=False,
|
||||
account_id=current_user.account_id,
|
||||
sort_order=data.sort_order,
|
||||
)
|
||||
db.add(device_type)
|
||||
await db.commit()
|
||||
await db.refresh(device_type)
|
||||
return DeviceTypeResponse.model_validate(device_type)
|
||||
|
||||
|
||||
@router.put("/{device_type_id}", response_model=DeviceTypeResponse)
|
||||
async def update_device_type(
|
||||
device_type_id: UUID,
|
||||
data: DeviceTypeUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DeviceTypeResponse:
|
||||
device_type = await db.get(DeviceType, device_type_id)
|
||||
if not device_type:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
if device_type.is_system:
|
||||
raise HTTPException(status_code=403, detail="Cannot modify system device types")
|
||||
if device_type.account_id != current_user.account_id:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(device_type, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(device_type)
|
||||
return DeviceTypeResponse.model_validate(device_type)
|
||||
|
||||
|
||||
@router.delete("/{device_type_id}", status_code=204)
|
||||
async def delete_device_type(
|
||||
device_type_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> None:
|
||||
device_type = await db.get(DeviceType, device_type_id)
|
||||
if not device_type:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
if device_type.is_system:
|
||||
raise HTTPException(status_code=403, detail="Cannot delete system device types")
|
||||
if device_type.account_id != current_user.account_id:
|
||||
raise HTTPException(status_code=404, detail="Device type not found")
|
||||
|
||||
await db.delete(device_type)
|
||||
await db.commit()
|
||||
@@ -69,6 +69,7 @@ async def create_schedule(
|
||||
|
||||
schedule = MaintenanceSchedule(
|
||||
tree_id=data.tree_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
cron_expression=data.cron_expression,
|
||||
timezone=data.timezone,
|
||||
|
||||
331
backend/app/api/endpoints/network_diagrams.py
Normal file
331
backend/app/api/endpoints/network_diagrams.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Network diagrams API endpoints."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.device_type import DeviceType
|
||||
from app.models.network_diagram import NetworkDiagram
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
from app.schemas.network_diagram import (
|
||||
NetworkDiagramCreate,
|
||||
NetworkDiagramUpdate,
|
||||
NetworkDiagramResponse,
|
||||
NetworkDiagramListItem,
|
||||
AIGenerateRequest,
|
||||
AIGenerateResponse,
|
||||
DiagramImportRequest,
|
||||
DiagramImportResponse,
|
||||
DiagramExportResponse,
|
||||
DiagramNode,
|
||||
DiagramEdge,
|
||||
)
|
||||
from app.services import network_diagram_ai_service
|
||||
|
||||
# Maps system device-type slugs to their category — mirrors frontend deviceRegistry.ts
|
||||
_SLUG_CATEGORY: dict[str, str] = {
|
||||
"router": "network", "switch": "network", "access-point": "network", "load-balancer": "network",
|
||||
"firewall": "security", "badge-reader": "security",
|
||||
"server": "compute", "vm": "compute", "container": "compute",
|
||||
"nas": "storage", "san": "storage", "cloud-storage": "storage",
|
||||
"cloud": "cloud", "aws": "cloud", "azure": "cloud", "gcp": "cloud", "isp": "cloud",
|
||||
"workstation": "endpoint", "laptop": "endpoint", "tablet": "endpoint",
|
||||
"phone": "endpoint", "printer": "endpoint",
|
||||
"ups": "infrastructure", "pdu": "infrastructure", "rack": "infrastructure",
|
||||
"patch-panel": "infrastructure", "camera": "infrastructure",
|
||||
"nvr": "infrastructure", "iot": "infrastructure",
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/network-diagrams", tags=["network-diagrams"])
|
||||
|
||||
|
||||
async def _get_diagram_or_404(
|
||||
diagram_id: UUID,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
) -> NetworkDiagram:
|
||||
diagram = await db.get(NetworkDiagram, diagram_id)
|
||||
if not diagram or diagram.account_id != account_id or diagram.is_archived:
|
||||
raise HTTPException(status_code=404, detail="Diagram not found")
|
||||
return diagram
|
||||
|
||||
|
||||
def _diagram_to_response(diagram: NetworkDiagram) -> NetworkDiagramResponse:
|
||||
return NetworkDiagramResponse.model_validate(diagram)
|
||||
|
||||
|
||||
def _diagram_to_list_item(
|
||||
diagram: NetworkDiagram,
|
||||
custom_slug_category: dict[str, str] | None = None,
|
||||
) -> NetworkDiagramListItem:
|
||||
nodes = diagram.nodes if isinstance(diagram.nodes, list) else []
|
||||
slug_to_cat = {**_SLUG_CATEGORY, **(custom_slug_category or {})}
|
||||
|
||||
category_counts: dict[str, int] = {}
|
||||
for node in nodes:
|
||||
slug = node.get("type", "") if isinstance(node, dict) else ""
|
||||
cat = slug_to_cat.get(slug, "other")
|
||||
category_counts[cat] = category_counts.get(cat, 0) + 1
|
||||
|
||||
return NetworkDiagramListItem(
|
||||
id=diagram.id,
|
||||
name=diagram.name,
|
||||
client_name=diagram.client_name,
|
||||
description=diagram.description,
|
||||
node_count=len(nodes),
|
||||
category_counts=category_counts,
|
||||
created_by=diagram.created_by,
|
||||
created_at=diagram.created_at,
|
||||
updated_at=diagram.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def _get_available_slugs(account_id: UUID, db: AsyncSession) -> set[str]:
|
||||
stmt = select(DeviceType.slug).where(
|
||||
or_(
|
||||
DeviceType.account_id == PLATFORM_ACCOUNT_ID,
|
||||
DeviceType.account_id == account_id,
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return {row[0] for row in result.all()}
|
||||
|
||||
|
||||
@router.get("/clients", response_model=list[str])
|
||||
async def list_client_names(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> list[str]:
|
||||
stmt = (
|
||||
select(NetworkDiagram.client_name)
|
||||
.where(
|
||||
NetworkDiagram.account_id == current_user.account_id,
|
||||
NetworkDiagram.is_archived.is_(False),
|
||||
NetworkDiagram.client_name.isnot(None),
|
||||
NetworkDiagram.client_name != "",
|
||||
)
|
||||
.distinct()
|
||||
.order_by(NetworkDiagram.client_name)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return [row[0] for row in result.all()]
|
||||
|
||||
|
||||
@router.get("/", response_model=list[NetworkDiagramListItem])
|
||||
async def list_diagrams(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
client_name: str | None = Query(default=None),
|
||||
search: str | None = Query(default=None),
|
||||
) -> list[NetworkDiagramListItem]:
|
||||
stmt = (
|
||||
select(NetworkDiagram)
|
||||
.where(
|
||||
NetworkDiagram.account_id == current_user.account_id,
|
||||
NetworkDiagram.is_archived.is_(False),
|
||||
)
|
||||
.order_by(NetworkDiagram.updated_at.desc())
|
||||
)
|
||||
|
||||
if client_name:
|
||||
stmt = stmt.where(NetworkDiagram.client_name == client_name)
|
||||
|
||||
if search:
|
||||
escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
search_filter = f"%{escaped}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
NetworkDiagram.name.ilike(search_filter),
|
||||
NetworkDiagram.client_name.ilike(search_filter),
|
||||
)
|
||||
)
|
||||
|
||||
# Single query for custom device types so category_counts is accurate
|
||||
dt_stmt = select(DeviceType.slug, DeviceType.category).where(
|
||||
DeviceType.is_system.is_(False),
|
||||
DeviceType.account_id == current_user.account_id,
|
||||
)
|
||||
dt_result = await db.execute(dt_stmt)
|
||||
custom_slug_category = {row[0]: row[1] for row in dt_result.all()}
|
||||
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [_diagram_to_list_item(r, custom_slug_category) for r in rows]
|
||||
|
||||
|
||||
@router.post("/", response_model=NetworkDiagramResponse, status_code=201)
|
||||
async def create_diagram(
|
||||
data: NetworkDiagramCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=data.name,
|
||||
client_name=data.client_name,
|
||||
asset_name=data.asset_name,
|
||||
description=data.description,
|
||||
nodes=[n.model_dump() for n in data.nodes],
|
||||
edges=[e.model_dump() for e in data.edges],
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(diagram)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.get("/{diagram_id}", response_model=NetworkDiagramResponse)
|
||||
async def get_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.put("/{diagram_id}", response_model=NetworkDiagramResponse)
|
||||
async def update_diagram(
|
||||
diagram_id: UUID,
|
||||
data: NetworkDiagramUpdate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
if "nodes" in update_data and update_data["nodes"] is not None:
|
||||
update_data["nodes"] = [n.model_dump() if hasattr(n, "model_dump") else n for n in update_data["nodes"]]
|
||||
if "edges" in update_data and update_data["edges"] is not None:
|
||||
update_data["edges"] = [e.model_dump() if hasattr(e, "model_dump") else e for e in update_data["edges"]]
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(diagram, field, value)
|
||||
|
||||
diagram.updated_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
return _diagram_to_response(diagram)
|
||||
|
||||
|
||||
@router.delete("/{diagram_id}", status_code=204)
|
||||
async def archive_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> None:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
diagram.is_archived = True
|
||||
diagram.updated_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.post("/{diagram_id}/duplicate", response_model=NetworkDiagramResponse, status_code=201)
|
||||
async def duplicate_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> NetworkDiagramResponse:
|
||||
source = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
copy = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=f"Copy of {source.name}",
|
||||
client_name=source.client_name,
|
||||
asset_name=source.asset_name,
|
||||
description=source.description,
|
||||
nodes=source.nodes,
|
||||
edges=source.edges,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(copy)
|
||||
await db.commit()
|
||||
await db.refresh(copy)
|
||||
return _diagram_to_response(copy)
|
||||
|
||||
|
||||
@router.get("/{diagram_id}/export", response_model=DiagramExportResponse)
|
||||
async def export_diagram(
|
||||
diagram_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DiagramExportResponse:
|
||||
diagram = await _get_diagram_or_404(diagram_id, current_user.account_id, db)
|
||||
nodes = [DiagramNode(**n) for n in (diagram.nodes or [])]
|
||||
edges = [DiagramEdge(**e) for e in (diagram.edges or [])]
|
||||
return DiagramExportResponse(
|
||||
schemaVersion=1,
|
||||
name=diagram.name,
|
||||
client_name=diagram.client_name,
|
||||
description=diagram.description,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
exportedAt=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", response_model=DiagramImportResponse, status_code=201)
|
||||
async def import_diagram(
|
||||
data: DiagramImportRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> DiagramImportResponse:
|
||||
available_slugs = await _get_available_slugs(current_user.account_id, db)
|
||||
|
||||
warnings: list[str] = []
|
||||
for node in data.nodes:
|
||||
if node.type not in available_slugs:
|
||||
warnings.append(f"Unknown device type '{node.type}' — will render with default icon")
|
||||
|
||||
diagram = NetworkDiagram(
|
||||
account_id=current_user.account_id,
|
||||
name=data.name,
|
||||
client_name=data.client_name,
|
||||
description=data.description,
|
||||
nodes=[n.model_dump() for n in data.nodes],
|
||||
edges=[e.model_dump() for e in data.edges],
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(diagram)
|
||||
await db.commit()
|
||||
await db.refresh(diagram)
|
||||
|
||||
return DiagramImportResponse(
|
||||
diagram=_diagram_to_response(diagram),
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/ai-generate", response_model=AIGenerateResponse)
|
||||
async def ai_generate_diagram(
|
||||
data: AIGenerateRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> AIGenerateResponse:
|
||||
available_slugs_set = await _get_available_slugs(current_user.account_id, db)
|
||||
available_slugs = list(available_slugs_set)
|
||||
|
||||
existing_node_ids: list[str] | None = None
|
||||
if data.mode == "merge" and data.existingBounds:
|
||||
existing_node_ids = []
|
||||
|
||||
try:
|
||||
return await network_diagram_ai_service.generate_diagram(
|
||||
request=data,
|
||||
available_slugs=available_slugs,
|
||||
existing_node_ids=existing_node_ids,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
except Exception:
|
||||
logger.exception("AI diagram generation failed")
|
||||
raise HTTPException(status_code=500, detail="Diagram generation failed")
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_active_user
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.models.psa_connection import PsaConnection
|
||||
from app.models.session import Session
|
||||
@@ -98,7 +99,7 @@ async def get_onboarding_status(
|
||||
|
||||
@router.post("/onboarding-status/dismiss", response_model=OnboardingStatus)
|
||||
async def dismiss_onboarding(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> OnboardingStatus:
|
||||
"""Dismiss the onboarding checklist for the current user."""
|
||||
|
||||
@@ -91,6 +91,7 @@ async def submit_step_feedback(
|
||||
new_rating = StepRating(
|
||||
step_id=step_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
session_id=session_uuid,
|
||||
was_helpful=data.was_helpful,
|
||||
# rating is nullable now — thumbs-only mode
|
||||
|
||||
@@ -85,6 +85,7 @@ async def create_session(
|
||||
session = await script_builder_service.create_session(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
team_id=current_user.team_id,
|
||||
language=data.language,
|
||||
)
|
||||
|
||||
@@ -196,6 +196,7 @@ async def start_session(
|
||||
new_session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -693,6 +694,7 @@ async def prepare_session(
|
||||
new_session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=data.assigned_to_id or current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -770,6 +772,7 @@ async def batch_launch_sessions(
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_snapshot=tree_snapshot,
|
||||
path_taken=[],
|
||||
decisions=[],
|
||||
@@ -1102,6 +1105,7 @@ async def psa_post_to_ticket(
|
||||
# Log to audit trail
|
||||
log_entry = PsaPostLog(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
psa_connection_id=psa_connection.id if psa_connection else None,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
note_type=data.note_type,
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.models.session import Session
|
||||
from app.models.session_share import SessionShare, SessionShareView
|
||||
from app.models.user import User
|
||||
@@ -210,7 +211,7 @@ async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[Use
|
||||
async def access_share(
|
||||
share_token: str,
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
):
|
||||
"""Access a shared session via share token.
|
||||
|
||||
|
||||
@@ -460,6 +460,7 @@ async def rate_step(
|
||||
rating = StepRating(
|
||||
step_id=step_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
rating=rating_data.rating,
|
||||
was_helpful=rating_data.was_helpful,
|
||||
review_text=rating_data.review_text,
|
||||
|
||||
@@ -103,6 +103,7 @@ async def create_supporting_data(
|
||||
|
||||
item = SessionSupportingData(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
label=data.label,
|
||||
data_type=data.data_type,
|
||||
content=data.content,
|
||||
|
||||
@@ -18,12 +18,10 @@ async def list_target_lists(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""List all target lists for the current user's team."""
|
||||
if not current_user.team_id:
|
||||
return []
|
||||
"""List all target lists for the current user's account."""
|
||||
result = await db.execute(
|
||||
select(TargetList)
|
||||
.where(TargetList.team_id == current_user.team_id)
|
||||
.where(TargetList.account_id == current_user.account_id)
|
||||
.order_by(TargetList.name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
@@ -36,11 +34,9 @@ async def create_target_list(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Create a new target list for the current team."""
|
||||
if not current_user.team_id:
|
||||
raise HTTPException(status_code=400, detail="User must belong to a team")
|
||||
"""Create a new target list for the current account."""
|
||||
target_list = TargetList(
|
||||
team_id=current_user.team_id,
|
||||
account_id=current_user.account_id,
|
||||
created_by=current_user.id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
@@ -61,7 +57,7 @@ async def get_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -81,7 +77,7 @@ async def update_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
@@ -91,7 +87,7 @@ async def update_target_list(
|
||||
if "name" in update_fields and data.name is not None:
|
||||
target_list.name = data.name
|
||||
if "description" in update_fields:
|
||||
target_list.description = data.description # allow setting to None
|
||||
target_list.description = data.description
|
||||
if "targets" in update_fields and data.targets is not None:
|
||||
target_list.targets = [t.model_dump() for t in data.targets]
|
||||
await db.commit()
|
||||
@@ -109,7 +105,7 @@ async def delete_target_list(
|
||||
result = await db.execute(
|
||||
select(TargetList).where(
|
||||
TargetList.id == list_id,
|
||||
TargetList.team_id == current_user.team_id,
|
||||
TargetList.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
target_list = result.scalar_one_or_none()
|
||||
|
||||
@@ -1048,6 +1048,7 @@ async def create_tree_share(
|
||||
# Create share
|
||||
tree_share = TreeShare(
|
||||
tree_id=tree.id,
|
||||
account_id=tree.account_id, # share belongs to the tree's tenant, not the actor
|
||||
share_token=share_token,
|
||||
created_by=current_user.id,
|
||||
allow_forking=share_data.allow_forking,
|
||||
|
||||
@@ -24,6 +24,7 @@ from app.api.endpoints import (
|
||||
branding,
|
||||
categories,
|
||||
copilot,
|
||||
device_types,
|
||||
feedback,
|
||||
flow_proposals,
|
||||
flowpilot_analytics,
|
||||
@@ -32,6 +33,7 @@ from app.api.endpoints import (
|
||||
invite,
|
||||
kb_accelerator,
|
||||
maintenance_schedules,
|
||||
network_diagrams,
|
||||
notifications,
|
||||
onboarding,
|
||||
public_templates,
|
||||
@@ -93,7 +95,6 @@ api_router.include_router(admin_settings.router)
|
||||
api_router.include_router(admin_categories.router)
|
||||
api_router.include_router(admin_survey.router)
|
||||
api_router.include_router(admin_gallery.router)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-facing endpoints — tenant context required
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -130,6 +131,7 @@ api_router.include_router(integrations.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(onboarding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(branding.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(supporting_data.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(network_diagrams.router, dependencies=_tenant_deps)
|
||||
# session_handoffs queue router must come before ai_sessions to avoid conflict
|
||||
api_router.include_router(session_handoffs.queue_router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_resolutions.router, dependencies=_tenant_deps)
|
||||
@@ -142,3 +144,4 @@ api_router.include_router(script_builder.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(beta_feedback.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_branches.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(session_handoffs.router, dependencies=_tenant_deps)
|
||||
api_router.include_router(device_types.router, dependencies=_tenant_deps)
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""
|
||||
Admin database engine — connects as resolutionflow_admin (BYPASSRLS).
|
||||
|
||||
Use ONLY for /admin/* endpoints and internal tooling.
|
||||
Never use this engine from user-facing endpoints.
|
||||
Use ONLY where explicit application-level access control makes database-layer
|
||||
tenant filtering unnecessary: /admin/* endpoints, internal tooling, and public
|
||||
endpoints that enforce their own authorization before returning data (e.g.
|
||||
share access via opaque token + visibility check).
|
||||
"""
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -25,7 +27,7 @@ _admin_session_factory = async_sessionmaker(
|
||||
|
||||
|
||||
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield an admin DB session (BYPASSRLS). Use only on /admin/* endpoints."""
|
||||
"""Yield an admin DB session (BYPASSRLS). See module docstring for approved use cases."""
|
||||
async with _admin_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
@@ -12,10 +12,19 @@ async def log_audit(
|
||||
resource_type: str,
|
||||
resource_id: Optional[UUID] = None,
|
||||
details: Optional[dict] = None,
|
||||
account_id: Optional[UUID] = None,
|
||||
) -> None:
|
||||
"""Record an audit log entry. Does not commit — piggybacks on the caller's commit."""
|
||||
if account_id is None:
|
||||
# Derive from the acting user's account as a fallback (one extra query).
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
result = await db.execute(select(User.account_id).where(User.id == user_id))
|
||||
account_id = result.scalar_one()
|
||||
|
||||
entry = AuditLog(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
|
||||
@@ -128,6 +128,7 @@ class Settings(BaseSettings):
|
||||
"variable_inference": "fast",
|
||||
"kb_convert": "standard",
|
||||
"script_build": "standard",
|
||||
"network_diagram_generate": "standard",
|
||||
}
|
||||
|
||||
def get_model_for_action(self, action_type: str) -> str:
|
||||
|
||||
@@ -21,7 +21,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
"""Create batch sessions for a scheduled maintenance run."""
|
||||
# Import all models first to ensure SQLAlchemy mapper relationships resolve
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.maintenance_schedule import MaintenanceSchedule
|
||||
from app.models.session import Session
|
||||
from app.models.target_list import TargetList
|
||||
@@ -118,7 +118,7 @@ async def _fire_maintenance_schedule(schedule_id: str) -> None:
|
||||
async def _cleanup_expired_ai_conversations() -> None:
|
||||
"""Delete expired AI wizard conversations."""
|
||||
import app.models # noqa: F401
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_conversation import AIConversation
|
||||
|
||||
async with async_session_maker() as db:
|
||||
|
||||
@@ -14,6 +14,8 @@ import logging
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import _admin_session_factory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_ACCOUNT_EMAIL = "noreply@resolutionflow.com"
|
||||
@@ -52,40 +54,45 @@ async def _ensure_system_account(db: AsyncSession) -> uuid.UUID:
|
||||
async def ensure_service_account(db: AsyncSession) -> uuid.UUID:
|
||||
"""Ensure the ResolutionFlow service account exists and return its ID.
|
||||
|
||||
Idempotent — safe to call on every startup. Creates the account if it
|
||||
does not exist. The account has no usable password and is_service_account=True
|
||||
so it can never log in via normal auth flows.
|
||||
Idempotent — safe to call on every startup. This lookup must bypass RLS
|
||||
because startup runs before any request-scoped tenant context exists and
|
||||
the users table is tenant-isolated in Phase 4. The service account is
|
||||
normally created by Alembic migration 1490781700bc; the runtime create path
|
||||
remains as a self-healing fallback for environments that predate that seed.
|
||||
"""
|
||||
_ = db # Retained for call-site compatibility in app lifespan startup.
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
async with _admin_session_factory() as admin_db:
|
||||
result = await admin_db.execute(
|
||||
select(User).where(User.email == SERVICE_ACCOUNT_EMAIL)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is not None:
|
||||
if not user.is_service_account:
|
||||
user.is_service_account = True
|
||||
await db.commit()
|
||||
return user.id
|
||||
if user is not None:
|
||||
if not user.is_service_account:
|
||||
user.is_service_account = True
|
||||
await admin_db.commit()
|
||||
return user.id
|
||||
|
||||
account_id = await _ensure_system_account(db)
|
||||
account_id = await _ensure_system_account(admin_db)
|
||||
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name=SERVICE_ACCOUNT_NAME,
|
||||
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=True,
|
||||
must_change_password=False,
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||
return new_user.id
|
||||
new_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name=SERVICE_ACCOUNT_NAME,
|
||||
password_hash="!service-account-no-login", # bcrypt can't produce this prefix
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=True,
|
||||
must_change_password=False,
|
||||
account_id=account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
admin_db.add(new_user)
|
||||
await admin_db.commit()
|
||||
logger.info(f"[service_account] Created service account (id={new_user.id})")
|
||||
return new_user.id
|
||||
|
||||
@@ -25,7 +25,8 @@ if settings.SENTRY_DSN:
|
||||
),
|
||||
)
|
||||
|
||||
from app.core.database import init_db, async_session_maker
|
||||
from app.core.database import init_db
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.core.logging_config import setup_logging
|
||||
from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||
from app.core.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
@@ -56,6 +56,8 @@ from .session_handoff import SessionHandoff
|
||||
from .session_resolution_output import SessionResolutionOutput
|
||||
from .template_tree import TemplateTree
|
||||
from .platform_step import PlatformStep
|
||||
from .device_type import DeviceType
|
||||
from .network_diagram import NetworkDiagram
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -126,4 +128,6 @@ __all__ = [
|
||||
"SessionResolutionOutput",
|
||||
"TemplateTree",
|
||||
"PlatformStep",
|
||||
"DeviceType",
|
||||
"NetworkDiagram",
|
||||
]
|
||||
|
||||
@@ -21,6 +21,12 @@ class AuditLog(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
|
||||
resource_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
|
||||
47
backend/app/models/device_type.py
Normal file
47
backend/app/models/device_type.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Device type model for network diagrams."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import String, Boolean, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class DeviceType(Base):
|
||||
"""A device type for network diagram nodes (platform or account-custom)."""
|
||||
__tablename__ = "device_types"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
slug: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
comment="Unique identifier used in diagram node data",
|
||||
)
|
||||
label: Mapped[str] = mapped_column(
|
||||
String(100), nullable=False,
|
||||
comment="Display name",
|
||||
)
|
||||
category: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
comment="network, compute, storage, cloud, endpoint, infrastructure, security",
|
||||
)
|
||||
is_system: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False,
|
||||
comment="True for built-in types that cannot be deleted",
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
comment="Platform account for system types, tenant account for custom types",
|
||||
)
|
||||
sort_order: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0,
|
||||
comment="Display order within category",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
53
backend/app/models/network_diagram.py
Normal file
53
backend/app/models/network_diagram.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Network diagram model."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class NetworkDiagram(Base):
|
||||
"""A network topology diagram scoped to one account."""
|
||||
__tablename__ = "network_diagrams"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
asset_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
nodes: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
edges: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, server_default="'[]'")
|
||||
thumbnail_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
is_archived: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False,
|
||||
)
|
||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
creator: Mapped["User | None"] = relationship("User", foreign_keys=[created_by])
|
||||
@@ -8,7 +8,6 @@ from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
@@ -18,10 +17,6 @@ class TargetList(Base):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
team_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("teams.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
|
||||
@@ -25,6 +25,12 @@ class TreeShare(Base):
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
share_token: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
unique=True,
|
||||
|
||||
37
backend/app/schemas/device_type.py
Normal file
37
backend/app/schemas/device_type.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Pydantic schemas for device types."""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DeviceTypeCreate(BaseModel):
|
||||
slug: str = Field(min_length=1, max_length=50, pattern=r"^[a-z0-9\-]+$")
|
||||
label: str = Field(min_length=1, max_length=100)
|
||||
category: str = Field(
|
||||
min_length=1, max_length=50,
|
||||
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
|
||||
)
|
||||
sort_order: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class DeviceTypeUpdate(BaseModel):
|
||||
label: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
category: str | None = Field(
|
||||
default=None, min_length=1, max_length=50,
|
||||
pattern=r"^(network|compute|storage|cloud|endpoint|infrastructure|security)$",
|
||||
)
|
||||
sort_order: int | None = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class DeviceTypeResponse(BaseModel):
|
||||
id: UUID
|
||||
slug: str
|
||||
label: str
|
||||
category: str
|
||||
is_system: bool
|
||||
account_id: UUID
|
||||
sort_order: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
136
backend/app/schemas/network_diagram.py
Normal file
136
backend/app/schemas/network_diagram.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Pydantic schemas for network diagrams."""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
|
||||
class DeviceProperties(BaseModel):
|
||||
hostname: str | None = None
|
||||
ip: str | None = None
|
||||
subnet: str | None = None
|
||||
vendor: str | None = None
|
||||
model: str | None = None
|
||||
role: str | None = None
|
||||
vlan: str | None = None
|
||||
notes: str | None = None
|
||||
status: str = Field(default="unknown", pattern=r"^(unknown|online|offline|degraded)$")
|
||||
|
||||
|
||||
class DiagramNode(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
label: str
|
||||
position: Position
|
||||
properties: DeviceProperties = Field(default_factory=DeviceProperties)
|
||||
|
||||
|
||||
class DiagramEdge(BaseModel):
|
||||
id: str
|
||||
source: str
|
||||
target: str
|
||||
label: str | None = None
|
||||
connectionType: str = "ethernet"
|
||||
speed: str | None = None
|
||||
notes: str | None = None
|
||||
routing: str | None = None
|
||||
|
||||
|
||||
class NetworkDiagramCreate(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
|
||||
|
||||
class NetworkDiagramUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] | None = None
|
||||
edges: list[DiagramEdge] | None = None
|
||||
|
||||
|
||||
class NetworkDiagramResponse(BaseModel):
|
||||
id: UUID
|
||||
account_id: UUID
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
asset_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
thumbnail_url: str | None = None
|
||||
is_archived: bool = False
|
||||
created_by: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class NetworkDiagramListItem(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
node_count: int = 0
|
||||
category_counts: dict[str, int] = Field(default_factory=dict)
|
||||
created_by: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ExistingBounds(BaseModel):
|
||||
minX: float
|
||||
maxX: float
|
||||
minY: float
|
||||
maxY: float
|
||||
|
||||
|
||||
class AIGenerateRequest(BaseModel):
|
||||
description: str = Field(min_length=1, max_length=5000)
|
||||
client_name: str | None = None
|
||||
mode: str = Field(default="replace", pattern=r"^(replace|merge)$")
|
||||
existingBounds: ExistingBounds | None = None
|
||||
|
||||
|
||||
class AIGenerateResponse(BaseModel):
|
||||
nodes: list[DiagramNode]
|
||||
edges: list[DiagramEdge]
|
||||
suggestedName: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class DiagramImportRequest(BaseModel):
|
||||
schemaVersion: int = Field(ge=1, le=1)
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode] = Field(default_factory=list)
|
||||
edges: list[DiagramEdge] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DiagramImportResponse(BaseModel):
|
||||
diagram: NetworkDiagramResponse
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DiagramExportResponse(BaseModel):
|
||||
schemaVersion: int = 1
|
||||
name: str
|
||||
client_name: str | None = None
|
||||
description: str | None = None
|
||||
nodes: list[DiagramNode]
|
||||
edges: list[DiagramEdge]
|
||||
exportedAt: str
|
||||
@@ -23,7 +23,7 @@ class TargetListUpdate(BaseModel):
|
||||
|
||||
class TargetListResponse(BaseModel):
|
||||
id: UUID
|
||||
team_id: UUID
|
||||
account_id: UUID
|
||||
created_by: Optional[UUID]
|
||||
name: str
|
||||
description: Optional[str]
|
||||
|
||||
@@ -34,6 +34,7 @@ class BranchManager:
|
||||
root = SessionBranch(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
parent_branch_id=None,
|
||||
branch_order=1,
|
||||
label="Root",
|
||||
@@ -68,9 +69,17 @@ class BranchManager:
|
||||
"status": "untried",
|
||||
})
|
||||
|
||||
# Load session to get account_id for FK constraints
|
||||
session_result = await self.db.execute(
|
||||
select(AISession).where(AISession.id == session_id)
|
||||
)
|
||||
session = session_result.scalar_one_or_none()
|
||||
account_id = session.account_id if session else None
|
||||
|
||||
fork_point = ForkPoint(
|
||||
id=uuid.uuid4(),
|
||||
session_id=session_id,
|
||||
account_id=account_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
trigger_step_id=trigger_step_id,
|
||||
fork_reason=fork_reason,
|
||||
@@ -90,6 +99,7 @@ class BranchManager:
|
||||
branch = SessionBranch(
|
||||
id=branch_ids[i],
|
||||
session_id=session_id,
|
||||
account_id=account_id,
|
||||
parent_branch_id=parent_branch_id,
|
||||
fork_point_step_id=trigger_step_id,
|
||||
branch_order=i + 1,
|
||||
|
||||
@@ -56,6 +56,7 @@ class HandoffManager:
|
||||
|
||||
handoff = SessionHandoff(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
handed_off_by=user_id,
|
||||
intent=intent,
|
||||
source_branch_id=session.active_branch_id,
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.ai_session import AISession
|
||||
from app.services.knowledge_flywheel import analyze_session
|
||||
|
||||
|
||||
151
backend/app/services/network_diagram_ai_service.py
Normal file
151
backend/app/services/network_diagram_ai_service.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""AI service for generating network diagrams from natural language."""
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.core.config import settings
|
||||
from app.schemas.network_diagram import (
|
||||
AIGenerateRequest,
|
||||
AIGenerateResponse,
|
||||
DiagramNode,
|
||||
DiagramEdge,
|
||||
DeviceProperties,
|
||||
Position,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """You are a network diagram generator for MSP engineers.
|
||||
Given a plain English description of a network, you must return ONLY valid JSON with no markdown, no explanation, no preamble.
|
||||
|
||||
Return this exact structure:
|
||||
{{
|
||||
"nodes": [
|
||||
{{
|
||||
"id": "unique-string",
|
||||
"type": "device-type-slug",
|
||||
"label": "device label",
|
||||
"position": {{ "x": number, "y": number }},
|
||||
"properties": {{
|
||||
"hostname": "string or null",
|
||||
"ip": "string or null",
|
||||
"subnet": "string or null",
|
||||
"vendor": "string or null",
|
||||
"model": "string or null",
|
||||
"role": "string or null",
|
||||
"vlan": "string or null",
|
||||
"notes": "string or null",
|
||||
"status": "unknown"
|
||||
}}
|
||||
}}
|
||||
],
|
||||
"edges": [
|
||||
{{
|
||||
"id": "unique-string",
|
||||
"source": "node-id",
|
||||
"target": "node-id",
|
||||
"label": "connection label or null",
|
||||
"connectionType": "ethernet|fiber|wifi|vpn|vlan|wan",
|
||||
"speed": "string or null",
|
||||
"notes": "string or null"
|
||||
}}
|
||||
],
|
||||
"suggestedName": "short descriptive diagram name",
|
||||
"notes": "any important assumptions or missing info, or null"
|
||||
}}
|
||||
|
||||
Available device type slugs: {available_slugs}
|
||||
|
||||
Position nodes thoughtfully in a logical network topology layout.
|
||||
Use x/y coordinates between 0 and 1200 for x, 0 and 800 for y.
|
||||
Place WAN/internet at top, core network in middle, endpoints at bottom.
|
||||
{merge_instructions}"""
|
||||
|
||||
MERGE_INSTRUCTIONS = """
|
||||
IMPORTANT: You are ADDING devices to an existing diagram. Do NOT replace existing devices.
|
||||
The existing diagram occupies this bounding box: minX={minX}, maxX={maxX}, minY={minY}, maxY={maxY}.
|
||||
Place all new nodes OUTSIDE this bounding box — below (y > {maxY} + 100) or to the right (x > {maxX} + 100).
|
||||
You may create edges that connect new nodes to existing nodes if the description implies a connection.
|
||||
Use these existing node IDs for connections: {existing_node_ids}"""
|
||||
|
||||
|
||||
async def generate_diagram(
|
||||
request: AIGenerateRequest,
|
||||
available_slugs: list[str],
|
||||
existing_node_ids: list[str] | None = None,
|
||||
) -> AIGenerateResponse:
|
||||
merge_instructions = ""
|
||||
if request.mode == "merge" and request.existingBounds:
|
||||
b = request.existingBounds
|
||||
merge_instructions = MERGE_INSTRUCTIONS.format(
|
||||
minX=b.minX, maxX=b.maxX, minY=b.minY, maxY=b.maxY,
|
||||
existing_node_ids=", ".join(existing_node_ids or []),
|
||||
)
|
||||
|
||||
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
available_slugs=", ".join(available_slugs),
|
||||
merge_instructions=merge_instructions,
|
||||
)
|
||||
|
||||
model = settings.get_model_for_action("network_diagram_generate")
|
||||
provider = get_ai_provider(model)
|
||||
|
||||
messages = [{"role": "user", "content": request.description}]
|
||||
|
||||
response_text, input_tokens, output_tokens = await provider.generate_json(
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Network diagram AI generation: input_tokens=%d, output_tokens=%d",
|
||||
input_tokens, output_tokens,
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(response_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to parse AI response as JSON: %s", e)
|
||||
raise ValueError("AI generated an invalid response, please try again")
|
||||
|
||||
try:
|
||||
nodes = []
|
||||
for raw_node in data.get("nodes", []):
|
||||
node_type = raw_node.get("type", "server")
|
||||
if node_type not in available_slugs:
|
||||
logger.warning("Unknown device type '%s', falling back to 'server'", node_type)
|
||||
node_type = "server"
|
||||
|
||||
nodes.append(DiagramNode(
|
||||
id=raw_node["id"],
|
||||
type=node_type,
|
||||
label=raw_node.get("label", node_type),
|
||||
position=Position(**raw_node.get("position", {"x": 0, "y": 0})),
|
||||
properties=DeviceProperties(**{
|
||||
k: v for k, v in raw_node.get("properties", {}).items()
|
||||
if k in DeviceProperties.model_fields
|
||||
}),
|
||||
))
|
||||
|
||||
edges = []
|
||||
for raw_edge in data.get("edges", []):
|
||||
edges.append(DiagramEdge(
|
||||
id=raw_edge["id"],
|
||||
source=raw_edge["source"],
|
||||
target=raw_edge["target"],
|
||||
label=raw_edge.get("label"),
|
||||
connectionType=raw_edge.get("connectionType", "ethernet"),
|
||||
speed=raw_edge.get("speed"),
|
||||
notes=raw_edge.get("notes"),
|
||||
))
|
||||
except KeyError as e:
|
||||
logger.warning("AI response missing required field: %s", e)
|
||||
raise ValueError(f"AI generated incomplete data (missing {e}), please try again")
|
||||
|
||||
return AIGenerateResponse(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
suggestedName=data.get("suggestedName"),
|
||||
notes=data.get("notes"),
|
||||
)
|
||||
@@ -371,6 +371,7 @@ async def push_documentation(
|
||||
# Log success
|
||||
log_entry = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
account_id=session.account_id,
|
||||
ai_session_id=session.id,
|
||||
psa_connection_id=session.psa_connection_id,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
@@ -394,6 +395,7 @@ async def push_documentation(
|
||||
# Log failure with retry scheduling
|
||||
log_entry = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
account_id=session.account_id,
|
||||
ai_session_id=session.id,
|
||||
psa_connection_id=session.psa_connection_id,
|
||||
ticket_id=session.psa_ticket_id,
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.psa_post_log import PsaPostLog
|
||||
from app.services.psa_documentation_service import retry_failed_push
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class ResolutionOutputGenerator:
|
||||
|
||||
output = SessionResolutionOutput(
|
||||
session_id=session_id,
|
||||
account_id=session.account_id,
|
||||
output_type=output_type,
|
||||
generated_content=content,
|
||||
status="draft",
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.admin_database import _admin_session_factory as async_session_maker
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ def _extract_script_from_response(content: str, language: str) -> tuple[str | No
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
team_id: UUID | None,
|
||||
language: str,
|
||||
initial_prompt: str | None = None,
|
||||
@@ -151,6 +152,7 @@ async def create_session(
|
||||
"""Create a new Script Builder session."""
|
||||
session = ScriptBuilderSession(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
team_id=team_id,
|
||||
language=language,
|
||||
)
|
||||
|
||||
@@ -80,7 +80,10 @@ def _display_code() -> str:
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
engine = create_async_engine(settings.DATABASE_URL, echo=False)
|
||||
# Must use ADMIN_DATABASE_URL (BYPASSRLS) — Phase 4 enabled RLS on users.
|
||||
# The app-role connection has no tenant context at seed time and would see 0 rows.
|
||||
admin_url = getattr(settings, "ADMIN_DATABASE_URL", None) or settings.DATABASE_URL
|
||||
engine = create_async_engine(admin_url, echo=False)
|
||||
password_hash = get_password_hash(SHARED_PASSWORD)
|
||||
now = datetime.now(timezone.utc)
|
||||
team_account_id: uuid.UUID | None = None
|
||||
|
||||
@@ -75,6 +75,19 @@ async def test_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
('team', NULL, NULL, NULL, true, true, '["markdown", "text", "html"]')
|
||||
"""))
|
||||
|
||||
# Seed the platform/system account (PLATFORM_ACCOUNT_ID) needed by
|
||||
# global categories, gallery items, and other platform-owned content.
|
||||
await conn.execute(sa.text("""
|
||||
INSERT INTO accounts (id, name, display_code, created_at, updated_at)
|
||||
VALUES (
|
||||
'00000000-0000-0000-0000-000000000001',
|
||||
'ResolutionFlow System',
|
||||
'RF-SYS-1',
|
||||
NOW(), NOW()
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""))
|
||||
|
||||
# Create async session maker
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestAdminGlobalCategories:
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Category"
|
||||
assert data["slug"] == "test-category"
|
||||
assert data["account_id"] is None
|
||||
assert data["account_id"] == "00000000-0000-0000-0000-000000000001" # PLATFORM_ACCOUNT_ID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_global_category(
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.tree import Tree
|
||||
from app.models.script_template import ScriptTemplate, ScriptCategory
|
||||
|
||||
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -22,6 +23,7 @@ async def _create_tree(db: AsyncSession, admin_user_id: str) -> Tree:
|
||||
name="Gallery Test Flow",
|
||||
tree_type="troubleshooting",
|
||||
visibility="public",
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=False,
|
||||
gallery_sort_order=0,
|
||||
tree_structure={
|
||||
@@ -53,6 +55,7 @@ async def _create_script(db: AsyncSession, admin_user_id: str) -> ScriptTemplate
|
||||
script = ScriptTemplate(
|
||||
id=uuid.uuid4(),
|
||||
category_id=category.id,
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
name="Gallery Test Script",
|
||||
slug=f"gallery-test-script-{uuid.uuid4().hex[:6]}",
|
||||
script_body="Write-Host 'Test'",
|
||||
|
||||
@@ -594,6 +594,7 @@ class TestPsaMetrics:
|
||||
post_log = PsaPostLog(
|
||||
id=uuid.uuid4(),
|
||||
ai_session_id=push_session_id,
|
||||
account_id=account_id,
|
||||
ticket_id="TICKET-123",
|
||||
note_type="internal",
|
||||
content_posted="Session summary",
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.security import get_password_hash
|
||||
from app.models.account import Account
|
||||
from app.models.team import Team
|
||||
from app.models.user import User
|
||||
|
||||
@@ -23,6 +24,8 @@ async def _create_team_with_admin(
|
||||
team_name: str = "Branding Test Team",
|
||||
) -> tuple[dict, str, Team]:
|
||||
"""Create a team + team admin user. Returns (auth_headers, team_id_str, team)."""
|
||||
account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper())
|
||||
test_db.add(account)
|
||||
team = Team(name=team_name)
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
@@ -36,6 +39,8 @@ async def _create_team_with_admin(
|
||||
team_id=team.id,
|
||||
is_team_admin=True,
|
||||
role="engineer",
|
||||
account_id=account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(user)
|
||||
await test_db.commit()
|
||||
@@ -58,6 +63,15 @@ async def _create_team_member(
|
||||
is_team_admin: bool = False,
|
||||
) -> dict:
|
||||
"""Create a regular team member. Returns auth_headers."""
|
||||
# Look up the account associated with this team via an existing member
|
||||
from sqlalchemy import select as _select
|
||||
from app.models.user import User as _User
|
||||
result = await test_db.execute(
|
||||
_select(_User).where(_User.team_id == team.id).limit(1)
|
||||
)
|
||||
team_member = result.scalar_one_or_none()
|
||||
member_account_id = team_member.account_id if team_member else None
|
||||
|
||||
email = f"member_{uuid.uuid4().hex[:8]}@test.com"
|
||||
user = User(
|
||||
email=email,
|
||||
@@ -67,6 +81,8 @@ async def _create_team_member(
|
||||
team_id=team.id,
|
||||
is_team_admin=is_team_admin,
|
||||
role="engineer",
|
||||
account_id=member_account_id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(user)
|
||||
await test_db.commit()
|
||||
|
||||
@@ -334,12 +334,13 @@ class TestDraftTreesAPI:
|
||||
"""Test that migration defaults existing trees to published status."""
|
||||
# Create a tree without specifying status (relies on DB default)
|
||||
from uuid import UUID, uuid4
|
||||
_platform_id = UUID("00000000-0000-0000-0000-000000000001")
|
||||
tree = Tree(
|
||||
name="Legacy Tree",
|
||||
description="Created before status field",
|
||||
tree_structure={"id": "root", "type": "solution", "title": "Fix"},
|
||||
author_id=None,
|
||||
account_id=None
|
||||
account_id=_platform_id,
|
||||
)
|
||||
test_db.add(tree)
|
||||
await test_db.commit()
|
||||
|
||||
@@ -127,10 +127,12 @@ async def test_cannot_schedule_other_teams_tree(client: AsyncClient, auth_header
|
||||
test_db.add(other_team)
|
||||
await test_db.flush()
|
||||
|
||||
from uuid import UUID as _UUID
|
||||
other_tree = Tree(
|
||||
name="Other Team Tree",
|
||||
tree_type="maintenance",
|
||||
team_id=other_team.id,
|
||||
account_id=_UUID("00000000-0000-0000-0000-000000000001"),
|
||||
tree_structure={
|
||||
"steps": [
|
||||
{"id": "s1", "type": "procedure_step", "title": "Step",
|
||||
|
||||
96
backend/tests/test_network_diagrams.py
Normal file
96
backend/tests/test_network_diagrams.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.device_type import DeviceType
|
||||
from app.models.user import User
|
||||
from app.core.service_account import PLATFORM_ACCOUNT_ID
|
||||
|
||||
|
||||
async def _login_headers(client, email: str, password: str) -> dict[str, str]:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_device_types_include_platform_and_account_custom(client, test_db, auth_headers, test_user):
|
||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
user = result.scalar_one()
|
||||
|
||||
test_db.add(
|
||||
DeviceType(
|
||||
id=uuid.uuid4(),
|
||||
slug="platform-router",
|
||||
label="Platform Router",
|
||||
category="network",
|
||||
is_system=True,
|
||||
account_id=PLATFORM_ACCOUNT_ID,
|
||||
sort_order=0,
|
||||
)
|
||||
)
|
||||
await test_db.commit()
|
||||
|
||||
create_response = await client.post(
|
||||
"/api/v1/device-types/",
|
||||
json={
|
||||
"slug": "tenant-appliance",
|
||||
"label": "Tenant Appliance",
|
||||
"category": "network",
|
||||
"sort_order": 3,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
assert create_response.json()["account_id"] == str(user.account_id)
|
||||
|
||||
list_response = await client.get("/api/v1/device-types/", headers=auth_headers)
|
||||
assert list_response.status_code == 200
|
||||
payload = list_response.json()
|
||||
slugs = {item["slug"] for item in payload}
|
||||
|
||||
assert "platform-router" in slugs
|
||||
assert "tenant-appliance" in slugs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_diagrams_are_account_scoped(client, test_db, auth_headers, test_user):
|
||||
other_user = {
|
||||
"email": "other-network@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Other Network User",
|
||||
}
|
||||
register_response = await client.post("/api/v1/auth/register", json=other_user)
|
||||
assert register_response.status_code in (200, 201)
|
||||
other_headers = await _login_headers(client, other_user["email"], other_user["password"])
|
||||
|
||||
owner_result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
owner = owner_result.scalar_one()
|
||||
|
||||
create_response = await client.post(
|
||||
"/api/v1/network-diagrams/",
|
||||
json={
|
||||
"name": "HQ Core",
|
||||
"client_name": "Acme",
|
||||
"description": "Primary topology",
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_response.status_code == 201
|
||||
diagram = create_response.json()
|
||||
assert diagram["account_id"] == str(owner.account_id)
|
||||
|
||||
own_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=auth_headers)
|
||||
assert own_get.status_code == 200
|
||||
|
||||
other_get = await client.get(f"/api/v1/network-diagrams/{diagram['id']}", headers=other_headers)
|
||||
assert other_get.status_code == 404
|
||||
@@ -200,6 +200,7 @@ class TestAccountPermissions:
|
||||
})
|
||||
outsider_headers = {"Authorization": f"Bearer {outsider_login.json()['access_token']}"}
|
||||
|
||||
# Outsider should NOT see the private tree
|
||||
# Outsider should NOT see the private tree.
|
||||
# With RLS, the tree is invisible to other tenants — 404 not 403.
|
||||
response = await client.get(f"/api/v1/trees/{tree_id}", headers=outsider_headers)
|
||||
assert response.status_code == 403
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -464,7 +464,6 @@ async def test_target_list_account_id_from_team_admin(test_db: AsyncSession):
|
||||
await test_db.flush()
|
||||
|
||||
target_list = TargetList(
|
||||
team_id=team.id,
|
||||
account_id=account.id,
|
||||
created_by=user.id,
|
||||
name="Server Targets",
|
||||
|
||||
@@ -11,6 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.script_template import ScriptCategory, ScriptTemplate
|
||||
from app.models.tree import Tree
|
||||
|
||||
_PLATFORM_ACCOUNT_ID = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -41,6 +43,7 @@ async def _create_featured_tree(db: AsyncSession, name: str = "Featured Flow", f
|
||||
description="A featured flow for the gallery",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure=_make_tree_structure(4),
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=featured,
|
||||
is_active=True,
|
||||
usage_count=42,
|
||||
@@ -74,6 +77,7 @@ async def _create_featured_script(
|
||||
) -> ScriptTemplate:
|
||||
script = ScriptTemplate(
|
||||
category_id=category.id,
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
name=name,
|
||||
slug=name.lower().replace(" ", "-"),
|
||||
description="A gallery-featured script",
|
||||
@@ -312,7 +316,7 @@ class TestCategoriesEndpoint:
|
||||
from app.models.category import TreeCategory
|
||||
|
||||
# Create a category and a featured tree in that category
|
||||
cat = TreeCategory(name="Networking", slug="networking", is_active=True)
|
||||
cat = TreeCategory(name="Networking", slug="networking", is_active=True, account_id=_PLATFORM_ACCOUNT_ID)
|
||||
test_db.add(cat)
|
||||
await test_db.commit()
|
||||
await test_db.refresh(cat)
|
||||
@@ -321,6 +325,7 @@ class TestCategoriesEndpoint:
|
||||
name="Router Diagnostics",
|
||||
tree_type="troubleshooting",
|
||||
tree_structure=_make_tree_structure(2),
|
||||
account_id=_PLATFORM_ACCOUNT_ID,
|
||||
is_gallery_featured=True,
|
||||
is_active=True,
|
||||
usage_count=5,
|
||||
|
||||
@@ -62,6 +62,7 @@ async def test_edit_output(client: AsyncClient, test_user, auth_headers, test_db
|
||||
|
||||
output = SessionResolutionOutput(
|
||||
session_id=session.id,
|
||||
account_id=session.account_id,
|
||||
output_type="psa_ticket_notes",
|
||||
generated_content="Original notes",
|
||||
status="draft",
|
||||
|
||||
@@ -16,11 +16,20 @@ Run with:
|
||||
The test DB is patherly_test (matches conftest.py default).
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
|
||||
# All tests in this module use module-scoped async fixtures (admin_conn,
|
||||
# seed_rls_test_data) which run on the module event loop. Without this marker,
|
||||
# pytest-asyncio 0.23+ defaults tests to function-scoped loops, causing
|
||||
# "Future attached to a different loop" errors on the asyncpg connections.
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="module")
|
||||
|
||||
_DB_HOST = os.getenv("TEST_DB_HOST", "localhost")
|
||||
_DB_PORT = int(os.getenv("TEST_DB_PORT", "5432"))
|
||||
_DB_NAME = os.getenv("TEST_DB_NAME", "patherly_test") # matches conftest.py
|
||||
@@ -37,7 +46,25 @@ ACCOUNT_B_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def admin_conn():
|
||||
def _ensure_rls_schema():
|
||||
"""Re-apply Alembic migrations before the module runs.
|
||||
|
||||
Function-scoped test_db fixtures in other modules drop and recreate the
|
||||
public schema using Base.metadata.create_all, which does not enable RLS
|
||||
or create DB roles. This fixture re-runs 'alembic upgrade head' so that
|
||||
the full migration-managed schema (including RLS policies) is in place.
|
||||
"""
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "upgrade", "head"],
|
||||
cwd=backend_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def admin_conn(_ensure_rls_schema):
|
||||
"""Superuser asyncpg connection for fixture setup and teardown."""
|
||||
conn = await asyncpg.connect(_ADMIN_DSN)
|
||||
yield conn
|
||||
@@ -170,7 +197,6 @@ async def conn_no_context():
|
||||
# trees
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -178,7 +204,6 @@ async def test_trees_account_a_cannot_see_account_b_rows(conn_a):
|
||||
assert len(rows) == 0, "Account A should not see Account B trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
@@ -186,7 +211,6 @@ async def test_trees_account_a_can_see_own_rows(conn_a):
|
||||
assert len(rows) >= 1, "Account A should see its own trees"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
rows = await conn_no_context.fetch(
|
||||
"SELECT id FROM trees WHERE is_default = FALSE AND is_public = FALSE"
|
||||
@@ -198,7 +222,6 @@ async def test_trees_no_context_sees_no_private_trees(conn_no_context):
|
||||
# tree_tags — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -206,7 +229,6 @@ async def test_tree_tags_account_a_cannot_see_account_b_tags(conn_a):
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
rows_a = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_tags WHERE account_id = '{PLATFORM_ACCOUNT_ID}'"
|
||||
@@ -222,7 +244,6 @@ async def test_tree_tags_both_tenants_see_platform_tags(conn_a, conn_b):
|
||||
# tree_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -234,7 +255,6 @@ async def test_tree_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# step_categories — platform visibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_categories WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -246,7 +266,6 @@ async def test_step_categories_account_a_cannot_see_account_b(conn_a):
|
||||
# psa_connections — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_connections WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
@@ -258,9 +277,782 @@ async def test_psa_connections_account_a_cannot_see_account_b(conn_a):
|
||||
# flow_proposals — tenant-only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_proposals_account_a_cannot_see_account_b(conn_a):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM flow_proposals WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def session_row_ids(admin_conn):
|
||||
"""
|
||||
Insert one `sessions` row and one `ai_sessions` row for each of
|
||||
ACCOUNT_A and ACCOUNT_B using the superuser connection (BYPASSRLS).
|
||||
Returns a dict with the inserted IDs for use in tests.
|
||||
Cleans up on exit.
|
||||
"""
|
||||
# Resolve a valid tree_id and user_id for each account
|
||||
tree_a = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
|
||||
)
|
||||
tree_b = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
user_a = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}' LIMIT 1"
|
||||
)
|
||||
user_b = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
|
||||
assert tree_a is not None, f"No tree found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
|
||||
assert tree_b is not None, f"No tree found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
|
||||
assert user_a is not None, f"No user found for ACCOUNT_A ({ACCOUNT_A_ID}) — seed_rls_test_data must run first"
|
||||
assert user_b is not None, f"No user found for ACCOUNT_B ({ACCOUNT_B_ID}) — seed_rls_test_data must run first"
|
||||
|
||||
tree_a_id = str(tree_a["id"])
|
||||
tree_b_id = str(tree_b["id"])
|
||||
user_a_id = str(user_a["id"])
|
||||
user_b_id = str(user_b["id"])
|
||||
|
||||
session_a_id = str(uuid.uuid4())
|
||||
session_b_id = str(uuid.uuid4())
|
||||
ai_session_a_id = str(uuid.uuid4())
|
||||
ai_session_b_id = str(uuid.uuid4())
|
||||
|
||||
# Insert sessions rows (sessions uses started_at not created_at)
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES
|
||||
('{session_a_id}', '{tree_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()),
|
||||
('{session_b_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW())
|
||||
""")
|
||||
|
||||
# Insert ai_sessions rows
|
||||
# confidence_tier valid values: 'guided' | 'exploring' | 'discovery'
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_sessions (
|
||||
id, user_id, account_id, session_type, intake_type,
|
||||
intake_content, status, confidence_tier, confidence_score,
|
||||
created_at, updated_at
|
||||
) VALUES
|
||||
('{ai_session_a_id}', '{user_a_id}', '{ACCOUNT_A_ID}',
|
||||
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
|
||||
NOW(), NOW()),
|
||||
('{ai_session_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'guided', 'free_text', '{{}}'::jsonb, 'active', 'guided', 0.0,
|
||||
NOW(), NOW())
|
||||
""")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Seed Account B rows for every "cannot-see" table that would otherwise be
|
||||
# empty. Without these, isolation tests pass vacuously even when RLS is off.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# session_branches (FK: ai_sessions.id)
|
||||
branch_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_branches (
|
||||
id, session_id, account_id, branch_order, label, status,
|
||||
conversation_messages, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 1, 'test-branch', 'active',
|
||||
'[]'::jsonb, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID)
|
||||
branch_b_id = str(branch_b_row["id"])
|
||||
|
||||
# session_supporting_data (FK: sessions.id)
|
||||
supporting_data_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_supporting_data (
|
||||
id, session_id, account_id, label, data_type, content,
|
||||
sort_order, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'test-data', 'text_snippet',
|
||||
'test content', 0, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", session_b_id, ACCOUNT_B_ID)
|
||||
supporting_data_b_id = str(supporting_data_b_row["id"])
|
||||
|
||||
# session_resolution_outputs (FK: ai_sessions.id)
|
||||
resolution_output_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_resolution_outputs (
|
||||
id, session_id, account_id, output_type, generated_content,
|
||||
status, generated_by_model, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'psa_ticket_notes',
|
||||
'test content', 'draft', 'test-model', NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID)
|
||||
resolution_output_b_id = str(resolution_output_b_row["id"])
|
||||
|
||||
# session_handoffs (FK: ai_sessions.id, users.id)
|
||||
handoff_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO session_handoffs (
|
||||
id, session_id, account_id, handed_off_by, intent, snapshot,
|
||||
priority, psa_note_pushed, notification_sent, created_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, 'park',
|
||||
'{}'::jsonb, 'normal', false, false, NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
|
||||
handoff_b_id = str(handoff_b_row["id"])
|
||||
|
||||
# maintenance_schedules (FK: trees.id)
|
||||
maintenance_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO maintenance_schedules (
|
||||
id, tree_id, account_id, cron_expression, timezone,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, '0 9 * * 1', 'UTC',
|
||||
NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", tree_b_id, ACCOUNT_B_ID)
|
||||
maintenance_b_id = str(maintenance_b_row["id"])
|
||||
|
||||
# psa_post_log (FK: ai_sessions.id, users.id)
|
||||
psa_log_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO psa_post_log (
|
||||
id, ai_session_id, account_id, ticket_id, note_type,
|
||||
content_posted, status, posted_by, posted_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, 'TEST-0001', 'internal',
|
||||
'test note', 'success', $3::uuid, NOW()
|
||||
) RETURNING id
|
||||
""", ai_session_b_id, ACCOUNT_B_ID, user_b_id)
|
||||
psa_log_b_id = str(psa_log_b_row["id"])
|
||||
|
||||
# script_templates requires a script_categories row — insert a temporary one
|
||||
script_category_b_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO script_categories (id, name, slug, sort_order, is_active, created_at, updated_at)
|
||||
VALUES ('{script_category_b_id}', 'RLS Test Category', 'rls-test-category-{script_category_b_id[:8]}',
|
||||
0, true, NOW(), NOW())
|
||||
""")
|
||||
|
||||
script_template_b_row = await admin_conn.fetchrow(f"""
|
||||
INSERT INTO script_templates (
|
||||
id, category_id, account_id, name, slug, script_body,
|
||||
complexity, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), '{script_category_b_id}'::uuid, $1::uuid,
|
||||
'RLS Test Template', 'rls-test-template-b-' || gen_random_uuid()::text,
|
||||
'Write-Host "test"', 'beginner', true, NOW(), NOW()
|
||||
) RETURNING id
|
||||
""", ACCOUNT_B_ID)
|
||||
script_template_b_id = str(script_template_b_row["id"])
|
||||
|
||||
# script_generations (FK: script_templates.id, users.id)
|
||||
script_gen_b_row = await admin_conn.fetchrow("""
|
||||
INSERT INTO script_generations (
|
||||
id, template_id, user_id, account_id, parameters_used,
|
||||
generated_script, created_at
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1::uuid, $2::uuid, $3::uuid, '{}'::jsonb,
|
||||
'test script', NOW()
|
||||
) RETURNING id
|
||||
""", script_template_b_id, user_b_id, ACCOUNT_B_ID)
|
||||
script_gen_b_id = str(script_gen_b_row["id"])
|
||||
|
||||
try:
|
||||
yield {
|
||||
"session_a": session_a_id,
|
||||
"session_b": session_b_id,
|
||||
"ai_session_a": ai_session_a_id,
|
||||
"ai_session_b": ai_session_b_id,
|
||||
}
|
||||
finally:
|
||||
# Cleanup in reverse FK order (children before parents)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_generations WHERE id = '{script_gen_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_branches WHERE id = '{branch_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_supporting_data WHERE id = '{supporting_data_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_resolution_outputs WHERE id = '{resolution_output_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM session_handoffs WHERE id = '{handoff_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM maintenance_schedules WHERE id = '{maintenance_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM psa_post_log WHERE id = '{psa_log_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_templates WHERE id = '{script_template_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_categories WHERE id = '{script_category_b_id}'"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM sessions WHERE id IN ('{session_a_id}', '{session_b_id}')"
|
||||
)
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM ai_sessions WHERE id IN ('{ai_session_a_id}', '{ai_session_b_id}')"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_sessions_account_a_cannot_see_account_b_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_b']}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B sessions"
|
||||
|
||||
|
||||
async def test_sessions_account_a_can_see_own_sessions(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM sessions WHERE id = '{session_row_ids['session_a']}'"
|
||||
)
|
||||
assert len(rows) == 1, "Account A should see its own sessions"
|
||||
|
||||
|
||||
async def test_sessions_no_context_sees_nothing(conn_no_context, session_row_ids):
|
||||
rows = await conn_no_context.fetch(
|
||||
f"SELECT id FROM sessions WHERE id IN "
|
||||
f"('{session_row_ids['session_a']}', '{session_row_ids['session_b']}')"
|
||||
)
|
||||
assert len(rows) == 0, "No-context connection should see no sessions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ai_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_ai_sessions_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_b']}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_sessions"
|
||||
|
||||
|
||||
async def test_ai_sessions_account_a_can_see_own(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_sessions WHERE id = '{session_row_ids['ai_session_a']}'"
|
||||
)
|
||||
assert len(rows) == 1, "Account A should see its own ai_sessions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_branches_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_branches WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_branches"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_supporting_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_supporting_data_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_supporting_data WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_supporting_data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_resolution_outputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_resolution_outputs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_resolution_outputs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_resolution_outputs"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_handoffs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_handoffs_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_handoffs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_handoffs"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_templates_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_templates WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_templates"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_generations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_generations_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_generations WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_generations"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# maintenance_schedules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_maintenance_schedules_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM maintenance_schedules WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B maintenance_schedules"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# psa_post_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_psa_post_log_account_a_cannot_see_account_b(conn_a, session_row_ids):
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM psa_post_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B psa_post_log"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_library — visibility-aware policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_library_account_a_cannot_see_account_b_private_steps(admin_conn, conn_a):
|
||||
"""Private/non-public steps owned by Account B must not be visible to Account A."""
|
||||
private_step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{private_step_id}', '{ACCOUNT_B_ID}', 'RLS Private Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_library "
|
||||
f"WHERE id = '{private_step_id}' AND visibility != 'public'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B's private step_library rows"
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM step_library WHERE id = '{private_step_id}'"
|
||||
)
|
||||
|
||||
|
||||
async def test_step_library_account_a_can_see_account_b_public_steps(admin_conn, conn_a):
|
||||
"""Public steps owned by Account B MUST be visible to Account A (cross-tenant visibility)."""
|
||||
public_step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{public_step_id}', '{ACCOUNT_B_ID}', 'RLS Public Step', 'action',
|
||||
'{{}}'::jsonb, 'public', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_library WHERE id = '{public_step_id}'"
|
||||
)
|
||||
assert len(rows) == 1, (
|
||||
"Account A should see public steps owned by Account B "
|
||||
"(cross-tenant public visibility policy)"
|
||||
)
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM step_library WHERE id = '{public_step_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 3 RLS isolation tests
|
||||
# Tables: step_ratings, step_usage_log, target_lists,
|
||||
# session_shares, audit_logs, tree_shares
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers shared by Phase 3 fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _get_user_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
"SELECT id FROM users WHERE email = 'rls-user-b@example.com'"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
async def _get_tree_b_id(admin_conn) -> str:
|
||||
row = await admin_conn.fetchrow(
|
||||
f"SELECT id FROM trees WHERE account_id = '{ACCOUNT_B_ID}' LIMIT 1"
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_ratings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_ratings_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step ratings belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
# Need a step_library row as FK target
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 RLS Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
rating_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_ratings (
|
||||
id, step_id, user_id, account_id, is_verified_use, is_visible,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{rating_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
FALSE, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_ratings WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_ratings"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_ratings WHERE id = '{rating_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# step_usage_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_step_usage_log_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see step usage logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_library (
|
||||
id, account_id, title, step_type, content,
|
||||
visibility, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ACCOUNT_B_ID}', 'Phase3 Usage Step', 'action',
|
||||
'{{}}'::jsonb, 'private', TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
# Need a sessions row as FK for usage log
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO step_usage_log (
|
||||
id, step_id, user_id, account_id, session_id, used_at
|
||||
) VALUES (
|
||||
'{log_id}', '{step_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{session_id}', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM step_usage_log WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B step_usage_log"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM step_usage_log WHERE id = '{log_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM step_library WHERE id = '{step_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# target_lists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_target_lists_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see target lists belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
tl_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO target_lists (
|
||||
id, account_id, created_by, name, targets, created_at, updated_at
|
||||
) VALUES (
|
||||
'{tl_id}', '{ACCOUNT_B_ID}', '{user_b_id}',
|
||||
'Phase3 RLS Target List', '[]'::jsonb, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM target_lists WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B target_lists"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM target_lists WHERE id = '{tl_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_session_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see session shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need a sessions row as FK
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO sessions (
|
||||
id, tree_id, user_id, account_id, tree_snapshot,
|
||||
path_taken, decisions, custom_steps, started_at
|
||||
) VALUES (
|
||||
'{session_id}', '{tree_b_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'[]'::jsonb, '[]'::jsonb, '[]'::jsonb, '[]'::jsonb, NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-rls-test-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO session_shares (
|
||||
id, session_id, account_id, share_token, visibility,
|
||||
created_by, view_count, is_active, created_at, updated_at
|
||||
) VALUES (
|
||||
'{share_id}', '{session_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', 'account', '{user_b_id}',
|
||||
0, TRUE, NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM session_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B session_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM session_shares WHERE id = '{share_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM sessions WHERE id = '{session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audit_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_audit_logs_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see audit logs belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
log_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO audit_logs (
|
||||
id, user_id, account_id, action, resource_type, created_at
|
||||
) VALUES (
|
||||
'{log_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'test.action', 'test_resource', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM audit_logs WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B audit_logs"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM audit_logs WHERE id = '{log_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tree_shares
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_tree_shares_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see tree shares belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
share_id = str(uuid.uuid4())
|
||||
share_token = f"phase3-tree-rls-{share_id[:8]}"
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO tree_shares (
|
||||
id, tree_id, account_id, share_token, created_by,
|
||||
allow_forking, created_at
|
||||
) VALUES (
|
||||
'{share_id}', '{tree_b_id}', '{ACCOUNT_B_ID}',
|
||||
'{share_token}', '{user_b_id}', TRUE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM tree_shares WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B tree_shares"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM tree_shares WHERE id = '{share_id}'")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Phase 4 RLS isolation tests
|
||||
# Tables: users, script_builder_sessions, ai_session_steps, notifications
|
||||
#
|
||||
# Note: platform_steps and template_trees have no account_id column and no RLS —
|
||||
# they are globally readable by all authenticated users.
|
||||
# ===========================================================================
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_users_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see users belonging to Account B."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B users"
|
||||
|
||||
|
||||
async def test_users_account_a_can_see_own(admin_conn, conn_a):
|
||||
"""Account A must be able to see its own users."""
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM users WHERE account_id = '{ACCOUNT_A_ID}'"
|
||||
)
|
||||
assert len(rows) > 0, "Account A should see its own users"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# script_builder_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_script_builder_sessions_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see script builder sessions belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO script_builder_sessions (
|
||||
id, user_id, account_id, language, created_at, updated_at
|
||||
) VALUES (
|
||||
'{session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'powershell', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM script_builder_sessions WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B script_builder_sessions"
|
||||
finally:
|
||||
await admin_conn.execute(
|
||||
f"DELETE FROM script_builder_sessions WHERE id = '{session_id}'"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ai_session_steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_ai_session_steps_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see ai_session_steps belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
tree_b_id = await _get_tree_b_id(admin_conn)
|
||||
|
||||
# Need an ai_sessions row as FK
|
||||
ai_session_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_sessions (
|
||||
id, user_id, account_id, flow_type, status, confidence_tier,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
'{ai_session_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'troubleshooting', 'active', 'guided', NOW(), NOW()
|
||||
)
|
||||
""")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO ai_session_steps (
|
||||
id, session_id, account_id, step_type, content,
|
||||
created_at
|
||||
) VALUES (
|
||||
'{step_id}', '{ai_session_id}', '{ACCOUNT_B_ID}',
|
||||
'question', 'Phase4 RLS test step', NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM ai_session_steps WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B ai_session_steps"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM ai_session_steps WHERE id = '{step_id}'")
|
||||
await admin_conn.execute(f"DELETE FROM ai_sessions WHERE id = '{ai_session_id}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# notifications
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def test_notifications_account_a_cannot_see_account_b(admin_conn, conn_a):
|
||||
"""Account A must not see notifications belonging to Account B."""
|
||||
user_b_id = await _get_user_b_id(admin_conn)
|
||||
|
||||
notif_id = str(uuid.uuid4())
|
||||
await admin_conn.execute(f"""
|
||||
INSERT INTO notifications (
|
||||
id, user_id, account_id, type, title, message,
|
||||
is_read, created_at
|
||||
) VALUES (
|
||||
'{notif_id}', '{user_b_id}', '{ACCOUNT_B_ID}',
|
||||
'info', 'Phase4 RLS Test', 'RLS isolation test notification',
|
||||
FALSE, NOW()
|
||||
)
|
||||
""")
|
||||
try:
|
||||
rows = await conn_a.fetch(
|
||||
f"SELECT id FROM notifications WHERE account_id = '{ACCOUNT_B_ID}'"
|
||||
)
|
||||
assert len(rows) == 0, "Account A should not see Account B notifications"
|
||||
finally:
|
||||
await admin_conn.execute(f"DELETE FROM notifications WHERE id = '{notif_id}'")
|
||||
|
||||
|
||||
@@ -155,6 +155,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[{"node_id": "root", "timestamp": datetime.now(timezone.utc).isoformat()}],
|
||||
@@ -199,6 +200,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -239,6 +241,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -279,6 +282,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=UUID(test_user["user_data"]["id"]),
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
@@ -352,6 +356,7 @@ class TestSaveSessionAsTreeAPI:
|
||||
session = Session(
|
||||
tree_id=tree.id,
|
||||
user_id=other_user.id,
|
||||
account_id=UUID(test_user["user_data"]["account_id"]),
|
||||
tree_snapshot=tree.tree_structure,
|
||||
path_taken=["root"],
|
||||
decisions=[],
|
||||
|
||||
89
backend/tests/test_service_account.py
Normal file
89
backend/tests/test_service_account.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core import service_account as service_account_module
|
||||
from app.core.service_account import (
|
||||
SERVICE_ACCOUNT_EMAIL,
|
||||
SYSTEM_ACCOUNT_DISPLAY_CODE,
|
||||
ensure_service_account,
|
||||
)
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class _SessionFactoryOverride:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self
|
||||
|
||||
async def __aenter__(self):
|
||||
return self._session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_creates_and_reuses_seeded_user(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
service_account_id = await ensure_service_account(test_db)
|
||||
|
||||
created_user = (
|
||||
await test_db.execute(select(User).where(User.id == service_account_id))
|
||||
).scalar_one()
|
||||
assert created_user.email == SERVICE_ACCOUNT_EMAIL
|
||||
assert created_user.is_service_account is True
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
assert created_user.account_id == system_account.id
|
||||
|
||||
second_id = await ensure_service_account(test_db)
|
||||
assert second_id == service_account_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_service_account_marks_existing_user_as_service_account(test_db, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
service_account_module,
|
||||
"_admin_session_factory",
|
||||
_SessionFactoryOverride(test_db),
|
||||
)
|
||||
|
||||
system_account = (
|
||||
await test_db.execute(
|
||||
select(Account).where(Account.display_code == SYSTEM_ACCOUNT_DISPLAY_CODE)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
existing_user = User(
|
||||
email=SERVICE_ACCOUNT_EMAIL,
|
||||
name="ResolutionFlow",
|
||||
password_hash="!service-account-no-login",
|
||||
role="engineer",
|
||||
is_super_admin=False,
|
||||
is_team_admin=False,
|
||||
is_active=True,
|
||||
is_service_account=False,
|
||||
must_change_password=False,
|
||||
account_id=system_account.id,
|
||||
account_role="engineer",
|
||||
)
|
||||
test_db.add(existing_user)
|
||||
await test_db.commit()
|
||||
|
||||
resolved_id = await ensure_service_account(test_db)
|
||||
await test_db.refresh(existing_user)
|
||||
|
||||
assert resolved_id == existing_user.id
|
||||
assert existing_user.is_service_account is True
|
||||
@@ -3,37 +3,10 @@ import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.team import Team
|
||||
from app.models.user import User
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient, test_db: AsyncSession, test_user: dict):
|
||||
"""Override auth_headers to ensure the test user has a team_id assigned."""
|
||||
# Fetch the user from DB and assign a team
|
||||
result = await test_db.execute(select(User).where(User.email == test_user["email"]))
|
||||
user = result.scalar_one()
|
||||
|
||||
# Create a team and assign the user to it
|
||||
team = Team(name="Test Team")
|
||||
test_db.add(team)
|
||||
await test_db.flush()
|
||||
|
||||
user.team_id = team.id
|
||||
await test_db.commit()
|
||||
|
||||
# Re-login to get a fresh token
|
||||
login_data = {
|
||||
"email": test_user["email"],
|
||||
"password": test_user["password"],
|
||||
}
|
||||
resp = await client.post("/api/v1/auth/login/json", json=login_data)
|
||||
assert resp.status_code == 200
|
||||
token_data = resp.json()
|
||||
return {"Authorization": f"Bearer {token_data['access_token']}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_target_list(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
@@ -107,25 +80,28 @@ async def test_delete_target_list(client: AsyncClient, auth_headers: dict):
|
||||
assert get.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from team B cannot access team A's list."""
|
||||
async def test_cannot_access_other_accounts_list(client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""User from account B cannot access account A's target list."""
|
||||
import uuid
|
||||
from app.models.team import Team
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
# Create team A list using existing auth_headers
|
||||
# Create account A list using existing auth_headers
|
||||
create = await client.post(
|
||||
"/api/v1/target-lists/",
|
||||
json={"name": "Team A List", "targets": [{"label": "SRV-A"}]},
|
||||
json={"name": "Account A List", "targets": [{"label": "SRV-A"}]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create.status_code == 201
|
||||
list_id = create.json()["id"]
|
||||
|
||||
# Create a separate team B with its own user
|
||||
team_b = Team(name=f"Team B {uuid.uuid4()}")
|
||||
test_db.add(team_b)
|
||||
# Create a separate account B with its own user
|
||||
account_b = Account(
|
||||
name=f"Account B {uuid.uuid4()}",
|
||||
display_code=f"AB{str(uuid.uuid4())[:6].upper()}",
|
||||
)
|
||||
test_db.add(account_b)
|
||||
await test_db.flush()
|
||||
|
||||
user_b = User(
|
||||
@@ -133,11 +109,13 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
password_hash=get_password_hash("password123"),
|
||||
name="User B",
|
||||
is_active=True,
|
||||
team_id=team_b.id,
|
||||
account_id=account_b.id,
|
||||
account_role="engineer",
|
||||
role="engineer",
|
||||
)
|
||||
test_db.add(user_b)
|
||||
await test_db.flush()
|
||||
await test_db.commit()
|
||||
|
||||
# Get auth token for user B
|
||||
login = await client.post(
|
||||
@@ -148,6 +126,6 @@ async def test_cannot_access_other_teams_list(client: AsyncClient, auth_headers:
|
||||
token_b = login.json()["access_token"]
|
||||
headers_b = {"Authorization": f"Bearer {token_b}"}
|
||||
|
||||
# Team B cannot access Team A's list
|
||||
# Account B cannot access Account A's list
|
||||
resp = await client.get(f"/api/v1/target-lists/{list_id}", headers=headers_b)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -117,6 +117,7 @@ class TestTreeSharing:
|
||||
for i in range(3):
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token=f"token_{i}_" + "x" * 56,
|
||||
created_by=sample_tree.author_id,
|
||||
allow_forking=i % 2 == 0
|
||||
@@ -162,6 +163,7 @@ class TestTreeSharing:
|
||||
# Create a share
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="public_test_token" + "x" * 47,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True
|
||||
@@ -192,6 +194,7 @@ class TestTreeSharing:
|
||||
# Create expired share
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="expired_token" + "x" * 50,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True,
|
||||
@@ -209,6 +212,7 @@ class TestTreeSharing:
|
||||
from uuid import UUID
|
||||
share = TreeShare(
|
||||
tree_id=sample_tree.id,
|
||||
account_id=sample_tree.account_id,
|
||||
share_token="inactive_tree_token" + "x" * 44,
|
||||
created_by=UUID(test_user["user_data"]["id"]),
|
||||
allow_forking=True
|
||||
@@ -248,6 +252,37 @@ class TestTreeSharing:
|
||||
tokens.add(token)
|
||||
assert len(tokens) == 5
|
||||
|
||||
async def test_share_account_id_matches_tree_not_actor(
|
||||
self, client: AsyncClient, sample_tree, auth_headers, test_db
|
||||
):
|
||||
"""Share account_id must equal tree.account_id, not the actor's account_id.
|
||||
|
||||
A super admin in a different account can share any tree. The resulting
|
||||
TreeShare row must live in the tree-owner's account so that the tree
|
||||
owner's RLS context covers it. If account_id were derived from the
|
||||
actor instead, the share would vanish from the tree owner's view once
|
||||
RLS is enabled.
|
||||
"""
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{sample_tree.id}/share",
|
||||
json={"allow_forking": True},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
share_token = response.json()["share_token"]
|
||||
|
||||
result = await test_db.execute(
|
||||
select(TreeShare).where(TreeShare.share_token == share_token)
|
||||
)
|
||||
share = result.scalar_one()
|
||||
assert share.account_id == sample_tree.account_id, (
|
||||
"TreeShare.account_id must equal tree.account_id, not the actor's account. "
|
||||
"Shares must live in the tree owner's tenant for RLS to cover them."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_defaults_visibility_to_team(test_db):
|
||||
|
||||
7
frontend/package-lock.json
generated
7
frontend/package-lock.json
generated
@@ -23,6 +23,7 @@
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"date-fns": "^4.1.0",
|
||||
"html-to-image": "^1.11.13",
|
||||
"immer": "^11.1.3",
|
||||
"lucide-react": "^0.563.0",
|
||||
"monaco-editor": "^0.55.1",
|
||||
@@ -5331,6 +5332,12 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/html-to-image": {
|
||||
"version": "1.11.13",
|
||||
"resolved": "https://registry.npmjs.org/html-to-image/-/html-to-image-1.11.13.tgz",
|
||||
"integrity": "sha512-cuOPoI7WApyhBElTTb9oqsawRvZ0rHhaHwghRLlTuffoD1B2aDemlCruLeZrUIIdvG7gs9xeELEPm6PhuASqrg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/html-url-attributes": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz",
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"date-fns": "^4.1.0",
|
||||
"html-to-image": "^1.11.13",
|
||||
"immer": "^11.1.3",
|
||||
"lucide-react": "^0.563.0",
|
||||
"monaco-editor": "^0.55.1",
|
||||
|
||||
BIN
frontend/public/images/hero_001.jpg
Normal file
BIN
frontend/public/images/hero_001.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 686 KiB |
23
frontend/src/api/deviceTypes.ts
Normal file
23
frontend/src/api/deviceTypes.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import apiClient from './client'
|
||||
import type { DeviceTypeResponse, DeviceTypeCreate } from '@/types'
|
||||
|
||||
export const deviceTypesApi = {
|
||||
async list(): Promise<DeviceTypeResponse[]> {
|
||||
const response = await apiClient.get<DeviceTypeResponse[]>('/device-types/')
|
||||
return response.data
|
||||
},
|
||||
|
||||
async create(data: DeviceTypeCreate): Promise<DeviceTypeResponse> {
|
||||
const response = await apiClient.post<DeviceTypeResponse>('/device-types/', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async update(id: string, data: Partial<DeviceTypeCreate>): Promise<DeviceTypeResponse> {
|
||||
const response = await apiClient.put<DeviceTypeResponse>(`/device-types/${id}`, data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async remove(id: string): Promise<void> {
|
||||
await apiClient.delete(`/device-types/${id}`)
|
||||
},
|
||||
}
|
||||
@@ -35,3 +35,5 @@ export { betaFeedbackApi } from './betaFeedback'
|
||||
export { branchesApi } from './branches'
|
||||
export { handoffsApi } from './handoffs'
|
||||
export { resolutionsApi } from './resolutions'
|
||||
export { deviceTypesApi } from './deviceTypes'
|
||||
export { networkDiagramsApi } from './networkDiagrams'
|
||||
|
||||
63
frontend/src/api/networkDiagrams.ts
Normal file
63
frontend/src/api/networkDiagrams.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import apiClient from './client'
|
||||
import type {
|
||||
NetworkDiagramResponse,
|
||||
NetworkDiagramListItem,
|
||||
NetworkDiagramCreate,
|
||||
NetworkDiagramUpdate,
|
||||
AIGenerateRequest,
|
||||
AIGenerateResponse,
|
||||
DiagramImportData,
|
||||
DiagramImportResponse,
|
||||
DiagramExportResponse,
|
||||
} from '@/types'
|
||||
|
||||
export const networkDiagramsApi = {
|
||||
async list(params?: { client_name?: string; search?: string }): Promise<NetworkDiagramListItem[]> {
|
||||
const response = await apiClient.get<NetworkDiagramListItem[]>('/network-diagrams/', { params })
|
||||
return response.data
|
||||
},
|
||||
|
||||
async get(id: string): Promise<NetworkDiagramResponse> {
|
||||
const response = await apiClient.get<NetworkDiagramResponse>(`/network-diagrams/${id}`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async create(data: NetworkDiagramCreate): Promise<NetworkDiagramResponse> {
|
||||
const response = await apiClient.post<NetworkDiagramResponse>('/network-diagrams/', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async update(id: string, data: NetworkDiagramUpdate): Promise<NetworkDiagramResponse> {
|
||||
const response = await apiClient.put<NetworkDiagramResponse>(`/network-diagrams/${id}`, data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async archive(id: string): Promise<void> {
|
||||
await apiClient.delete(`/network-diagrams/${id}`)
|
||||
},
|
||||
|
||||
async duplicate(id: string): Promise<NetworkDiagramResponse> {
|
||||
const response = await apiClient.post<NetworkDiagramResponse>(`/network-diagrams/${id}/duplicate`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async exportJson(id: string): Promise<DiagramExportResponse> {
|
||||
const response = await apiClient.get<DiagramExportResponse>(`/network-diagrams/${id}/export`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async importJson(data: DiagramImportData): Promise<DiagramImportResponse> {
|
||||
const response = await apiClient.post<DiagramImportResponse>('/network-diagrams/import', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async aiGenerate(data: AIGenerateRequest): Promise<AIGenerateResponse> {
|
||||
const response = await apiClient.post<AIGenerateResponse>('/network-diagrams/ai-generate', data)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async listClients(): Promise<string[]> {
|
||||
const response = await apiClient.get<string[]>('/network-diagrams/clients')
|
||||
return response.data
|
||||
},
|
||||
}
|
||||
@@ -57,6 +57,7 @@ function loadTaskState(sessionId: string): TaskResponse[] | null {
|
||||
} catch { return null }
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react-refresh/only-export-components
|
||||
export function clearTaskState(sessionId: string) {
|
||||
try { sessionStorage.removeItem(`${TASK_LANE_STORAGE_KEY}:${sessionId}`) } catch { /* ignore */ }
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ export function TeamSummary() {
|
||||
const { isAccountOwner } = usePermissions()
|
||||
const navigate = useNavigate()
|
||||
const [escalationCount, setEscalationCount] = useState(0)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [loading, setLoading] = useState(!!isAccountOwner)
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAccountOwner) { setLoading(false); return }
|
||||
if (!isAccountOwner) return
|
||||
aiSessionsApi.getEscalationQueue()
|
||||
.then((esc) => setEscalationCount(esc.length))
|
||||
.catch(() => {})
|
||||
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
LayoutGrid, Clock, AlertTriangle, GitBranch, Code2, Wand2,
|
||||
ListChecks, Download, BarChart3,
|
||||
Settings, Pin, PinOff,
|
||||
History, FileText,
|
||||
History, FileText, Network,
|
||||
} from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { useUserPreferencesStore } from '@/store/userPreferencesStore'
|
||||
@@ -86,10 +86,11 @@ export function Sidebar() {
|
||||
{
|
||||
href: '/trees', icon: GitBranch, label: 'Flows', shortLabel: 'Flows',
|
||||
badge: stats?.tree_counts.total || undefined,
|
||||
matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue'],
|
||||
matchPaths: ['/trees', '/flows', '/my-trees', '/step-library', '/review-queue', '/network-diagrams'],
|
||||
children: [
|
||||
{ href: '/trees', label: 'Flow Library', count: stats?.tree_counts.total || undefined },
|
||||
{ href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined },
|
||||
{ href: '/network-diagrams', label: 'Network Maps' },
|
||||
{ href: '/step-library', label: 'Solutions Library' },
|
||||
{ href: '/review-queue', label: 'Review Queue' },
|
||||
],
|
||||
@@ -134,6 +135,7 @@ export function Sidebar() {
|
||||
{ href: '/trees?type=procedural', label: 'Projects', count: stats?.tree_counts.procedural || undefined },
|
||||
],
|
||||
},
|
||||
{ href: '/network-diagrams', icon: Network, label: 'Network Maps', shortLabel: 'NetMap', matchPaths: ['/network-diagrams'] },
|
||||
{ href: '/scripts', icon: Code2, label: 'Scripts', shortLabel: 'Scripts' },
|
||||
{ href: '/script-builder', icon: Wand2, label: 'Script Builder', shortLabel: 'Builder' },
|
||||
{ href: '/review-queue', icon: ListChecks, label: 'Review Queue', shortLabel: 'Review' },
|
||||
|
||||
232
frontend/src/components/network/CanvasEmptyPrompt.tsx
Normal file
232
frontend/src/components/network/CanvasEmptyPrompt.tsx
Normal file
@@ -0,0 +1,232 @@
|
||||
import { useState, useCallback, useEffect } from 'react'
|
||||
import { Sparkles, ArrowRight, PencilRuler, Wand2, X } from 'lucide-react'
|
||||
import { networkDiagramsApi } from '@/api'
|
||||
import type { AIGenerateResponse } from '@/types'
|
||||
|
||||
const EXAMPLE_PROMPTS = [
|
||||
'Small office with firewall and core switch',
|
||||
'Azure hybrid cloud with VPN gateway',
|
||||
'Branch office connected to HQ via MPLS',
|
||||
'Data center with redundant core switches',
|
||||
'Remote workforce with Meraki and cloud apps',
|
||||
]
|
||||
|
||||
interface CanvasEmptyPromptProps {
|
||||
onGenerate: (result: AIGenerateResponse, mode: 'replace' | 'merge') => void
|
||||
}
|
||||
|
||||
export function CanvasEmptyPrompt({ onGenerate }: CanvasEmptyPromptProps) {
|
||||
const [mode, setMode] = useState<'choice' | 'ai' | 'manual'>('choice')
|
||||
const [description, setDescription] = useState('')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
const switchToManual = useCallback(() => {
|
||||
if (loading) return
|
||||
setMode('manual')
|
||||
setError(null)
|
||||
}, [loading])
|
||||
|
||||
const handleGenerate = useCallback(async (text?: string) => {
|
||||
const desc = (text ?? description).trim()
|
||||
if (!desc) return
|
||||
setLoading(true)
|
||||
setError(null)
|
||||
try {
|
||||
const result = await networkDiagramsApi.aiGenerate({
|
||||
description: desc,
|
||||
mode: 'replace',
|
||||
existingBounds: null,
|
||||
})
|
||||
onGenerate(result, 'replace')
|
||||
} catch (err: unknown) {
|
||||
setError(err instanceof Error ? err.message : 'Generation failed. Please try again.')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [description, onGenerate])
|
||||
|
||||
useEffect(() => {
|
||||
if (mode === 'manual') return
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault()
|
||||
switchToManual()
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener('keydown', handleKeyDown)
|
||||
return () => window.removeEventListener('keydown', handleKeyDown)
|
||||
}, [mode, switchToManual])
|
||||
|
||||
if (mode === 'manual') {
|
||||
return (
|
||||
<div className="pointer-events-none absolute inset-x-0 bottom-6 z-10 flex justify-center px-6">
|
||||
<div className="pointer-events-auto flex max-w-xl items-center gap-3 rounded-lg border border-default bg-card px-4 py-3 shadow-xl">
|
||||
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-accent/10 text-accent">
|
||||
<PencilRuler size={14} />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<p className="text-sm font-medium text-heading">Manual mode is on</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Drag devices from the left panel onto the canvas, or reopen AI whenever you want.
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => setMode('ai')}
|
||||
className="inline-flex shrink-0 items-center gap-1 rounded-full border border-default px-3 py-1 text-xs font-medium text-primary hover:border-accent hover:text-accent"
|
||||
>
|
||||
<Sparkles size={12} />
|
||||
Open AI Generator
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="pointer-events-none absolute inset-0 z-10 flex items-center justify-center bg-[rgba(10,14,20,0.42)] px-6"
|
||||
onClick={event => {
|
||||
if (event.target === event.currentTarget) {
|
||||
switchToManual()
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="pointer-events-auto relative w-full max-w-lg rounded-lg border border-default bg-card p-8 shadow-2xl">
|
||||
<button
|
||||
onClick={switchToManual}
|
||||
disabled={loading}
|
||||
aria-label="Close AI prompt and build manually"
|
||||
className="absolute right-4 top-4 inline-flex h-8 w-8 items-center justify-center rounded-full border border-default text-muted-foreground hover:border-hover hover:text-primary disabled:opacity-40"
|
||||
>
|
||||
<X size={14} />
|
||||
</button>
|
||||
{mode === 'choice' ? (
|
||||
<>
|
||||
<div className="mb-6 text-center">
|
||||
<div className="mb-2 flex items-center justify-center gap-2">
|
||||
<Wand2 size={16} className="text-accent" />
|
||||
<h2 className="font-heading text-base font-semibold text-heading">
|
||||
Start a network map
|
||||
</h2>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Generate a topology with AI or start with a blank canvas and build it manually.
|
||||
</p>
|
||||
<p className="mt-2 text-[11px] text-muted-foreground/80">
|
||||
Press <span className="font-medium text-primary">Esc</span> or click outside to skip AI and start dragging devices.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
<button
|
||||
onClick={() => setMode('ai')}
|
||||
className="rounded-lg border border-accent/40 bg-accent/10 p-4 text-left transition-colors hover:border-accent hover:bg-accent/15"
|
||||
>
|
||||
<div className="mb-3 inline-flex rounded-lg bg-accent/15 p-2 text-accent">
|
||||
<Sparkles size={16} />
|
||||
</div>
|
||||
<div className="mb-1 text-sm font-semibold text-heading">Generate with AI</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Describe the environment and let AI lay out the first version for you.
|
||||
</p>
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={switchToManual}
|
||||
className="rounded-lg border border-default bg-elevated/40 p-4 text-left transition-colors hover:border-accent hover:bg-elevated/60"
|
||||
>
|
||||
<div className="mb-3 inline-flex rounded-lg bg-primary/10 p-2 text-primary">
|
||||
<PencilRuler size={16} />
|
||||
</div>
|
||||
<div className="mb-1 text-sm font-semibold text-heading">Build manually</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Close this prompt and use click-and-drag from the left toolbar to place devices on the canvas.
|
||||
</p>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-5 text-center">
|
||||
<div className="mb-2 flex items-center justify-center gap-2">
|
||||
<Sparkles size={16} className="text-accent" />
|
||||
<h2 className="font-heading text-base font-semibold text-heading">
|
||||
Describe your network
|
||||
</h2>
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
AI will generate the topology in seconds, or you can go back and switch to manual creation.
|
||||
</p>
|
||||
<p className="mt-2 text-[11px] text-muted-foreground/80">
|
||||
Press <span className="font-medium text-primary">Esc</span>, click outside, or use the close button to build manually instead.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="relative mb-3">
|
||||
<textarea
|
||||
value={description}
|
||||
onChange={e => setDescription(e.target.value)}
|
||||
onKeyDown={e => {
|
||||
if (e.key === 'Enter' && (e.metaKey || e.ctrlKey)) handleGenerate()
|
||||
}}
|
||||
placeholder="e.g. Small office with a firewall, core switch, 3 access points, a file server, and 20 workstations"
|
||||
rows={3}
|
||||
disabled={loading}
|
||||
autoFocus
|
||||
className="w-full resize-none rounded-lg border border-default bg-input px-4 py-3 pb-7 text-sm text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none disabled:opacity-50"
|
||||
/>
|
||||
<span className="pointer-events-none absolute bottom-2 right-3 text-[10px] text-muted-foreground">
|
||||
⌘↵ to generate
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mb-4 flex flex-wrap gap-1.5">
|
||||
{EXAMPLE_PROMPTS.map(p => (
|
||||
<button
|
||||
key={p}
|
||||
onClick={() => handleGenerate(p)}
|
||||
disabled={loading}
|
||||
className="rounded-full border border-default px-3 py-1 text-xs text-muted-foreground transition-colors hover:border-accent hover:text-accent disabled:opacity-40"
|
||||
>
|
||||
{p}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{error && <p className="mb-3 text-xs text-red-400">{error}</p>}
|
||||
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
onClick={switchToManual}
|
||||
disabled={loading}
|
||||
className="flex-1 rounded-lg border border-default px-4 py-2.5 text-sm font-medium text-primary hover:border-accent hover:text-accent disabled:opacity-40"
|
||||
>
|
||||
Build Manually
|
||||
</button>
|
||||
{loading ? (
|
||||
<div className="flex flex-1 items-center justify-center gap-2 py-2.5">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-accent border-t-transparent" />
|
||||
<span className="text-sm text-muted-foreground">Mapping your network…</span>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleGenerate()}
|
||||
disabled={!description.trim()}
|
||||
className="flex flex-1 items-center justify-center gap-2 rounded-lg bg-accent px-4 py-2.5 text-sm font-medium text-white transition-opacity hover:bg-accent/90 disabled:opacity-40"
|
||||
>
|
||||
<Sparkles size={14} />
|
||||
Generate Diagram
|
||||
<ArrowRight size={14} />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
119
frontend/src/components/network/ContextMenu.tsx
Normal file
119
frontend/src/components/network/ContextMenu.tsx
Normal file
@@ -0,0 +1,119 @@
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { Copy, CopyPlus, Trash2, ClipboardPaste, BoxSelect, Maximize2, BringToFront, SendToBack } from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
interface MenuAction {
|
||||
label: string
|
||||
icon: React.ElementType
|
||||
shortcut: string
|
||||
onClick: () => void
|
||||
disabled?: boolean
|
||||
dividerBefore?: boolean
|
||||
}
|
||||
|
||||
interface ContextMenuProps {
|
||||
position: { x: number; y: number }
|
||||
actions: MenuAction[]
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
export function ContextMenu({ position, actions, onClose }: ContextMenuProps) {
|
||||
const menuRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const clampedPosition = { ...position }
|
||||
if (typeof window !== 'undefined') {
|
||||
const itemCount = actions.length
|
||||
const dividerCount = actions.filter(a => a.dividerBefore).length
|
||||
const menuWidth = 192
|
||||
const menuHeight = itemCount * 36 + dividerCount * 9 + 8
|
||||
if (clampedPosition.x + menuWidth > window.innerWidth) {
|
||||
clampedPosition.x = window.innerWidth - menuWidth - 8
|
||||
}
|
||||
if (clampedPosition.y + menuHeight > window.innerHeight) {
|
||||
clampedPosition.y = window.innerHeight - menuHeight - 8
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (e: MouseEvent) => {
|
||||
if (menuRef.current && !menuRef.current.contains(e.target as HTMLElement)) {
|
||||
onClose()
|
||||
}
|
||||
}
|
||||
const handleEscape = (e: KeyboardEvent) => {
|
||||
if (e.key === 'Escape') onClose()
|
||||
}
|
||||
const handleScroll = () => onClose()
|
||||
|
||||
document.addEventListener('mousedown', handleClickOutside)
|
||||
document.addEventListener('keydown', handleEscape)
|
||||
document.addEventListener('scroll', handleScroll, true)
|
||||
return () => {
|
||||
document.removeEventListener('mousedown', handleClickOutside)
|
||||
document.removeEventListener('keydown', handleEscape)
|
||||
document.removeEventListener('scroll', handleScroll, true)
|
||||
}
|
||||
}, [onClose])
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={menuRef}
|
||||
className="fixed z-50 w-48 rounded-lg border border-default bg-card py-1 shadow-lg"
|
||||
style={{ left: clampedPosition.x, top: clampedPosition.y }}
|
||||
>
|
||||
{actions.map((action) => (
|
||||
<div key={action.label}>
|
||||
{action.dividerBefore && (
|
||||
<div className="my-1 border-t border-default" />
|
||||
)}
|
||||
<button
|
||||
onClick={() => {
|
||||
action.onClick()
|
||||
onClose()
|
||||
}}
|
||||
disabled={action.disabled}
|
||||
className={cn(
|
||||
'flex w-full items-center gap-2 px-3 py-2 text-xs text-primary hover:bg-elevated',
|
||||
action.disabled && 'opacity-40 pointer-events-none',
|
||||
)}
|
||||
>
|
||||
<action.icon size={14} />
|
||||
<span>{action.label}</span>
|
||||
<span className="ml-auto text-[10px] text-muted-foreground">{action.shortcut}</span>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react-refresh/only-export-components
|
||||
export function getNodeMenuActions(handlers: {
|
||||
onCopy: () => void
|
||||
onDuplicate: () => void
|
||||
onBringToFront: () => void
|
||||
onSendToBack: () => void
|
||||
onDelete: () => void
|
||||
}): MenuAction[] {
|
||||
return [
|
||||
{ label: 'Copy', icon: Copy, shortcut: 'Ctrl+C', onClick: handlers.onCopy },
|
||||
{ label: 'Duplicate', icon: CopyPlus, shortcut: 'Ctrl+D', onClick: handlers.onDuplicate },
|
||||
{ label: 'Bring to Front', icon: BringToFront, shortcut: ']', onClick: handlers.onBringToFront, dividerBefore: true },
|
||||
{ label: 'Send to Back', icon: SendToBack, shortcut: '[', onClick: handlers.onSendToBack },
|
||||
{ label: 'Delete', icon: Trash2, shortcut: 'Del', onClick: handlers.onDelete, dividerBefore: true },
|
||||
]
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react-refresh/only-export-components
|
||||
export function getCanvasMenuActions(handlers: {
|
||||
onPaste: () => void
|
||||
onSelectAll: () => void
|
||||
onFitView: () => void
|
||||
hasClipboard: boolean
|
||||
}): MenuAction[] {
|
||||
return [
|
||||
{ label: 'Paste', icon: ClipboardPaste, shortcut: 'Ctrl+V', onClick: handlers.onPaste, disabled: !handlers.hasClipboard },
|
||||
{ label: 'Select All', icon: BoxSelect, shortcut: 'Ctrl+A', onClick: handlers.onSelectAll },
|
||||
{ label: 'Fit View', icon: Maximize2, shortcut: '⌘⇧F', onClick: handlers.onFitView },
|
||||
]
|
||||
}
|
||||
167
frontend/src/components/network/DiagramHeader.tsx
Normal file
167
frontend/src/components/network/DiagramHeader.tsx
Normal file
@@ -0,0 +1,167 @@
|
||||
import { useState, useCallback, useRef, useEffect } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { ChevronLeft, Save, Download, FileJson, Image, FileText } from 'lucide-react'
|
||||
|
||||
interface DiagramHeaderProps {
|
||||
name: string
|
||||
clientName: string | null
|
||||
isDirty: boolean
|
||||
isSaving: boolean
|
||||
lastSavedAt: Date | null
|
||||
diagramId: string | null
|
||||
onNameChange: (name: string) => void
|
||||
onSave: () => void
|
||||
onExportPng: () => void
|
||||
onExportPdf: () => void
|
||||
onExportJson: () => void
|
||||
}
|
||||
|
||||
export function DiagramHeader({
|
||||
name,
|
||||
clientName,
|
||||
isDirty,
|
||||
isSaving,
|
||||
lastSavedAt,
|
||||
diagramId,
|
||||
onNameChange,
|
||||
onSave,
|
||||
onExportPng,
|
||||
onExportPdf,
|
||||
onExportJson,
|
||||
}: DiagramHeaderProps) {
|
||||
const navigate = useNavigate()
|
||||
const [editing, setEditing] = useState(false)
|
||||
const [editValue, setEditValue] = useState(name)
|
||||
const [showExportMenu, setShowExportMenu] = useState(false)
|
||||
const inputRef = useRef<HTMLInputElement>(null)
|
||||
const exportMenuRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (editing && inputRef.current) {
|
||||
inputRef.current.focus()
|
||||
inputRef.current.select()
|
||||
}
|
||||
}, [editing])
|
||||
|
||||
useEffect(() => {
|
||||
setEditValue(name)
|
||||
}, [name])
|
||||
|
||||
useEffect(() => {
|
||||
if (!showExportMenu) return
|
||||
const handleClick = (e: MouseEvent) => {
|
||||
if (exportMenuRef.current && !exportMenuRef.current.contains(e.target as HTMLElement)) {
|
||||
setShowExportMenu(false)
|
||||
}
|
||||
}
|
||||
document.addEventListener('mousedown', handleClick)
|
||||
return () => document.removeEventListener('mousedown', handleClick)
|
||||
}, [showExportMenu])
|
||||
|
||||
const handleConfirmName = useCallback(() => {
|
||||
setEditing(false)
|
||||
if (editValue.trim() && editValue !== name) {
|
||||
onNameChange(editValue.trim())
|
||||
} else {
|
||||
setEditValue(name)
|
||||
}
|
||||
}, [editValue, name, onNameChange])
|
||||
|
||||
const formatLastSaved = () => {
|
||||
if (!lastSavedAt) return null
|
||||
// eslint-disable-next-line react-hooks/purity
|
||||
const diff = Date.now() - lastSavedAt.getTime()
|
||||
if (diff < 60_000) return 'Saved just now'
|
||||
const mins = Math.floor(diff / 60_000)
|
||||
return `Saved ${mins}m ago`
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-14 items-center gap-3 border-b border-default bg-card px-4">
|
||||
<button
|
||||
onClick={() => navigate('/network-diagrams')}
|
||||
className="flex items-center gap-1 text-xs text-muted-foreground hover:text-primary"
|
||||
>
|
||||
<ChevronLeft size={16} />
|
||||
Network Maps
|
||||
</button>
|
||||
|
||||
<div className="mx-2 h-5 w-px bg-border-default" />
|
||||
|
||||
{editing ? (
|
||||
<input
|
||||
ref={inputRef}
|
||||
value={editValue}
|
||||
onChange={e => setEditValue(e.target.value)}
|
||||
onBlur={handleConfirmName}
|
||||
onKeyDown={e => { if (e.key === 'Enter') handleConfirmName(); if (e.key === 'Escape') { setEditing(false); setEditValue(name) } }}
|
||||
className="rounded border border-accent bg-input px-2 py-1 text-sm font-heading font-semibold text-heading focus:outline-none"
|
||||
/>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => setEditing(true)}
|
||||
className="text-sm font-heading font-semibold text-heading hover:text-accent"
|
||||
>
|
||||
{name || 'Untitled Diagram'}
|
||||
</button>
|
||||
)}
|
||||
|
||||
{clientName && (
|
||||
<span className="rounded-full bg-elevated px-2 py-0.5 text-[10px] text-muted-foreground">
|
||||
{clientName}
|
||||
</span>
|
||||
)}
|
||||
|
||||
<div className="flex-1" />
|
||||
|
||||
{isDirty && !isSaving ? (
|
||||
<span className="text-[10px] text-amber-400">Unsaved changes</span>
|
||||
) : lastSavedAt ? (
|
||||
<span className="text-[10px] text-muted-foreground">{formatLastSaved()}</span>
|
||||
) : null}
|
||||
|
||||
<button
|
||||
onClick={onSave}
|
||||
disabled={isSaving}
|
||||
className="flex items-center gap-1.5 rounded bg-accent px-3 py-1.5 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
|
||||
>
|
||||
<Save size={14} />
|
||||
{isSaving ? 'Saving...' : 'Save'}
|
||||
</button>
|
||||
|
||||
<div className="relative" ref={exportMenuRef}>
|
||||
<button
|
||||
onClick={() => setShowExportMenu(prev => !prev)}
|
||||
className="flex items-center gap-1.5 rounded border border-default px-3 py-1.5 text-xs text-primary hover:border-hover"
|
||||
>
|
||||
<Download size={14} />
|
||||
Export
|
||||
</button>
|
||||
{showExportMenu && (
|
||||
<div className="absolute right-0 top-full z-50 mt-1 w-40 rounded border border-default bg-card py-1 shadow-lg">
|
||||
<button
|
||||
onClick={() => { onExportPng(); setShowExportMenu(false) }}
|
||||
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
<Image size={12} /> Export PNG
|
||||
</button>
|
||||
<button
|
||||
onClick={() => { onExportPdf(); setShowExportMenu(false) }}
|
||||
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
<FileText size={12} /> Export PDF
|
||||
</button>
|
||||
{diagramId && (
|
||||
<button
|
||||
onClick={() => { onExportJson(); setShowExportMenu(false) }}
|
||||
className="flex w-full items-center gap-2 px-3 py-1.5 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
<FileJson size={12} /> Export JSON
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
119
frontend/src/components/network/NetworkCanvas.tsx
Normal file
119
frontend/src/components/network/NetworkCanvas.tsx
Normal file
@@ -0,0 +1,119 @@
|
||||
import { useCallback } from 'react'
|
||||
import {
|
||||
ReactFlow,
|
||||
Background,
|
||||
Controls,
|
||||
MiniMap,
|
||||
BackgroundVariant,
|
||||
type OnConnect,
|
||||
type OnNodesChange,
|
||||
type OnEdgesChange,
|
||||
type Node,
|
||||
type Edge,
|
||||
} from '@xyflow/react'
|
||||
import { nodeTypes } from './nodes/nodeTypes'
|
||||
import { edgeTypes } from './edges/edgeTypes'
|
||||
import { getDeviceRenderConfig } from './nodes/deviceRegistry'
|
||||
import type { DeviceNodeData } from './nodes/DeviceNode'
|
||||
|
||||
interface NetworkCanvasProps {
|
||||
nodes: Node[]
|
||||
edges: Edge[]
|
||||
onNodesChange: OnNodesChange
|
||||
onEdgesChange: OnEdgesChange
|
||||
onConnect: OnConnect
|
||||
onNodeSelect: (nodeId: string | null) => void
|
||||
onEdgeSelect: (edgeId: string | null) => void
|
||||
onDrop: (event: React.DragEvent) => void
|
||||
onDragOver: (event: React.DragEvent) => void
|
||||
onDragLeave?: (event: React.DragEvent) => void
|
||||
isDragOver?: boolean
|
||||
onNodeContextMenu?: (event: React.MouseEvent, node: Node) => void
|
||||
onPaneContextMenu?: (event: MouseEvent | React.MouseEvent) => void
|
||||
onPaneClick?: () => void
|
||||
}
|
||||
|
||||
export function NetworkCanvas({
|
||||
nodes,
|
||||
edges,
|
||||
onNodesChange,
|
||||
onEdgesChange,
|
||||
onConnect,
|
||||
onNodeSelect,
|
||||
onEdgeSelect,
|
||||
onDrop,
|
||||
onDragOver,
|
||||
onDragLeave,
|
||||
isDragOver,
|
||||
onNodeContextMenu,
|
||||
onPaneContextMenu,
|
||||
onPaneClick: onPaneClickProp,
|
||||
}: NetworkCanvasProps) {
|
||||
const handleSelectionChange = useCallback(({ nodes: selectedNodes, edges: selectedEdges }: { nodes: Node[]; edges: Edge[] }) => {
|
||||
if (selectedNodes.length === 1) {
|
||||
onNodeSelect(selectedNodes[0].id)
|
||||
onEdgeSelect(null)
|
||||
} else if (selectedEdges.length === 1) {
|
||||
onEdgeSelect(selectedEdges[0].id)
|
||||
onNodeSelect(null)
|
||||
} else {
|
||||
onNodeSelect(null)
|
||||
onEdgeSelect(null)
|
||||
}
|
||||
}, [onNodeSelect, onEdgeSelect])
|
||||
|
||||
const handlePaneClick = useCallback(() => {
|
||||
onNodeSelect(null)
|
||||
onEdgeSelect(null)
|
||||
onPaneClickProp?.()
|
||||
}, [onNodeSelect, onEdgeSelect, onPaneClickProp])
|
||||
|
||||
const getNodeColor = useCallback((node: Node) => {
|
||||
if (node.type === 'group') return 'var(--color-bg-elevated)'
|
||||
const data = node.data as unknown as DeviceNodeData
|
||||
return getDeviceRenderConfig(data?.deviceType || '', data?.category).color
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className="relative h-full w-full" onDragLeave={onDragLeave}>
|
||||
<ReactFlow
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onConnect={onConnect}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
onPaneClick={handlePaneClick}
|
||||
onDrop={onDrop}
|
||||
onDragOver={onDragOver}
|
||||
onNodeContextMenu={onNodeContextMenu}
|
||||
onPaneContextMenu={onPaneContextMenu}
|
||||
nodeTypes={nodeTypes}
|
||||
edgeTypes={edgeTypes}
|
||||
defaultEdgeOptions={{ type: 'connection' }}
|
||||
deleteKeyCode={['Backspace', 'Delete']}
|
||||
multiSelectionKeyCode="Shift"
|
||||
snapToGrid={true}
|
||||
snapGrid={[20, 20]}
|
||||
fitView
|
||||
className="bg-page"
|
||||
>
|
||||
<Background variant={BackgroundVariant.Dots} color="var(--color-border-default)" gap={20} size={1} />
|
||||
<Controls className="!border-default !bg-card [&>button]:!border-default [&>button]:!bg-card [&>button]:!fill-text-primary" />
|
||||
<MiniMap
|
||||
nodeColor={getNodeColor}
|
||||
maskColor="rgba(0,0,0,0.5)"
|
||||
className="!border-default !bg-card"
|
||||
position="bottom-right"
|
||||
/>
|
||||
</ReactFlow>
|
||||
{isDragOver && (
|
||||
<div className="pointer-events-none absolute inset-2 z-10 flex items-center justify-center rounded-lg border-2 border-dashed border-accent/30">
|
||||
<span className="rounded-md bg-card/80 px-3 py-1.5 text-sm text-muted-foreground">
|
||||
Drop to add
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
71
frontend/src/components/network/edges/ConnectionEdge.tsx
Normal file
71
frontend/src/components/network/edges/ConnectionEdge.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
import { memo } from 'react'
|
||||
import { BaseEdge, EdgeLabelRenderer, getStraightPath, getBezierPath, getSmoothStepPath, type EdgeProps } from '@xyflow/react'
|
||||
|
||||
interface ConnectionEdgeData {
|
||||
connectionType?: string
|
||||
routing?: string | null
|
||||
speed?: string | null
|
||||
notes?: string | null
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
const CONNECTION_STYLES: Record<string, { stroke: string; strokeDasharray?: string; strokeWidth: number }> = {
|
||||
ethernet: { stroke: '#60a5fa', strokeWidth: 2 },
|
||||
fiber: { stroke: '#34d399', strokeWidth: 3 },
|
||||
wifi: { stroke: '#a78bfa', strokeDasharray: '3,3', strokeWidth: 2 },
|
||||
vpn: { stroke: '#eab308', strokeDasharray: '8,4', strokeWidth: 2 },
|
||||
vlan: { stroke: '#848b9b', strokeWidth: 2 },
|
||||
wan: { stroke: '#f87171', strokeDasharray: '12,4', strokeWidth: 2 },
|
||||
}
|
||||
|
||||
const DEFAULT_STYLE = { stroke: '#848b9b', strokeWidth: 2 }
|
||||
|
||||
function getEdgePath(routing: string | null | undefined, props: EdgeProps) {
|
||||
const base = {
|
||||
sourceX: props.sourceX,
|
||||
sourceY: props.sourceY,
|
||||
sourcePosition: props.sourcePosition,
|
||||
targetX: props.targetX,
|
||||
targetY: props.targetY,
|
||||
targetPosition: props.targetPosition,
|
||||
}
|
||||
if (routing === 'curved') return getBezierPath(base)
|
||||
if (routing === 'step') return getSmoothStepPath(base)
|
||||
return getStraightPath(base)
|
||||
}
|
||||
|
||||
function ConnectionEdgeComponent(props: EdgeProps) {
|
||||
const edgeData = props.data as ConnectionEdgeData | undefined
|
||||
const connectionType = edgeData?.connectionType || 'ethernet'
|
||||
const style = CONNECTION_STYLES[connectionType] || DEFAULT_STYLE
|
||||
|
||||
const [edgePath, labelX, labelY] = getEdgePath(edgeData?.routing, props)
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
style={{
|
||||
...style,
|
||||
...(props.selected ? { stroke: '#60a5fa', strokeWidth: style.strokeWidth + 1 } : {}),
|
||||
}}
|
||||
/>
|
||||
{props.label && (
|
||||
<EdgeLabelRenderer>
|
||||
<div
|
||||
className="nodrag nopan rounded bg-page px-1.5 py-0.5 text-[10px] text-muted-foreground"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: 'all',
|
||||
}}
|
||||
>
|
||||
{props.label}
|
||||
</div>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export const ConnectionEdge = memo(ConnectionEdgeComponent)
|
||||
7
frontend/src/components/network/edges/edgeTypes.ts
Normal file
7
frontend/src/components/network/edges/edgeTypes.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { ConnectionEdge } from './ConnectionEdge'
|
||||
import { AnimatedSvgEdge } from '../ui/animated-svg-edge'
|
||||
|
||||
export const edgeTypes = {
|
||||
connection: ConnectionEdge,
|
||||
animated: AnimatedSvgEdge,
|
||||
}
|
||||
252
frontend/src/components/network/hooks/useCanvasShortcuts.ts
Normal file
252
frontend/src/components/network/hooks/useCanvasShortcuts.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { useReactFlow, type Node, type Edge } from '@xyflow/react'
|
||||
|
||||
interface ClipboardData {
|
||||
nodes: Array<{
|
||||
type: string
|
||||
data: Record<string, unknown>
|
||||
style?: React.CSSProperties
|
||||
relativePosition: { x: number; y: number }
|
||||
}>
|
||||
edges: Array<{
|
||||
sourceIndex: number
|
||||
targetIndex: number
|
||||
type?: string
|
||||
data?: Record<string, unknown>
|
||||
label?: string
|
||||
}>
|
||||
}
|
||||
|
||||
function generateId(prefix: string): string {
|
||||
return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`
|
||||
}
|
||||
|
||||
function isInputFocused(): boolean {
|
||||
const tag = document.activeElement?.tagName
|
||||
return tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT'
|
||||
}
|
||||
|
||||
export function useCanvasShortcuts({
|
||||
nodes: _nodes, // eslint-disable-line @typescript-eslint/no-unused-vars
|
||||
edges,
|
||||
setNodes,
|
||||
setEdges,
|
||||
setIsDirty,
|
||||
canvasRef,
|
||||
}: {
|
||||
nodes: Node[]
|
||||
edges: Edge[]
|
||||
setNodes: React.Dispatch<React.SetStateAction<Node[]>>
|
||||
setEdges: React.Dispatch<React.SetStateAction<Edge[]>>
|
||||
setIsDirty: (dirty: boolean) => void
|
||||
canvasRef: React.RefObject<HTMLDivElement | null>
|
||||
}) {
|
||||
const { getNodes, fitView, screenToFlowPosition, setNodes: rfSetNodes } = useReactFlow()
|
||||
const clipboardRef = useRef<ClipboardData | null>(null)
|
||||
|
||||
const getSelectedNodes = useCallback((): Node[] => {
|
||||
return getNodes().filter(n => n.selected)
|
||||
}, [getNodes])
|
||||
|
||||
const copyNodes = useCallback(() => {
|
||||
const selected = getSelectedNodes()
|
||||
if (selected.length === 0) return
|
||||
|
||||
const centroid = {
|
||||
x: selected.reduce((sum, n) => sum + n.position.x, 0) / selected.length,
|
||||
y: selected.reduce((sum, n) => sum + n.position.y, 0) / selected.length,
|
||||
}
|
||||
|
||||
const selectedIds = new Set(selected.map(n => n.id))
|
||||
|
||||
const clipNodes = selected.map(n => ({
|
||||
type: n.type || 'device',
|
||||
data: structuredClone(n.data),
|
||||
style: n.style ? { ...n.style } : undefined,
|
||||
relativePosition: {
|
||||
x: n.position.x - centroid.x,
|
||||
y: n.position.y - centroid.y,
|
||||
},
|
||||
}))
|
||||
|
||||
const selectedList = selected.map(n => n.id)
|
||||
const clipEdges = edges
|
||||
.filter(e => selectedIds.has(e.source) && selectedIds.has(e.target))
|
||||
.map(e => ({
|
||||
sourceIndex: selectedList.indexOf(e.source),
|
||||
targetIndex: selectedList.indexOf(e.target),
|
||||
type: e.type,
|
||||
data: e.data ? structuredClone(e.data) as Record<string, unknown> : undefined,
|
||||
label: typeof e.label === 'string' ? e.label : undefined,
|
||||
}))
|
||||
|
||||
clipboardRef.current = { nodes: clipNodes, edges: clipEdges }
|
||||
}, [getSelectedNodes, edges])
|
||||
|
||||
const pasteNodes = useCallback(() => {
|
||||
const clipboard = clipboardRef.current
|
||||
if (!clipboard || clipboard.nodes.length === 0) return
|
||||
|
||||
const canvasEl = canvasRef.current
|
||||
if (!canvasEl) return
|
||||
const rect = canvasEl.getBoundingClientRect()
|
||||
const center = screenToFlowPosition({
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top + rect.height / 2,
|
||||
})
|
||||
|
||||
const newNodeIds: string[] = []
|
||||
const newNodes: Node[] = clipboard.nodes.map(cn => {
|
||||
const prefix = cn.type === 'group' ? 'group' : 'device'
|
||||
const id = generateId(prefix)
|
||||
newNodeIds.push(id)
|
||||
return {
|
||||
id,
|
||||
type: cn.type,
|
||||
position: {
|
||||
x: center.x + cn.relativePosition.x,
|
||||
y: center.y + cn.relativePosition.y,
|
||||
},
|
||||
data: structuredClone(cn.data) as Record<string, unknown>,
|
||||
style: cn.style ? { ...cn.style } : undefined,
|
||||
selected: true,
|
||||
}
|
||||
})
|
||||
|
||||
const newEdges: Edge[] = clipboard.edges.map(ce => ({
|
||||
id: generateId('edge'),
|
||||
source: newNodeIds[ce.sourceIndex],
|
||||
target: newNodeIds[ce.targetIndex],
|
||||
type: ce.type,
|
||||
data: ce.data ? structuredClone(ce.data) as Record<string, unknown> : undefined,
|
||||
label: ce.label,
|
||||
}))
|
||||
|
||||
setNodes(nds => [
|
||||
...nds.map(n => ({ ...n, selected: false })),
|
||||
...newNodes,
|
||||
])
|
||||
setEdges(eds => [...eds, ...newEdges])
|
||||
setIsDirty(true)
|
||||
}, [canvasRef, screenToFlowPosition, setNodes, setEdges, setIsDirty])
|
||||
|
||||
const duplicateNodes = useCallback(() => {
|
||||
const selected = getSelectedNodes()
|
||||
if (selected.length === 0) return
|
||||
|
||||
const selectedIds = new Set(selected.map(n => n.id))
|
||||
const idMap = new Map<string, string>()
|
||||
|
||||
const newNodes: Node[] = selected.map(n => {
|
||||
const prefix = n.type === 'group' ? 'group' : 'device'
|
||||
const newId = generateId(prefix)
|
||||
idMap.set(n.id, newId)
|
||||
return {
|
||||
id: newId,
|
||||
type: n.type,
|
||||
position: { x: n.position.x + 30, y: n.position.y + 30 },
|
||||
data: structuredClone(n.data) as Record<string, unknown>,
|
||||
style: n.style ? { ...n.style } : undefined,
|
||||
selected: true,
|
||||
}
|
||||
})
|
||||
|
||||
const newEdges: Edge[] = edges
|
||||
.filter(e => selectedIds.has(e.source) && selectedIds.has(e.target))
|
||||
.map(e => ({
|
||||
id: generateId('edge'),
|
||||
source: idMap.get(e.source)!,
|
||||
target: idMap.get(e.target)!,
|
||||
type: e.type,
|
||||
data: e.data ? structuredClone(e.data) as Record<string, unknown> : undefined,
|
||||
label: e.label,
|
||||
}))
|
||||
|
||||
setNodes(nds => [
|
||||
...nds.map(n => ({ ...n, selected: false })),
|
||||
...newNodes,
|
||||
])
|
||||
setEdges(eds => [...eds, ...newEdges])
|
||||
setIsDirty(true)
|
||||
}, [getSelectedNodes, edges, setNodes, setEdges, setIsDirty])
|
||||
|
||||
const selectAll = useCallback(() => {
|
||||
rfSetNodes(nds => nds.map(n => ({ ...n, selected: true })))
|
||||
}, [rfSetNodes])
|
||||
|
||||
const deleteSelected = useCallback(() => {
|
||||
const selected = getSelectedNodes()
|
||||
if (selected.length === 0) return
|
||||
const selectedIds = new Set(selected.map(n => n.id))
|
||||
setNodes(nds => nds.filter(n => !selectedIds.has(n.id)))
|
||||
setEdges(eds => eds.filter(e => !selectedIds.has(e.source) && !selectedIds.has(e.target)))
|
||||
setIsDirty(true)
|
||||
}, [getSelectedNodes, setNodes, setEdges, setIsDirty])
|
||||
|
||||
const bringSelectedToFront = useCallback(() => {
|
||||
const selected = getSelectedNodes()
|
||||
if (!selected.length) return
|
||||
const selectedIds = new Set(selected.map(n => n.id))
|
||||
setNodes(nds => {
|
||||
const maxZ = Math.max(0, ...nds.map(n => n.zIndex ?? 0))
|
||||
return nds.map(n => selectedIds.has(n.id) ? { ...n, zIndex: maxZ + 1 } : n)
|
||||
})
|
||||
setIsDirty(true)
|
||||
}, [getSelectedNodes, setNodes, setIsDirty])
|
||||
|
||||
const sendSelectedToBack = useCallback(() => {
|
||||
const selected = getSelectedNodes()
|
||||
if (!selected.length) return
|
||||
const selectedIds = new Set(selected.map(n => n.id))
|
||||
setNodes(nds => {
|
||||
const minZ = Math.min(0, ...nds.map(n => n.zIndex ?? 0))
|
||||
return nds.map(n => selectedIds.has(n.id) ? { ...n, zIndex: minZ - 1 } : n)
|
||||
})
|
||||
setIsDirty(true)
|
||||
}, [getSelectedNodes, setNodes, setIsDirty])
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
if (isInputFocused()) return
|
||||
|
||||
const ctrl = e.ctrlKey || e.metaKey
|
||||
|
||||
if (ctrl && e.key === 'c') {
|
||||
e.preventDefault()
|
||||
copyNodes()
|
||||
} else if (ctrl && e.key === 'v') {
|
||||
e.preventDefault()
|
||||
pasteNodes()
|
||||
} else if (ctrl && e.key === 'd') {
|
||||
e.preventDefault()
|
||||
duplicateNodes()
|
||||
} else if (ctrl && e.key === 'a') {
|
||||
e.preventDefault()
|
||||
selectAll()
|
||||
} else if (ctrl && e.shiftKey && (e.key === 'f' || e.key === 'F')) {
|
||||
e.preventDefault()
|
||||
fitView({ padding: 0.2 })
|
||||
} else if (e.key === ']' && !ctrl) {
|
||||
e.preventDefault()
|
||||
bringSelectedToFront()
|
||||
} else if (e.key === '[' && !ctrl) {
|
||||
e.preventDefault()
|
||||
sendSelectedToBack()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown)
|
||||
return () => document.removeEventListener('keydown', handleKeyDown)
|
||||
}, [copyNodes, pasteNodes, duplicateNodes, selectAll, fitView, bringSelectedToFront, sendSelectedToBack])
|
||||
|
||||
return {
|
||||
copyNodes,
|
||||
pasteNodes,
|
||||
duplicateNodes,
|
||||
selectAll,
|
||||
deleteSelected,
|
||||
bringSelectedToFront,
|
||||
sendSelectedToBack,
|
||||
hasClipboard: () => clipboardRef.current !== null && clipboardRef.current.nodes.length > 0,
|
||||
}
|
||||
}
|
||||
106
frontend/src/components/network/nodes/DeviceNode.tsx
Normal file
106
frontend/src/components/network/nodes/DeviceNode.tsx
Normal file
@@ -0,0 +1,106 @@
|
||||
import { memo } from 'react'
|
||||
import { Position, NodeResizer, type NodeProps } from '@xyflow/react'
|
||||
import { BaseNode, BaseNodeHeader, BaseNodeHeaderTitle, BaseNodeContent } from '../ui/base-node'
|
||||
import { BaseHandle } from '../ui/base-handle'
|
||||
import { NodeStatusIndicator, type NodeStatus } from '../ui/node-status-indicator'
|
||||
import { NodeTooltip, NodeTooltipTrigger, NodeTooltipContent } from '../ui/node-tooltip'
|
||||
import { getDeviceRenderConfig } from './deviceRegistry'
|
||||
import type { DeviceProperties } from '@/types'
|
||||
|
||||
export interface DeviceNodeData {
|
||||
label: string
|
||||
deviceType: string
|
||||
category?: string
|
||||
properties: DeviceProperties
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
function TooltipRow({ label, value }: { label: string; value: string | null | undefined }) {
|
||||
if (!value) return null
|
||||
return (
|
||||
<div className="flex gap-2">
|
||||
<span className="text-[10px] uppercase tracking-wider text-muted-foreground">{label}</span>
|
||||
<span className="text-xs font-mono text-primary">{value}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const NODE_DEFAULT = 120 // default square side in px
|
||||
const NODE_MIN = 80 // minimum square side in px
|
||||
const NODE_MAX = 280 // maximum square side in px
|
||||
|
||||
function DeviceNodeComponent({ data, selected, width, height }: NodeProps) {
|
||||
const nodeData = data as unknown as DeviceNodeData
|
||||
const { icon: Icon, color } = getDeviceRenderConfig(nodeData.deviceType, nodeData.category)
|
||||
const status = (nodeData.properties?.status || 'unknown') as NodeStatus
|
||||
const ip = nodeData.properties?.ip
|
||||
const props = nodeData.properties || {}
|
||||
|
||||
// Use the shorter dimension so content never overflows a non-square node
|
||||
const size = Math.min(width ?? NODE_DEFAULT, height ?? NODE_DEFAULT)
|
||||
const scale = size / NODE_DEFAULT
|
||||
|
||||
// Icon: 28px at default, clamped to [14, 72]
|
||||
const iconPx = Math.round(Math.max(14, Math.min(72, scale * 28)))
|
||||
// Label font: 11px at default, clamped to [9, 20]
|
||||
const labelPx = Math.max(9, Math.min(20, Math.round(scale * 11)))
|
||||
// IP font: 9px at default, clamped to [8, 16]
|
||||
const ipPx = Math.max(8, Math.min(16, Math.round(scale * 9)))
|
||||
|
||||
const hasTooltipContent = props.hostname || props.ip || props.vendor || props.model || props.role || props.notes
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeResizer
|
||||
isVisible={selected}
|
||||
minWidth={NODE_MIN}
|
||||
minHeight={NODE_MIN}
|
||||
maxWidth={NODE_MAX}
|
||||
maxHeight={NODE_MAX}
|
||||
keepAspectRatio
|
||||
lineStyle={{ borderColor: 'var(--color-accent)', borderWidth: 1 }}
|
||||
handleStyle={{ width: 8, height: 8, borderColor: 'var(--color-accent)', background: 'var(--color-card)' }}
|
||||
/>
|
||||
<NodeStatusIndicator status={status}>
|
||||
<NodeTooltip>
|
||||
<NodeTooltipTrigger>
|
||||
<BaseNode className="w-full h-full group flex flex-col items-center justify-center">
|
||||
<BaseNodeHeader className="flex-col gap-1 items-center py-2 px-2">
|
||||
<Icon size={iconPx} style={{ color }} />
|
||||
<BaseNodeHeaderTitle className="text-center leading-tight" style={{ fontSize: labelPx }}>
|
||||
{nodeData.label}
|
||||
</BaseNodeHeaderTitle>
|
||||
</BaseNodeHeader>
|
||||
{ip && (
|
||||
<BaseNodeContent className="items-center pt-0 pb-1">
|
||||
<span className="font-mono text-muted-foreground" style={{ fontSize: ipPx }}>{ip}</span>
|
||||
</BaseNodeContent>
|
||||
)}
|
||||
<BaseHandle type="target" position={Position.Top} />
|
||||
<BaseHandle type="source" position={Position.Bottom} />
|
||||
<BaseHandle type="target" position={Position.Left} id="left" />
|
||||
<BaseHandle type="source" position={Position.Right} id="right" />
|
||||
</BaseNode>
|
||||
</NodeTooltipTrigger>
|
||||
{hasTooltipContent && (
|
||||
<NodeTooltipContent position={Position.Top}>
|
||||
<div className="flex flex-col gap-1 min-w-[140px]">
|
||||
<TooltipRow label="Host" value={props.hostname} />
|
||||
<TooltipRow label="IP" value={props.ip} />
|
||||
{(props.vendor || props.model) && (
|
||||
<TooltipRow label="HW" value={[props.vendor, props.model].filter(Boolean).join(' ')} />
|
||||
)}
|
||||
<TooltipRow label="Role" value={props.role} />
|
||||
{props.notes && (
|
||||
<TooltipRow label="Notes" value={props.notes.length > 100 ? props.notes.slice(0, 100) + '...' : props.notes} />
|
||||
)}
|
||||
</div>
|
||||
</NodeTooltipContent>
|
||||
)}
|
||||
</NodeTooltip>
|
||||
</NodeStatusIndicator>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export const DeviceNode = memo(DeviceNodeComponent)
|
||||
113
frontend/src/components/network/nodes/deviceRegistry.ts
Normal file
113
frontend/src/components/network/nodes/deviceRegistry.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import type { LucideIcon } from 'lucide-react'
|
||||
import {
|
||||
Router, Network, BrickWallFire, Wifi, Server, Monitor, Boxes, Package, Cloud,
|
||||
Printer, Smartphone, HardDrive, Gauge, Database, CloudCog,
|
||||
Cpu, Tablet, Laptop, BatteryCharging, RectangleVertical,
|
||||
Cable, Camera, KeyRound, Globe, Video, PlugZap, Radio,
|
||||
} from 'lucide-react'
|
||||
|
||||
export interface DeviceRenderConfig {
|
||||
icon: LucideIcon
|
||||
color: string
|
||||
}
|
||||
|
||||
// Category-semantic color palette — each color carries meaning:
|
||||
// Network (blue) — backbone connectivity layer
|
||||
// Security (orange) — critical/protective elements
|
||||
// Compute (emerald)— running workloads and VMs
|
||||
// Endpoint (amber) — user-facing devices
|
||||
// Storage (violet) — data at rest
|
||||
// Cloud (cyan) — external/internet-connected
|
||||
// Infra (steel) — physical/passive hardware
|
||||
export const NETWORK_COLOR = '#60a5fa'
|
||||
export const SECURITY_COLOR = '#f87171'
|
||||
export const COMPUTE_COLOR = '#34d399'
|
||||
export const ENDPOINT_COLOR = '#fbbf24'
|
||||
export const STORAGE_COLOR = '#a78bfa'
|
||||
export const CLOUD_COLOR = '#67e8f9'
|
||||
export const INFRA_COLOR = '#94a3b8'
|
||||
|
||||
const SYSTEM_DEVICE_ICONS: Record<string, DeviceRenderConfig> = {
|
||||
// Network layer
|
||||
'router': { icon: Router, color: NETWORK_COLOR },
|
||||
'switch': { icon: Network, color: NETWORK_COLOR },
|
||||
'access-point': { icon: Wifi, color: NETWORK_COLOR },
|
||||
'load-balancer': { icon: Gauge, color: NETWORK_COLOR },
|
||||
|
||||
// Security
|
||||
'firewall': { icon: BrickWallFire, color: SECURITY_COLOR },
|
||||
'badge-reader': { icon: KeyRound, color: SECURITY_COLOR },
|
||||
|
||||
// Compute
|
||||
'server': { icon: Server, color: COMPUTE_COLOR },
|
||||
'vm': { icon: Boxes, color: COMPUTE_COLOR },
|
||||
'container': { icon: Package, color: COMPUTE_COLOR },
|
||||
|
||||
// Storage
|
||||
'nas': { icon: Database, color: STORAGE_COLOR },
|
||||
'san': { icon: HardDrive, color: STORAGE_COLOR },
|
||||
'cloud-storage': { icon: CloudCog, color: STORAGE_COLOR },
|
||||
|
||||
// Cloud / Internet
|
||||
'cloud': { icon: Cloud, color: CLOUD_COLOR },
|
||||
'aws': { icon: Cloud, color: CLOUD_COLOR },
|
||||
'azure': { icon: Cloud, color: CLOUD_COLOR },
|
||||
'gcp': { icon: Cloud, color: CLOUD_COLOR },
|
||||
'isp': { icon: Globe, color: CLOUD_COLOR },
|
||||
|
||||
// Endpoints
|
||||
'workstation': { icon: Monitor, color: ENDPOINT_COLOR },
|
||||
'laptop': { icon: Laptop, color: ENDPOINT_COLOR },
|
||||
'tablet': { icon: Tablet, color: ENDPOINT_COLOR },
|
||||
'phone': { icon: Smartphone, color: ENDPOINT_COLOR },
|
||||
'printer': { icon: Printer, color: ENDPOINT_COLOR },
|
||||
|
||||
// Infrastructure / physical
|
||||
'ups': { icon: BatteryCharging, color: INFRA_COLOR },
|
||||
'pdu': { icon: PlugZap, color: INFRA_COLOR },
|
||||
'rack': { icon: RectangleVertical, color: INFRA_COLOR },
|
||||
'patch-panel': { icon: Cable, color: INFRA_COLOR },
|
||||
'camera': { icon: Camera, color: INFRA_COLOR },
|
||||
'nvr': { icon: Video, color: INFRA_COLOR },
|
||||
'iot': { icon: Radio, color: INFRA_COLOR },
|
||||
}
|
||||
|
||||
const CATEGORY_DEFAULTS: Record<string, DeviceRenderConfig> = {
|
||||
'network': { icon: Router, color: NETWORK_COLOR },
|
||||
'compute': { icon: Server, color: COMPUTE_COLOR },
|
||||
'storage': { icon: Database, color: STORAGE_COLOR },
|
||||
'cloud': { icon: Cloud, color: CLOUD_COLOR },
|
||||
'endpoint': { icon: Monitor, color: ENDPOINT_COLOR },
|
||||
'infrastructure': { icon: PlugZap, color: INFRA_COLOR },
|
||||
'security': { icon: BrickWallFire, color: SECURITY_COLOR },
|
||||
}
|
||||
|
||||
const FALLBACK: DeviceRenderConfig = { icon: Cpu, color: INFRA_COLOR }
|
||||
|
||||
export function getDeviceRenderConfig(slug: string, category?: string): DeviceRenderConfig {
|
||||
if (SYSTEM_DEVICE_ICONS[slug]) return SYSTEM_DEVICE_ICONS[slug]
|
||||
if (category && CATEGORY_DEFAULTS[category]) return CATEGORY_DEFAULTS[category]
|
||||
return FALLBACK
|
||||
}
|
||||
|
||||
export const CATEGORY_LABELS: Record<string, string> = {
|
||||
'network': 'Network',
|
||||
'compute': 'Compute',
|
||||
'storage': 'Storage',
|
||||
'cloud': 'Cloud',
|
||||
'endpoint': 'Endpoints',
|
||||
'infrastructure': 'Infrastructure',
|
||||
'security': 'Security',
|
||||
}
|
||||
|
||||
export const CATEGORY_COLORS: Record<string, string> = {
|
||||
'network': NETWORK_COLOR,
|
||||
'compute': COMPUTE_COLOR,
|
||||
'storage': STORAGE_COLOR,
|
||||
'cloud': CLOUD_COLOR,
|
||||
'endpoint': ENDPOINT_COLOR,
|
||||
'infrastructure': INFRA_COLOR,
|
||||
'security': SECURITY_COLOR,
|
||||
}
|
||||
|
||||
export const CATEGORY_ORDER = ['network', 'compute', 'storage', 'cloud', 'endpoint', 'infrastructure', 'security']
|
||||
7
frontend/src/components/network/nodes/nodeTypes.ts
Normal file
7
frontend/src/components/network/nodes/nodeTypes.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { DeviceNode } from './DeviceNode'
|
||||
import { GroupNode } from '../ui/labeled-group-node'
|
||||
|
||||
export const nodeTypes = {
|
||||
device: DeviceNode,
|
||||
group: GroupNode,
|
||||
}
|
||||
168
frontend/src/components/network/panels/AIAssistPanel.tsx
Normal file
168
frontend/src/components/network/panels/AIAssistPanel.tsx
Normal file
@@ -0,0 +1,168 @@
|
||||
import { useState, useCallback } from 'react'
|
||||
import { Sparkles, ChevronUp, ChevronDown, AlertTriangle } from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { networkDiagramsApi } from '@/api'
|
||||
import type { AIGenerateResponse } from '@/types'
|
||||
|
||||
interface AIAssistPanelProps {
|
||||
onGenerate: (result: AIGenerateResponse, mode: 'replace' | 'merge') => void
|
||||
getExistingBounds: () => { minX: number; maxX: number; minY: number; maxY: number } | null
|
||||
hasNodes: boolean
|
||||
}
|
||||
|
||||
export function AIAssistPanel({ onGenerate, getExistingBounds, hasNodes }: AIAssistPanelProps) {
|
||||
const [expanded, setExpanded] = useState(false)
|
||||
const [description, setDescription] = useState('')
|
||||
const [mode, setMode] = useState<'replace' | 'merge'>('replace')
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [replaceConfirm, setReplaceConfirm] = useState(false)
|
||||
|
||||
const handleGenerate = useCallback(async () => {
|
||||
if (!description.trim()) return
|
||||
setLoading(true)
|
||||
setError(null)
|
||||
setReplaceConfirm(false)
|
||||
try {
|
||||
const result = await networkDiagramsApi.aiGenerate({
|
||||
description: description.trim(),
|
||||
mode,
|
||||
existingBounds: mode === 'merge' ? getExistingBounds() : null,
|
||||
})
|
||||
onGenerate(result, mode)
|
||||
setDescription('')
|
||||
setExpanded(false)
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : 'Generation failed. Please try again.'
|
||||
setError(msg)
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [description, mode, onGenerate, getExistingBounds])
|
||||
|
||||
// Reset confirm state when mode changes or panel collapses
|
||||
const handleModeChange = (newMode: 'replace' | 'merge') => {
|
||||
setMode(newMode)
|
||||
setReplaceConfirm(false)
|
||||
}
|
||||
|
||||
const needsReplaceConfirm = mode === 'replace' && hasNodes
|
||||
|
||||
if (!expanded) {
|
||||
return (
|
||||
<div className="border-t border-default bg-card">
|
||||
<button
|
||||
onClick={() => setExpanded(true)}
|
||||
className="flex w-full items-center justify-center gap-2 px-4 py-2 text-xs text-muted-foreground hover:text-primary"
|
||||
>
|
||||
<Sparkles size={14} />
|
||||
AI Generate
|
||||
<ChevronUp size={14} />
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border-t border-default bg-card">
|
||||
<div className="flex items-center justify-between border-b border-default px-4 py-2">
|
||||
<div className="flex items-center gap-2 text-xs font-medium text-heading">
|
||||
<Sparkles size={14} />
|
||||
AI Generate
|
||||
</div>
|
||||
<button
|
||||
onClick={() => { setExpanded(false); setReplaceConfirm(false) }}
|
||||
className="text-muted-foreground hover:text-primary"
|
||||
>
|
||||
<ChevronDown size={14} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-3 p-4">
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
onClick={() => handleModeChange('replace')}
|
||||
className={cn(
|
||||
'rounded px-3 py-1 text-xs font-medium transition-colors',
|
||||
mode === 'replace'
|
||||
? 'bg-accent text-white'
|
||||
: 'border border-default text-muted-foreground hover:text-primary',
|
||||
)}
|
||||
>
|
||||
Generate New
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleModeChange('merge')}
|
||||
className={cn(
|
||||
'rounded px-3 py-1 text-xs font-medium transition-colors',
|
||||
mode === 'merge'
|
||||
? 'bg-accent text-white'
|
||||
: 'border border-default text-muted-foreground hover:text-primary',
|
||||
)}
|
||||
>
|
||||
Add to Diagram
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{needsReplaceConfirm && (
|
||||
<div className="flex items-start gap-2 rounded border border-yellow-500/30 bg-yellow-500/5 px-3 py-2">
|
||||
<AlertTriangle size={14} className="mt-0.5 shrink-0 text-yellow-400" />
|
||||
<p className="text-[11px] text-yellow-400">
|
||||
This will replace your current diagram. Save first if needed.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<textarea
|
||||
value={description}
|
||||
onChange={e => setDescription(e.target.value)}
|
||||
placeholder="Describe the network you want to create... e.g. 'Small office with a firewall, core switch, 3 access points, and a file server'"
|
||||
rows={3}
|
||||
disabled={loading}
|
||||
className="w-full resize-none rounded border border-default bg-input px-3 py-2 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none disabled:opacity-50"
|
||||
/>
|
||||
|
||||
{error && <p className="text-[11px] text-red-400">{error}</p>}
|
||||
|
||||
{loading ? (
|
||||
<div className="flex items-center justify-center gap-2 py-2">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-accent border-t-transparent" />
|
||||
<span className="text-xs text-muted-foreground">Generating your network diagram…</span>
|
||||
</div>
|
||||
) : needsReplaceConfirm && !replaceConfirm ? (
|
||||
<button
|
||||
onClick={() => setReplaceConfirm(true)}
|
||||
disabled={!description.trim()}
|
||||
className="rounded border border-yellow-500/40 bg-yellow-500/10 px-4 py-2 text-xs font-medium text-yellow-400 hover:bg-yellow-500/20 disabled:opacity-50"
|
||||
>
|
||||
Replace Diagram…
|
||||
</button>
|
||||
) : needsReplaceConfirm && replaceConfirm ? (
|
||||
<div className="flex gap-2">
|
||||
<button
|
||||
onClick={() => setReplaceConfirm(false)}
|
||||
className="flex-1 rounded border border-default px-3 py-2 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
onClick={handleGenerate}
|
||||
disabled={!description.trim()}
|
||||
className="flex-1 rounded bg-red-500/20 px-3 py-2 text-xs font-medium text-red-400 hover:bg-red-500/30 disabled:opacity-50"
|
||||
>
|
||||
Yes, Replace
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={handleGenerate}
|
||||
disabled={!description.trim()}
|
||||
className="rounded bg-accent px-4 py-2 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
|
||||
>
|
||||
Generate
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
227
frontend/src/components/network/panels/DeviceToolbar.tsx
Normal file
227
frontend/src/components/network/panels/DeviceToolbar.tsx
Normal file
@@ -0,0 +1,227 @@
|
||||
import { useState, useMemo, useCallback } from 'react'
|
||||
import { Search, Plus, ChevronDown, ChevronRight, X, LayoutGrid, GripVertical, Globe } from 'lucide-react'
|
||||
import { getDeviceRenderConfig, CATEGORY_LABELS, CATEGORY_ORDER } from '../nodes/deviceRegistry'
|
||||
import type { DeviceTypeResponse, DeviceTypeCreate } from '@/types'
|
||||
import { deviceTypesApi } from '@/api'
|
||||
|
||||
interface DeviceToolbarProps {
|
||||
deviceTypes: DeviceTypeResponse[]
|
||||
onDeviceTypesChange: () => void
|
||||
}
|
||||
|
||||
export function DeviceToolbar({ deviceTypes, onDeviceTypesChange }: DeviceToolbarProps) {
|
||||
const [search, setSearch] = useState('')
|
||||
const [collapsedCategories, setCollapsedCategories] = useState<Set<string>>(new Set())
|
||||
const [showAddForm, setShowAddForm] = useState(false)
|
||||
const [newType, setNewType] = useState<DeviceTypeCreate>({ slug: '', label: '', category: 'network' })
|
||||
const [addError, setAddError] = useState<string | null>(null)
|
||||
const [addLoading, setAddLoading] = useState(false)
|
||||
|
||||
const filteredByCategory = useMemo(() => {
|
||||
const lower = search.toLowerCase()
|
||||
const filtered = search
|
||||
? deviceTypes.filter(dt => dt.label.toLowerCase().includes(lower) || dt.slug.toLowerCase().includes(lower))
|
||||
: deviceTypes
|
||||
|
||||
const grouped: Record<string, DeviceTypeResponse[]> = {}
|
||||
for (const dt of filtered) {
|
||||
if (!grouped[dt.category]) grouped[dt.category] = []
|
||||
grouped[dt.category].push(dt)
|
||||
}
|
||||
return grouped
|
||||
}, [deviceTypes, search])
|
||||
|
||||
const toggleCategory = useCallback((cat: string) => {
|
||||
setCollapsedCategories(prev => {
|
||||
const next = new Set(prev)
|
||||
if (next.has(cat)) next.delete(cat)
|
||||
else next.add(cat)
|
||||
return next
|
||||
})
|
||||
}, [])
|
||||
|
||||
const handleDragStart = useCallback((e: React.DragEvent, deviceType: DeviceTypeResponse) => {
|
||||
e.dataTransfer.setData('application/reactflow-device', JSON.stringify({
|
||||
slug: deviceType.slug,
|
||||
label: deviceType.label,
|
||||
category: deviceType.category,
|
||||
}))
|
||||
e.dataTransfer.effectAllowed = 'move'
|
||||
}, [])
|
||||
|
||||
const handleAddType = useCallback(async () => {
|
||||
if (!newType.slug || !newType.label) {
|
||||
setAddError('Slug and label are required')
|
||||
return
|
||||
}
|
||||
setAddLoading(true)
|
||||
setAddError(null)
|
||||
try {
|
||||
await deviceTypesApi.create(newType)
|
||||
setNewType({ slug: '', label: '', category: 'network' })
|
||||
setShowAddForm(false)
|
||||
onDeviceTypesChange()
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : 'Failed to create device type'
|
||||
setAddError(msg)
|
||||
} finally {
|
||||
setAddLoading(false)
|
||||
}
|
||||
}, [newType, onDeviceTypesChange])
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-[200px] flex-col border-r border-default bg-sidebar">
|
||||
<div className="relative p-2">
|
||||
<Search size={14} className="absolute left-4 top-1/2 -translate-y-1/2 text-muted-foreground" />
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search devices..."
|
||||
value={search}
|
||||
onChange={e => setSearch(e.target.value)}
|
||||
className="w-full rounded-md border border-default bg-input pl-8 pr-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-y-auto px-2 pb-2">
|
||||
{CATEGORY_ORDER.map(cat => {
|
||||
const items = filteredByCategory[cat] || []
|
||||
const isCloud = cat === 'cloud'
|
||||
const ispMatchesSearch = !search || 'isp'.includes(search.toLowerCase()) || 'internet service provider'.includes(search.toLowerCase())
|
||||
const showIsp = isCloud && ispMatchesSearch
|
||||
if (!items.length && !showIsp) return null
|
||||
const collapsed = collapsedCategories.has(cat)
|
||||
const totalCount = items.length + (showIsp ? 1 : 0)
|
||||
|
||||
return (
|
||||
<div key={cat} className="mb-1">
|
||||
<button
|
||||
onClick={() => toggleCategory(cat)}
|
||||
className="flex w-full items-center gap-1 rounded px-1 py-1 text-[10px] font-semibold uppercase tracking-wider text-muted-foreground hover:text-primary"
|
||||
>
|
||||
{collapsed ? <ChevronRight size={12} /> : <ChevronDown size={12} />}
|
||||
{CATEGORY_LABELS[cat] || cat}
|
||||
<span className="ml-auto text-[10px] font-normal">{totalCount}</span>
|
||||
</button>
|
||||
{!collapsed && (
|
||||
<div className="flex flex-col gap-0.5">
|
||||
{items.map(dt => {
|
||||
const { icon: Icon, color } = getDeviceRenderConfig(dt.slug, dt.category)
|
||||
return (
|
||||
<div
|
||||
key={dt.id}
|
||||
draggable
|
||||
onDragStart={e => handleDragStart(e, dt)}
|
||||
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
|
||||
>
|
||||
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
|
||||
<Icon size={14} style={{ color }} />
|
||||
<span>{dt.label}</span>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
{showIsp && (
|
||||
<div
|
||||
draggable
|
||||
onDragStart={e => {
|
||||
e.dataTransfer.setData('application/reactflow-device', JSON.stringify({
|
||||
slug: 'isp',
|
||||
label: 'ISP',
|
||||
category: 'cloud',
|
||||
}))
|
||||
e.dataTransfer.effectAllowed = 'move'
|
||||
}}
|
||||
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
|
||||
>
|
||||
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
|
||||
<Globe size={14} style={{ color: 'var(--color-accent)' }} />
|
||||
<span>ISP</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Grouping section */}
|
||||
<div className="mb-1 mt-2 border-t border-default pt-2">
|
||||
<div className="flex items-center gap-1 px-1 py-1 text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
|
||||
Grouping
|
||||
</div>
|
||||
<div className="flex flex-col gap-0.5">
|
||||
{[
|
||||
{ slug: 'subnet', label: 'Subnet' },
|
||||
{ slug: 'vlan', label: 'VLAN' },
|
||||
{ slug: 'site', label: 'Site' },
|
||||
{ slug: 'dmz', label: 'DMZ' },
|
||||
].map(item => (
|
||||
<div
|
||||
key={item.slug}
|
||||
draggable
|
||||
onDragStart={e => {
|
||||
e.dataTransfer.setData('application/reactflow-group', JSON.stringify(item))
|
||||
e.dataTransfer.effectAllowed = 'move'
|
||||
}}
|
||||
className="flex cursor-grab items-center gap-2 rounded px-2 py-1.5 text-xs text-primary hover:bg-elevated active:cursor-grabbing active:scale-[0.98] transition-transform"
|
||||
>
|
||||
<GripVertical size={12} className="shrink-0 text-muted-foreground/50" />
|
||||
<LayoutGrid size={14} className="text-muted-foreground" />
|
||||
<span>{item.label}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="border-t border-default p-2">
|
||||
{!showAddForm ? (
|
||||
<button
|
||||
onClick={() => setShowAddForm(true)}
|
||||
className="flex w-full items-center justify-center gap-1 rounded border border-default px-2 py-1.5 text-xs text-muted-foreground hover:border-hover hover:text-primary"
|
||||
>
|
||||
<Plus size={12} />
|
||||
Custom Type
|
||||
</button>
|
||||
) : (
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">New Type</span>
|
||||
<button onClick={() => { setShowAddForm(false); setAddError(null) }} className="text-muted-foreground hover:text-primary">
|
||||
<X size={12} />
|
||||
</button>
|
||||
</div>
|
||||
<input
|
||||
placeholder="slug (e.g. pacs-server)"
|
||||
value={newType.slug}
|
||||
onChange={e => setNewType(prev => ({ ...prev, slug: e.target.value.toLowerCase().replace(/[^a-z0-9-]/g, '') }))}
|
||||
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
|
||||
/>
|
||||
<input
|
||||
placeholder="Label (e.g. PACS Server)"
|
||||
value={newType.label}
|
||||
onChange={e => setNewType(prev => ({ ...prev, label: e.target.value }))}
|
||||
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
|
||||
/>
|
||||
<select
|
||||
value={newType.category}
|
||||
onChange={e => setNewType(prev => ({ ...prev, category: e.target.value }))}
|
||||
className="rounded border border-default bg-input px-2 py-1 text-xs text-primary focus:border-accent focus:outline-none"
|
||||
>
|
||||
{CATEGORY_ORDER.map(c => (
|
||||
<option key={c} value={c}>{CATEGORY_LABELS[c]}</option>
|
||||
))}
|
||||
</select>
|
||||
{addError && <p className="text-[10px] text-red-400">{addError}</p>}
|
||||
<button
|
||||
onClick={handleAddType}
|
||||
disabled={addLoading}
|
||||
className="rounded bg-accent px-2 py-1 text-xs font-medium text-white hover:bg-accent/90 disabled:opacity-50"
|
||||
>
|
||||
{addLoading ? 'Adding...' : 'Add Type'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
412
frontend/src/components/network/panels/PropertiesPanel.tsx
Normal file
412
frontend/src/components/network/panels/PropertiesPanel.tsx
Normal file
@@ -0,0 +1,412 @@
|
||||
import { useCallback, useState, useEffect } from 'react'
|
||||
import { Trash2, Minus, Spline, GitBranch, BringToFront, SendToBack } from 'lucide-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import type { DeviceProperties, DiagramEdge } from '@/types'
|
||||
import type { Node, Edge } from '@xyflow/react'
|
||||
import type { DeviceNodeData } from '../nodes/DeviceNode'
|
||||
|
||||
interface PropertiesPanelProps {
|
||||
selectedNode: Node | null
|
||||
selectedEdge: Edge | null
|
||||
onNodeUpdate: (nodeId: string, data: Partial<DeviceNodeData>) => void
|
||||
onEdgeUpdate: (edgeId: string, data: Partial<DiagramEdge>) => void
|
||||
onEdgeTypeChange: (edgeId: string, edgeType: string) => void
|
||||
onBringToFront: (nodeId: string) => void
|
||||
onSendToBack: (nodeId: string) => void
|
||||
onDeleteNode: (nodeId: string) => void
|
||||
onDeleteEdge: (edgeId: string) => void
|
||||
}
|
||||
|
||||
type NodeStatus = 'online' | 'offline' | 'degraded' | 'unknown'
|
||||
|
||||
const STATUS_CONFIG: Record<NodeStatus, { color: string; label: string }> = {
|
||||
online: { color: '#34d399', label: 'Online' },
|
||||
offline: { color: '#f87171', label: 'Offline' },
|
||||
degraded: { color: '#fbbf24', label: 'Degraded' },
|
||||
unknown: { color: '#94a3b8', label: 'Unknown' },
|
||||
}
|
||||
|
||||
const STATUS_OPTIONS = Object.keys(STATUS_CONFIG) as NodeStatus[]
|
||||
const CONNECTION_TYPE_OPTIONS = ['ethernet', 'fiber', 'wifi', 'vpn', 'vlan', 'wan'] as const
|
||||
|
||||
function FieldLabel({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<label className="text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
|
||||
{children}
|
||||
</label>
|
||||
)
|
||||
}
|
||||
|
||||
function FieldInput({ value, onChange, placeholder, mono }: {
|
||||
value: string
|
||||
onChange: (val: string) => void
|
||||
placeholder?: string
|
||||
mono?: boolean
|
||||
}) {
|
||||
return (
|
||||
<input
|
||||
type="text"
|
||||
value={value}
|
||||
onChange={e => onChange(e.target.value)}
|
||||
placeholder={placeholder}
|
||||
className={cn(
|
||||
'w-full rounded border border-default bg-input px-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none',
|
||||
mono && 'font-mono',
|
||||
)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function SectionDivider({ label }: { label: string }) {
|
||||
return (
|
||||
<div className="flex items-center gap-2 pt-1">
|
||||
<span className="whitespace-nowrap text-[10px] font-semibold uppercase tracking-wider text-muted-foreground">
|
||||
{label}
|
||||
</span>
|
||||
<div className="flex-1 border-t border-default" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export function PropertiesPanel({
|
||||
selectedNode,
|
||||
selectedEdge,
|
||||
onNodeUpdate,
|
||||
onEdgeUpdate,
|
||||
onEdgeTypeChange,
|
||||
onBringToFront,
|
||||
onSendToBack,
|
||||
onDeleteNode,
|
||||
onDeleteEdge,
|
||||
}: PropertiesPanelProps) {
|
||||
const [deleteConfirm, setDeleteConfirm] = useState(false)
|
||||
|
||||
// Reset confirm state whenever the selection changes
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
useEffect(() => { setDeleteConfirm(false) }, [selectedNode?.id, selectedEdge?.id])
|
||||
|
||||
const handlePropertyChange = useCallback((field: keyof DeviceProperties, value: string) => {
|
||||
if (!selectedNode) return
|
||||
const nodeData = selectedNode.data as unknown as DeviceNodeData
|
||||
onNodeUpdate(selectedNode.id, {
|
||||
properties: { ...nodeData.properties, [field]: value },
|
||||
} as Partial<DeviceNodeData>)
|
||||
}, [selectedNode, onNodeUpdate])
|
||||
|
||||
const handleLabelChange = useCallback((value: string) => {
|
||||
if (!selectedNode) return
|
||||
onNodeUpdate(selectedNode.id, { label: value } as Partial<DeviceNodeData>)
|
||||
}, [selectedNode, onNodeUpdate])
|
||||
|
||||
if (!selectedNode && !selectedEdge) {
|
||||
return (
|
||||
<div className="flex h-full w-[260px] flex-col items-center justify-center border-l border-default bg-sidebar px-4">
|
||||
<p className="text-center text-xs text-muted-foreground">
|
||||
Select a device or connection to edit its properties
|
||||
</p>
|
||||
<p className="mt-1 text-center text-[10px] text-muted-foreground/60">
|
||||
Hover a device to preview its info
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (selectedEdge) {
|
||||
const edgeData = (selectedEdge.data || {}) as Record<string, unknown>
|
||||
const connectionType = (edgeData.connectionType as string) || 'ethernet'
|
||||
const isCustomType = !CONNECTION_TYPE_OPTIONS.includes(connectionType as typeof CONNECTION_TYPE_OPTIONS[number])
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-[260px] flex-col border-l border-default bg-sidebar">
|
||||
<div className="border-b border-default px-3 py-2">
|
||||
<h3 className="text-xs font-semibold text-heading">Connection</h3>
|
||||
</div>
|
||||
<div className="flex flex-1 flex-col gap-3 overflow-y-auto p-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Label</FieldLabel>
|
||||
<FieldInput
|
||||
value={(selectedEdge.label as string) || ''}
|
||||
onChange={val => onEdgeUpdate(selectedEdge.id, { label: val || null })}
|
||||
placeholder="Connection label"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Type</FieldLabel>
|
||||
<select
|
||||
value={isCustomType ? '__custom__' : connectionType}
|
||||
onChange={e => {
|
||||
const val = e.target.value
|
||||
if (val !== '__custom__') {
|
||||
onEdgeUpdate(selectedEdge.id, { connectionType: val })
|
||||
}
|
||||
}}
|
||||
className="w-full rounded border border-default bg-input px-2 py-1.5 text-xs text-primary focus:border-accent focus:outline-none"
|
||||
>
|
||||
{CONNECTION_TYPE_OPTIONS.map(opt => (
|
||||
<option key={opt} value={opt}>{opt.charAt(0).toUpperCase() + opt.slice(1)}</option>
|
||||
))}
|
||||
<option value="__custom__">Custom…</option>
|
||||
</select>
|
||||
{isCustomType && (
|
||||
<FieldInput
|
||||
value={connectionType}
|
||||
onChange={val => onEdgeUpdate(selectedEdge.id, { connectionType: val })}
|
||||
placeholder="Custom type name"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Speed</FieldLabel>
|
||||
<FieldInput
|
||||
value={(edgeData.speed as string) || ''}
|
||||
onChange={val => onEdgeUpdate(selectedEdge.id, { speed: val || null })}
|
||||
placeholder="e.g. 1 Gbps"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Notes</FieldLabel>
|
||||
<FieldInput
|
||||
value={(edgeData.notes as string) || ''}
|
||||
onChange={val => onEdgeUpdate(selectedEdge.id, { notes: val || null })}
|
||||
placeholder="Port info, cable type…"
|
||||
mono
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Line Style</FieldLabel>
|
||||
<div className="flex gap-1">
|
||||
{([
|
||||
{ value: null, icon: Minus, label: 'Straight' },
|
||||
{ value: 'curved', icon: Spline, label: 'Curved' },
|
||||
{ value: 'step', icon: GitBranch, label: 'Step' },
|
||||
] as const).map(({ value, icon: Icon, label }) => {
|
||||
const routing = (edgeData.routing as string | null | undefined) ?? null
|
||||
const active = routing === value
|
||||
return (
|
||||
<button
|
||||
key={label}
|
||||
title={label}
|
||||
onClick={() => onEdgeUpdate(selectedEdge.id, { routing: value })}
|
||||
className={cn(
|
||||
'flex flex-1 items-center justify-center gap-1 rounded border py-1.5 text-[10px] transition-colors',
|
||||
active
|
||||
? 'border-accent bg-accent/10 text-accent'
|
||||
: 'border-default text-muted-foreground hover:border-hover hover:text-primary',
|
||||
)}
|
||||
>
|
||||
<Icon size={12} />
|
||||
{label}
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<FieldLabel>Show Traffic</FieldLabel>
|
||||
<button
|
||||
onClick={() => {
|
||||
const newType = selectedEdge.type === 'animated' ? 'connection' : 'animated'
|
||||
onEdgeTypeChange(selectedEdge.id, newType)
|
||||
}}
|
||||
className={cn(
|
||||
'relative h-5 w-9 rounded-full transition-colors',
|
||||
selectedEdge.type === 'animated' ? 'bg-accent' : 'bg-elevated',
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
'absolute top-0.5 left-0.5 h-4 w-4 rounded-full bg-white transition-transform',
|
||||
selectedEdge.type === 'animated' && 'translate-x-4',
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="border-t border-default p-3">
|
||||
{deleteConfirm ? (
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<p className="text-center text-[10px] text-muted-foreground">Delete this connection?</p>
|
||||
<div className="flex gap-1.5">
|
||||
<button
|
||||
onClick={() => setDeleteConfirm(false)}
|
||||
className="flex-1 rounded border border-default px-2 py-1.5 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
onClick={() => onDeleteEdge(selectedEdge.id)}
|
||||
className="flex-1 rounded bg-red-500/20 px-2 py-1.5 text-xs font-medium text-red-400 hover:bg-red-500/30"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => setDeleteConfirm(true)}
|
||||
className="flex w-full items-center justify-center gap-1.5 rounded border border-red-500/30 px-2 py-1.5 text-xs text-red-400 hover:bg-red-500/10"
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
Delete Connection
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const nodeData = selectedNode!.data as unknown as DeviceNodeData
|
||||
const props = nodeData.properties || {} as DeviceProperties
|
||||
const currentStatus = (props.status || 'unknown') as NodeStatus
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-[260px] flex-col border-l border-default bg-sidebar">
|
||||
<div className="border-b border-default px-3 py-2">
|
||||
<h3 className="text-xs font-semibold text-heading">Device Properties</h3>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-1 flex-col gap-3 overflow-y-auto p-3">
|
||||
|
||||
{/* Identity */}
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Name</FieldLabel>
|
||||
<FieldInput value={nodeData.label} onChange={handleLabelChange} placeholder="Device name" />
|
||||
</div>
|
||||
|
||||
{/* Layering */}
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Layer</FieldLabel>
|
||||
<div className="flex gap-1.5">
|
||||
<button
|
||||
onClick={() => onBringToFront(selectedNode!.id)}
|
||||
title="Bring to Front ]"
|
||||
className="flex flex-1 items-center justify-center gap-1.5 rounded border border-default px-2 py-1.5 text-[10px] text-muted-foreground hover:border-hover hover:text-primary"
|
||||
>
|
||||
<BringToFront size={12} />
|
||||
Bring Front
|
||||
</button>
|
||||
<button
|
||||
onClick={() => onSendToBack(selectedNode!.id)}
|
||||
title="Send to Back ["
|
||||
className="flex flex-1 items-center justify-center gap-1.5 rounded border border-default px-2 py-1.5 text-[10px] text-muted-foreground hover:border-hover hover:text-primary"
|
||||
>
|
||||
<SendToBack size={12} />
|
||||
Send Back
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Status badge grid */}
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<FieldLabel>Status</FieldLabel>
|
||||
<div className="grid grid-cols-2 gap-1">
|
||||
{STATUS_OPTIONS.map(opt => {
|
||||
const { color, label } = STATUS_CONFIG[opt]
|
||||
const active = currentStatus === opt
|
||||
return (
|
||||
<button
|
||||
key={opt}
|
||||
onClick={() => handlePropertyChange('status', opt)}
|
||||
className={cn(
|
||||
'flex items-center justify-center gap-1.5 rounded border py-1.5 text-[10px] font-medium transition-colors',
|
||||
active
|
||||
? 'border-transparent text-white'
|
||||
: 'border-default text-muted-foreground hover:border-hover hover:text-primary',
|
||||
)}
|
||||
style={active ? { backgroundColor: color } : undefined}
|
||||
>
|
||||
<span
|
||||
className="h-1.5 w-1.5 shrink-0 rounded-full"
|
||||
style={{ backgroundColor: active ? 'rgba(255,255,255,0.8)' : color }}
|
||||
/>
|
||||
{label}
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Network section */}
|
||||
<SectionDivider label="Network" />
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>IP Address</FieldLabel>
|
||||
<FieldInput value={props.ip || ''} onChange={v => handlePropertyChange('ip', v)} placeholder="e.g. 10.0.0.1" mono />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Subnet</FieldLabel>
|
||||
<FieldInput value={props.subnet || ''} onChange={v => handlePropertyChange('subnet', v)} placeholder="e.g. 10.0.0.0/24" mono />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>VLAN</FieldLabel>
|
||||
<FieldInput value={props.vlan || ''} onChange={v => handlePropertyChange('vlan', v)} placeholder="e.g. 10" mono />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Hardware section */}
|
||||
<SectionDivider label="Hardware" />
|
||||
<div className="flex flex-col gap-2">
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Hostname</FieldLabel>
|
||||
<FieldInput value={props.hostname || ''} onChange={v => handlePropertyChange('hostname', v)} placeholder="e.g. core-rtr-01" mono />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Vendor</FieldLabel>
|
||||
<FieldInput value={props.vendor || ''} onChange={v => handlePropertyChange('vendor', v)} placeholder="e.g. Cisco" />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Model</FieldLabel>
|
||||
<FieldInput value={props.model || ''} onChange={v => handlePropertyChange('model', v)} placeholder="e.g. ISR 4331" />
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<FieldLabel>Role</FieldLabel>
|
||||
<FieldInput value={props.role || ''} onChange={v => handlePropertyChange('role', v)} placeholder="e.g. Core gateway" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Notes */}
|
||||
<SectionDivider label="Notes" />
|
||||
<div className="flex flex-col gap-1">
|
||||
<textarea
|
||||
value={props.notes || ''}
|
||||
onChange={e => handlePropertyChange('notes', e.target.value)}
|
||||
placeholder="Additional notes…"
|
||||
rows={3}
|
||||
className="w-full resize-none rounded border border-default bg-input px-2 py-1.5 text-xs text-primary placeholder:text-muted-foreground focus:border-accent focus:outline-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
<div className="border-t border-default p-3">
|
||||
{deleteConfirm ? (
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<p className="text-center text-[10px] text-muted-foreground">Delete this device?</p>
|
||||
<div className="flex gap-1.5">
|
||||
<button
|
||||
onClick={() => setDeleteConfirm(false)}
|
||||
className="flex-1 rounded border border-default px-2 py-1.5 text-xs text-primary hover:bg-elevated"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
onClick={() => onDeleteNode(selectedNode!.id)}
|
||||
className="flex-1 rounded bg-red-500/20 px-2 py-1.5 text-xs font-medium text-red-400 hover:bg-red-500/30"
|
||||
>
|
||||
Delete
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => setDeleteConfirm(true)}
|
||||
className="flex w-full items-center justify-center gap-1.5 rounded border border-red-500/30 px-2 py-1.5 text-xs text-red-400 hover:bg-red-500/10"
|
||||
>
|
||||
<Trash2 size={12} />
|
||||
Delete Device
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
131
frontend/src/components/network/ui/animated-svg-edge.tsx
Normal file
131
frontend/src/components/network/ui/animated-svg-edge.tsx
Normal file
@@ -0,0 +1,131 @@
|
||||
import { memo } from 'react'
|
||||
import {
|
||||
BaseEdge,
|
||||
getSmoothStepPath,
|
||||
getStraightPath,
|
||||
getBezierPath,
|
||||
type EdgeProps,
|
||||
} from '@xyflow/react'
|
||||
|
||||
interface AnimatedEdgeData {
|
||||
connectionType?: string
|
||||
duration?: number
|
||||
direction?: 'forward' | 'reverse' | 'alternate' | 'alternate-reverse'
|
||||
path?: 'bezier' | 'smoothstep' | 'step' | 'straight'
|
||||
repeat?: number | 'indefinite'
|
||||
shape?: 'circle' | 'package'
|
||||
speed?: string | null
|
||||
notes?: string | null
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
const CONNECTION_COLORS: Record<string, string> = {
|
||||
ethernet: '#60a5fa',
|
||||
fiber: '#34d399',
|
||||
wifi: '#a78bfa',
|
||||
vpn: '#eab308',
|
||||
vlan: '#848b9b',
|
||||
wan: '#f87171',
|
||||
}
|
||||
|
||||
const DEFAULT_COLOR = '#848b9b'
|
||||
|
||||
function getPath(
|
||||
props: EdgeProps,
|
||||
pathType: string,
|
||||
): [string, number, number] {
|
||||
const params = {
|
||||
sourceX: props.sourceX,
|
||||
sourceY: props.sourceY,
|
||||
sourcePosition: props.sourcePosition,
|
||||
targetX: props.targetX,
|
||||
targetY: props.targetY,
|
||||
targetPosition: props.targetPosition,
|
||||
}
|
||||
|
||||
switch (pathType) {
|
||||
case 'bezier': {
|
||||
const [path, labelX, labelY] = getBezierPath(params)
|
||||
return [path, labelX, labelY]
|
||||
}
|
||||
case 'straight': {
|
||||
const [path, labelX, labelY] = getStraightPath(params)
|
||||
return [path, labelX, labelY]
|
||||
}
|
||||
default: {
|
||||
const [path, labelX, labelY] = getSmoothStepPath(params)
|
||||
return [path, labelX, labelY]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getAnimateMotionProps(data: AnimatedEdgeData) {
|
||||
const duration = data.duration ?? 2
|
||||
const direction = data.direction ?? 'forward'
|
||||
const repeat = data.repeat ?? 'indefinite'
|
||||
|
||||
const keyPoints: Record<string, string> = {
|
||||
forward: '0;1',
|
||||
reverse: '1;0',
|
||||
alternate: '0;1',
|
||||
'alternate-reverse': '1;0',
|
||||
}
|
||||
|
||||
return {
|
||||
dur: `${duration}s`,
|
||||
repeatCount: String(repeat),
|
||||
keyPoints: keyPoints[direction] || '0;1',
|
||||
keyTimes: '0;1',
|
||||
}
|
||||
}
|
||||
|
||||
function AnimatedSvgEdgeComponent(props: EdgeProps) {
|
||||
const data = (props.data || {}) as AnimatedEdgeData
|
||||
const connectionType = data.connectionType || 'ethernet'
|
||||
const color = CONNECTION_COLORS[connectionType] || DEFAULT_COLOR
|
||||
const pathType = data.path ?? 'smoothstep'
|
||||
const shape = data.shape ?? 'circle'
|
||||
|
||||
const [edgePath] = getPath(props, pathType)
|
||||
const motionProps = getAnimateMotionProps(data)
|
||||
|
||||
return (
|
||||
<>
|
||||
<BaseEdge
|
||||
path={edgePath}
|
||||
style={{
|
||||
stroke: color,
|
||||
strokeWidth: props.selected ? 3 : 2,
|
||||
...(connectionType === 'wifi' || connectionType === 'wan' || connectionType === 'vpn'
|
||||
? { strokeDasharray: connectionType === 'wifi' ? '3,3' : '8,4' }
|
||||
: {}),
|
||||
}}
|
||||
/>
|
||||
<circle r={0} fill={color}>
|
||||
<animateMotion
|
||||
path={edgePath}
|
||||
calcMode="linear"
|
||||
{...motionProps}
|
||||
/>
|
||||
<animate
|
||||
attributeName="r"
|
||||
values="0;3;3;3;0"
|
||||
keyTimes="0;0.05;0.5;0.95;1"
|
||||
dur={motionProps.dur}
|
||||
repeatCount={motionProps.repeatCount}
|
||||
/>
|
||||
</circle>
|
||||
{shape === 'package' && (
|
||||
<rect x={-4} y={-4} width={8} height={8} rx={2} fill={color} opacity={0.8}>
|
||||
<animateMotion
|
||||
path={edgePath}
|
||||
calcMode="linear"
|
||||
{...motionProps}
|
||||
/>
|
||||
</rect>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export const AnimatedSvgEdge = memo(AnimatedSvgEdgeComponent)
|
||||
20
frontend/src/components/network/ui/base-handle.tsx
Normal file
20
frontend/src/components/network/ui/base-handle.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import { Handle, type HandleProps } from '@xyflow/react'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
export type BaseHandleProps = HandleProps
|
||||
|
||||
export function BaseHandle({ className, children, ...props }: ComponentProps<typeof Handle>) {
|
||||
return (
|
||||
<Handle
|
||||
{...props}
|
||||
className={cn(
|
||||
'h-[10px] w-[10px] rounded-full border border-default bg-elevated transition-opacity',
|
||||
'opacity-0 group-hover:opacity-100',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</Handle>
|
||||
)
|
||||
}
|
||||
56
frontend/src/components/network/ui/base-node.tsx
Normal file
56
frontend/src/components/network/ui/base-node.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
export function BaseNode({ className, ...props }: ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'bg-card text-heading relative rounded-lg border border-default',
|
||||
'transition-colors hover:border-hover',
|
||||
'in-[.selected]:border-accent',
|
||||
className,
|
||||
)}
|
||||
tabIndex={0}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function BaseNodeHeader({ className, ...props }: ComponentProps<'header'>) {
|
||||
return (
|
||||
<header
|
||||
{...props}
|
||||
className={cn('flex flex-row items-center gap-2 px-3 py-2', className)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function BaseNodeHeaderTitle({ className, ...props }: ComponentProps<'h3'>) {
|
||||
return (
|
||||
<h3
|
||||
data-slot="base-node-title"
|
||||
className={cn('select-none flex-1 text-xs font-semibold text-heading', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function BaseNodeContent({ className, ...props }: ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="base-node-content"
|
||||
className={cn('flex flex-col gap-y-1 px-3 pb-2', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function BaseNodeFooter({ className, ...props }: ComponentProps<'div'>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="base-node-footer"
|
||||
className={cn('flex flex-col items-center gap-y-1 border-t border-default px-3 pt-1.5 pb-2', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
68
frontend/src/components/network/ui/labeled-group-node.tsx
Normal file
68
frontend/src/components/network/ui/labeled-group-node.tsx
Normal file
@@ -0,0 +1,68 @@
|
||||
import type { ReactNode, ComponentProps } from 'react'
|
||||
import { Panel, NodeResizer, type NodeProps, type PanelPosition } from '@xyflow/react'
|
||||
import { BaseNode } from './base-node'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
export type GroupNodeLabelProps = ComponentProps<'div'>
|
||||
|
||||
export function GroupNodeLabel({ children, className, ...props }: GroupNodeLabelProps) {
|
||||
return (
|
||||
<div className="h-full w-full" {...props}>
|
||||
<div className={cn('bg-card text-muted-foreground w-fit p-2 text-[10px] font-semibold uppercase tracking-wider', className)}>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export interface GroupNodeData {
|
||||
label?: string
|
||||
groupType?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export type GroupNodeProps = Partial<NodeProps> & {
|
||||
label?: ReactNode
|
||||
position?: PanelPosition
|
||||
}
|
||||
|
||||
function getLabelClassName(position?: PanelPosition): string {
|
||||
switch (position) {
|
||||
case 'top-left': return 'rounded-br-sm'
|
||||
case 'top-center': return 'rounded-b-sm'
|
||||
case 'top-right': return 'rounded-bl-sm'
|
||||
case 'bottom-left': return 'rounded-tr-sm'
|
||||
case 'bottom-right': return 'rounded-tl-sm'
|
||||
case 'bottom-center': return 'rounded-t-sm'
|
||||
default: return 'rounded-br-sm'
|
||||
}
|
||||
}
|
||||
|
||||
export function GroupNode({ data, selected }: NodeProps) {
|
||||
const nodeData = data as unknown as GroupNodeData
|
||||
const label = nodeData.label || 'Group'
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeResizer
|
||||
isVisible={selected}
|
||||
minWidth={150}
|
||||
minHeight={100}
|
||||
lineStyle={{ borderColor: 'var(--color-accent)', borderWidth: 1 }}
|
||||
handleStyle={{ width: 8, height: 8, borderColor: 'var(--color-accent)', background: 'var(--color-card)' }}
|
||||
/>
|
||||
<BaseNode
|
||||
className={cn(
|
||||
'h-full w-full min-h-[100px] min-w-[150px] overflow-hidden rounded-lg bg-elevated/30 border-default/50',
|
||||
selected && 'border-accent',
|
||||
)}
|
||||
>
|
||||
<Panel className="m-0 p-0" position="top-left">
|
||||
<GroupNodeLabel className={getLabelClassName('top-left')}>
|
||||
{label}
|
||||
</GroupNodeLabel>
|
||||
</Panel>
|
||||
</BaseNode>
|
||||
</>
|
||||
)
|
||||
}
|
||||
39
frontend/src/components/network/ui/labeled-handle.tsx
Normal file
39
frontend/src/components/network/ui/labeled-handle.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import { type HandleProps, Position } from '@xyflow/react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { BaseHandle } from './base-handle'
|
||||
|
||||
const flexDirections: Record<string, string> = {
|
||||
[Position.Top]: 'flex-col',
|
||||
[Position.Right]: 'flex-row-reverse justify-end',
|
||||
[Position.Bottom]: 'flex-col-reverse justify-end',
|
||||
[Position.Left]: 'flex-row',
|
||||
}
|
||||
|
||||
export function LabeledHandle({
|
||||
className,
|
||||
labelClassName,
|
||||
handleClassName,
|
||||
title,
|
||||
position,
|
||||
...props
|
||||
}: HandleProps &
|
||||
ComponentProps<'div'> & {
|
||||
title: string
|
||||
handleClassName?: string
|
||||
labelClassName?: string
|
||||
}) {
|
||||
const { ref, ...handleProps } = props
|
||||
return (
|
||||
<div
|
||||
title={title}
|
||||
className={cn('relative flex items-center', flexDirections[position], className)}
|
||||
ref={ref}
|
||||
>
|
||||
<BaseHandle position={position} className={handleClassName} {...handleProps} />
|
||||
<label className={cn('text-muted-foreground text-[10px] font-mono px-1.5', labelClassName)}>
|
||||
{title}
|
||||
</label>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
43
frontend/src/components/network/ui/node-status-indicator.tsx
Normal file
43
frontend/src/components/network/ui/node-status-indicator.tsx
Normal file
@@ -0,0 +1,43 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
export type NodeStatus = 'online' | 'offline' | 'degraded' | 'unknown'
|
||||
|
||||
const STATUS_BORDER_COLORS: Record<NodeStatus, string> = {
|
||||
online: 'border-emerald-400',
|
||||
offline: 'border-red-400',
|
||||
degraded: 'border-yellow-400',
|
||||
unknown: '',
|
||||
}
|
||||
|
||||
const STATUS_GLOW: Record<NodeStatus, string> = {
|
||||
online: 'shadow-[0_0_8px_rgba(52,211,153,0.3)]',
|
||||
offline: 'shadow-[0_0_8px_rgba(248,113,113,0.3)]',
|
||||
degraded: 'shadow-[0_0_8px_rgba(250,204,21,0.3)]',
|
||||
unknown: '',
|
||||
}
|
||||
|
||||
interface NodeStatusIndicatorProps {
|
||||
status?: NodeStatus
|
||||
children: ReactNode
|
||||
className?: string
|
||||
}
|
||||
|
||||
export function NodeStatusIndicator({ status = 'unknown', children, className }: NodeStatusIndicatorProps) {
|
||||
if (status === 'unknown') {
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'rounded-lg border-2 transition-colors',
|
||||
STATUS_BORDER_COLORS[status],
|
||||
STATUS_GLOW[status],
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user