Files
resolutionflow/backend/tests/test_branding.py
chihlasm 758cd61621 fix: propagate account_id through all write paths missing NOT NULL coverage
Service layer (production code):
- branch_manager: set account_id on SessionBranch (root + fork) and ForkPoint
  from session.account_id; load session in create_fork for this purpose
- handoff_manager: set account_id on SessionHandoff from session.account_id
- ai_suggestions endpoint: set account_id on AISuggestion from current_user
- steps endpoint (/feedback): set account_id on StepRating from current_user
- ratings endpoint: set account_id on StepRating from current_user

Test infrastructure:
- conftest.py: seed PLATFORM_ACCOUNT_ID (00000000-...-0001) account after
  Base.metadata.create_all so global categories and gallery items have a valid FK
- test_rls_isolation: add _ensure_rls_schema fixture that runs
  'alembic upgrade head' before module tests — previous function-scoped
  test_db fixtures drop the schema, leaving the RLS tests with no tables
- test_branding: create Account before User in helper functions
- test_admin_gallery: set account_id=PLATFORM_ACCOUNT_ID on Tree/ScriptTemplate
- test_public_templates: set account_id=PLATFORM_ACCOUNT_ID on Tree,
  ScriptTemplate, TreeCategory
- test_resolution_outputs: set account_id=session.account_id on
  SessionResolutionOutput
- test_analytics_phase5: set account_id on PsaPostLog
- test_draft_trees: replace account_id=None with PLATFORM_ACCOUNT_ID in
  migration default test (NOT NULL now enforced)
- test_maintenance_schedules: set account_id on other_tree
- test_save_session_as_tree: set account_id on all 5 Session() constructors

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 04:24:36 +00:00

266 lines
8.9 KiB
Python

"""Tests for team branding endpoints (logo upload + company display name)."""
import uuid
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.security import get_password_hash
from app.models.account import Account
from app.models.team import Team
from app.models.user import User
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _create_team_with_admin(
    test_db: AsyncSession,
    client: AsyncClient,
    *,
    team_name: str = "Branding Test Team",
) -> tuple[dict, str, Team]:
    """Seed an account, a team and a team-admin user, then log that user in.

    The ``Account`` row is created first so the admin user can be given a
    concrete ``account_id``.

    Args:
        test_db: Session used to persist the seeded rows.
        client: HTTP client used for the ``/api/v1/auth/login/json`` call.
        team_name: Name given to both the new account and the new team.

    Returns:
        ``(auth_headers, team_id_str, team)`` where ``auth_headers`` carries
        the admin's bearer token.
    """
    password = "Password123!"

    account = Account(name=team_name, display_code=uuid.uuid4().hex[:8].upper())
    team = Team(name=team_name)
    test_db.add_all([account, team])
    # Flush so ``account.id`` and ``team.id`` are populated before the user
    # row below references them.
    await test_db.flush()

    admin_email = f"admin_{uuid.uuid4().hex[:8]}@test.com"
    test_db.add(
        User(
            email=admin_email,
            password_hash=get_password_hash(password),
            name="Team Admin",
            is_active=True,
            team_id=team.id,
            is_team_admin=True,
            role="engineer",
            account_id=account.id,
            account_role="engineer",
        )
    )
    await test_db.commit()

    login = await client.post(
        "/api/v1/auth/login/json",
        json={"email": admin_email, "password": password},
    )
    assert login.status_code == 200
    auth_headers = {"Authorization": f"Bearer {login.json()['access_token']}"}
    return auth_headers, str(team.id), team
