feat: self-serve signup Phase 2 (frontend cutover) (#162)
Some checks failed
CI / e2e (push) Has been cancelled
CI / frontend (push) Has been cancelled
CI / backend (push) Has been cancelled
Mirror to GitHub / mirror (push) Has been cancelled

Co-authored-by: Michael Chihlas <michael@resolutionflow.com>
Co-committed-by: Michael Chihlas <michael@resolutionflow.com>
This commit was merged in pull request #162.
This commit is contained in:
2026-05-07 18:42:20 +00:00
committed by chihlasm
parent f918b766b0
commit f1be3abcc5
123 changed files with 11563 additions and 559 deletions

View File

@@ -0,0 +1,290 @@
"""Tests for the public GET /accounts/invites/{code}/lookup endpoint
(consumed by the /accept-invite page on the frontend)."""
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from app.models.account_invite import AccountInvite
@pytest.mark.asyncio
async def test_invite_lookup_returns_account_info_for_valid_code(
client, test_db, test_user, auth_headers
):
"""A freshly-created, unused, unexpired invite resolves to the inviter's
account name + the inviter's display name + the invited email + role."""
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "lookup@example.com", "role": "engineer"},
headers=auth_headers,
)
assert create_resp.status_code == 201, create_resp.json()
code = create_resp.json()["code"]
response = await client.get(f"/api/v1/accounts/invites/{code}/lookup")
assert response.status_code == 200, response.json()
body = response.json()
assert body["invited_email"] == "lookup@example.com"
assert body["role"] == "engineer"
assert body["inviter_name"] == test_user["user_data"]["name"]
# account_name is whatever the test_user fixture seeded for the account.
assert isinstance(body["account_name"], str) and body["account_name"]
@pytest.mark.asyncio
async def test_invite_lookup_returns_404_for_invalid_or_expired_code(
client, test_db, test_user
):
"""Three failure modes (unknown code, expired, revoked, used) all collapse
to the same 404 + invite_invalid_or_expired_or_revoked error code."""
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
account_id = uuid.UUID(test_user["user_data"]["account_id"])
# 1) Unknown code
unknown = await client.get("/api/v1/accounts/invites/DOESNOTEXIST/lookup")
assert unknown.status_code == 404
assert unknown.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 2) Expired
expired_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="expired@example.com",
code="EXPIREDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
)
test_db.add(expired_invite)
await test_db.commit()
expired = await client.get("/api/v1/accounts/invites/EXPIREDLOOKUP01/lookup")
assert expired.status_code == 404
assert expired.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 3) Revoked
revoked_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="revoked@example.com",
code="REVOKEDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
revoked_at=datetime.now(timezone.utc),
)
test_db.add(revoked_invite)
await test_db.commit()
revoked = await client.get("/api/v1/accounts/invites/REVOKEDLOOKUP01/lookup")
assert revoked.status_code == 404
assert revoked.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# 4) Already used
used_invite = AccountInvite(
account_id=account_id,
invited_by_id=invited_by_id,
email="used@example.com",
code="USEDLOOKUP01",
role="engineer",
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
accepted_by_id=invited_by_id,
used_at=datetime.now(timezone.utc),
)
test_db.add(used_invite)
await test_db.commit()
used = await client.get("/api/v1/accounts/invites/USEDLOOKUP01/lookup")
assert used.status_code == 404
assert used.json()["detail"]["error"] == "invite_invalid_or_expired_or_revoked"
# Sanity: rows survived (no destructive side effects).
persisted = (
await test_db.execute(
select(AccountInvite).where(
AccountInvite.code.in_(
["EXPIREDLOOKUP01", "REVOKEDLOOKUP01", "USEDLOOKUP01"]
)
)
)
).scalars().all()
assert len(persisted) == 3
@pytest.mark.asyncio
async def test_oauth_callback_links_invite_when_account_invite_code_supplied(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Brand-new OAuth user with account_invite_code joins the invited account
instead of getting a personal one. Invite is marked used."""
from app.core.config import settings
from app.models.user import User
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "oauth-invite@example.com", "role": "engineer"},
headers=auth_headers,
)
code = create_resp.json()["code"]
inviter_account_id = uuid.UUID(test_user["user_data"]["account_id"])
profile = OAuthProfile(
provider_subject="google_invite_subject_1",
email="oauth-invite@example.com",
name="OAuth Invitee",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": "oauth-invite@example.com",
},
)
assert response.status_code == 200, response.json()
assert response.json()["is_new_user"] is True
user = (
await test_db.execute(
select(User).where(User.email == "oauth-invite@example.com")
)
).scalar_one()
assert user.account_id == inviter_account_id
assert user.account_role == "engineer"
invite = (
await test_db.execute(
select(AccountInvite).where(AccountInvite.code == code)
)
).scalar_one()
assert invite.used_at is not None
assert invite.accepted_by_id == user.id
@pytest.mark.asyncio
async def test_oauth_callback_existing_email_with_invite_returns_400(
client, test_db, test_user, auth_headers, monkeypatch
):
"""If a user already exists with the invited email (e.g., previously
registered via password), arriving via /accept-invite OAuth must NOT
silently link the OAuth identity to their existing account and skip the
invite. Surface email_already_registered_use_login so the user signs in
and accepts the invite from the dashboard instead."""
from app.core.config import settings
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
# 1) Pre-existing user with a password (separate from the inviter).
existing_email = "already-here@example.com"
register_resp = await client.post(
"/api/v1/auth/register",
json={
"email": existing_email,
"password": "PreviousPassword123!",
"name": "Already Here",
},
)
assert register_resp.status_code in (200, 201), register_resp.json()
# 2) Inviter creates an invite for that exact email.
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": existing_email, "role": "engineer"},
headers=auth_headers,
)
assert create_resp.status_code == 201, create_resp.json()
code = create_resp.json()["code"]
# 3) The existing user does Google OAuth and the callback receives the
# invite code. Backend must reject — not link silently.
profile = OAuthProfile(
provider_subject="google_existing_subject_1",
email=existing_email,
name="Already Here",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": existing_email,
},
)
assert response.status_code == 400, response.json()
assert (
response.json()["detail"]["error"] == "email_already_registered_use_login"
)
# 4) Sanity: the invite was NOT consumed.
invite = (
await test_db.execute(
select(AccountInvite).where(AccountInvite.code == code)
)
).scalar_one()
assert invite.used_at is None
assert invite.accepted_by_id is None
@pytest.mark.asyncio
async def test_oauth_callback_invite_email_mismatch_returns_400(
client, test_db, test_user, auth_headers, monkeypatch
):
"""If the OAuth profile's email differs from the invite's email, the
backend rejects the link with invite_email_mismatch (mirrors register)."""
from app.core.config import settings
from app.services.oauth_providers import OAuthProfile
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
with patch(
"app.core.email.EmailService.send_account_invite_email",
new_callable=AsyncMock,
return_value=True,
):
create_resp = await client.post(
"/api/v1/accounts/me/invites",
json={"email": "expected@example.com", "role": "engineer"},
headers=auth_headers,
)
code = create_resp.json()["code"]
profile = OAuthProfile(
provider_subject="google_invite_subject_2",
email="different@example.com",
name="Wrong Email",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback",
json={
"code": "auth_code_xyz",
"account_invite_code": code,
"invited_email": "expected@example.com",
},
)
assert response.status_code == 400, response.json()
assert response.json()["detail"]["error"] == "invite_email_mismatch"

View File

@@ -1,7 +1,12 @@
"""Integration tests for admin plan limits and account override endpoints."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from sqlalchemy import select
from app.models.plan_billing import PlanBilling
class TestAdminPlanLimits:
@@ -56,3 +61,204 @@ class TestAdminPlanLimits:
"""Non-admin gets 403."""
response = await client.get("/api/v1/admin/plan-limits", headers=auth_headers)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_admin_plan_limits_get_includes_plan_billing_fields_when_present(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""GET /admin/plan-limits returns plan_billing fields when a row exists,
and None for plans that don't have one yet."""
# Seed a plan_billing row for "pro".
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "pro")
)).scalar_one_or_none()
if existing is None:
test_db.add(PlanBilling(
plan="pro",
display_name="Pro",
description="For working teams",
monthly_price_cents=4900,
annual_price_cents=49000,
stripe_product_id="prod_seed",
stripe_monthly_price_id="price_seed_m",
stripe_annual_price_id="price_seed_a",
is_public=True,
is_archived=False,
sort_order=10,
))
await test_db.commit()
response = await client.get(
"/api/v1/admin/plan-limits", headers=admin_auth_headers
)
assert response.status_code == 200
plans_by_name = {p["plan"]: p for p in response.json()}
assert "pro" in plans_by_name
pro = plans_by_name["pro"]
assert pro["display_name"] == "Pro"
assert pro["monthly_price_cents"] == 4900
assert pro["stripe_monthly_price_id"] == "price_seed_m"
assert pro["is_public"] is True
assert pro["is_archived"] is False
assert pro["sort_order"] == 10
# A plan without a plan_billing row should still return, with None
# billing fields.
if "free" in plans_by_name:
free = plans_by_name["free"]
# free has no plan_billing row in the seed → fields are None.
no_billing_row = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "free")
)).scalar_one_or_none() is None
if no_billing_row:
assert free["display_name"] is None
assert free["monthly_price_cents"] is None
assert free["stripe_product_id"] is None
@pytest.mark.asyncio
async def test_admin_plan_limits_put_creates_plan_billing_row(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits upserts a plan_billing row when billing
fields are included in the body."""
# Ensure no plan_billing row exists for "team" yet.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text", "pdf"],
"display_name": "Team",
"description": "For growing shops",
"monthly_price_cents": 9900,
"annual_price_cents": 99000,
"stripe_product_id": "prod_team_test",
"stripe_monthly_price_id": "price_team_m",
"stripe_annual_price_id": "price_team_a",
"is_public": True,
"is_archived": False,
"sort_order": 20,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
body = response.json()
assert body["display_name"] == "Team"
assert body["monthly_price_cents"] == 9900
assert body["stripe_product_id"] == "prod_team_test"
assert body["sort_order"] == 20
# Confirm the row was actually persisted.
await test_db.commit() # ensure session sees other-session writes
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team"
assert pb.monthly_price_cents == 9900
assert pb.stripe_monthly_price_id == "price_team_m"
assert pb.is_public is True
@pytest.mark.asyncio
async def test_admin_plan_limits_put_does_not_null_out_required_fields(
self, client: AsyncClient, admin_auth_headers: dict, test_db
):
"""PUT /admin/plan-limits must not NULL out NOT NULL columns on the
plan_billing row when the caller passes explicit nulls. The set of
guarded fields is {display_name, is_public, is_archived, sort_order}.
"""
# Seed a plan_billing row for "team" with non-default values for every
# NOT NULL field so we can detect any clobbering.
existing = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
if existing is not None:
await test_db.delete(existing)
await test_db.commit()
seeded = PlanBilling(
plan="team",
display_name="Team Seeded",
is_public=False,
is_archived=True,
sort_order=5,
)
test_db.add(seeded)
await test_db.commit()
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "team",
"max_trees": None,
"max_sessions_per_month": None,
"max_users": None,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
# Explicit nulls for every NOT NULL plan_billing field.
"display_name": None,
"is_public": None,
"is_archived": None,
"sort_order": None,
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
# Confirm the seeded NOT NULL values were preserved.
await test_db.commit() # ensure session sees writes from the request
pb = (await test_db.execute(
select(PlanBilling).where(PlanBilling.plan == "team")
)).scalar_one_or_none()
assert pb is not None
assert pb.display_name == "Team Seeded"
assert pb.is_public is False
assert pb.is_archived is True
assert pb.sort_order == 5
@pytest.mark.asyncio
async def test_admin_plan_limits_put_invalidates_billing_cache(
self, client: AsyncClient, admin_auth_headers: dict
):
"""PUT /admin/plan-limits calls BillingService.invalidate_billing_cache
with the account_ids on the affected plan."""
# Patch the staticmethod on the class. The endpoint imports
# BillingService at module load, so patch the symbol on the class
# itself — both the import and the dotted reference resolve to it.
with patch(
"app.api.endpoints.admin_plan_limits.BillingService.invalidate_billing_cache",
new_callable=AsyncMock,
) as spy:
response = await client.put(
"/api/v1/admin/plan-limits",
json={
"plan": "pro",
"max_trees": 25,
"max_sessions_per_month": 500,
"max_users": 10,
"custom_branding": True,
"priority_support": True,
"export_formats": ["markdown", "text"],
},
headers=admin_auth_headers,
)
assert response.status_code == 200, response.text
spy.assert_awaited_once()
(account_ids_arg,) = spy.await_args.args
# admin fixture seeds an active Pro Subscription, so we expect at
# least one account_id in the invalidation list.
assert isinstance(account_ids_arg, list)
assert len(account_ids_arg) >= 1

View File

@@ -0,0 +1,43 @@
"""Integration tests for the legacy /beta-signup redirect.
Phase 2 retires the public beta-signup form in favor of the regular
register flow. The endpoint stays mounted but answers with a 307 to
the absolute frontend `/register?from=beta` URL so any external links
keep working. There is no `beta_signup` table to migrate — the old
endpoint only fired an email notification — so this test only covers
the redirect contract.
"""
import pytest
from app.core.config import settings
@pytest.mark.asyncio
async def test_beta_signup_redirects_to_register(client, monkeypatch):
"""POST /beta-signup returns 307 to the absolute frontend register URL."""
monkeypatch.setattr(settings, "FRONTEND_URL", "https://example.com")
response = await client.post(
"/api/v1/beta-signup",
json={"email": "anyone@example.com"},
)
assert response.status_code == 307, response.text
assert (
response.headers["location"]
== "https://example.com/register?from=beta"
)
@pytest.mark.asyncio
async def test_beta_signup_redirect_ignores_body(client, monkeypatch):
"""Redirect fires regardless of payload — no validation on the legacy route."""
monkeypatch.setattr(settings, "FRONTEND_URL", "https://example.com")
response = await client.post("/api/v1/beta-signup", json={})
assert response.status_code == 307
assert (
response.headers["location"]
== "https://example.com/register?from=beta"
)

View File

@@ -0,0 +1,83 @@
import uuid
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import select
from app.models.account import Account
@pytest.mark.asyncio
async def test_billing_portal_returns_url_for_account_with_stripe_customer(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Happy path: account has a stripe_customer_id and Stripe is configured →
GET /billing/portal-session returns the portal URL."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
monkeypatch.setattr(settings, "FRONTEND_URL", "https://app.example.com")
account_id = uuid.UUID(test_user["user_data"]["account_id"])
account = (await test_db.execute(
select(Account).where(Account.id == account_id)
)).scalar_one()
account.stripe_customer_id = "cus_test_456"
await test_db.commit()
fake_session = MagicMock()
fake_session.url = "https://billing.stripe.com/p/session/test_abc"
with patch(
"stripe.billing_portal.Session.create",
return_value=fake_session,
) as portal_mock:
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 200, response.json()
assert response.json() == {"url": "https://billing.stripe.com/p/session/test_abc"}
portal_mock.assert_called_once()
call_kwargs = portal_mock.call_args.kwargs
assert call_kwargs["customer"] == "cus_test_456"
assert call_kwargs["return_url"] == "https://app.example.com/account/billing"
@pytest.mark.asyncio
async def test_billing_portal_returns_503_when_stripe_not_configured(
client, test_db, test_user, auth_headers, monkeypatch
):
"""STRIPE_SECRET_KEY unset → settings.stripe_enabled is False → 503."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", None)
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 503
assert response.json()["detail"]["error"] == "stripe_not_configured"
@pytest.mark.asyncio
async def test_billing_portal_returns_400_when_account_has_no_stripe_customer(
client, test_db, test_user, auth_headers, monkeypatch
):
"""Account with no stripe_customer_id (never completed checkout) → 400
with `no_stripe_customer` error."""
from app.core.config import settings
monkeypatch.setattr(settings, "STRIPE_SECRET_KEY", "sk_test_dummy")
# test_user fixture seeds an account with no stripe_customer_id by default.
account_id = uuid.UUID(test_user["user_data"]["account_id"])
account = (await test_db.execute(
select(Account).where(Account.id == account_id)
)).scalar_one()
assert account.stripe_customer_id is None
response = await client.get(
"/api/v1/billing/portal-session",
headers=auth_headers,
)
assert response.status_code == 400
assert response.json()["detail"]["error"] == "no_stripe_customer"

