"""PSA integration endpoints — connection CRUD and test.""" from __future__ import annotations from datetime import datetime, timezone from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_active_user, require_account_owner from app.core.database import get_db from app.models.psa_connection import PsaConnection from app.models.user import User from app.schemas.psa_connection import ( PsaConnectionCreate, PsaConnectionResponse, PsaConnectionTestResponse, PsaConnectionUpdate, ) from app.services.psa.encryption import ( decrypt_credentials, encrypt_credentials, mask_credential, ) router = APIRouter(prefix="/integrations/psa", tags=["integrations"]) # ── helpers ────────────────────────────────────────────────────────── def _to_response(conn: PsaConnection) -> PsaConnectionResponse: """Build a response DTO with masked credential hints.""" creds = decrypt_credentials(conn.credentials_encrypted) return PsaConnectionResponse( id=conn.id, account_id=conn.account_id, provider=conn.provider, display_name=conn.display_name, site_url=conn.site_url, company_id=conn.company_id, is_active=conn.is_active, last_validated_at=conn.last_validated_at, created_at=conn.created_at, updated_at=conn.updated_at, public_key_hint=mask_credential(creds.get("public_key")), private_key_hint=mask_credential(creds.get("private_key")), ) async def _get_connection( account_id: UUID, db: AsyncSession ) -> PsaConnection | None: result = await db.execute( select(PsaConnection).where(PsaConnection.account_id == account_id) ) return result.scalar_one_or_none() async def _test_credentials( provider: str, site_url: str, company_id: str, public_key: str, private_key: str, client_id: str, ) -> PsaConnectionTestResponse: """Instantiate a provider and run test_connection.""" if provider == "connectwise": from app.services.psa.connectwise.client import ConnectWiseClient from app.services.psa.connectwise.provider import ConnectWiseProvider client = ConnectWiseClient( site_url=site_url, company_id=company_id, public_key=public_key, private_key=private_key, client_id=client_id, ) result = await ConnectWiseProvider(client).test_connection() return PsaConnectionTestResponse( success=result.success, message=result.message, server_version=result.server_version, ) return PsaConnectionTestResponse( success=False, message=f"Unsupported provider: {provider}", ) # ── endpoints ──────────────────────────────────────────────────────── @router.get("/connections", response_model=PsaConnectionResponse | None) async def get_connection( current_user: Annotated[User, Depends(get_current_active_user)], db: Annotated[AsyncSession, Depends(get_db)], ): """Return the account's PSA connection (redacted credentials) or null.""" if not current_user.account_id: return None conn = await _get_connection(current_user.account_id, db) if not conn: return None return _to_response(conn) @router.post( "/connections", response_model=PsaConnectionResponse, status_code=status.HTTP_201_CREATED, ) async def create_connection( data: PsaConnectionCreate, current_user: Annotated[User, Depends(require_account_owner)], db: Annotated[AsyncSession, Depends(get_db)], ): """Create a new PSA connection. Tests credentials before saving.""" if not current_user.account_id: raise HTTPException(status.HTTP_400_BAD_REQUEST, "No account associated with user") # Check for existing connection existing = await _get_connection(current_user.account_id, db) if existing: raise HTTPException( status.HTTP_409_CONFLICT, "A PSA connection already exists for this account. Update or delete the existing one.", ) # Test connection before saving test_result = await _test_credentials( provider=data.provider, site_url=data.site_url, company_id=data.company_id, public_key=data.public_key, private_key=data.private_key, client_id=data.client_id, ) if not test_result.success: raise HTTPException( status.HTTP_422_UNPROCESSABLE_ENTITY, f"Connection test failed: {test_result.message}", ) credentials = { "public_key": data.public_key, "private_key": data.private_key, "client_id": data.client_id, } conn = PsaConnection( account_id=current_user.account_id, provider=data.provider, display_name=data.display_name, site_url=data.site_url, company_id=data.company_id, credentials_encrypted=encrypt_credentials(credentials), is_active=True, last_validated_at=datetime.now(timezone.utc), ) db.add(conn) await db.commit() await db.refresh(conn) return _to_response(conn) @router.put("/connections/{connection_id}", response_model=PsaConnectionResponse) async def update_connection( connection_id: UUID, data: PsaConnectionUpdate, current_user: Annotated[User, Depends(require_account_owner)], db: Annotated[AsyncSession, Depends(get_db)], ): """Update an existing PSA connection. Re-tests if credentials change.""" conn = await _get_connection_or_404(connection_id, current_user, db) # Decrypt existing credentials creds = decrypt_credentials(conn.credentials_encrypted) # Track whether credential fields changed cred_fields = {"public_key", "private_key", "client_id"} cred_changed = False # Apply updates update_data = data.model_dump(exclude_unset=True) for field, value in update_data.items(): if field in cred_fields: if value is not None and value != creds.get(field): creds[field] = value cred_changed = True else: setattr(conn, field, value) # Re-test if credentials changed if cred_changed: site_url = update_data.get("site_url", conn.site_url) company_id_val = update_data.get("company_id", conn.company_id) test_result = await _test_credentials( provider=conn.provider, site_url=site_url, company_id=company_id_val, public_key=creds["public_key"], private_key=creds["private_key"], client_id=creds["client_id"], ) if not test_result.success: raise HTTPException( status.HTTP_422_UNPROCESSABLE_ENTITY, f"Connection test failed: {test_result.message}", ) conn.credentials_encrypted = encrypt_credentials(creds) conn.last_validated_at = datetime.now(timezone.utc) conn.updated_at = datetime.now(timezone.utc) await db.commit() await db.refresh(conn) return _to_response(conn) @router.delete( "/connections/{connection_id}", status_code=status.HTTP_204_NO_CONTENT, ) async def delete_connection( connection_id: UUID, current_user: Annotated[User, Depends(require_account_owner)], db: Annotated[AsyncSession, Depends(get_db)], ): """Delete a PSA connection.""" conn = await _get_connection_or_404(connection_id, current_user, db) await db.delete(conn) await db.commit() @router.post( "/connections/{connection_id}/test", response_model=PsaConnectionTestResponse, ) async def test_connection( connection_id: UUID, current_user: Annotated[User, Depends(require_account_owner)], db: Annotated[AsyncSession, Depends(get_db)], ): """Test an existing PSA connection.""" conn = await _get_connection_or_404(connection_id, current_user, db) creds = decrypt_credentials(conn.credentials_encrypted) result = await _test_credentials( provider=conn.provider, site_url=conn.site_url, company_id=conn.company_id, public_key=creds["public_key"], private_key=creds["private_key"], client_id=creds["client_id"], ) if result.success: conn.last_validated_at = datetime.now(timezone.utc) await db.commit() return result # ── internal helpers ───────────────────────────────────────────────── async def _get_connection_or_404( connection_id: UUID, user: User, db: AsyncSession ) -> PsaConnection: """Fetch a connection by ID, ensuring it belongs to the user's account.""" result = await db.execute( select(PsaConnection).where(PsaConnection.id == connection_id) ) conn = result.scalar_one_or_none() if not conn: raise HTTPException(status.HTTP_404_NOT_FOUND, "PSA connection not found") if conn.account_id != user.account_id: raise HTTPException(status.HTTP_404_NOT_FOUND, "PSA connection not found") return conn