async def _create_team_member(
    test_db: AsyncSession,
    client: AsyncClient,
    team: Team,
    *,
    is_team_admin: bool = False,
) -> dict:
    """Create a user on an existing team and log them in.

    The new user reuses the ``account_id`` of a user already on ``team``
    (normally the admin created by ``_create_team_with_admin``). The team
    must therefore have at least one member before this is called.

    Args:
        test_db: Session used to look up and persist users.
        client: HTTP client used for the login request.
        team: Team the new user joins.
        is_team_admin: Whether the new user gets the team-admin flag.

    Returns:
        Auth headers carrying the new user's bearer token.

    Raises:
        AssertionError: If ``team`` has no existing member to borrow an
            account from, or if login fails.
    """
    # Find the account associated with this team via an existing member.
    # ``select`` and ``User`` are the module-level imports; no local re-import
    # is needed.
    result = await test_db.execute(
        select(User).where(User.team_id == team.id).limit(1)
    )
    team_member = result.scalar_one_or_none()
    # Fail fast with a readable message instead of creating a user with
    # account_id=None. That would presumably be rejected later by the
    # NOT NULL account_id constraint, with a far less obvious error.
    assert team_member is not None, (
        f"team {team.id} has no members to take an account_id from; "
        "create it with _create_team_with_admin first"
    )

    email = f"member_{uuid.uuid4().hex[:8]}@test.com"
    user = User(
        email=email,
        password_hash=get_password_hash("Password123!"),
        name="Team Member",
        is_active=True,
        team_id=team.id,
        is_team_admin=is_team_admin,
        role="engineer",
        account_id=team_member.account_id,
        account_role="engineer",
    )
    test_db.add(user)
    await test_db.commit()

    resp = await client.post(
        "/api/v1/auth/login/json",
        json={"email": email, "password": "Password123!"},
    )
    assert resp.status_code == 200, resp.text
    token = resp.json()["access_token"]
    return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_branding_defaults(client: AsyncClient, test_db: AsyncSession):
    """A team that never configured branding reports no logo and no name."""
    auth, tid, _team = await _create_team_with_admin(test_db, client)

    response = await client.get(f"/api/v1/teams/{tid}/branding", headers=auth)

    assert response.status_code == 200
    body = response.json()
    # Nothing was uploaded, so the logo flag is off and the optional fields
    # come back as null.
    assert body["has_logo"] is False
    assert body["company_display_name"] is None
    assert body["logo_content_type"] is None
@pytest.mark.asyncio
async def test_upload_logo_with_company_name(client: AsyncClient, test_db: AsyncSession):
    """PATCH with valid PNG logo + company name succeeds."""
    headers, team_id, _ = await _create_team_with_admin(test_db, client)
    # 1x1 transparent RGBA PNG (67 bytes): signature, IHDR, IDAT, IEND.
    # The IDAT chunk's CRC-32 is 0x0D0A2DB4, spelled b"\r\n-\xb4" below.
    # Dropping the "-" (0x2D) byte yields a 66-byte, corrupt image.
    png_bytes = (
        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
        b"\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89"
        b"\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
        b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
    )
    # Guard the fixture itself so a mangled literal fails here, not as a
    # confusing server-side rejection.
    assert len(png_bytes) == 67
    resp = await client.patch(
        f"/api/v1/teams/{team_id}/branding",
        headers=headers,
        files={"logo": ("logo.png", png_bytes, "image/png")},
        data={"company_display_name": "Acme MSP"},
    )
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["has_logo"] is True
    assert data["logo_content_type"] == "image/png"
    assert data["company_display_name"] == "Acme MSP"
@pytest.mark.asyncio
async def test_upload_oversized_logo(client: AsyncClient, test_db: AsyncSession):
    """A logo one byte larger than 2 MB is rejected with 400."""
    headers, team_id, _ = await _create_team_with_admin(test_db, client)
    size_cap = 2 * 1024 * 1024
    too_big = bytes(size_cap + 1)  # zero-filled, one byte past the cap

    response = await client.patch(
        f"/api/v1/teams/{team_id}/branding",
        headers=headers,
        files={"logo": ("big.png", too_big, "image/png")},
    )

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "maximum size" in detail.lower()
@pytest.mark.asyncio
async def test_upload_invalid_content_type(client: AsyncClient, test_db: AsyncSession):
    """A logo declared as application/pdf is refused with 400."""
    auth, tid, _ = await _create_team_with_admin(test_db, client)
    pdf_upload = ("doc.pdf", b"%PDF-fake", "application/pdf")

    response = await client.patch(
        f"/api/v1/teams/{tid}/branding",
        headers=auth,
        files={"logo": pdf_upload},
    )

    assert response.status_code == 400
    message = response.json()["detail"].lower()
    assert "content type" in message