View File

@@ -0,0 +1,100 @@
"""Integration tests for the public runtime config endpoint.
Covers GET /api/v1/config/public and the SELF_SERVE_ENABLED interaction
with the existing /auth/register invite-code gate.
"""
from __future__ import annotations
import pytest
from httpx import AsyncClient
from app.core.config import settings
class TestConfigPublic:
"""GET /api/v1/config/public — anonymous, no auth."""
@pytest.mark.asyncio
async def test_get_config_public_returns_self_serve_flag(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Endpoint reflects the current SELF_SERVE_ENABLED setting and the
configured OAuth providers, with no auth required."""
# Default-off: SELF_SERVE_ENABLED is False unless explicitly set.
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", None)
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
body = response.json()
assert body == {"self_serve_enabled": False, "oauth_providers": []}
# Flip it on, with both OAuth providers configured.
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "google-test-id")
monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
body = response.json()
assert body["self_serve_enabled"] is True
assert body["oauth_providers"] == ["google", "microsoft"]
# Only Microsoft configured.
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
monkeypatch.setattr(settings, "MS_CLIENT_ID", "ms-test-id")
response = await client.get("/api/v1/config/public")
assert response.status_code == 200
assert response.json()["oauth_providers"] == ["microsoft"]
class TestRegisterInviteCodeGate:
"""Regression + new-behavior tests for /auth/register vs SELF_SERVE_ENABLED."""
@pytest.mark.asyncio
async def test_register_invite_code_required_when_self_serve_disabled(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Pre-self-serve behavior: REQUIRE_INVITE_CODE=True without an
invite code (and no account-invite) must still 400."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", False)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "no-invite@example.com",
"password": "SecurePass123!",
"name": "No Invite",
},
)
assert response.status_code == 400
assert "invite code is required" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_register_invite_code_optional_when_self_serve_enabled(
self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch
):
"""Self-serve on: registration succeeds with no invite code even
when REQUIRE_INVITE_CODE is True. The user, personal account, and
a Pro trial subscription are all created."""
monkeypatch.setattr(settings, "REQUIRE_INVITE_CODE", True)
monkeypatch.setattr(settings, "SELF_SERVE_ENABLED", True)
response = await client.post(
"/api/v1/auth/register",
json={
"email": "self-serve@example.com",
"password": "SecurePass123!",
"name": "Self Serve",
},
)
assert response.status_code == 201, response.text
body = response.json()
assert body["email"] == "self-serve@example.com"
assert body["account_role"] == "owner"
assert "account_id" in body

View File

@@ -2,8 +2,10 @@ import uuid
import pytest
from unittest.mock import patch
from sqlalchemy import select
from app.core.security import decode_token, hash_token
from app.models.user import User
from app.models.oauth_identity import OAuthIdentity
from app.models.refresh_token import RefreshToken
from app.models.subscription import Subscription
from app.services.oauth_providers import OAuthProfile
@@ -118,3 +120,77 @@ async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
)).scalar_one()
assert identity.provider == "microsoft"
@pytest.mark.asyncio
async def test_oauth_google_callback_stores_refresh_token_jti(
client, test_db, monkeypatch
):
"""A successful Google OAuth callback must persist the refresh-token JTI
in the refresh_tokens table — otherwise /auth/refresh rejects it."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_jti_test",
email="jtitest@example.com",
name="JTI Test",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
response = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert response.status_code == 200, response.json()
body = response.json()
refresh_token_str = body["refresh_token"]
payload = decode_token(refresh_token_str)
assert payload is not None
jti = payload["jti"]
token_hash = hash_token(jti)
user = (await test_db.execute(
select(User).where(User.email == "jtitest@example.com")
)).scalar_one()
stored = (await test_db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)).scalar_one_or_none()
assert stored is not None, "OAuth callback did not persist refresh-token JTI"
assert stored.user_id == user.id
assert stored.revoked_at is None
@pytest.mark.asyncio
async def test_oauth_refresh_works_after_oauth_signup(
client, test_db, monkeypatch
):
"""End-to-end: OAuth callback issues a refresh token; calling /auth/refresh
with that token must succeed (not be rejected as revoked)."""
from app.core.config import settings
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
profile = OAuthProfile(
provider_subject="google_subject_refresh_test",
email="refresh-after-oauth@example.com",
name="Refresh After OAuth",
)
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
callback_resp = await client.post(
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
)
assert callback_resp.status_code == 200, callback_resp.json()
refresh_token_str = callback_resp.json()["refresh_token"]
refresh_resp = await client.post(
"/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token_str}"},
)
assert refresh_resp.status_code == 200, refresh_resp.json()
refreshed = refresh_resp.json()
assert refreshed["access_token"]
assert refreshed["refresh_token"]
# Token rotation: new refresh token differs from the original.
assert refreshed["refresh_token"] != refresh_token_str

