diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 9fbd5815..32ada630 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -235,6 +235,7 @@ _SUBSCRIPTION_GUARD_ALLOWLIST = { "/api/v1/billing/portal-session", "/api/v1/users/me", "/api/v1/users/me/onboarding-step", + "/api/v1/users/me/onboarding-dismiss-rest", } @@ -298,6 +299,8 @@ _EMAIL_VERIFICATION_ALLOWLIST = { "/api/v1/auth/email/verify", "/api/v1/auth/password/change", "/api/v1/users/me", + "/api/v1/users/me/onboarding-step", + "/api/v1/users/me/onboarding-dismiss-rest", "/api/v1/billing/state", "/api/v1/billing/checkout-session", "/api/v1/billing/portal-session", diff --git a/backend/app/api/endpoints/onboarding.py b/backend/app/api/endpoints/onboarding.py index 534f58a6..4cecf091 100644 --- a/backend/app/api/endpoints/onboarding.py +++ b/backend/app/api/endpoints/onboarding.py @@ -2,19 +2,24 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_active_user from app.core.database import get_db from app.core.admin_database import get_admin_db +from app.models.account import Account from app.models.assistant_chat import AssistantChat from app.models.psa_connection import PsaConnection from app.models.session import Session from app.models.tree import Tree from app.models.user import User -from app.schemas.onboarding import OnboardingStatus +from app.schemas.onboarding import ( + OnboardingStatus, + OnboardingStepRequest, + OnboardingStepResponse, +) router = APIRouter(prefix="/users", tags=["onboarding"]) @@ -109,3 +114,98 @@ async def dismiss_onboarding( # Return updated status (reuse the GET logic) return await get_onboarding_status(db=db, current_user=current_user) + + +# --------------------------------------------------------------------------- +# Welcome wizard endpoints (Phase 2) +# +# These persist Step 1/2/3 progress for the post-signup welcome wizard. +# Mounted on /users/me/* (the parent router prefix is /users) so the wizard +# can run before email verification and during trial. +# --------------------------------------------------------------------------- + + +@router.patch("/me/onboarding-step", response_model=OnboardingStepResponse) +async def patch_onboarding_step( + body: OnboardingStepRequest, + db: Annotated[AsyncSession, Depends(get_admin_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> OnboardingStepResponse: + """Persist welcome-wizard progress for the current user. + + Contract: + - step=1 + complete writes accounts.name, accounts.team_size_bucket, + users.role_at_signup, then sets users.onboarding_step_completed=1. + - step=2 + complete writes accounts.primary_psa, then sets + users.onboarding_step_completed=2. + - step=3 + complete just sets users.onboarding_step_completed=3 + (invites are POSTed separately). + - action="skip" ignores `data` entirely and only advances the step. + - The new step must be >= current onboarding_step_completed (None=>0); + otherwise 400. Idempotent re-PATCH of the same step succeeds. + """ + current_step = current_user.onboarding_step_completed or 0 + if body.step < current_step: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "step_cannot_decrease", + "current_step": current_step, + "requested_step": body.step, + }, + ) + + if body.action == "complete" and body.data is not None and body.step in (1, 2): + # Load the user's account for field writes. Step 3 has no data writes. + account_result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = account_result.scalar_one_or_none() + if account is None: + # Should never happen — user is required to have an account_id. + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="account_not_found", + ) + + if body.step == 1: + data = body.data + if data.company_name is not None: + account.name = data.company_name + if data.team_size_bucket is not None: + account.team_size_bucket = data.team_size_bucket + if data.role_at_signup is not None: + current_user.role_at_signup = data.role_at_signup + elif body.step == 2: + data = body.data + if data.primary_psa is not None: + account.primary_psa = data.primary_psa + + current_user.onboarding_step_completed = body.step + await db.commit() + await db.refresh(current_user) + + return OnboardingStepResponse( + onboarding_step_completed=current_user.onboarding_step_completed, + onboarding_dismissed=current_user.onboarding_dismissed, + ) + + +@router.post("/me/onboarding-dismiss-rest", response_model=OnboardingStepResponse) +async def dismiss_onboarding_rest( + db: Annotated[AsyncSession, Depends(get_admin_db)], + current_user: Annotated[User, Depends(get_current_active_user)], +) -> OnboardingStepResponse: + """Set users.onboarding_dismissed=TRUE — backs the wizard's "Skip the rest" button. + + Returns the same shape as the step PATCH so the frontend can update its + local store from a single response. + """ + current_user.onboarding_dismissed = True + await db.commit() + await db.refresh(current_user) + + return OnboardingStepResponse( + onboarding_step_completed=current_user.onboarding_step_completed, + onboarding_dismissed=current_user.onboarding_dismissed, + ) diff --git a/backend/app/schemas/onboarding.py b/backend/app/schemas/onboarding.py index d21647b5..303e1ceb 100644 --- a/backend/app/schemas/onboarding.py +++ b/backend/app/schemas/onboarding.py @@ -1,4 +1,6 @@ -from pydantic import BaseModel +from typing import Literal, Optional + +from pydantic import BaseModel, Field class OnboardingStatus(BaseModel): @@ -10,3 +12,40 @@ class OnboardingStatus(BaseModel): connected_psa: bool is_team_user: bool dismissed: bool + + +# --- Welcome wizard (Phase 2) ---------------------------------------------- + + +TeamSizeBucket = Literal["1-2", "3-5", "6-10", "11-25", "26+"] +RoleAtSignup = Literal["owner", "lead_tech", "tech", "other"] +PrimaryPsa = Literal["connectwise", "autotask", "halopsa", "none"] +WizardStep = Literal[1, 2, 3] +WizardAction = Literal["complete", "skip"] + + +class OnboardingStepData(BaseModel): + """Optional payload carried with `action="complete"` for steps 1 and 2. + + Step 1 fields: company_name, team_size_bucket, role_at_signup + Step 2 fields: primary_psa + Step 3 has no data (invitations posted separately). + """ + + # Step 1 + company_name: Optional[str] = Field(default=None, max_length=255) + team_size_bucket: Optional[TeamSizeBucket] = None + role_at_signup: Optional[RoleAtSignup] = None + # Step 2 + primary_psa: Optional[PrimaryPsa] = None + + +class OnboardingStepRequest(BaseModel): + step: WizardStep + action: WizardAction + data: Optional[OnboardingStepData] = None + + +class OnboardingStepResponse(BaseModel): + onboarding_step_completed: Optional[int] + onboarding_dismissed: bool diff --git a/backend/tests/test_onboarding_step.py b/backend/tests/test_onboarding_step.py new file mode 100644 index 00000000..eaf9a2a9 --- /dev/null +++ b/backend/tests/test_onboarding_step.py @@ -0,0 +1,149 @@ +"""Tests for welcome-wizard onboarding-step endpoints (Phase 2).""" + +import pytest +from sqlalchemy import select + +from app.models.account import Account +from app.models.user import User + + +@pytest.mark.asyncio +async def test_onboarding_step1_complete_writes_account_name_and_team_size_and_role( + client, auth_headers, test_db, test_user +): + """Step 1 + complete writes account.name + team_size_bucket + user.role_at_signup + and advances onboarding_step_completed to 1.""" + response = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={ + "step": 1, + "action": "complete", + "data": { + "company_name": "Acme MSP", + "team_size_bucket": "3-5", + "role_at_signup": "owner", + }, + }, + ) + assert response.status_code == 200, response.text + data = response.json() + assert data["onboarding_step_completed"] == 1 + assert data["onboarding_dismissed"] is False + + # Verify persisted writes + account_id = test_user["user_data"]["account_id"] + user_email = test_user["email"] + + acct = ( + await test_db.execute(select(Account).where(Account.id == account_id)) + ).scalar_one() + assert acct.name == "Acme MSP" + assert acct.team_size_bucket == "3-5" + + user = ( + await test_db.execute(select(User).where(User.email == user_email)) + ).scalar_one() + assert user.role_at_signup == "owner" + assert user.onboarding_step_completed == 1 + + +@pytest.mark.asyncio +async def test_onboarding_step2_skip_advances_without_psa( + client, auth_headers, test_db, test_user +): + """Step 2 + skip ignores data entirely and only advances the step counter + (no primary_psa write).""" + # Capture original account.primary_psa so we can assert it's untouched. + account_id = test_user["user_data"]["account_id"] + acct_before = ( + await test_db.execute(select(Account).where(Account.id == account_id)) + ).scalar_one() + psa_before = acct_before.primary_psa # likely None + + # Advance step 1 first so step 2 is allowed. + r1 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={"step": 1, "action": "skip"}, + ) + assert r1.status_code == 200, r1.text + + # Skip step 2 — even if data is present it must be ignored. + r2 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={ + "step": 2, + "action": "skip", + "data": {"primary_psa": "connectwise"}, + }, + ) + assert r2.status_code == 200, r2.text + assert r2.json()["onboarding_step_completed"] == 2 + + # Re-fetch account: primary_psa must NOT have been written. + test_db.expire_all() + acct_after = ( + await test_db.execute(select(Account).where(Account.id == account_id)) + ).scalar_one() + assert acct_after.primary_psa == psa_before + + +@pytest.mark.asyncio +async def test_onboarding_step_cannot_decrease(client, auth_headers): + """A step=2 PATCH followed by step=1 must return 400.""" + # Advance to step 2. + r1 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={"step": 1, "action": "skip"}, + ) + assert r1.status_code == 200, r1.text + r2 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={"step": 2, "action": "skip"}, + ) + assert r2.status_code == 200, r2.text + assert r2.json()["onboarding_step_completed"] == 2 + + # Try to go back to step 1 — must fail. + r3 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={"step": 1, "action": "skip"}, + ) + assert r3.status_code == 400, r3.text + + # Idempotent re-PATCH of same step succeeds. + r4 = await client.patch( + "/api/v1/users/me/onboarding-step", + headers=auth_headers, + json={"step": 2, "action": "skip"}, + ) + assert r4.status_code == 200, r4.text + assert r4.json()["onboarding_step_completed"] == 2 + + +@pytest.mark.asyncio +async def test_onboarding_dismiss_rest_sets_flag( + client, auth_headers, test_db, test_user +): + """POST /users/me/onboarding-dismiss-rest sets users.onboarding_dismissed=TRUE.""" + response = await client.post( + "/api/v1/users/me/onboarding-dismiss-rest", + headers=auth_headers, + ) + assert response.status_code == 200, response.text + data = response.json() + assert data["onboarding_dismissed"] is True + # step counter is whatever it was (None for a fresh user). + assert "onboarding_step_completed" in data + + # Verify persisted. + user_email = test_user["email"] + user = ( + await test_db.execute(select(User).where(User.email == user_email)) + ).scalar_one() + assert user.onboarding_dismissed is True