@pytest.mark.asyncio
async def test_delete_logo(client: AsyncClient, test_db: AsyncSession):
    """DELETE logo clears logo_data while keeping company_display_name."""
    headers, team_id, _ = await _create_team_with_admin(test_db, client)
    # Upload a logo + name first. The payload is only a PNG signature plus
    # zero padding, not a decodable image.
    png_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50
    setup = await client.patch(
        f"/api/v1/teams/{team_id}/branding",
        headers=headers,
        files={"logo": ("logo.png", png_bytes, "image/png")},
        data={"company_display_name": "Keep This Name"},
    )
    # Check the precondition: if no logo were stored, the has_logo=False
    # assertion below would pass even if DELETE did nothing.
    assert setup.status_code == 200, setup.text
    assert setup.json()["has_logo"] is True
    # Delete logo
    resp = await client.delete(
        f"/api/v1/teams/{team_id}/branding/logo",
        headers=headers,
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["has_logo"] is False
    assert data["logo_content_type"] is None
    assert data["company_display_name"] == "Keep This Name"
@pytest.mark.asyncio
async def test_non_admin_cannot_update(client: AsyncClient, test_db: AsyncSession):
    """A team member without the admin flag gets 403 when PATCHing branding."""
    _, team_id, team = await _create_team_with_admin(test_db, client)
    # Same team, but created with the helper's default is_team_admin=False.
    member_auth = await _create_team_member(test_db, client, team)

    attempt = await client.patch(
        f"/api/v1/teams/{team_id}/branding",
        headers=member_auth,
        data={"company_display_name": "Should Fail"},
    )
    assert attempt.status_code == 403
@pytest.mark.asyncio
async def test_non_admin_cannot_delete_logo(client: AsyncClient, test_db: AsyncSession):
    """A team member without the admin flag gets 403 when DELETEing the logo."""
    _, team_id, team = await _create_team_with_admin(test_db, client)
    member_auth = await _create_team_member(test_db, client, team)

    logo_url = f"/api/v1/teams/{team_id}/branding/logo"
    attempt = await client.delete(logo_url, headers=member_auth)

    assert attempt.status_code == 403
@pytest.mark.asyncio
async def test_non_member_cannot_read(client: AsyncClient, test_db: AsyncSession):
    """An admin of one team gets 403 reading another team's branding."""
    _, target_team_id, _ = await _create_team_with_admin(
        test_db, client, team_name="Team A"
    )
    # Admin of an unrelated team: authenticated, but not a member of Team A.
    outsider_headers, *_ = await _create_team_with_admin(
        test_db, client, team_name="Team B"
    )

    response = await client.get(
        f"/api/v1/teams/{target_team_id}/branding",
        headers=outsider_headers,
    )
    assert response.status_code == 403
@pytest.mark.asyncio
async def test_member_can_read_branding(client: AsyncClient, test_db: AsyncSession):
    """Read access is not limited to admins: a plain member may GET branding."""
    _, team_id, team = await _create_team_with_admin(test_db, client)
    member_auth = await _create_team_member(test_db, client, team)

    branding_url = f"/api/v1/teams/{team_id}/branding"
    response = await client.get(branding_url, headers=member_auth)

    assert response.status_code == 200
    # Freshly created team: no logo has been uploaded yet.
    assert response.json()["has_logo"] is False
@pytest.mark.asyncio
async def test_update_display_name_only(client: AsyncClient, test_db: AsyncSession):
    """Setting only company_display_name (no logo part) is accepted."""
    auth, tid, _ = await _create_team_with_admin(test_db, client)
    form_fields = {"company_display_name": "Just A Name"}

    response = await client.patch(
        f"/api/v1/teams/{tid}/branding",
        headers=auth,
        data=form_fields,
    )

    assert response.status_code == 200
    branding = response.json()
    assert branding["company_display_name"] == "Just A Name"
    # No file was sent, so the logo flag must stay off.
    assert branding["has_logo"] is False