View File

@@ -1,6 +1,11 @@
"""Tests for onboarding status endpoints."""
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.models.user import User
@pytest.mark.asyncio
@@ -21,6 +26,42 @@ async def test_onboarding_status_fresh_user(client, auth_headers):
assert data["connected_psa"] is False
assert data["is_team_user"] is False
assert data["dismissed"] is False
# Phase 2 fields default to false on a fresh, unverified user with no wizard progress.
assert data["email_verified"] is False
assert data["shop_setup_done"] is False
@pytest.mark.asyncio
async def test_onboarding_status_includes_email_verified_and_shop_setup_done(
client, auth_headers, test_user, test_db
):
"""email_verified flips when email_verified_at is set; shop_setup_done flips at step >= 1."""
# Sanity-check baseline.
response = await client.get(
"/api/v1/users/onboarding-status",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["email_verified"] is False
assert data["shop_setup_done"] is False
# Mutate the underlying user, then re-fetch.
user_email = test_user["email"]
result = await test_db.execute(select(User).where(User.email == user_email))
user = result.scalar_one()
user.email_verified_at = datetime.now(tz=timezone.utc)
user.onboarding_step_completed = 1
await test_db.commit()
response = await client.get(
"/api/v1/users/onboarding-status",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["email_verified"] is True
assert data["shop_setup_done"] is True
@pytest.mark.asyncio

View File

@@ -0,0 +1,149 @@
"""Tests for welcome-wizard onboarding-step endpoints (Phase 2)."""
import pytest
from sqlalchemy import select
from app.models.account import Account
from app.models.user import User
@pytest.mark.asyncio
async def test_onboarding_step1_complete_writes_account_name_and_team_size_and_role(
client, auth_headers, test_db, test_user
):
"""Step 1 + complete writes account.name + team_size_bucket + user.role_at_signup
and advances onboarding_step_completed to 1."""
response = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={
"step": 1,
"action": "complete",
"data": {
"company_name": "Acme MSP",
"team_size_bucket": "3-5",
"role_at_signup": "owner",
},
},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["onboarding_step_completed"] == 1
assert data["onboarding_dismissed"] is False
# Verify persisted writes
account_id = test_user["user_data"]["account_id"]
user_email = test_user["email"]
acct = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
assert acct.name == "Acme MSP"
assert acct.team_size_bucket == "3-5"
user = (
await test_db.execute(select(User).where(User.email == user_email))
).scalar_one()
assert user.role_at_signup == "owner"
assert user.onboarding_step_completed == 1
@pytest.mark.asyncio
async def test_onboarding_step2_skip_advances_without_psa(
client, auth_headers, test_db, test_user
):
"""Step 2 + skip ignores data entirely and only advances the step counter
(no primary_psa write)."""
# Capture original account.primary_psa so we can assert it's untouched.
account_id = test_user["user_data"]["account_id"]
acct_before = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
psa_before = acct_before.primary_psa # likely None
# Advance step 1 first so step 2 is allowed.
r1 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r1.status_code == 200, r1.text
# Skip step 2 — even if data is present it must be ignored.
r2 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={
"step": 2,
"action": "skip",
"data": {"primary_psa": "connectwise"},
},
)
assert r2.status_code == 200, r2.text
assert r2.json()["onboarding_step_completed"] == 2
# Re-fetch account: primary_psa must NOT have been written.
test_db.expire_all()
acct_after = (
await test_db.execute(select(Account).where(Account.id == account_id))
).scalar_one()
assert acct_after.primary_psa == psa_before
@pytest.mark.asyncio
async def test_onboarding_step_cannot_decrease(client, auth_headers):
"""A step=2 PATCH followed by step=1 must return 400."""
# Advance to step 2.
r1 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r1.status_code == 200, r1.text
r2 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 2, "action": "skip"},
)
assert r2.status_code == 200, r2.text
assert r2.json()["onboarding_step_completed"] == 2
# Try to go back to step 1 — must fail.
r3 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 1, "action": "skip"},
)
assert r3.status_code == 400, r3.text
# Idempotent re-PATCH of same step succeeds.
r4 = await client.patch(
"/api/v1/users/me/onboarding-step",
headers=auth_headers,
json={"step": 2, "action": "skip"},
)
assert r4.status_code == 200, r4.text
assert r4.json()["onboarding_step_completed"] == 2
@pytest.mark.asyncio
async def test_onboarding_dismiss_rest_sets_flag(
client, auth_headers, test_db, test_user
):
"""POST /users/me/onboarding-dismiss-rest sets users.onboarding_dismissed=TRUE."""
response = await client.post(
"/api/v1/users/me/onboarding-dismiss-rest",
headers=auth_headers,
)
assert response.status_code == 200, response.text
data = response.json()
assert data["onboarding_dismissed"] is True
# step counter is whatever it was (None for a fresh user).
assert "onboarding_step_completed" in data
# Verify persisted.
user_email = test_user["email"]
user = (
await test_db.execute(select(User).where(User.email == user_email))
).scalar_one()
assert user.onboarding_dismissed is True

