feat: ConnectWise PSA integration (#106)

PSA abstraction layer with provider pattern, ConnectWise integration (connection management, ticket linking, note posting, status updates, member mapping), Integrations page UI, Fernet credential encryption, in-memory TTL cache, 6 DB migrations, ConnectWise API reference docs.
This commit was merged in pull request #106.
This commit is contained in:
chihlasm
2026-03-15 01:45:35 -04:00
committed by GitHub
parent 80e094215f
commit 46865882c6
60 changed files with 726716 additions and 11 deletions

View File

@@ -19,6 +19,9 @@ from app.models.survey_invite import SurveyInvite
from app.models.ai_suggestion import AISuggestion # noqa: F401
from app.models.kb_import import KBImport, KBImportNode # noqa: F401
from app.models.script_template import ScriptCategory, ScriptTemplate, ScriptGeneration # noqa: F401
from app.models.psa_connection import PsaConnection # noqa: F401
from app.models.psa_post_log import PsaPostLog # noqa: F401
from app.models.psa_member_mapping import PsaMemberMapping # noqa: F401
from app.core.config import settings
# this is the Alembic Config object

View File

@@ -0,0 +1,39 @@
"""Add psa_connections table.
Revision ID: 058
Revises: 057
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "058"
down_revision = "057"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"psa_connections",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("provider", sa.String(50), nullable=False),
sa.Column("display_name", sa.String(100), nullable=False),
sa.Column("site_url", sa.String(255), nullable=False),
sa.Column("company_id", sa.String(100), nullable=False),
sa.Column("credentials_encrypted", sa.Text(), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("true")),
sa.Column("last_validated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["account_id"], ["accounts.id"], ondelete="CASCADE"),
sa.UniqueConstraint("account_id"),
)
op.create_index("ix_psa_connections_account_id", "psa_connections", ["account_id"])
def downgrade() -> None:
op.drop_index("ix_psa_connections_account_id")
op.drop_table("psa_connections")

View File

@@ -0,0 +1,31 @@
"""Add psa_ticket_id and psa_connection_id to sessions.
Revision ID: 059
Revises: 058
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "059"
down_revision = "058"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("sessions", sa.Column("psa_ticket_id", sa.String(100), nullable=True))
op.add_column(
"sessions",
sa.Column(
"psa_connection_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("psa_connections.id", ondelete="SET NULL"),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("sessions", "psa_connection_id")
op.drop_column("sessions", "psa_ticket_id")

View File

@@ -0,0 +1,57 @@
"""Add psa_post_log table for PSA note posting audit trail.
Revision ID: 060
Revises: 059
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "060"
down_revision = "059"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"psa_post_log",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column(
"psa_connection_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("psa_connections.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("ticket_id", sa.String(100), nullable=False),
sa.Column("note_type", sa.String(50), nullable=False),
sa.Column("content_posted", sa.Text(), nullable=False),
sa.Column("external_note_id", sa.String(100), nullable=True),
sa.Column("status", sa.String(20), nullable=False),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("status_changed_from", sa.String(100), nullable=True),
sa.Column("status_changed_to", sa.String(100), nullable=True),
sa.Column(
"posted_by",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id"),
nullable=False,
),
sa.Column(
"posted_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_table("psa_post_log")

View File

@@ -0,0 +1,60 @@
"""Add psa_member_mappings table for user-to-CW-member mapping.
Revision ID: 061
Revises: 060
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "061"
down_revision = "060"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"psa_member_mappings",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"psa_connection_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("psa_connections.id", ondelete="CASCADE"),
nullable=False,
index=True,
),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("external_member_id", sa.String(100), nullable=False),
sa.Column("external_member_name", sa.String(200), nullable=False),
sa.Column("matched_by", sa.String(50), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.UniqueConstraint(
"psa_connection_id", "user_id",
name="uq_psa_member_mapping_connection_user",
),
sa.UniqueConstraint(
"psa_connection_id", "external_member_id",
name="uq_psa_member_mapping_connection_member",
),
)
def downgrade() -> None:
op.drop_table("psa_member_mappings")

View File

@@ -0,0 +1,565 @@
"""PSA integration endpoints — connection CRUD and test."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import delete
from app.api.deps import get_current_active_user, require_account_owner, require_engineer_or_admin
from app.core.database import get_db
from app.models.psa_connection import PsaConnection
from app.models.psa_member_mapping import PsaMemberMapping
from app.models.user import User
from app.schemas.psa_connection import (
PsaConnectionCreate,
PsaConnectionResponse,
PsaConnectionTestResponse,
PsaConnectionUpdate,
PSATicketSearchResult,
PSATicketStatusItem,
PsaMemberMappingResponse,
PsaMemberMappingSaveRequest,
PsaMemberResponse,
AutoMatchResult,
)
from app.core.config import settings
from app.services.psa.encryption import (
decrypt_credentials,
encrypt_credentials,
mask_credential,
)
router = APIRouter(prefix="/integrations/psa", tags=["integrations"])
# ── helpers ──────────────────────────────────────────────────────────
def _to_response(conn: PsaConnection) -> PsaConnectionResponse:
"""Build a response DTO with masked credential hints."""
creds = decrypt_credentials(conn.credentials_encrypted)
return PsaConnectionResponse(
id=conn.id,
account_id=conn.account_id,
provider=conn.provider,
display_name=conn.display_name,
site_url=conn.site_url,
company_id=conn.company_id,
is_active=conn.is_active,
last_validated_at=conn.last_validated_at,
created_at=conn.created_at,
updated_at=conn.updated_at,
public_key_hint=mask_credential(creds.get("public_key")),
private_key_hint=mask_credential(creds.get("private_key")),
)
async def _get_connection(
account_id: UUID, db: AsyncSession
) -> PsaConnection | None:
result = await db.execute(
select(PsaConnection).where(PsaConnection.account_id == account_id)
)
return result.scalar_one_or_none()
async def _test_credentials(
provider: str,
site_url: str,
company_id: str,
public_key: str,
private_key: str,
client_id: str,
) -> PsaConnectionTestResponse:
"""Instantiate a provider and run test_connection."""
if provider == "connectwise":
from app.services.psa.connectwise.client import ConnectWiseClient
from app.services.psa.connectwise.provider import ConnectWiseProvider
client = ConnectWiseClient(
site_url=site_url,
company_id=company_id,
public_key=public_key,
private_key=private_key,
client_id=client_id,
)
result = await ConnectWiseProvider(client).test_connection()
return PsaConnectionTestResponse(
success=result.success,
message=result.message,
server_version=result.server_version,
)
return PsaConnectionTestResponse(
success=False,
message=f"Unsupported provider: {provider}",
)
# ── endpoints ────────────────────────────────────────────────────────
@router.get("/connections", response_model=PsaConnectionResponse | None)
async def get_connection(
current_user: Annotated[User, Depends(get_current_active_user)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Return the account's PSA connection (redacted credentials) or null."""
if not current_user.account_id:
return None
conn = await _get_connection(current_user.account_id, db)
if not conn:
return None
return _to_response(conn)
@router.post(
"/connections",
response_model=PsaConnectionResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_connection(
data: PsaConnectionCreate,
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Create a new PSA connection. Tests credentials before saving."""
if not current_user.account_id:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No account associated with user")
if not settings.CW_CLIENT_ID:
raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE, "ConnectWise integration is not configured on this server")
# Check for existing connection
existing = await _get_connection(current_user.account_id, db)
if existing:
raise HTTPException(
status.HTTP_409_CONFLICT,
"A PSA connection already exists for this account. Update or delete the existing one.",
)
# Test connection before saving
test_result = await _test_credentials(
provider=data.provider,
site_url=data.site_url,
company_id=data.company_id,
public_key=data.public_key,
private_key=data.private_key,
client_id=settings.CW_CLIENT_ID,
)
if not test_result.success:
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY,
f"Connection test failed: {test_result.message}",
)
credentials = {
"public_key": data.public_key,
"private_key": data.private_key,
}
conn = PsaConnection(
account_id=current_user.account_id,
provider=data.provider,
display_name=data.display_name,
site_url=data.site_url,
company_id=data.company_id,
credentials_encrypted=encrypt_credentials(credentials),
is_active=True,
last_validated_at=datetime.now(timezone.utc),
)
db.add(conn)
await db.commit()
await db.refresh(conn)
return _to_response(conn)
@router.put("/connections/{connection_id}", response_model=PsaConnectionResponse)
async def update_connection(
connection_id: UUID,
data: PsaConnectionUpdate,
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Update an existing PSA connection. Re-tests if credentials change."""
conn = await _get_connection_or_404(connection_id, current_user, db)
# Decrypt existing credentials
creds = decrypt_credentials(conn.credentials_encrypted)
# Track whether credential fields changed
cred_fields = {"public_key", "private_key"}
cred_changed = False
# Apply updates
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field in cred_fields:
if value is not None and value != creds.get(field):
creds[field] = value
cred_changed = True
else:
setattr(conn, field, value)
# Re-test if credentials changed
if cred_changed:
site_url = update_data.get("site_url", conn.site_url)
company_id_val = update_data.get("company_id", conn.company_id)
test_result = await _test_credentials(
provider=conn.provider,
site_url=site_url,
company_id=company_id_val,
public_key=creds["public_key"],
private_key=creds["private_key"],
client_id=settings.CW_CLIENT_ID or "",
)
if not test_result.success:
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY,
f"Connection test failed: {test_result.message}",
)
conn.credentials_encrypted = encrypt_credentials(creds)
conn.last_validated_at = datetime.now(timezone.utc)
conn.updated_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(conn)
return _to_response(conn)
@router.delete(
"/connections/{connection_id}",
status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_connection(
connection_id: UUID,
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Delete a PSA connection."""
conn = await _get_connection_or_404(connection_id, current_user, db)
await db.delete(conn)
await db.commit()
@router.post(
"/connections/{connection_id}/test",
response_model=PsaConnectionTestResponse,
)
async def test_connection(
connection_id: UUID,
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Test an existing PSA connection."""
conn = await _get_connection_or_404(connection_id, current_user, db)
creds = decrypt_credentials(conn.credentials_encrypted)
result = await _test_credentials(
provider=conn.provider,
site_url=conn.site_url,
company_id=conn.company_id,
public_key=creds["public_key"],
private_key=creds["private_key"],
client_id=settings.CW_CLIENT_ID or "",
)
if result.success:
conn.last_validated_at = datetime.now(timezone.utc)
await db.commit()
# Invalidate cached PSA data when connection is re-validated
from app.services.psa.cache import psa_cache
psa_cache.clear()
return result
# ── ticket / status / company endpoints ──────────────────────────
@router.get("/tickets/search", response_model=list[PSATicketSearchResult])
async def search_tickets(
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
query: str = "",
board_id: int | None = None,
status_id: int | None = None,
include_closed: bool = False,
):
"""Search ConnectWise tickets."""
if not current_user.account_id:
raise HTTPException(status_code=400, detail="User has no account")
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError
try:
provider = await get_provider_for_account(current_user.account_id, db)
tickets = await provider.search_tickets(
query, board_id=board_id, status_id=status_id, include_closed=include_closed
)
return [
PSATicketSearchResult(
id=t.id,
summary=t.summary,
company_name=t.company_name,
board_name=t.board_name,
status_name=t.status_name,
priority_name=t.priority_name,
closed=t.closed,
)
for t in tickets
]
except PSAError as e:
raise HTTPException(status_code=502, detail=str(e))
@router.get("/tickets/{ticket_id}")
async def get_ticket(
ticket_id: str,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Get a single CW ticket by ID."""
if not current_user.account_id:
raise HTTPException(status_code=400, detail="User has no account")
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError, PSANotFoundError
try:
provider = await get_provider_for_account(current_user.account_id, db)
ticket = await provider.get_ticket(ticket_id)
return ticket
except PSANotFoundError:
raise HTTPException(status_code=404, detail="Ticket not found")
except PSAError as e:
raise HTTPException(status_code=502, detail=str(e))
@router.get("/tickets/{ticket_id}/statuses", response_model=list[PSATicketStatusItem])
async def get_ticket_statuses(
ticket_id: str,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Get available statuses for a ticket's board."""
if not current_user.account_id:
raise HTTPException(status_code=400, detail="User has no account")
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError, PSANotFoundError
try:
provider = await get_provider_for_account(current_user.account_id, db)
ticket = await provider.get_ticket(ticket_id)
if not ticket.board_id:
raise HTTPException(status_code=400, detail="Ticket has no board")
statuses = await provider.get_ticket_statuses(ticket.board_id)
return [PSATicketStatusItem(id=s.id, name=s.name, is_closed=s.is_closed) for s in statuses]
except PSANotFoundError:
raise HTTPException(status_code=404, detail="Ticket not found")
except PSAError as e:
raise HTTPException(status_code=502, detail=str(e))
# ── member mapping endpoints ─────────────────────────────────────────
@router.get("/members", response_model=list[PsaMemberResponse])
async def list_members(
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""List CW members (from CW API)."""
if not current_user.account_id:
raise HTTPException(status_code=400, detail="User has no account")
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError
try:
provider = await get_provider_for_account(current_user.account_id, db)
members = await provider.list_members()
return [
PsaMemberResponse(id=m.id, identifier=m.identifier, name=m.name, email=m.email)
for m in members
]
except PSAError as e:
raise HTTPException(status_code=502, detail=str(e))
@router.get("/member-mappings", response_model=list[PsaMemberMappingResponse])
async def get_member_mappings(
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Get all member mappings for the account."""
conn = await _get_account_connection(current_user.account_id, db)
if not conn:
return []
result = await db.execute(
select(PsaMemberMapping).where(PsaMemberMapping.psa_connection_id == conn.id)
)
mappings = result.scalars().all()
response = []
for m in mappings:
user_result = await db.execute(select(User).where(User.id == m.user_id))
user = user_result.scalar_one_or_none()
if user:
response.append(PsaMemberMappingResponse(
id=str(m.id),
user_id=str(m.user_id),
user_email=user.email,
user_name=user.name,
external_member_id=m.external_member_id,
external_member_name=m.external_member_name,
matched_by=m.matched_by,
))
return response
@router.post("/member-mappings", response_model=list[PsaMemberMappingResponse])
async def save_member_mappings(
mappings: list[PsaMemberMappingSaveRequest],
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Save/update member mappings (batch). Replaces all existing mappings."""
conn = await _get_account_connection(current_user.account_id, db)
if not conn:
raise HTTPException(status_code=400, detail="No PSA connection configured")
# Delete existing mappings
await db.execute(
delete(PsaMemberMapping).where(PsaMemberMapping.psa_connection_id == conn.id)
)
# Insert new mappings
for m in mappings:
mapping = PsaMemberMapping(
psa_connection_id=conn.id,
user_id=UUID(m.user_id),
external_member_id=m.external_member_id,
external_member_name=m.external_member_name,
matched_by="manual_admin",
)
db.add(mapping)
await db.commit()
# Return the saved mappings
return await get_member_mappings(current_user, db)
@router.post("/member-mappings/auto-match", response_model=AutoMatchResult)
async def auto_match_members(
current_user: Annotated[User, Depends(require_account_owner)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Auto-match RF users to CW members by email."""
conn = await _get_account_connection(current_user.account_id, db)
if not conn:
raise HTTPException(status_code=400, detail="No PSA connection configured")
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError
try:
provider = await get_provider_for_account(current_user.account_id, db)
cw_members = await provider.list_members()
except PSAError as e:
raise HTTPException(status_code=502, detail=str(e))
# Build email → member lookup
email_to_member: dict = {}
for m in cw_members:
if m.email:
email_to_member[m.email.lower()] = m
# Get account users
users_result = await db.execute(
select(User).where(User.account_id == current_user.account_id, User.is_active.is_(True))
)
users = users_result.scalars().all()
matched = []
unmatched_count = 0
for user in users:
cw_member = email_to_member.get(user.email.lower())
if cw_member:
# Check if mapping already exists
existing = await db.execute(
select(PsaMemberMapping).where(
PsaMemberMapping.psa_connection_id == conn.id,
PsaMemberMapping.user_id == user.id,
)
)
if not existing.scalar_one_or_none():
mapping = PsaMemberMapping(
psa_connection_id=conn.id,
user_id=user.id,
external_member_id=cw_member.id,
external_member_name=cw_member.name,
matched_by="auto_email",
)
db.add(mapping)
matched.append((mapping, user))
else:
unmatched_count += 1
await db.commit()
# Build response
matched_response = [
PsaMemberMappingResponse(
id=str(m.id),
user_id=str(m.user_id),
user_email=u.email,
user_name=u.name,
external_member_id=m.external_member_id,
external_member_name=m.external_member_name,
matched_by=m.matched_by,
)
for m, u in matched
]
return AutoMatchResult(matched=matched_response, unmatched_users=unmatched_count)
# ── internal helpers ─────────────────────────────────────────────────
async def _get_account_connection(
account_id: UUID | None, db: AsyncSession
) -> PsaConnection | None:
"""Get the PSA connection for an account."""
if not account_id:
return None
result = await db.execute(
select(PsaConnection).where(PsaConnection.account_id == account_id)
)
return result.scalar_one_or_none()
async def _get_connection_or_404(
connection_id: UUID, user: User, db: AsyncSession
) -> PsaConnection:
"""Fetch a connection by ID, ensuring it belongs to the user's account."""
result = await db.execute(
select(PsaConnection).where(PsaConnection.id == connection_id)
)
conn = result.scalar_one_or_none()
if not conn:
raise HTTPException(status.HTTP_404_NOT_FOUND, "PSA connection not found")
if conn.account_id != user.account_id:
raise HTTPException(status.HTTP_404_NOT_FOUND, "PSA connection not found")
return conn

View File

@@ -23,8 +23,12 @@ from app.schemas.session import (
SessionComplete,
SessionVariablesUpdate,
PrepareSessionRequest,
TicketLinkRequest,
TicketLinkResponse,
PSATicketResponse,
)
from app.api.deps import get_current_active_user
from app.schemas.psa_connection import PsaPostRequest
from app.api.deps import get_current_active_user, require_engineer_or_admin
from app.core.permissions import can_access_tree
from app.services.export_service import generate_markdown_export, generate_text_export, generate_html_export, generate_psa_export
@@ -738,3 +742,382 @@ async def batch_launch_sessions(
for s in created_sessions
],
)
# ── PSA Ticket Link ─────────────────────────────────────────────────
@router.patch("/{session_id}/ticket-link", response_model=TicketLinkResponse)
async def link_ticket(
session_id: UUID,
data: TicketLinkRequest,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Link or unlink a PSA ticket to/from a session."""
from app.models.psa_connection import PsaConnection
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSANotFoundError, PSAError
# Look up session
result = await db.execute(select(Session).where(Session.id == session_id))
session = result.scalar_one_or_none()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found",
)
# Verify ownership or admin
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
if not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session",
)
# Unlink
if data.psa_ticket_id is None:
session.psa_ticket_id = None
session.psa_connection_id = None
await db.commit()
return TicketLinkResponse(
session_id=str(session.id),
psa_ticket_id=None,
ticket=None,
)
# Link — validate ticket exists in CW
if not current_user.account_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No account associated with your user",
)
try:
provider = await get_provider_for_account(current_user.account_id, db)
except PSAError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(exc),
)
# Fetch the connection to store its ID
conn_result = await db.execute(
select(PsaConnection).where(
PsaConnection.account_id == current_user.account_id,
PsaConnection.is_active.is_(True),
)
)
psa_connection = conn_result.scalar_one_or_none()
try:
ticket = await provider.get_ticket(data.psa_ticket_id)
except PSANotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Ticket not found in ConnectWise",
)
except PSAError as exc:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"PSA error: {exc}",
)
session.psa_ticket_id = ticket.id
session.psa_connection_id = psa_connection.id if psa_connection else None
await db.commit()
return TicketLinkResponse(
session_id=str(session.id),
psa_ticket_id=ticket.id,
ticket=PSATicketResponse(
id=ticket.id,
summary=ticket.summary,
company_name=ticket.company_name,
board_name=ticket.board_name,
status_name=ticket.status_name,
priority_name=ticket.priority_name,
),
)
# ── PSA Post to Ticket ────────────────────────────────────────────
@router.get("/{session_id}/psa-post/preview")
async def psa_post_preview(
session_id: UUID,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Preview the content that will be posted to the linked PSA ticket.
Generates session documentation in PSA format, fetches current ticket
details and available statuses, and counts previous posts.
"""
from app.models.psa_post_log import PsaPostLog
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError
from app.schemas.psa_connection import (
PsaPreviewResponse,
PSATicketSearchResult,
PSATicketStatusItem,
)
from sqlalchemy import func as sa_func
# Load session
result = await db.execute(select(Session).where(Session.id == session_id))
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
if not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="You don't have access to this session")
if not session.psa_ticket_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Session has no linked PSA ticket. Link a ticket first.",
)
if not current_user.account_id:
raise HTTPException(status_code=400, detail="No account associated with your user")
# Generate PSA export content
export_options = SessionExport(
format="psa",
include_timestamps=True,
include_tree_info=True,
include_outcome_notes=True,
include_next_steps=True,
include_summary=True,
)
content = generate_psa_export(session, export_options)
# Resolve session variables in content
session_vars = getattr(session, "session_variables", None) or {}
if session_vars:
from app.services.variable_service import resolve_variables
content = resolve_variables(content, session_vars)
# Fetch ticket details and statuses from CW
try:
provider = await get_provider_for_account(current_user.account_id, db)
ticket = await provider.get_ticket(session.psa_ticket_id)
available_statuses: list[PSATicketStatusItem] = []
if ticket.board_id:
statuses = await provider.get_ticket_statuses(ticket.board_id)
available_statuses = [
PSATicketStatusItem(id=s.id, name=s.name, is_closed=s.is_closed)
for s in statuses
]
except PSAError as e:
raise HTTPException(status_code=502, detail=f"PSA error: {e}")
# Count previous posts
count_result = await db.execute(
select(sa_func.count(PsaPostLog.id)).where(
PsaPostLog.session_id == session_id
)
)
previous_posts = count_result.scalar_one()
return PsaPreviewResponse(
content=content,
ticket=PSATicketSearchResult(
id=ticket.id,
summary=ticket.summary,
company_name=ticket.company_name,
board_name=ticket.board_name,
status_name=ticket.status_name,
priority_name=ticket.priority_name,
closed=ticket.closed,
),
available_statuses=available_statuses,
character_count=len(content),
previous_posts=previous_posts,
)
@router.post("/{session_id}/psa-post")
async def psa_post_to_ticket(
session_id: UUID,
data: PsaPostRequest,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Post session documentation as a note to the linked PSA ticket.
Optionally updates the ticket status if update_status_id is provided.
All actions are logged in psa_post_log for audit trail.
"""
from app.models.psa_connection import PsaConnection
from app.models.psa_post_log import PsaPostLog
from app.services.psa.registry import get_provider_for_account
from app.services.psa.exceptions import PSAError
from app.schemas.psa_connection import PsaPostResponse
# Load session
result = await db.execute(select(Session).where(Session.id == session_id))
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
if not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="You don't have access to this session")
if not session.psa_ticket_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Session has no linked PSA ticket. Link a ticket first.",
)
if not current_user.account_id:
raise HTTPException(status_code=400, detail="No account associated with your user")
# Get PSA connection ID for audit
conn_result = await db.execute(
select(PsaConnection).where(
PsaConnection.account_id == current_user.account_id,
PsaConnection.is_active.is_(True),
)
)
psa_connection = conn_result.scalar_one_or_none()
# Look up member mapping for attribution
from app.models.psa_member_mapping import PsaMemberMapping
member_id = None
if psa_connection:
mapping_result = await db.execute(
select(PsaMemberMapping).where(
PsaMemberMapping.psa_connection_id == psa_connection.id,
PsaMemberMapping.user_id == current_user.id,
)
)
mapping = mapping_result.scalar_one_or_none()
if mapping:
member_id = mapping.external_member_id
# Post note
try:
provider = await get_provider_for_account(current_user.account_id, db)
note_result = await provider.post_note(
ticket_id=session.psa_ticket_id,
text=data.content,
note_type=data.note_type,
member_id=member_id,
)
note_status = "success"
external_note_id = note_result.id
error_message = None
except PSAError as e:
note_status = "failed"
external_note_id = None
error_message = str(e)
# Optionally update ticket status
status_changed_from = None
status_changed_to = None
if data.update_status_id and note_status == "success":
try:
# Get current status before update
current_ticket = await provider.get_ticket(session.psa_ticket_id)
status_changed_from = current_ticket.status_name
if current_ticket.status_id != data.update_status_id:
updated_ticket = await provider.update_ticket_status(
session.psa_ticket_id, data.update_status_id
)
status_changed_to = updated_ticket.status_name
except PSAError as e:
# Log the status update failure but don't fail the whole request
# since the note was already posted successfully
if error_message:
error_message += f"; Status update failed: {e}"
else:
error_message = f"Note posted successfully but status update failed: {e}"
# Log to audit trail
log_entry = PsaPostLog(
session_id=session.id,
psa_connection_id=psa_connection.id if psa_connection else None,
ticket_id=session.psa_ticket_id,
note_type=data.note_type,
content_posted=data.content,
external_note_id=external_note_id,
status=note_status,
error_message=error_message,
status_changed_from=status_changed_from,
status_changed_to=status_changed_to,
posted_by=current_user.id,
)
db.add(log_entry)
await db.commit()
await db.refresh(log_entry)
if note_status == "failed":
raise HTTPException(
status_code=502,
detail=error_message or "Failed to post note to PSA",
)
return PsaPostResponse(
id=str(log_entry.id),
session_id=str(session.id),
ticket_id=session.psa_ticket_id,
note_type=data.note_type,
status=note_status,
external_note_id=external_note_id,
error_message=error_message,
status_changed_from=status_changed_from,
status_changed_to=status_changed_to,
posted_at=log_entry.posted_at.isoformat(),
)
@router.get("/{session_id}/psa-posts")
async def list_psa_posts(
session_id: UUID,
current_user: Annotated[User, Depends(require_engineer_or_admin)],
db: Annotated[AsyncSession, Depends(get_db)],
):
"""List all PSA post history for a session, ordered by most recent first."""
from app.models.psa_post_log import PsaPostLog
from app.schemas.psa_connection import PsaPostLogResponse
# Verify session access
result = await db.execute(select(Session).where(Session.id == session_id))
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.user_id != current_user.id and session.assigned_to_id != current_user.id:
if not current_user.is_super_admin:
raise HTTPException(status_code=403, detail="You don't have access to this session")
# Query post log
log_result = await db.execute(
select(PsaPostLog)
.where(PsaPostLog.session_id == session_id)
.order_by(PsaPostLog.posted_at.desc())
)
logs = log_result.scalars().all()
return [
PsaPostLogResponse(
id=str(log.id),
ticket_id=log.ticket_id,
note_type=log.note_type,
status=log.status,
error_message=log.error_message,
status_changed_from=log.status_changed_from,
status_changed_to=log.status_changed_to,
posted_at=log.posted_at.isoformat(),
content_preview=log.content_posted[:200],
)
for log in logs
]

View File

@@ -17,6 +17,7 @@ from app.api.endpoints import ai_suggestions
from app.api.endpoints import kb_accelerator
from app.api.endpoints import beta_signup
from app.api.endpoints import scripts
from app.api.endpoints import integrations
api_router = APIRouter()
@@ -58,3 +59,4 @@ api_router.include_router(ai_suggestions.router)
api_router.include_router(kb_accelerator.router)
api_router.include_router(beta_signup.router)
api_router.include_router(scripts.router)
api_router.include_router(integrations.router)

View File

@@ -119,6 +119,16 @@ class Settings(BaseSettings):
"""Check if any AI provider is configured."""
return self.ANTHROPIC_API_KEY is not None or self.GOOGLE_AI_API_KEY is not None
# ConnectWise PSA Integration
# CW_CLIENT_ID is a product-level GUID registered at developer.connectwise.com
# All MSP customers share this single clientId — it identifies ResolutionFlow as the integration
CW_CLIENT_ID: Optional[str] = None
@property
def cw_enabled(self) -> bool:
"""Check if ConnectWise integration is configured."""
return self.CW_CLIENT_ID is not None
# Monitoring
SENTRY_DSN: Optional[str] = None

View File

@@ -36,6 +36,9 @@ from .survey_response import SurveyResponse
from .survey_invite import SurveyInvite
from .kb_import import KBImport, KBImportNode
from .script_template import ScriptCategory, ScriptTemplate, ScriptGeneration
from .psa_connection import PsaConnection
from .psa_post_log import PsaPostLog
from .psa_member_mapping import PsaMemberMapping
__all__ = [
"User",
@@ -86,4 +89,7 @@ __all__ = [
"ScriptCategory",
"ScriptTemplate",
"ScriptGeneration",
"PsaConnection",
"PsaPostLog",
"PsaMemberMapping",
]

View File

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
from app.models.step_category import StepCategory
from app.models.step_library import StepLibrary
from app.models.account_limit_override import AccountLimitOverride
from app.models.psa_connection import PsaConnection
class Account(Base):
@@ -53,3 +54,4 @@ class Account(Base):
step_categories: Mapped[list["StepCategory"]] = relationship("StepCategory", foreign_keys="[StepCategory.account_id]", back_populates="account")
step_library: Mapped[list["StepLibrary"]] = relationship("StepLibrary", foreign_keys="[StepLibrary.account_id]", back_populates="account")
limit_override: Mapped[Optional["AccountLimitOverride"]] = relationship("AccountLimitOverride", back_populates="account", uselist=False)
psa_connection: Mapped[Optional["PsaConnection"]] = relationship("PsaConnection", back_populates="account", uselist=False)

View File

@@ -0,0 +1,48 @@
"""PSA connection model — one per account."""
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import String, DateTime, Boolean, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class PsaConnection(Base):
__tablename__ = "psa_connections"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
unique=True,
index=True,
)
provider: Mapped[str] = mapped_column(String(50), nullable=False)
display_name: Mapped[str] = mapped_column(String(100), nullable=False)
site_url: Mapped[str] = mapped_column(String(255), nullable=False)
company_id: Mapped[str] = mapped_column(String(100), nullable=False)
credentials_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_validated_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
account = relationship("Account", back_populates="psa_connection")

View File

@@ -0,0 +1,47 @@
"""Maps ResolutionFlow users to CW members."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import String, DateTime, ForeignKey, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class PsaMemberMapping(Base):
__tablename__ = "psa_member_mappings"
__table_args__ = (
UniqueConstraint("psa_connection_id", "user_id", name="uq_psa_member_mapping_connection_user"),
UniqueConstraint("psa_connection_id", "external_member_id", name="uq_psa_member_mapping_connection_member"),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
psa_connection_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("psa_connections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
external_member_id: Mapped[str] = mapped_column(String(100), nullable=False)
external_member_name: Mapped[str] = mapped_column(String(200), nullable=False)
matched_by: Mapped[str] = mapped_column(String(50), nullable=False) # 'auto_email', 'manual_admin', 'manual_self'
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False,
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
)
# Relationships
psa_connection = relationship("PsaConnection", foreign_keys=[psa_connection_id])
user = relationship("User", foreign_keys=[user_id])

View File

@@ -0,0 +1,58 @@
"""Audit trail for notes posted to PSA systems."""
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import String, DateTime, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
class PsaPostLog(Base):
__tablename__ = "psa_post_log"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
psa_connection_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("psa_connections.id", ondelete="SET NULL"),
nullable=True,
)
ticket_id: Mapped[str] = mapped_column(String(100), nullable=False)
note_type: Mapped[str] = mapped_column(String(50), nullable=False)
content_posted: Mapped[str] = mapped_column(Text, nullable=False)
external_note_id: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True
)
status: Mapped[str] = mapped_column(
String(20), nullable=False
) # 'success' or 'failed'
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status_changed_from: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True
)
status_changed_to: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True
)
posted_by: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id"), nullable=False
)
posted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
# Relationships
session = relationship("Session", foreign_keys=[session_id])
psa_connection = relationship("PsaConnection", foreign_keys=[psa_connection_id])
user = relationship("User", foreign_keys=[posted_by])

View File

@@ -83,6 +83,15 @@ class Session(Base):
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")
# PSA ticket link
psa_ticket_id: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
psa_connection_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("psa_connections.id", ondelete="SET NULL"),
nullable=True,
)
psa_connection = relationship("PsaConnection", foreign_keys=[psa_connection_id])
# Batch tracking (maintenance flows)
batch_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True), nullable=True, index=True

View File

@@ -15,6 +15,12 @@ from .script_template import (
ScriptTemplateCreate, ScriptTemplateUpdate, ScriptTemplateListItem, ScriptTemplateDetail,
ScriptGenerateRequest, ScriptGenerateResponse, ScriptGenerationRecord,
)
from .psa_connection import (
PsaConnectionCreate, PsaConnectionUpdate, PsaConnectionResponse, PsaConnectionTestResponse,
PSATicketSearchResult, PSATicketStatusItem,
PsaPostRequest, PsaPostResponse, PsaPreviewResponse, PsaPostLogResponse,
PsaMemberMappingResponse, PsaMemberMappingSaveRequest, PsaMemberResponse, AutoMatchResult,
)
__all__ = [
# User
@@ -39,4 +45,9 @@ __all__ = [
"ScriptCategoryResponse",
"ScriptTemplateCreate", "ScriptTemplateUpdate", "ScriptTemplateListItem", "ScriptTemplateDetail",
"ScriptGenerateRequest", "ScriptGenerateResponse", "ScriptGenerationRecord",
# PSA Connection
"PsaConnectionCreate", "PsaConnectionUpdate", "PsaConnectionResponse", "PsaConnectionTestResponse",
"PSATicketSearchResult", "PSATicketStatusItem",
"PsaPostRequest", "PsaPostResponse", "PsaPreviewResponse", "PsaPostLogResponse",
"PsaMemberMappingResponse", "PsaMemberMappingSaveRequest", "PsaMemberResponse", "AutoMatchResult",
]

View File

@@ -0,0 +1,138 @@
"""Pydantic schemas for PSA connection management."""
from __future__ import annotations
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class PsaConnectionCreate(BaseModel):
provider: str = Field(default="connectwise", pattern="^(connectwise|autotask)$")
display_name: str = Field(min_length=1, max_length=100)
site_url: str = Field(min_length=1, max_length=255)
company_id: str = Field(min_length=1, max_length=100)
public_key: str = Field(min_length=1)
private_key: str = Field(min_length=1)
# Note: client_id is NOT per-MSP — it's a product-level GUID from settings.CW_CLIENT_ID
class PsaConnectionUpdate(BaseModel):
display_name: str | None = None
site_url: str | None = None
company_id: str | None = None
public_key: str | None = None
private_key: str | None = None
class PsaConnectionResponse(BaseModel):
id: UUID
account_id: UUID
provider: str
display_name: str
site_url: str
company_id: str
is_active: bool
last_validated_at: datetime | None
created_at: datetime
updated_at: datetime
public_key_hint: str
private_key_hint: str
model_config = {"from_attributes": True}
class PsaConnectionTestResponse(BaseModel):
success: bool
message: str
server_version: str | None = None
# ── Ticket search & status schemas ────────────────────────────────
class PSATicketSearchResult(BaseModel):
id: str
summary: str
company_name: str | None = None
board_name: str | None = None
status_name: str | None = None
priority_name: str | None = None
closed: bool = False
class PSATicketStatusItem(BaseModel):
id: int
name: str
is_closed: bool = False
# ── PSA post (note posting) schemas ──────────────────────────────
class PsaPostRequest(BaseModel):
note_type: str = Field(pattern="^(internal_analysis|resolution|description)$")
content: str = Field(min_length=1)
update_status_id: int | None = None
class PsaPostResponse(BaseModel):
id: str
session_id: str
ticket_id: str
note_type: str
status: str
external_note_id: str | None = None
error_message: str | None = None
status_changed_from: str | None = None
status_changed_to: str | None = None
posted_at: str
class PsaPreviewResponse(BaseModel):
content: str
ticket: PSATicketSearchResult
available_statuses: list[PSATicketStatusItem]
character_count: int
previous_posts: int
class PsaPostLogResponse(BaseModel):
id: str
ticket_id: str
note_type: str
status: str
error_message: str | None = None
status_changed_from: str | None = None
status_changed_to: str | None = None
posted_at: str
content_preview: str # first 200 chars
# ── Member mapping schemas ───────────────────────────────────────
class PsaMemberMappingResponse(BaseModel):
id: str
user_id: str
user_email: str
user_name: str
external_member_id: str
external_member_name: str
matched_by: str
class PsaMemberMappingSaveRequest(BaseModel):
user_id: str
external_member_id: str
external_member_name: str
class PsaMemberResponse(BaseModel):
id: str
identifier: str
name: str
email: str | None = None
class AutoMatchResult(BaseModel):
matched: list[PsaMemberMappingResponse]
unmatched_users: int

View File

@@ -94,6 +94,10 @@ class SessionResponse(BaseModel):
batch_id: Optional[UUID] = None
target_label: Optional[str] = None
# PSA ticket link
psa_ticket_id: Optional[str] = None
psa_connection_id: Optional[UUID] = None
class Config:
from_attributes = True
@@ -140,3 +144,28 @@ class SaveAsTreeResponse(BaseModel):
tree_id: UUID
tree_name: str
message: str
# ── PSA ticket link ──────────────────────────────────────────────────
class TicketLinkRequest(BaseModel):
"""Link or unlink a PSA ticket to a session."""
psa_ticket_id: Optional[str] = None # null to unlink
class PSATicketResponse(BaseModel):
"""PSA ticket details returned when linking."""
id: str
summary: str
company_name: Optional[str] = None
board_name: Optional[str] = None
status_name: Optional[str] = None
priority_name: Optional[str] = None
class TicketLinkResponse(BaseModel):
"""Response after linking/unlinking a ticket."""
session_id: str
psa_ticket_id: Optional[str] = None
ticket: Optional[PSATicketResponse] = None

View File

@@ -0,0 +1 @@
"""PSA integration abstraction layer."""

View File

@@ -0,0 +1,68 @@
"""Abstract base class for PSA provider implementations."""
from __future__ import annotations
from abc import ABC, abstractmethod
from .types import (
ConnectionTestResult,
PSATicket,
PSANote,
PSAStatus,
PSACompany,
PSAMember,
PSAConfiguration,
)
class PSAProvider(ABC):
"""Abstract base for PSA integrations (ConnectWise, Autotask, etc.)."""
@abstractmethod
async def test_connection(self) -> ConnectionTestResult:
...
@abstractmethod
async def get_ticket(self, ticket_id: str) -> PSATicket:
...
@abstractmethod
async def search_tickets(self, query: str, **filters) -> list[PSATicket]:
...
@abstractmethod
async def post_note(
self,
ticket_id: str,
text: str,
note_type: str,
member_id: str | None = None,
) -> PSANote:
...
@abstractmethod
async def update_ticket_status(
self,
ticket_id: str,
status_id: int,
) -> PSATicket:
...
@abstractmethod
async def get_ticket_statuses(self, board_id: int) -> list[PSAStatus]:
...
@abstractmethod
async def list_companies(self, **filters) -> list[PSACompany]:
...
@abstractmethod
async def get_company(self, company_id: str) -> PSACompany:
...
@abstractmethod
async def list_members(self) -> list[PSAMember]:
...
@abstractmethod
async def get_ticket_configurations(self, ticket_id: str) -> list[PSAConfiguration]:
...

View File

@@ -0,0 +1,38 @@
"""Simple in-memory TTL cache for PSA API responses."""
from __future__ import annotations
import time
from typing import Any
class PSACache:
"""Account-scoped in-memory cache with TTL expiry."""
def __init__(self) -> None:
self._store: dict[str, tuple[Any, float]] = {}
def get(self, key: str) -> Any | None:
entry = self._store.get(key)
if entry is None:
return None
value, expires_at = entry
if time.time() > expires_at:
del self._store[key]
return None
return value
def set(self, key: str, value: Any, ttl_seconds: int) -> None:
self._store[key] = (value, time.time() + ttl_seconds)
def invalidate(self, prefix: str) -> None:
"""Remove all entries matching a key prefix."""
keys_to_remove = [k for k in self._store if k.startswith(prefix)]
for k in keys_to_remove:
del self._store[k]
def clear(self) -> None:
self._store.clear()
# Global singleton — acceptable at current scale (see design doc section 6)
psa_cache = PSACache()

View File

@@ -0,0 +1 @@
"""ConnectWise PSA provider implementation."""

View File

@@ -0,0 +1,288 @@
"""Low-level HTTP client for ConnectWise PSA REST API.
Handles auth headers, base URL resolution (cloud vs on-premise),
pagination, retry with backoff, and error mapping.
"""
from __future__ import annotations
import asyncio
import base64
import ipaddress
import logging
import socket
from typing import Any
from urllib.parse import urlparse
import httpx
from app.services.psa.exceptions import (
PSAAuthError,
PSAConnectionError,
PSANotFoundError,
PSAPermissionError,
PSARateLimitError,
PSAServerError,
PSATimeoutError,
)
logger = logging.getLogger(__name__)
# Pinned CW API version per best-practices/PSA-Versioning.md
CW_API_VERSION = "2025.16"
CW_ACCEPT_HEADER = f"application/vnd.connectwise.com+json; version={CW_API_VERSION}"
# Known CW cloud domains (for SSRF prevention)
CW_ALLOWED_DOMAINS = {
"myconnectwise.net",
"connectwisedev.com",
}
REQUEST_TIMEOUT = 30.0
MAX_RETRIES = 2
MAX_PAGE_SIZE = 1000
def _validate_site_url(site_url: str) -> None:
"""Validate site_url is a known CW domain (SSRF prevention).
Rejects any hostname that is not a recognized ConnectWise domain
and any hostname that resolves to a private/loopback/link-local IP.
"""
# Ensure scheme for parsing
url = site_url if "://" in site_url else f"https://{site_url}"
parsed = urlparse(url)
hostname = parsed.hostname or ""
# Check against allowed domains
if not any(
hostname.endswith(f".{domain}") or hostname == domain
for domain in CW_ALLOWED_DOMAINS
):
raise PSAConnectionError(
f"Invalid ConnectWise site URL: {hostname}. "
"Must be a *.myconnectwise.net or *.connectwisedev.com domain.",
provider="connectwise",
)
# Resolve and check for private IPs
try:
addrs = socket.getaddrinfo(hostname, None)
for _, _, _, _, sockaddr in addrs:
ip = ipaddress.ip_address(sockaddr[0])
if ip.is_private or ip.is_loopback or ip.is_link_local:
raise PSAConnectionError(
f"Site URL resolves to a private/internal address: {sockaddr[0]}",
provider="connectwise",
)
except socket.gaierror:
raise PSAConnectionError(
f"Cannot resolve hostname: {hostname}",
provider="connectwise",
)
class ConnectWiseClient:
"""Async HTTP client for the ConnectWise PSA API.
Auth: Authorization: Basic {base64(companyId+publicKey:privateKey)} + clientId header
Accept: application/vnd.connectwise.com+json; version=2025.16
Base URL: resolved dynamically via /login/companyinfo/{companyId}
Pagination: page/pageSize params, max 1000 per page, while-loop pattern
Retry: respects 429 Retry-After, max 2 retries with exponential backoff for 5xx
Timeout: 30 seconds per request
"""
def __init__(
self,
site_url: str,
company_id: str,
public_key: str,
private_key: str,
client_id: str,
):
self.site_url = site_url.rstrip("/")
self.company_id = company_id
self.client_id = client_id
# Auth: Base64(companyId+publicKey:privateKey)
auth_string = f"{company_id}+{public_key}:{private_key}"
self._auth_b64 = base64.b64encode(auth_string.encode()).decode()
# Base URL resolved lazily on first request
self._base_url: str | None = None
async def _resolve_base_url(self) -> str:
"""Resolve the CW API base URL using /login/companyinfo/{companyId}.
Cloud environments return a versioned codebase (e.g., v2025_3/) requiring
an 'api-' prefix on the hostname. On-premise returns v4_6_release/ with
no prefix needed.
"""
if self._base_url:
return self._base_url
_validate_site_url(self.site_url)
info_url = f"https://{self.site_url}/login/companyinfo/{self.company_id}"
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
try:
resp = await client.get(info_url)
resp.raise_for_status()
except httpx.TimeoutException:
raise PSATimeoutError(
"Timed out resolving CW base URL", provider="connectwise"
)
except httpx.HTTPError as e:
raise PSAConnectionError(
f"Failed to resolve CW base URL: {e}", provider="connectwise"
)
data = resp.json()
codebase = data.get("Codebase", "v4_6_release/")
site_url = data.get("SiteUrl", self.site_url)
# Cloud codebase (e.g., v2025_3/) requires api- prefix
if codebase != "v4_6_release/":
if not site_url.startswith("api-"):
site_url = f"api-{site_url}"
self._base_url = f"https://{site_url}/{codebase}apis/3.0"
logger.info("Resolved CW base URL: %s", self._base_url)
return self._base_url
def _headers(self) -> dict[str, str]:
return {
"Authorization": f"Basic {self._auth_b64}",
"clientId": self.client_id,
"Accept": CW_ACCEPT_HEADER,
"Content-Type": "application/json",
}
async def _request(
self,
method: str,
path: str,
*,
params: dict[str, Any] | None = None,
json_body: Any = None,
retries: int = MAX_RETRIES,
) -> Any:
"""Make an authenticated request to the CW API with retry and error mapping."""
base_url = await self._resolve_base_url()
url = f"{base_url}/{path.lstrip('/')}"
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
for attempt in range(retries + 1):
try:
resp = await client.request(
method,
url,
headers=self._headers(),
params=params,
json=json_body,
)
except httpx.TimeoutException:
if attempt < retries:
await asyncio.sleep(2 ** attempt)
continue
raise PSATimeoutError(
"ConnectWise request timed out", provider="connectwise"
)
except httpx.ConnectError:
raise PSAConnectionError(
"Cannot reach ConnectWise server", provider="connectwise"
)
# Rate limit — retry with Retry-After backoff
if resp.status_code == 429:
if attempt < retries:
retry_after = int(resp.headers.get("Retry-After", "5"))
await asyncio.sleep(retry_after)
continue
raise PSARateLimitError(
"ConnectWise rate limit exceeded",
retry_after_seconds=int(
resp.headers.get("Retry-After", "60")
),
provider="connectwise",
)
# Map error status codes to typed exceptions
if resp.status_code == 401:
raise PSAAuthError(
"Invalid credentials. Check your API keys.",
provider="connectwise",
)
if resp.status_code == 403:
raise PSAPermissionError(
"Insufficient permissions. Check the API member's security role.",
provider="connectwise",
)
if resp.status_code == 404:
raise PSANotFoundError(
"Resource not found.", provider="connectwise"
)
if resp.status_code >= 500:
if attempt < retries:
await asyncio.sleep(2 ** attempt)
continue
raise PSAServerError(
"ConnectWise is experiencing issues. Try again.",
provider="connectwise",
)
resp.raise_for_status()
if resp.status_code == 204:
return None
return resp.json()
# Should not reach here, but satisfy type checker
raise PSAConnectionError(
"Request failed after all retries", provider="connectwise"
)
async def get(self, path: str, params: dict[str, Any] | None = None) -> Any:
"""GET request to CW API."""
return await self._request("GET", path, params=params)
async def post(self, path: str, json_body: Any = None) -> Any:
"""POST request to CW API."""
return await self._request("POST", path, json_body=json_body)
async def patch(self, path: str, json_body: Any = None) -> Any:
"""PATCH request to CW API (JSON Patch array format).
CW uses JSON Patch syntax: [{"op": "replace", "path": "field", "value": ...}]
"""
return await self._request("PATCH", path, json_body=json_body)
async def delete(self, path: str) -> Any:
"""DELETE request to CW API."""
return await self._request("DELETE", path)
async def get_paginated(
self,
path: str,
params: dict[str, Any] | None = None,
max_pages: int = 10,
) -> list[Any]:
"""Fetch all pages of a paginated CW endpoint.
Uses navigable pagination with page/pageSize params.
Stops when a page returns fewer results than pageSize or max_pages is reached.
"""
params = dict(params or {})
params.setdefault("pageSize", MAX_PAGE_SIZE)
all_results: list[Any] = []
for page in range(1, max_pages + 1):
params["page"] = page
results = await self.get(path, params=params)
if not results:
break
all_results.extend(results)
if len(results) < params["pageSize"]:
break
return all_results

View File

@@ -0,0 +1,283 @@
"""ConnectWise implementation of PSAProvider."""
from __future__ import annotations
from app.services.psa.base import PSAProvider
from app.services.psa.cache import psa_cache
from app.services.psa.types import (
ConnectionTestResult,
PSATicket,
PSANote,
PSAStatus,
PSACompany,
PSAMember,
PSAConfiguration,
)
from .client import ConnectWiseClient
class ConnectWiseProvider(PSAProvider):
"""ConnectWise PSA provider implementation."""
def __init__(self, client: ConnectWiseClient):
self.client = client
async def test_connection(self) -> ConnectionTestResult:
"""Test the CW connection by fetching system info."""
try:
info = await self.client.get("/system/info")
return ConnectionTestResult(
success=True,
message="Connected successfully.",
server_version=info.get("version", None),
)
except Exception as e:
return ConnectionTestResult(
success=False,
message=str(e),
server_version=None,
)
# ── Tickets ───────────────────────────────────────────────────────
async def get_ticket(self, ticket_id: str) -> PSATicket:
"""Fetch a single ticket by ID from ConnectWise."""
data = await self.client.get(
f"/service/tickets/{ticket_id}",
params={"fields": "id,summary,company,board,status,priority,closedFlag"},
)
return self._map_ticket(data)
async def search_tickets(self, query: str, **filters) -> list[PSATicket]:
"""Search CW tickets by summary. Supports board_id and status_id filters."""
params: dict = {
"fields": "id,summary,company,board,status,priority,closedFlag",
"orderBy": "id desc",
"pageSize": 25,
}
# Build CW condition query
conditions: list[str] = []
if query:
conditions.append(f"summary contains '{query}'")
if filters.get("board_id"):
conditions.append(f"board/id = {filters['board_id']}")
if filters.get("status_id"):
conditions.append(f"status/id = {filters['status_id']}")
if not filters.get("include_closed", False):
conditions.append("closedFlag = false")
if conditions:
params["conditions"] = " and ".join(conditions)
data = await self.client.get("/service/tickets", params=params)
return [
self._map_ticket(t)
for t in (data if isinstance(data, list) else [])
]
async def get_ticket_configurations(
self, ticket_id: str
) -> list[PSAConfiguration]:
"""Get configurations (assets) attached to a ticket."""
data = await self.client.get(
f"/service/tickets/{ticket_id}/configurations",
params={"fields": "id,deviceIdentifier,type,company"},
)
return [
PSAConfiguration(
id=str(c["id"]),
name=c.get("deviceIdentifier", ""),
type=c.get("type", {}).get("name") if c.get("type") else None,
company_name=c.get("company", {}).get("name") if c.get("company") else None,
)
for c in (data if isinstance(data, list) else [])
]
# ── Board statuses (cached) ───────────────────────────────────────
async def get_ticket_statuses(self, board_id: int) -> list[PSAStatus]:
"""Get available statuses for a CW service board (cached 1 hour)."""
cache_key = f"board_statuses:{board_id}"
cached = psa_cache.get(cache_key)
if cached is not None:
return cached
data = await self.client.get(
f"/service/boards/{board_id}/statuses",
params={"fields": "id,name,closedStatus", "pageSize": 100},
)
result = [
PSAStatus(
id=s["id"],
name=s["name"],
is_closed=s.get("closedStatus", False),
)
for s in (data if isinstance(data, list) else [])
]
psa_cache.set(cache_key, result, ttl_seconds=3600)
return result
# ── Companies ─────────────────────────────────────────────────────
async def list_companies(self, **filters) -> list[PSACompany]:
"""List companies from CW, optionally filtered by status."""
params: dict = {
"fields": "id,name,status",
"pageSize": 100,
"orderBy": "name asc",
}
conditions: list[str] = []
if filters.get("status"):
conditions.append(f"status/name = '{filters['status']}'")
if conditions:
params["conditions"] = " and ".join(conditions)
data = await self.client.get("/company/companies", params=params)
return [
PSACompany(
id=str(c["id"]),
name=c.get("name", ""),
status=c.get("status", {}).get("name") if c.get("status") else None,
)
for c in (data if isinstance(data, list) else [])
]
async def get_company(self, company_id: str) -> PSACompany:
"""Fetch a single company by ID."""
data = await self.client.get(
f"/company/companies/{company_id}",
params={"fields": "id,name,status"},
)
return PSACompany(
id=str(data["id"]),
name=data.get("name", ""),
status=data.get("status", {}).get("name") if data.get("status") else None,
)
# ── Notes & status updates ───────────────────────────────────────
async def post_note(
self,
ticket_id: str,
text: str,
note_type: str,
member_id: str | None = None,
) -> PSANote:
"""Post a note to a CW ticket.
Maps ResolutionFlow note types to CW flag fields:
- internal_analysis → internalAnalysisFlag (internal only)
- resolution → resolutionFlag (internal, triggers notifications)
- description → detailDescriptionFlag (external, triggers notifications)
"""
from app.services.psa.types import NoteType
flags = {
NoteType.INTERNAL_ANALYSIS: {
"internalAnalysisFlag": True,
"resolutionFlag": False,
"detailDescriptionFlag": False,
"internalFlag": True,
"processNotifications": False,
},
NoteType.RESOLUTION: {
"internalAnalysisFlag": False,
"resolutionFlag": True,
"detailDescriptionFlag": False,
"internalFlag": True,
"processNotifications": True,
},
NoteType.DESCRIPTION: {
"internalAnalysisFlag": False,
"resolutionFlag": False,
"detailDescriptionFlag": True,
"internalFlag": False,
"processNotifications": True,
},
}
note_flags = flags.get(note_type, flags[NoteType.INTERNAL_ANALYSIS])
# NOTE: CW Developer Guide states \n is "Not Supported" in JSON bodies
# and may be collapsed to a single space. CW does support markdown in ticket
# notes (see PSA-Markdown.md). This needs sandbox testing — if newlines are
# lost, consider using double-space line breaks or HTML <br> tags instead.
body: dict = {
"text": text,
**note_flags,
}
if member_id:
body["member"] = {"id": int(member_id)}
data = await self.client.post(
f"/service/tickets/{ticket_id}/notes", json_body=body
)
return PSANote(
id=str(data.get("id", "")),
text=data.get("text", ""),
note_type=note_type,
created_at=data.get("dateCreated"),
)
async def update_ticket_status(
self, ticket_id: str, status_id: int
) -> PSATicket:
"""Update a CW ticket's status using JSON Patch format."""
patch_body = [
{"op": "replace", "path": "status", "value": {"id": status_id}}
]
data = await self.client.patch(
f"/service/tickets/{ticket_id}", json_body=patch_body
)
return self._map_ticket(data)
async def list_members(self) -> list[PSAMember]:
"""List CW system members (cached 15 minutes)."""
cache_key = "members:all"
cached = psa_cache.get(cache_key)
if cached is not None:
return cached
data = await self.client.get_paginated(
"/system/members",
params={
"fields": "id,identifier,firstName,lastName,officeEmail",
"conditions": "inactiveFlag = false",
"pageSize": 1000,
},
)
result = [
PSAMember(
id=str(m["id"]),
identifier=m.get("identifier", ""),
name=f"{m.get('firstName', '')} {m.get('lastName', '')}".strip(),
email=m.get("officeEmail"),
)
for m in data
]
psa_cache.set(cache_key, result, ttl_seconds=900)
return result
# ── Private helpers ───────────────────────────────────────────────
@staticmethod
def _map_ticket(data: dict) -> PSATicket:
"""Map a CW ticket JSON dict to a PSATicket."""
return PSATicket(
id=str(data["id"]),
summary=data.get("summary", ""),
company_name=data.get("company", {}).get("name"),
company_id=str(data["company"]["id"]) if data.get("company") else None,
board_name=data.get("board", {}).get("name"),
board_id=data.get("board", {}).get("id"),
status_name=data.get("status", {}).get("name"),
status_id=data.get("status", {}).get("id"),
priority_name=data.get("priority", {}).get("name"),
priority_id=data.get("priority", {}).get("id"),
closed=data.get("closedFlag", False),
)

View File

@@ -0,0 +1,53 @@
"""Fernet-based credential encryption for PSA connections.
Uses the application SECRET_KEY to derive a Fernet encryption key via HKDF.
Credentials are stored as a single encrypted JSON blob.
"""
from __future__ import annotations
import json
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from app.core.config import settings
def _get_fernet() -> Fernet:
"""Derive a Fernet key from the application SECRET_KEY."""
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=b"resolutionflow-psa-credentials",
info=b"psa-credential-encryption",
)
key = hkdf.derive(settings.SECRET_KEY.encode())
fernet_key = base64.urlsafe_b64encode(key)
return Fernet(fernet_key)
def encrypt_credentials(credentials: dict) -> str:
"""Encrypt a credentials dict to a Fernet token string."""
f = _get_fernet()
plaintext = json.dumps(credentials).encode()
return f.encrypt(plaintext).decode()
def decrypt_credentials(encrypted: str) -> dict:
"""Decrypt a Fernet token string back to a credentials dict."""
f = _get_fernet()
plaintext = f.decrypt(encrypted.encode())
return json.loads(plaintext)
def mask_credential(value: str | None, visible_suffix: int = 4) -> str:
"""Return a masked version of a credential for display.
e.g., 'abcdefghij' -> '......ghij'
"""
if not value:
return "\u2022\u2022\u2022\u2022\u2022\u2022"
if len(value) <= visible_suffix:
return "\u2022\u2022\u2022\u2022\u2022\u2022" + value
return "\u2022\u2022\u2022\u2022\u2022\u2022" + value[-visible_suffix:]

View File

@@ -0,0 +1,45 @@
"""Typed exceptions for PSA integration errors."""
class PSAError(Exception):
"""Base exception for all PSA integration errors."""
def __init__(self, message: str, provider: str = "unknown"):
self.provider = provider
super().__init__(message)
class PSAAuthError(PSAError):
"""Invalid or expired credentials."""
pass
class PSAPermissionError(PSAError):
"""Insufficient permissions on the PSA side."""
pass
class PSANotFoundError(PSAError):
"""Requested resource (ticket, company, etc.) not found."""
pass
class PSARateLimitError(PSAError):
"""Rate limit exceeded. retry_after_seconds may be set."""
def __init__(self, message: str, retry_after_seconds: int | None = None, provider: str = "unknown"):
self.retry_after_seconds = retry_after_seconds
super().__init__(message, provider)
class PSAServerError(PSAError):
"""Remote PSA server error (5xx)."""
pass
class PSATimeoutError(PSAError):
"""Request to PSA timed out."""
pass
class PSAConnectionError(PSAError):
"""Cannot reach the PSA server."""
pass

View File

@@ -0,0 +1,51 @@
"""Factory for instantiating PSA providers from stored connection data."""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.psa_connection import PsaConnection
from app.services.psa.base import PSAProvider
from app.core.config import settings
from app.services.psa.encryption import decrypt_credentials
from app.services.psa.exceptions import PSAConnectionError
async def get_provider_for_account(
account_id: UUID, db: AsyncSession
) -> PSAProvider:
"""Look up account's PSA connection, decrypt credentials, instantiate provider."""
result = await db.execute(
select(PsaConnection).where(
PsaConnection.account_id == account_id,
PsaConnection.is_active.is_(True),
)
)
connection = result.scalar_one_or_none()
if not connection:
raise PSAConnectionError(
"No active PSA connection configured for this account.",
provider="unknown",
)
if connection.provider == "connectwise":
from app.services.psa.connectwise.client import ConnectWiseClient
from app.services.psa.connectwise.provider import ConnectWiseProvider
creds = decrypt_credentials(connection.credentials_encrypted)
client = ConnectWiseClient(
site_url=connection.site_url,
company_id=connection.company_id,
public_key=creds["public_key"],
private_key=creds["private_key"],
client_id=settings.CW_CLIENT_ID or "",
)
return ConnectWiseProvider(client)
raise PSAConnectionError(
f"Unsupported PSA provider: {connection.provider}",
provider=connection.provider,
)

View File

@@ -0,0 +1,63 @@
"""Provider-agnostic PSA data types."""
from __future__ import annotations
from pydantic import BaseModel
class ConnectionTestResult(BaseModel):
success: bool
message: str
server_version: str | None = None
class PSATicket(BaseModel):
id: str
summary: str
company_name: str | None = None
company_id: str | None = None
board_name: str | None = None
board_id: int | None = None
status_name: str | None = None
status_id: int | None = None
priority_name: str | None = None
priority_id: int | None = None
closed: bool = False
class PSANote(BaseModel):
id: str
text: str
note_type: str
created_at: str | None = None
class PSAStatus(BaseModel):
id: int
name: str
is_closed: bool = False
class PSACompany(BaseModel):
id: str
name: str
status: str | None = None
class PSAMember(BaseModel):
id: str
identifier: str # CW login username
name: str
email: str | None = None
class PSAConfiguration(BaseModel):
id: str
name: str
type: str | None = None
company_name: str | None = None
class NoteType:
INTERNAL_ANALYSIS = "internal_analysis"
RESOLUTION = "resolution"
DESCRIPTION = "description"

View File

@@ -0,0 +1,59 @@
"""Tests for PSA connection endpoints — routing and RBAC only.
We cannot fully test create/update/test endpoints in CI because they
call the ConnectWise API. These tests verify routing and authorization.
"""
import pytest
from sqlalchemy import select, update
from app.models.user import User
@pytest.mark.asyncio
async def test_get_connection_empty(client, admin_auth_headers):
"""GET returns null when no connection exists."""
response = await client.get(
"/api/v1/integrations/psa/connections",
headers=admin_auth_headers,
)
assert response.status_code == 200
assert response.json() is None
@pytest.mark.asyncio
async def test_create_connection_requires_owner(client, test_user, auth_headers, test_db):
"""Engineer (non-owner) should get 403 on create."""
# Downgrade the test user from owner to engineer so require_account_owner rejects
user_id = test_user["user_data"]["id"]
await test_db.execute(
update(User).where(User.id == user_id).values(account_role="engineer")
)
await test_db.commit()
payload = {
"provider": "connectwise",
"display_name": "Test CW",
"site_url": "https://na.myconnectwise.net",
"company_id": "testmsp",
"public_key": "pub123",
"private_key": "priv456",
"client_id": "client789",
}
response = await client.post(
"/api/v1/integrations/psa/connections",
json=payload,
headers=auth_headers,
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_delete_nonexistent_returns_404(client, admin_auth_headers):
"""DELETE with a nonexistent ID returns 404."""
import uuid
fake_id = uuid.uuid4()
response = await client.delete(
f"/api/v1/integrations/psa/connections/{fake_id}",
headers=admin_auth_headers,
)
assert response.status_code == 404

View File

@@ -0,0 +1,44 @@
"""Tests for PSA credential encryption/decryption."""
import pytest
from app.services.psa.encryption import encrypt_credentials, decrypt_credentials
class TestCredentialEncryption:
def test_round_trip(self):
"""Encrypt then decrypt returns original credentials."""
creds = {
"public_key": "abc123",
"private_key": "secret456",
"client_id": "my-client-id",
}
encrypted = encrypt_credentials(creds)
# Encrypted should be a non-empty string, different from input
assert isinstance(encrypted, str)
assert len(encrypted) > 0
assert "secret456" not in encrypted
decrypted = decrypt_credentials(encrypted)
assert decrypted == creds
def test_different_inputs_produce_different_outputs(self):
creds1 = {"public_key": "key1", "private_key": "priv1", "client_id": "cid1"}
creds2 = {"public_key": "key2", "private_key": "priv2", "client_id": "cid2"}
enc1 = encrypt_credentials(creds1)
enc2 = encrypt_credentials(creds2)
assert enc1 != enc2
def test_tampered_ciphertext_raises(self):
creds = {"public_key": "k", "private_key": "p", "client_id": "c"}
encrypted = encrypt_credentials(creds)
tampered = encrypted[:-5] + "XXXXX"
with pytest.raises(Exception):
decrypt_credentials(tampered)
def test_mask_private_key(self):
from app.services.psa.encryption import mask_credential
assert mask_credential("abcdefghij") == "\u2022\u2022\u2022\u2022\u2022\u2022ghij"
assert mask_credential("abc") == "\u2022\u2022\u2022\u2022\u2022\u2022abc"
assert mask_credential("") == "\u2022\u2022\u2022\u2022\u2022\u2022"
assert mask_credential(None) == "\u2022\u2022\u2022\u2022\u2022\u2022"