From 086e4c6d597ee4fa68d78da57e26bbab354cc3b3 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Sat, 14 Mar 2026 21:48:09 -0400 Subject: [PATCH] feat(psa): add Fernet credential encryption with HKDF key derivation Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/services/psa/encryption.py | 53 ++++++++++++++++++++++++++ backend/tests/test_psa_encryption.py | 44 +++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 backend/app/services/psa/encryption.py create mode 100644 backend/tests/test_psa_encryption.py diff --git a/backend/app/services/psa/encryption.py b/backend/app/services/psa/encryption.py new file mode 100644 index 00000000..b537e3fa --- /dev/null +++ b/backend/app/services/psa/encryption.py @@ -0,0 +1,53 @@ +"""Fernet-based credential encryption for PSA connections. + +Uses the application SECRET_KEY to derive a Fernet encryption key via HKDF. +Credentials are stored as a single encrypted JSON blob. +""" +from __future__ import annotations + +import json +import base64 + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.hkdf import HKDF + +from app.core.config import settings + + +def _get_fernet() -> Fernet: + """Derive a Fernet key from the application SECRET_KEY.""" + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=b"resolutionflow-psa-credentials", + info=b"psa-credential-encryption", + ) + key = hkdf.derive(settings.SECRET_KEY.encode()) + fernet_key = base64.urlsafe_b64encode(key) + return Fernet(fernet_key) + + +def encrypt_credentials(credentials: dict) -> str: + """Encrypt a credentials dict to a Fernet token string.""" + f = _get_fernet() + plaintext = json.dumps(credentials).encode() + return f.encrypt(plaintext).decode() + + +def decrypt_credentials(encrypted: str) -> dict: + """Decrypt a Fernet token string back to a credentials dict.""" + f = _get_fernet() + plaintext = f.decrypt(encrypted.encode()) + return json.loads(plaintext) + + +def mask_credential(value: str | None, visible_suffix: int = 4) -> str: + """Return a masked version of a credential for display. + e.g., 'abcdefghij' -> '......ghij' + """ + if not value: + return "\u2022\u2022\u2022\u2022\u2022\u2022" + if len(value) <= visible_suffix: + return "\u2022\u2022\u2022\u2022\u2022\u2022" + value + return "\u2022\u2022\u2022\u2022\u2022\u2022" + value[-visible_suffix:] diff --git a/backend/tests/test_psa_encryption.py b/backend/tests/test_psa_encryption.py new file mode 100644 index 00000000..9860b120 --- /dev/null +++ b/backend/tests/test_psa_encryption.py @@ -0,0 +1,44 @@ +"""Tests for PSA credential encryption/decryption.""" +import pytest +from app.services.psa.encryption import encrypt_credentials, decrypt_credentials + + +class TestCredentialEncryption: + def test_round_trip(self): + """Encrypt then decrypt returns original credentials.""" + creds = { + "public_key": "abc123", + "private_key": "secret456", + "client_id": "my-client-id", + } + encrypted = encrypt_credentials(creds) + + # Encrypted should be a non-empty string, different from input + assert isinstance(encrypted, str) + assert len(encrypted) > 0 + assert "secret456" not in encrypted + + decrypted = decrypt_credentials(encrypted) + assert decrypted == creds + + def test_different_inputs_produce_different_outputs(self): + creds1 = {"public_key": "key1", "private_key": "priv1", "client_id": "cid1"} + creds2 = {"public_key": "key2", "private_key": "priv2", "client_id": "cid2"} + + enc1 = encrypt_credentials(creds1) + enc2 = encrypt_credentials(creds2) + assert enc1 != enc2 + + def test_tampered_ciphertext_raises(self): + creds = {"public_key": "k", "private_key": "p", "client_id": "c"} + encrypted = encrypt_credentials(creds) + tampered = encrypted[:-5] + "XXXXX" + with pytest.raises(Exception): + decrypt_credentials(tampered) + + def test_mask_private_key(self): + from app.services.psa.encryption import mask_credential + assert mask_credential("abcdefghij") == "\u2022\u2022\u2022\u2022\u2022\u2022ghij" + assert mask_credential("abc") == "\u2022\u2022\u2022\u2022\u2022\u2022abc" + assert mask_credential("") == "\u2022\u2022\u2022\u2022\u2022\u2022" + assert mask_credential(None) == "\u2022\u2022\u2022\u2022\u2022\u2022"