View File

@@ -0,0 +1,132 @@
"""Integration tests for the public plans endpoint.
Covers GET /api/v1/plans/public — the marketing /pricing page data source.
"""
from __future__ import annotations
import pytest
from httpx import AsyncClient
from sqlalchemy import delete
from app.models.plan_billing import PlanBilling
from app.models.plan_limits import PlanLimits
async def _seed_plan_limits(test_db, plan: str, max_users: int | None) -> None:
"""Ensure a plan_limits row exists for the given plan name."""
existing = await test_db.get(PlanLimits, plan)
if existing is None:
test_db.add(
PlanLimits(
plan=plan,
max_trees=None,
max_sessions_per_month=None,
max_users=max_users,
custom_branding=False,
priority_support=False,
export_formats=["markdown", "text"],
)
)
await test_db.commit()
class TestGetPlansPublic:
"""GET /api/v1/plans/public — anonymous, no auth."""
@pytest.mark.asyncio
async def test_get_plans_public_returns_only_is_public_rows(
self, client: AsyncClient, test_db
):
"""Rows with is_public=False or is_archived=True must NOT appear."""
# Wipe any existing billing rows so this test owns the fixture state.
await test_db.execute(delete(PlanBilling))
await test_db.commit()
await _seed_plan_limits(test_db, "starter", 3)
await _seed_plan_limits(test_db, "pro", 10)
await _seed_plan_limits(test_db, "internal", None)
await _seed_plan_limits(test_db, "legacy", 5)
test_db.add_all(
[
PlanBilling(
plan="starter",
display_name="Starter",
monthly_price_cents=1900,
is_public=True,
is_archived=False,
sort_order=10,
),
PlanBilling(
plan="pro",
display_name="Pro",
monthly_price_cents=4900,
is_public=True,
is_archived=False,
sort_order=20,
),
PlanBilling(
plan="internal",
display_name="Internal",
is_public=False, # hidden
is_archived=False,
sort_order=30,
),
PlanBilling(
plan="legacy",
display_name="Legacy",
is_public=True,
is_archived=True, # archived
sort_order=40,
),
]
)
await test_db.commit()
response = await client.get("/api/v1/plans/public")
assert response.status_code == 200
plans = response.json()
plan_names = {p["plan"] for p in plans}
assert "starter" in plan_names
assert "pro" in plan_names
assert "internal" not in plan_names
assert "legacy" not in plan_names
# Schema sanity check
starter = next(p for p in plans if p["plan"] == "starter")
assert starter["display_name"] == "Starter"
assert starter["monthly_price_cents"] == 1900
assert starter["max_seats"] == 3
assert starter["is_public"] is True
@pytest.mark.asyncio
async def test_get_plans_public_orders_by_sort_order_then_plan(
self, client: AsyncClient, test_db
):
"""Result must be ordered by sort_order ASC, then plan name ASC."""
await test_db.execute(delete(PlanBilling))
await test_db.commit()
# plan_limits rows for FK satisfaction
for name in ("alpha", "bravo", "charlie", "delta"):
await _seed_plan_limits(test_db, name, None)
# Two with sort_order=10 (charlie should come before delta by plan ASC),
# one with sort_order=5 (alpha first overall),
# one with sort_order=20 (bravo last).
test_db.add_all(
[
PlanBilling(plan="charlie", display_name="C", sort_order=10, is_public=True, is_archived=False),
PlanBilling(plan="delta", display_name="D", sort_order=10, is_public=True, is_archived=False),
PlanBilling(plan="alpha", display_name="A", sort_order=5, is_public=True, is_archived=False),
PlanBilling(plan="bravo", display_name="B", sort_order=20, is_public=True, is_archived=False),
]
)
await test_db.commit()
response = await client.get("/api/v1/plans/public")
assert response.status_code == 200
ordered = [p["plan"] for p in response.json()]
assert ordered == ["alpha", "charlie", "delta", "bravo"]

View File

@@ -0,0 +1,134 @@
"""Integration tests for the public Talk-to-Sales endpoint.
POST /api/v1/sales-leads — no auth, rate-limited 5/hour per IP.
"""
from unittest.mock import AsyncMock, patch
import pytest
import sqlalchemy as sa
@pytest.mark.asyncio
async def test_sales_lead_creates_row_and_sends_notification_email(client, test_db):
"""Happy path: row inserted, notification email fired, 201 returned."""
payload = {
"email": "buyer@acme.example",
"name": "Pat Buyer",
"company": "Acme MSP",
"team_size": "11-50",
"message": "We're evaluating ResolutionFlow for our NOC team.",
"source": "pricing_page",
"posthog_distinct_id": "ph_distinct_123",
}
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(return_value=True),
) as mock_email:
response = await client.post("/api/v1/sales-leads", json=payload)
assert response.status_code == 201, response.text
body = response.json()
assert body["status"] == "received"
assert "id" in body
# Notification email was attempted (asyncio.create_task — give it a tick).
import asyncio
await asyncio.sleep(0)
await asyncio.sleep(0)
assert mock_email.await_count == 1
kwargs = mock_email.await_args.kwargs
assert kwargs["to_email"] # default placeholder until cutover
assert kwargs["lead"].email == "buyer@acme.example"
assert kwargs["lead"].source == "pricing_page"
# Row was inserted with normalized email + all fields preserved.
result = await test_db.execute(
sa.text("SELECT email, name, company, team_size, message, source, posthog_distinct_id, status FROM sales_leads")
)
rows = result.all()
assert len(rows) == 1
row = rows[0]
assert row.email == "buyer@acme.example"
assert row.name == "Pat Buyer"
assert row.company == "Acme MSP"
assert row.team_size == "11-50"
assert row.message == "We're evaluating ResolutionFlow for our NOC team."
assert row.source == "pricing_page"
assert row.posthog_distinct_id == "ph_distinct_123"
assert row.status == "new"
@pytest.mark.asyncio
async def test_sales_lead_email_failure_does_not_fail_request(client, test_db):
"""If the email send raises, the API still returns 201 and the row persists."""
payload = {
"email": "buyer2@acme.example",
"name": "Sam Lead",
"company": "Acme MSP",
"source": "register_footer",
}
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(side_effect=RuntimeError("resend exploded")),
):
response = await client.post("/api/v1/sales-leads", json=payload)
assert response.status_code == 201, response.text
# Row must still be persisted even though email failed.
import asyncio
await asyncio.sleep(0)
result = await test_db.execute(
sa.text("SELECT count(*) FROM sales_leads WHERE email = 'buyer2@acme.example'")
)
assert result.scalar() == 1
@pytest.mark.asyncio
async def test_sales_lead_rate_limited_after_5_per_hour(client):
"""The 6th submission within an hour from the same IP returns 429.
The default `limiter` is disabled in tests (DEBUG=true). We re-enable it
for this test, then reset its state on teardown so other tests aren't
affected.
"""
from app.core.rate_limit import limiter
was_enabled = limiter.enabled
limiter.enabled = True
try:
limiter.reset()
with patch(
"app.api.endpoints.sales_leads.EmailService.send_sales_lead_notification",
new=AsyncMock(return_value=True),
):
for i in range(5):
payload = {
"email": f"lead{i}@acme.example",
"name": f"Lead {i}",
"company": "Acme MSP",
"source": "landing_page",
}
resp = await client.post("/api/v1/sales-leads", json=payload)
assert resp.status_code == 201, f"submission {i}: {resp.text}"
# 6th should be rate-limited.
resp = await client.post(
"/api/v1/sales-leads",
json={
"email": "lead6@acme.example",
"name": "Lead 6",
"company": "Acme MSP",
"source": "landing_page",
},
)
assert resp.status_code == 429, resp.text
finally:
limiter.reset()
limiter.enabled = was_enabled

View File

@@ -142,3 +142,178 @@ async def test_webhook_idempotency(
assert r2.status_code == 200
assert r1.json()["applied"] is True
assert r2.json()["applied"] is False
# ----------------------------------------------------------------------------
# Atomic-idempotency regression tests
# ----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_apply_event_handler_failure_does_not_persist_idempotency_mark(
test_db, test_user,
):
"""If the handler raises, the StripeEvent row must NOT be persisted —
otherwise Stripe's retry will be silently dropped as a duplicate and the
subscription state will desync from Stripe."""
from app.services.billing import BillingService
from app.models.stripe_event import StripeEvent
event_id = "evt_handler_fail_1"
payload = {"data": {"object": {
"id": "sub_doesnotmatter",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 1}]},
"cancel_at_period_end": False,
}}}
boom = RuntimeError("simulated handler failure")
with patch(
"app.services.billing._handle_subscription_updated",
side_effect=boom,
):
with pytest.raises(RuntimeError, match="simulated handler failure"):
await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
# The StripeEvent row must not exist — handler raised, the entire
# transaction (idempotency mark + partial mutations) was rolled back.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one_or_none()
assert row is None, (
"StripeEvent row was persisted despite handler failure — "
"Stripe retry will be silently dropped"
)
@pytest.mark.asyncio
async def test_apply_event_retry_after_failure_succeeds(
test_db, test_user,
):
"""A failed first attempt followed by a successful retry must apply state.
This is the core Stripe webhook retry contract."""
from app.services.billing import BillingService
from app.models.stripe_event import StripeEvent
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(
account_id=account_id, plan="pro", status="trialing",
stripe_subscription_id="sub_retry",
))
await test_db.commit()
event_id = "evt_retry_1"
payload = {"data": {"object": {
"id": "sub_retry",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 3}]},
"cancel_at_period_end": False,
}}}
# First attempt — handler raises mid-flight.
with patch(
"app.services.billing._handle_subscription_updated",
side_effect=RuntimeError("transient blip"),
):
with pytest.raises(RuntimeError):
await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
# No idempotency mark, sub still trialing.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one_or_none()
assert row is None
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)).scalar_one()
assert sub.status == "trialing"
# Second attempt — same event_id, handler succeeds.
applied = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied is True
# Idempotency mark now persisted, sub state reconciled.
row = (await test_db.execute(
select(StripeEvent).where(StripeEvent.id == event_id)
)).scalar_one()
assert row.id == event_id
await test_db.refresh(sub)
assert sub.status == "active"
assert sub.seat_limit == 3
@pytest.mark.asyncio
async def test_apply_event_duplicate_event_id_skips(
test_db, test_user,
):
"""Two successful invocations with the same event_id must not double-apply.
Second call returns False; mutations are not repeated."""
from app.services.billing import BillingService
account_id = uuid.UUID(test_user["user_data"]["account_id"])
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
test_db.add(Subscription(
account_id=account_id, plan="pro", status="trialing",
stripe_subscription_id="sub_dup",
))
await test_db.commit()
event_id = "evt_dedupe_1"
payload = {"data": {"object": {
"id": "sub_dup",
"status": "active",
"current_period_start": 1714521600,
"current_period_end": 1717113600,
"items": {"data": [{"quantity": 7}]},
"cancel_at_period_end": False,
}}}
applied1 = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied1 is True
sub = (await test_db.execute(
select(Subscription).where(Subscription.account_id == account_id)
)).scalar_one()
assert sub.status == "active"
assert sub.seat_limit == 7
# Mutate locally so we can prove the second call doesn't re-run the handler.
sub.seat_limit = 99
await test_db.commit()
applied2 = await BillingService.apply_subscription_event(
test_db,
event_id=event_id,
event_type="customer.subscription.updated",
payload=payload,
)
assert applied2 is False
await test_db.refresh(sub)
# Handler did NOT run again — our local mutation is preserved.
assert sub.seat_limit == 99