Compare commits
10 Commits
9b709488d9
...
fbb41e789c
| Author | SHA1 | Date | |
|---|---|---|---|
| fbb41e789c | |||
| 97d36dd400 | |||
| f26f468878 | |||
| 79942c3fd3 | |||
| 4768ae0648 | |||
| e54d6c586a | |||
| 86893562b9 | |||
| b0708ed650 | |||
| 2ef2350de7 | |||
| f4606f073a |
@@ -2,35 +2,35 @@
|
||||
|
||||
# HANDOFF.md
|
||||
|
||||
**Last updated:** 2026-05-02 (post-PR-159 — guides Diátaxis rewrite merged into `main`)
|
||||
**Last updated:** 2026-05-06 (Phase 1 backend complete on `feat/self-serve-signup-spec`)
|
||||
|
||||
**Active task:** None. Pick next from `.ai/TODO.md` or `03-DEVELOPMENT-ROADMAP.md`.
|
||||
|
||||
**Just-merged:** PR #159 — In-product User Guides rewritten as 43 Diátaxis how-tos under 10 categories. Drops 3 deprecated guides, renames Step Library → Solutions Library, fixes tip-markdown rendering, adds 14 net-new how-tos for FlowPilot-era surfaces.
|
||||
**Active task:** Phase 1 self-serve signup backend foundation — DONE on branch. PR not yet opened.
|
||||
|
||||
## Where this session ended
|
||||
|
||||
PR #159 merged into `main`. CHANGELOG, CURRENT_TASK, SESSION_LOG all updated. See `.ai/CURRENT_TASK.md` "Recently shipped" for the structured rollup.
|
||||
24 commits on top of `main` (`31ca3fb`). All 26 tasks from `docs/superpowers/plans/2026-05-06-self-serve-signup-phase-1-backend.md` complete. Full pytest run is green (1167 passed, 35 deselected). Single alembic head: `c6cbfc534fad`.
|
||||
|
||||
The 43 guides live at `/guides` in the app. Schema is now category-aware (`Guide.category`, optional `relatedSlugs`); `categories` const drives hub ordering. Browser-verified against engineer + owner test users (sidebar labels, account sub-pages, pilot-screen header buttons, Tasks panel, integration form). tsc and Vite build clean.
|
||||
Phase 1 covered: schema additions (oauth_identities, plan_billing, sales_leads, stripe_events, plus 5 new columns across users/accounts/account_invites), Subscription complimentary status + has_pro_entitlement, the two new guards (`require_active_subscription`, `require_verified_email_after_grace`), full BillingService (start_trial / create_checkout_session / apply_subscription_event / get_billing_state), Stripe webhook handler, Google + Microsoft OAuth callbacks with oauth_identities linking, OAuth-only password guard, register-time verification email + invite email-match, bulk + soft-revoke invite routes, GET /billing/state, and the pilot complimentary backfill migration.
|
||||
|
||||
The conftest's `test_user` fixture was modified to seed a Pro/active Subscription post-register (delete-then-insert) so the new subscription guard doesn't 402 every existing test. Two existing tests adapted because they explicitly assumed the old free-plan default: `test_subscription_limits.py` (the two free-plan tests now downgrade inline) and `test_kb_accelerator.py::TestQuota::test_get_quota` (the `kb_setup` fixture downgrades to free).
|
||||
|
||||
## Resume point — DO THIS NEXT
|
||||
|
||||
The issue cleanup plan continues from before this session. Pick up `docs/plans/2026-05-01-issue-cleanup-plan.md` at section 3: **#58 structured "step is wrong" quality signals**. Then section 4 (#60 recurring issue detection) and section 5 (#129 hierarchical guide navigation).
|
||||
|
||||
`$GITEA_TOKEN` is in `.claude/settings.local.json` — confirmed working via the PR-creation API call this session. Issue tracker actions can be done from the code-server LXC via `curl` against `https://gitea.resolutionflow.com/api/v1/...`.
|
||||
1. Open the PR for branch `feat/self-serve-signup-spec`. Use `gh pr create` against `main`. Suggested title: `feat: self-serve signup backend (Phase 1)`. Body should mention dark-launch posture (every new endpoint is gated by env config, not a feature flag — see Task 26 §3 in the plan).
|
||||
2. Phase 2 (frontend + cutover) lives in a sibling plan: `docs/superpowers/plans/2026-05-06-self-serve-signup-phase-2-frontend.md` (assumed; verify path). It's the next logical task once Phase 1 ships.
|
||||
|
||||
## Followups deferred from this session
|
||||
|
||||
Worth picking up if a related touch happens:
|
||||
|
||||
- **`change-teammate-role` how-to was dropped** from PR #159 because the test owner account has no non-owner members to inspect the role-change control. Once a teammate is invited via the Membership form on `/account`, verify whether the list exposes a Role dropdown (or some other control) for non-owners and add the guide back to `frontend/src/data/guides.ts` under the `account-admin` category.
|
||||
- **Resolve / Escalate modal contents are unverified.** Browser couldn't drive Resolve to completion (test session has 6 pending Tasks gating it; clicking Resolve fired a toast). The how-tos point at the right buttons in the right place, but the exact modal copy and the Escalation Mode wedge specifics are based on project context, not live observation. Worth a quick spot-check the next time a clean test session is available.
|
||||
- **OAuth callbacks don't call `_store_refresh_token`.** The Google/Microsoft callbacks issue a refresh JWT but never persist its hash to `refresh_tokens` (the password-login flow does via `auth.py:_store_refresh_token`). Result: refresh-token revocation/rotation lookups won't find OAuth-issued tokens. Decide before Phase 2 dark-launch whether to backfill — likely yes, by extracting `_store_refresh_token` to a shared module and calling it from `_sign_in_or_register`.
|
||||
- **`stripe_enabled` was relaxed** in Task 14 from `bool(STRIPE_SECRET_KEY) and bool(STRIPE_WEBHOOK_SECRET)` to just the secret key. The webhook handler in Task 16 independently checks `STRIPE_WEBHOOK_SECRET` before calling `construct_event`, so signature verification is still safe — but if any other code reads `stripe_enabled` and assumes the webhook secret is set, that's a latent bug. Audit before Phase 2 cutover.
|
||||
- **`backend/app/core/stripe_handlers.py`** is a stub module that's no longer referenced after Task 16. Safe to delete in a follow-up; left in place to keep Phase 1 diff focused.
|
||||
- **Pilot backfill migration `c6cbfc534fad` has not been applied to prod yet.** It runs once at deploy time and is forward-only.
|
||||
|
||||
## Environment notes (carry-forward)
|
||||
|
||||
- Code-server LXC has bun + docker but no native `python`/`node`/`npm`. Use `docker exec resolutionflow_{backend,frontend} …` for build/test commands.
|
||||
- No `gh` CLI on this LXC — use the Gitea API (`$GITEA_TOKEN`) for PR/issue work, or run `gh` from a host that has it.
|
||||
- Code-server LXC has bun + docker but no native python/node/npm. Use `docker exec resolutionflow_{backend,frontend} ...` for build/test commands.
|
||||
- Pytest WORKDIR is `/app` — test paths in pytest commands are `tests/<file>`, NOT `backend/tests/<file>`.
|
||||
- Backend pytest cmd: `docker exec resolutionflow_backend pytest tests/<path> -v --override-ini="addopts="`. The full run takes ~25 min.
|
||||
- Alembic via `docker exec -w /app resolutionflow_backend alembic ...`. Never pass `--rev-id`.
|
||||
- No `gh` CLI on this LXC — use the Gitea API (`$GITEA_TOKEN` in `.claude/settings.local.json`) for PR/issue work, or run `gh` from a host that has it.
|
||||
- Headless Chromium (`/qa`, `/browse`) needs `CONTAINER=1` in the env launching the browse server (LXC namespace constraint).
|
||||
- `/etc/hosts` has `100.64.78.44 docker-01` so the headless browser resolves the bake-in `VITE_API_URL`. The previous handoff claimed this entry was persistent but it was missing on this LXC at the start of this session — re-added via `sudo tee` from a real terminal (the `!` shell prefix can't drive interactive sudo). Confirmed working.
|
||||
- Multi-head alembic state on `main` (heads `070`, `c0f3a4b7e91d`, `024`) is pre-existing. Use `alembic upgrade heads` (plural) if `head` complains.
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""subscriptions pilot complimentary backfill
|
||||
|
||||
This migration converts existing pilot/dev accounts to permanent complimentary
|
||||
Pro per the self-serve signup spec section 5. Forward-only; downgrade is
|
||||
prohibited because original status is not preserved.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "c6cbfc534fad"
|
||||
down_revision: Union[str, None] = "c982a3fc4bf1"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Set status='complimentary' and plan='pro' for all existing accounts that
|
||||
don't have a canceled or past_due subscription. Pilot users transition to
|
||||
permanent complimentary Pro per spec section 5.
|
||||
|
||||
Forward-only — does not preserve original status values."""
|
||||
conn = op.get_bind()
|
||||
# Update existing rows
|
||||
conn.execute(sa.text("""
|
||||
UPDATE subscriptions
|
||||
SET status = 'complimentary', plan = 'pro',
|
||||
current_period_end = NULL, current_period_start = NULL,
|
||||
updated_at = now()
|
||||
WHERE status NOT IN ('canceled', 'past_due')
|
||||
"""))
|
||||
# Backfill: any account without a Subscription row gets one
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO subscriptions (id, account_id, plan, status, cancel_at_period_end, created_at, updated_at)
|
||||
SELECT gen_random_uuid(), a.id, 'pro', 'complimentary', false, now(), now()
|
||||
FROM accounts a
|
||||
WHERE NOT EXISTS (SELECT 1 FROM subscriptions s WHERE s.account_id = a.id)
|
||||
"""))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
raise RuntimeError(
|
||||
"Cannot downgrade: original subscription state is not preserved. "
|
||||
"Restore from backup if needed."
|
||||
)
|
||||
@@ -19,7 +19,7 @@ from app.models.account_invite import AccountInvite
|
||||
from app.models.account_settings import AccountSettings
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, TransferOwnershipRequest
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, AccountInviteBulkCreate, AccountInviteBulkResponse, TransferOwnershipRequest
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.core.security import verify_password
|
||||
@@ -260,7 +260,7 @@ async def create_invite(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create an invite to join this account (owner only)."""
|
||||
"""Create an invite to join this account (owner only). Sends invite email."""
|
||||
code = secrets.token_urlsafe(16)
|
||||
|
||||
expires_at = None
|
||||
@@ -276,11 +276,109 @@ async def create_invite(
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.flush()
|
||||
|
||||
# Lookup account name for email
|
||||
account_result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = account_result.scalar_one()
|
||||
|
||||
# Send invite email — non-blocking on failure (function returns False on error)
|
||||
email_sent = await EmailService.send_account_invite_email(
|
||||
to_email=invite.email,
|
||||
code=code,
|
||||
account_name=account.name,
|
||||
role=invite.role,
|
||||
)
|
||||
if email_sent:
|
||||
invite.email_sent_at = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
return invite
|
||||
|
||||
|
||||
@router.post("/me/invites/bulk", response_model=AccountInviteBulkResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_invites_bulk(
|
||||
payload: AccountInviteBulkCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Create multiple invites in one call (wizard step 3 supports up to N).
|
||||
Per-row failures are returned in `failed`; successes in `created`."""
|
||||
# Lookup account once for email rendering
|
||||
account_result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = account_result.scalar_one()
|
||||
|
||||
created: list[AccountInvite] = []
|
||||
failed: list[dict] = []
|
||||
for invite_data in payload.invites:
|
||||
try:
|
||||
code = secrets.token_urlsafe(16)
|
||||
expires_at = None
|
||||
if invite_data.expires_in_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=invite_data.expires_in_days)
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=current_user.account_id,
|
||||
invited_by_id=current_user.id,
|
||||
email=invite_data.email,
|
||||
code=code,
|
||||
role=invite_data.role,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.flush()
|
||||
|
||||
email_sent = await EmailService.send_account_invite_email(
|
||||
to_email=invite.email,
|
||||
code=code,
|
||||
account_name=account.name,
|
||||
role=invite.role,
|
||||
)
|
||||
if email_sent:
|
||||
invite.email_sent_at = datetime.now(timezone.utc)
|
||||
|
||||
created.append(invite)
|
||||
except Exception as e:
|
||||
failed.append({"email": invite_data.email, "error": str(e)})
|
||||
|
||||
await db.commit()
|
||||
for inv in created:
|
||||
await db.refresh(inv)
|
||||
|
||||
return AccountInviteBulkResponse(created=created, failed=failed)
|
||||
|
||||
|
||||
@router.delete("/me/invites/{invite_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def revoke_invite(
|
||||
invite_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Soft-revoke an invitation by setting revoked_at. Idempotent on already-
|
||||
revoked invites; rejects already-accepted invites."""
|
||||
result = await db.execute(
|
||||
select(AccountInvite).where(
|
||||
AccountInvite.id == invite_id,
|
||||
AccountInvite.account_id == current_user.account_id,
|
||||
)
|
||||
)
|
||||
invite = result.scalar_one_or_none()
|
||||
if not invite:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invite not found")
|
||||
if invite.is_revoked:
|
||||
return None # idempotent
|
||||
if invite.is_used:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot revoke an accepted invite")
|
||||
invite.revoked_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/me/invites/{invite_id}/resend", response_model=AccountInviteResponse)
|
||||
async def resend_invite(
|
||||
invite_id: UUID,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone, timedelta
|
||||
@@ -41,6 +42,8 @@ from app.core.email import EmailService
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
from app.core.audit import log_audit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@@ -62,6 +65,22 @@ def _generate_display_code() -> str:
|
||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
async def _reject_if_oauth_only(db: AsyncSession, user) -> None:
|
||||
"""If the user has no password_hash, raise 400 with a list of linked
|
||||
providers so the client can redirect them to the right OAuth flow."""
|
||||
if user is None or user.password_hash is not None:
|
||||
return
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
result = await db.execute(
|
||||
select(OAuthIdentity.provider).where(OAuthIdentity.user_id == user.id)
|
||||
)
|
||||
providers = [row for row in result.scalars().all()]
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "use_oauth_provider", "providers": providers},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("3/minute")
|
||||
async def register(
|
||||
@@ -108,6 +127,12 @@ async def register(
|
||||
detail="Account invite code has expired"
|
||||
)
|
||||
|
||||
if account_invite_record.email.lower() != user_data.email.lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "invite_email_mismatch"},
|
||||
)
|
||||
|
||||
# Validate platform invite code (skip if account invite was provided)
|
||||
invite_code_record = None
|
||||
if not account_invite_record:
|
||||
@@ -228,6 +253,34 @@ async def register(
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# Auto-send verification email for newly-registered users.
|
||||
# Skip silently if verification already done (shouldn't happen for fresh
|
||||
# users, but defensive).
|
||||
if new_user.email_verified_at is None:
|
||||
verification_enabled = await SettingsManager.get(
|
||||
"email_verification_enabled", db, default=True
|
||||
)
|
||||
if verification_enabled:
|
||||
try:
|
||||
raw_token = create_email_verification_token(str(new_user.id))
|
||||
payload = decode_token(raw_token)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = EmailVerificationToken(
|
||||
token_hash=hash_token(payload["jti"]),
|
||||
user_id=new_user.id,
|
||||
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||
)
|
||||
db.add(token_record)
|
||||
await db.commit()
|
||||
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={raw_token}"
|
||||
await EmailService.send_email_verification_email(
|
||||
to_email=new_user.email,
|
||||
verification_url=verification_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("verification email send failed for %s: %s", new_user.email, e)
|
||||
|
||||
return new_user
|
||||
|
||||
|
||||
@@ -243,6 +296,7 @@ async def login(
|
||||
result = await db.execute(select(User).where(User.email == form_data.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
await _reject_if_oauth_only(db, user)
|
||||
if not user or not verify_password(form_data.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -280,6 +334,7 @@ async def login_json(
|
||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
await _reject_if_oauth_only(db, user)
|
||||
if not user or not verify_password(credentials.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -445,6 +500,7 @@ async def change_password(
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||
):
|
||||
"""Change the current user's password."""
|
||||
await _reject_if_oauth_only(db, current_user)
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -488,7 +544,7 @@ async def forgot_password(
|
||||
result = await db.execute(select(User).where(User.email == data.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
if user and user.password_hash is not None:
|
||||
# Create reset token JWT
|
||||
raw_token = create_password_reset_token(str(user.id))
|
||||
payload = decode_token(raw_token)
|
||||
|
||||
@@ -9,7 +9,11 @@ from app.core.admin_database import get_admin_db
|
||||
from app.core.config import settings
|
||||
from app.models.account import Account
|
||||
from app.models.user import User
|
||||
from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse
|
||||
from app.schemas.billing import (
|
||||
BillingStateResponse,
|
||||
CheckoutSessionCreate,
|
||||
CheckoutSessionResponse,
|
||||
)
|
||||
from app.services.billing import BillingService
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
@@ -34,3 +38,15 @@ async def create_checkout_session(
|
||||
cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan",
|
||||
)
|
||||
return CheckoutSessionResponse(url=url)
|
||||
|
||||
|
||||
@router.get("/state", response_model=BillingStateResponse)
|
||||
async def get_billing_state(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> BillingStateResponse:
|
||||
account = (await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)).scalar_one()
|
||||
state = await BillingService.get_billing_state(db, account)
|
||||
return BillingStateResponse(**state)
|
||||
|
||||
123
backend/app/api/endpoints/oauth.py
Normal file
123
backend/app/api/endpoints/oauth.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.admin_database import get_admin_db
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.account import Account
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
|
||||
from app.services.billing import BillingService
|
||||
from app.services.oauth_providers import (
|
||||
google_exchange_code,
|
||||
microsoft_exchange_code,
|
||||
OAuthProfile,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth-oauth"])
|
||||
|
||||
|
||||
def _generate_display_code(length: int = 8) -> str:
|
||||
"""Match the helper used by /auth/register — A-Z + 0-9, length 8."""
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
async def _sign_in_or_register(
|
||||
db: AsyncSession, provider: str, profile: OAuthProfile
|
||||
) -> tuple[User, bool]:
|
||||
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
|
||||
identity = (
|
||||
await db.execute(
|
||||
select(OAuthIdentity).where(
|
||||
OAuthIdentity.provider == provider,
|
||||
OAuthIdentity.provider_subject == profile.provider_subject,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if identity:
|
||||
user = (
|
||||
await db.execute(select(User).where(User.id == identity.user_id))
|
||||
).scalar_one()
|
||||
return user, False
|
||||
|
||||
user = (
|
||||
await db.execute(select(User).where(User.email == profile.email))
|
||||
).scalar_one_or_none()
|
||||
is_new_user = user is None
|
||||
if is_new_user:
|
||||
account = Account(
|
||||
name=f"{profile.name}'s Account",
|
||||
display_code=_generate_display_code(),
|
||||
)
|
||||
db.add(account)
|
||||
await db.flush()
|
||||
user = User(
|
||||
email=profile.email,
|
||||
name=profile.name,
|
||||
password_hash=None,
|
||||
account_id=account.id,
|
||||
account_role="owner",
|
||||
role="engineer",
|
||||
email_verified_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
account.owner_id = user.id
|
||||
await db.flush()
|
||||
# start_trial commits internally; flushed account/user above.
|
||||
await BillingService.start_trial(db, account.id)
|
||||
|
||||
db.add(
|
||||
OAuthIdentity(
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
provider_subject=profile.provider_subject,
|
||||
provider_email_at_link=profile.email,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user, is_new_user
|
||||
|
||||
|
||||
@router.post("/google/callback", response_model=OAuthCallbackResponse)
|
||||
async def google_callback(
|
||||
payload: OAuthCallbackPayload,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> OAuthCallbackResponse:
|
||||
if not settings.GOOGLE_CLIENT_ID:
|
||||
raise HTTPException(status_code=503, detail="Google sign-in not configured")
|
||||
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
|
||||
profile = await google_exchange_code(payload.code, redirect_uri)
|
||||
user, is_new = await _sign_in_or_register(db, "google", profile)
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
is_new_user=is_new,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/microsoft/callback", response_model=OAuthCallbackResponse)
|
||||
async def microsoft_callback(
|
||||
payload: OAuthCallbackPayload,
|
||||
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||
) -> OAuthCallbackResponse:
|
||||
if not settings.MS_CLIENT_ID:
|
||||
raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
|
||||
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/microsoft/callback"
|
||||
profile = await microsoft_exchange_code(payload.code, redirect_uri)
|
||||
user, is_new = await _sign_in_or_register(db, "microsoft", profile)
|
||||
return OAuthCallbackResponse(
|
||||
access_token=create_access_token({"sub": str(user.id)}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
is_new_user=is_new,
|
||||
)
|
||||
@@ -41,6 +41,7 @@ from app.api.endpoints import (
|
||||
maintenance_schedules,
|
||||
network_diagrams,
|
||||
notifications,
|
||||
oauth as oauth_endpoints,
|
||||
onboarding,
|
||||
public_templates,
|
||||
ratings,
|
||||
@@ -82,6 +83,7 @@ api_router = APIRouter()
|
||||
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
||||
# ---------------------------------------------------------------------------
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(oauth_endpoints.router)
|
||||
api_router.include_router(billing.router) # Reachable when subscription locked
|
||||
api_router.include_router(shared.router) # Public share links (no auth)
|
||||
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
||||
|
||||
@@ -194,6 +194,13 @@ class Settings(BaseSettings):
|
||||
"""Check if ConnectWise integration is configured."""
|
||||
return self.CW_CLIENT_ID is not None
|
||||
|
||||
# OAuth providers (self-serve signup)
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None
|
||||
MS_CLIENT_ID: Optional[str] = None
|
||||
MS_CLIENT_SECRET: Optional[str] = None
|
||||
OAUTH_REDIRECT_BASE: str = "http://localhost:5173"
|
||||
|
||||
# Monitoring
|
||||
SENTRY_DSN: Optional[str] = None
|
||||
|
||||
|
||||
@@ -42,3 +42,12 @@ class AccountInviteResponse(BaseModel):
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AccountInviteBulkCreate(BaseModel):
|
||||
invites: list[AccountInviteCreate]
|
||||
|
||||
|
||||
class AccountInviteBulkResponse(BaseModel):
|
||||
created: list[AccountInviteResponse]
|
||||
failed: list[dict] # entries shaped {"email": str, "error": str}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -10,3 +11,30 @@ class CheckoutSessionCreate(BaseModel):
|
||||
|
||||
class CheckoutSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class SubscriptionState(BaseModel):
|
||||
status: str
|
||||
plan: str
|
||||
current_period_start: Optional[datetime]
|
||||
current_period_end: Optional[datetime]
|
||||
cancel_at_period_end: bool
|
||||
seat_limit: Optional[int]
|
||||
has_pro_entitlement: bool
|
||||
is_paid: bool
|
||||
|
||||
|
||||
class PlanBillingState(BaseModel):
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
monthly_price_cents: Optional[int] = None
|
||||
annual_price_cents: Optional[int] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class BillingStateResponse(BaseModel):
|
||||
subscription: SubscriptionState
|
||||
plan_billing: Optional[PlanBillingState]
|
||||
plan_limits: Dict[str, Any]
|
||||
enabled_features: Dict[str, bool]
|
||||
|
||||
13
backend/app/schemas/oauth.py
Normal file
13
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OAuthCallbackPayload(BaseModel):
|
||||
code: str
|
||||
state: str | None = None
|
||||
|
||||
|
||||
class OAuthCallbackResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
is_new_user: bool
|
||||
@@ -105,6 +105,61 @@ class BillingService:
|
||||
)
|
||||
return session.url
|
||||
|
||||
@staticmethod
|
||||
async def get_billing_state(db: AsyncSession, account):
|
||||
"""Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
|
||||
flags for the account."""
|
||||
from app.models.plan_limits import PlanLimits
|
||||
from app.models.plan_billing import PlanBilling
|
||||
from app.models.feature_flag import (
|
||||
FeatureFlag, PlanFeatureDefault, AccountFeatureOverride,
|
||||
)
|
||||
|
||||
sub = (await db.execute(
|
||||
select(Subscription).where(Subscription.account_id == account.id)
|
||||
)).scalar_one_or_none()
|
||||
if sub is None:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="No subscription for account")
|
||||
|
||||
pl = (await db.execute(
|
||||
select(PlanLimits).where(PlanLimits.plan == sub.plan)
|
||||
)).scalar_one_or_none()
|
||||
pb = (await db.execute(
|
||||
select(PlanBilling).where(PlanBilling.plan == sub.plan)
|
||||
)).scalar_one_or_none()
|
||||
|
||||
# Resolved feature flags: plan defaults overridden by account overrides
|
||||
defaults = (await db.execute(
|
||||
select(PlanFeatureDefault, FeatureFlag)
|
||||
.join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
|
||||
.where(PlanFeatureDefault.plan == sub.plan)
|
||||
)).all()
|
||||
resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}
|
||||
overrides = (await db.execute(
|
||||
select(AccountFeatureOverride, FeatureFlag)
|
||||
.join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
|
||||
.where(AccountFeatureOverride.account_id == account.id)
|
||||
)).all()
|
||||
for ovr, flag in overrides:
|
||||
resolved[flag.flag_key] = ovr.enabled
|
||||
|
||||
return {
|
||||
"subscription": {
|
||||
"status": sub.status,
|
||||
"plan": sub.plan,
|
||||
"current_period_start": sub.current_period_start,
|
||||
"current_period_end": sub.current_period_end,
|
||||
"cancel_at_period_end": sub.cancel_at_period_end,
|
||||
"seat_limit": sub.seat_limit,
|
||||
"has_pro_entitlement": sub.has_pro_entitlement,
|
||||
"is_paid": sub.is_paid,
|
||||
},
|
||||
"plan_billing": pb,
|
||||
"plan_limits": _plan_limits_to_dict(pl) if pl else {},
|
||||
"enabled_features": resolved,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def apply_subscription_event(
|
||||
db: AsyncSession, event_id: str, event_type: str, payload: dict
|
||||
@@ -136,6 +191,10 @@ class BillingService:
|
||||
return True
|
||||
|
||||
|
||||
def _plan_limits_to_dict(pl) -> dict:
|
||||
return {c.name: getattr(pl, c.name) for c in pl.__table__.columns}
|
||||
|
||||
|
||||
def _excerpt(payload: dict) -> dict:
|
||||
obj = payload.get("data", {}).get("object", {})
|
||||
return {
|
||||
|
||||
71
backend/app/services/oauth_providers.py
Normal file
71
backend/app/services/oauth_providers.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""OAuth provider helpers. Each provider exposes:
|
||||
- exchange_code(code, redirect_uri) -> OAuthProfile
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthProfile:
|
||||
provider_subject: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
async def google_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||
async with httpx.AsyncClient(timeout=10) as cli:
|
||||
token_response = await cli.post(
|
||||
"https://oauth2.googleapis.com/token",
|
||||
data={
|
||||
"code": code,
|
||||
"client_id": settings.GOOGLE_CLIENT_ID,
|
||||
"client_secret": settings.GOOGLE_CLIENT_SECRET,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
access_token = token_response.json()["access_token"]
|
||||
|
||||
userinfo = await cli.get(
|
||||
"https://openidconnect.googleapis.com/v1/userinfo",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
userinfo.raise_for_status()
|
||||
data = userinfo.json()
|
||||
return OAuthProfile(
|
||||
provider_subject=data["sub"],
|
||||
email=data["email"],
|
||||
name=data.get("name") or data["email"].split("@")[0],
|
||||
)
|
||||
|
||||
|
||||
async def microsoft_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||
async with httpx.AsyncClient(timeout=10) as cli:
|
||||
token_response = await cli.post(
|
||||
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
data={
|
||||
"code": code,
|
||||
"client_id": settings.MS_CLIENT_ID,
|
||||
"client_secret": settings.MS_CLIENT_SECRET,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
access_token = token_response.json()["access_token"]
|
||||
|
||||
userinfo = await cli.get(
|
||||
"https://graph.microsoft.com/v1.0/me",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
userinfo.raise_for_status()
|
||||
data = userinfo.json()
|
||||
return OAuthProfile(
|
||||
provider_subject=data["id"],
|
||||
email=data.get("mail") or data["userPrincipalName"],
|
||||
name=data.get("displayName") or data["userPrincipalName"].split("@")[0],
|
||||
)
|
||||
180
backend/tests/test_account_invite_extensions.py
Normal file
180
backend/tests/test_account_invite_extensions.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from sqlalchemy import select
|
||||
from app.models.account_invite import AccountInvite
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite_sends_email_and_stamps_email_sent_at(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
"""Regression: today's create_invite does NOT send email. After this task, it MUST."""
|
||||
with patch(
|
||||
"app.core.email.EmailService.send_account_invite_email",
|
||||
new_callable=AsyncMock, return_value=True,
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "teammate@example.com", "role": "engineer"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201, response.json()
|
||||
mock_send.assert_called_once()
|
||||
kwargs = mock_send.call_args.kwargs
|
||||
assert kwargs["to_email"] == "teammate@example.com"
|
||||
assert kwargs["role"] == "engineer"
|
||||
assert kwargs["code"]
|
||||
|
||||
invite = (await test_db.execute(
|
||||
select(AccountInvite).where(AccountInvite.email == "teammate@example.com")
|
||||
)).scalar_one()
|
||||
assert invite.email_sent_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite_email_failure_still_creates_row(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
"""When EmailService returns False, the invite row is still created but
|
||||
email_sent_at remains NULL."""
|
||||
with patch(
|
||||
"app.core.email.EmailService.send_account_invite_email",
|
||||
new_callable=AsyncMock, return_value=False,
|
||||
):
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/invites",
|
||||
json={"email": "fail-mail@example.com", "role": "engineer"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
invite = (await test_db.execute(
|
||||
select(AccountInvite).where(AccountInvite.email == "fail-mail@example.com")
|
||||
)).scalar_one()
|
||||
assert invite.email_sent_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_invite_creates_n_rows_and_sends_n_emails(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
with patch(
|
||||
"app.core.email.EmailService.send_account_invite_email",
|
||||
new_callable=AsyncMock, return_value=True,
|
||||
) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/invites/bulk",
|
||||
json={"invites": [
|
||||
{"email": "a@example.com", "role": "engineer"},
|
||||
{"email": "b@example.com", "role": "engineer"},
|
||||
{"email": "c@example.com", "role": "viewer"},
|
||||
]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201, response.json()
|
||||
body = response.json()
|
||||
assert len(body["created"]) == 3
|
||||
assert body["failed"] == []
|
||||
assert mock_send.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_invite_sets_revoked_at(client, test_db, test_user, auth_headers):
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from app.models.account_invite import AccountInvite
|
||||
|
||||
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account_id,
|
||||
invited_by_id=invited_by_id,
|
||||
email="revoked@example.com",
|
||||
code="REVOKEME01",
|
||||
role="engineer",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
test_db.add(invite)
|
||||
await test_db.commit()
|
||||
invite_id = invite.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
await test_db.refresh(invite)
|
||||
assert invite.revoked_at is not None
|
||||
assert invite.is_valid is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_invite_idempotent(client, test_db, test_user, auth_headers):
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from app.models.account_invite import AccountInvite
|
||||
|
||||
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account_id,
|
||||
invited_by_id=invited_by_id,
|
||||
email="revoked2@example.com",
|
||||
code="REVOKEME02",
|
||||
role="engineer",
|
||||
revoked_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
test_db.add(invite)
|
||||
await test_db.commit()
|
||||
invite_id = invite.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_invite_404_when_not_found(client, test_user, auth_headers):
|
||||
import uuid
|
||||
response = await client.delete(
|
||||
f"/api/v1/accounts/me/invites/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_used_invite_returns_400(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from app.models.account_invite import AccountInvite
|
||||
|
||||
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account_id,
|
||||
invited_by_id=invited_by_id,
|
||||
email="used@example.com",
|
||||
code="USEDCODE01",
|
||||
role="engineer",
|
||||
accepted_by_id=invited_by_id, # mark as used
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
test_db.add(invite)
|
||||
await test_db.commit()
|
||||
invite_id = invite.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
64
backend/tests/test_billing_state_endpoint.py
Normal file
64
backend/tests/test_billing_state_endpoint.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_state_returns_subscription_plan_features(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
"""Subscription is already seeded by test_user fixture (pro/active).
|
||||
Add a feature flag default for `pro` and verify it shows up in the response."""
|
||||
flag = FeatureFlag(flag_key="psa_integration", display_name="PSA Integration")
|
||||
test_db.add(flag)
|
||||
await test_db.flush()
|
||||
test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=True))
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||
assert response.status_code == 200, response.json()
|
||||
body = response.json()
|
||||
assert body["subscription"]["status"] == "active"
|
||||
assert body["subscription"]["plan"] == "pro"
|
||||
assert body["subscription"]["has_pro_entitlement"] is True
|
||||
assert body["subscription"]["is_paid"] is True
|
||||
assert body["enabled_features"]["psa_integration"] is True
|
||||
# plan_limits should be a dict with the seeded pro limits from conftest
|
||||
assert body["plan_limits"]["plan"] == "pro"
|
||||
assert body["plan_limits"]["max_trees"] == 25
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_state_account_override_beats_plan_default(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
flag = FeatureFlag(flag_key="escalation_mode", display_name="Escalation Mode")
|
||||
test_db.add(flag)
|
||||
await test_db.flush()
|
||||
test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=False))
|
||||
test_db.add(AccountFeatureOverride(
|
||||
account_id=account_id, flag_id=flag.id, enabled=True,
|
||||
))
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["enabled_features"]["escalation_mode"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_state_404_when_no_subscription(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
"""Wipe the seeded subscription and verify the endpoint surfaces 404."""
|
||||
from sqlalchemy import delete
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
98
backend/tests/test_email_verification_autosend.py
Normal file
98
backend/tests/test_email_verification_autosend.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_auto_sends_verification_email(client, test_db):
|
||||
"""Fresh registration triggers send_email_verification_email."""
|
||||
with patch(
|
||||
"app.core.email.EmailService.send_email_verification_email",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
response = await client.post("/api/v1/auth/register", json={
|
||||
"email": "newshop@example.com",
|
||||
"password": "Verystrong1Pwd",
|
||||
"name": "New Shop",
|
||||
})
|
||||
assert response.status_code in (200, 201), response.json()
|
||||
mock_send.assert_called_once()
|
||||
kwargs = mock_send.call_args.kwargs
|
||||
assert kwargs["to_email"] == "newshop@example.com"
|
||||
assert "/verify-email?token=" in kwargs["verification_url"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_account_invite_code_email_mismatch_rejected(
|
||||
client, test_db, test_user
|
||||
):
|
||||
"""Invite code is for invited@example.com but user registers with a
|
||||
different email -> 400 invite_email_mismatch."""
|
||||
from app.models.account_invite import AccountInvite
|
||||
import uuid
|
||||
|
||||
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account_id,
|
||||
invited_by_id=invited_by_id,
|
||||
email="invited@example.com",
|
||||
code="INVITECODE99",
|
||||
role="engineer",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
test_db.add(invite)
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.post("/api/v1/auth/register", json={
|
||||
"email": "wrong-email@example.com",
|
||||
"password": "Verystrong1Pwd",
|
||||
"name": "Wrong Email",
|
||||
"account_invite_code": "INVITECODE99",
|
||||
})
|
||||
assert response.status_code == 400, response.json()
|
||||
assert response.json()["detail"]["error"] == "invite_email_mismatch"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_account_invite_code_email_match_accepted(
|
||||
client, test_db, test_user
|
||||
):
|
||||
"""Invite code is for invited@example.com - registering with that email
|
||||
succeeds and joins the existing account."""
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.user import User
|
||||
import uuid
|
||||
|
||||
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
|
||||
invite = AccountInvite(
|
||||
account_id=account_id,
|
||||
invited_by_id=invited_by_id,
|
||||
email="invited@example.com",
|
||||
code="INVITECODE100",
|
||||
role="engineer",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
test_db.add(invite)
|
||||
await test_db.commit()
|
||||
|
||||
with patch(
|
||||
"app.core.email.EmailService.send_email_verification_email",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
response = await client.post("/api/v1/auth/register", json={
|
||||
"email": "invited@example.com",
|
||||
"password": "Verystrong1Pwd",
|
||||
"name": "Invited",
|
||||
"account_invite_code": "INVITECODE100",
|
||||
})
|
||||
assert response.status_code in (200, 201), response.json()
|
||||
|
||||
new_user = (await test_db.execute(
|
||||
select(User).where(User.email == "invited@example.com")
|
||||
)).scalar_one()
|
||||
assert new_user.account_id == account_id # joined existing account
|
||||
@@ -13,6 +13,14 @@ pytestmark = pytest.mark.asyncio
|
||||
@pytest.fixture
|
||||
async def kb_setup(client, auth_headers, test_db):
|
||||
"""Seed KB plan limits and return helpers."""
|
||||
# KB tests were authored against a free-plan user. Phase 1 conftest seeds
|
||||
# the test_user with a pro/active Subscription; downgrade to free here so
|
||||
# quota numbers match the original test intent.
|
||||
from app.models.subscription import Subscription
|
||||
sub = (await test_db.execute(__import__("sqlalchemy").select(Subscription))).scalar_one()
|
||||
sub.plan = "free"
|
||||
await test_db.commit()
|
||||
|
||||
# Update plan_limits with KB columns for 'free' plan
|
||||
await test_db.execute(
|
||||
__import__("sqlalchemy").text("""
|
||||
|
||||
120
backend/tests/test_oauth_callbacks.py
Normal file
120
backend/tests/test_oauth_callbacks.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
from app.models.subscription import Subscription
|
||||
from app.services.oauth_providers import OAuthProfile
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_callback_creates_user_account_subscription(
|
||||
client, test_db, monkeypatch
|
||||
):
|
||||
"""Brand-new user via Google OAuth -> User + Account + Subscription + OAuthIdentity."""
|
||||
from app.core.config import settings
|
||||
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||
|
||||
profile = OAuthProfile(
|
||||
provider_subject="google_subject_123",
|
||||
email="newuser@example.com",
|
||||
name="New User",
|
||||
)
|
||||
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
body = response.json()
|
||||
assert body["is_new_user"] is True
|
||||
assert body["access_token"]
|
||||
|
||||
user = (await test_db.execute(
|
||||
select(User).where(User.email == "newuser@example.com")
|
||||
)).scalar_one()
|
||||
assert user.password_hash is None
|
||||
assert user.email_verified_at is not None
|
||||
|
||||
identity = (await test_db.execute(
|
||||
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert identity.provider == "google"
|
||||
assert identity.provider_subject == "google_subject_123"
|
||||
|
||||
sub = (await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == user.account_id)
|
||||
)).scalar_one()
|
||||
assert sub.status == "trialing"
|
||||
assert sub.plan == "pro"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_callback_existing_user_is_idempotent(
|
||||
client, test_db, test_user, monkeypatch
|
||||
):
|
||||
"""When test_user's email is already registered, OAuth links + returns the
|
||||
same user. Two calls with same provider_subject must not duplicate
|
||||
OAuthIdentity rows."""
|
||||
from app.core.config import settings
|
||||
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||
|
||||
user_id = uuid.UUID(test_user["user_data"]["id"])
|
||||
email = test_user["email"]
|
||||
name = test_user["user_data"]["name"]
|
||||
|
||||
profile = OAuthProfile(
|
||||
provider_subject="google_subject_456",
|
||||
email=email,
|
||||
name=name,
|
||||
)
|
||||
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||
r1 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||
r2 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r1.json()["is_new_user"] is False
|
||||
assert r2.json()["is_new_user"] is False
|
||||
|
||||
identities = (await test_db.execute(
|
||||
select(OAuthIdentity).where(OAuthIdentity.user_id == user_id)
|
||||
)).scalars().all()
|
||||
assert len(identities) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_callback_503_when_unconfigured(client, monkeypatch):
|
||||
from app.core.config import settings
|
||||
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/google/callback", json={"code": "x"}
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
|
||||
from app.core.config import settings
|
||||
monkeypatch.setattr(settings, "MS_CLIENT_ID", "client_dummy")
|
||||
monkeypatch.setattr(settings, "MS_CLIENT_SECRET", "secret_dummy")
|
||||
|
||||
profile = OAuthProfile(
|
||||
provider_subject="ms_subject_789",
|
||||
email="msuser@example.com",
|
||||
name="MS User",
|
||||
)
|
||||
with patch("app.api.endpoints.oauth.microsoft_exchange_code", return_value=profile):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/microsoft/callback", json={"code": "auth_code"}
|
||||
)
|
||||
assert response.status_code == 200, response.json()
|
||||
user = (await test_db.execute(
|
||||
select(User).where(User.email == "msuser@example.com")
|
||||
)).scalar_one()
|
||||
identity = (await test_db.execute(
|
||||
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
|
||||
)).scalar_one()
|
||||
assert identity.provider == "microsoft"
|
||||
83
backend/tests/test_oauth_only_user_paths.py
Normal file
83
backend/tests/test_oauth_only_user_paths.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
from app.models.oauth_identity import OAuthIdentity
|
||||
|
||||
|
||||
async def _make_oauth_only_user(test_db, email, *, with_identity=True):
|
||||
"""Create an OAuth-only user (password_hash=None) directly in the test DB."""
|
||||
import secrets
|
||||
account = Account(
|
||||
name=f"{email}-acct",
|
||||
display_code=secrets.token_hex(4).upper(),
|
||||
)
|
||||
test_db.add(account)
|
||||
await test_db.flush()
|
||||
user = User(
|
||||
email=email,
|
||||
name="OAuth User",
|
||||
password_hash=None,
|
||||
account_id=account.id,
|
||||
account_role="owner",
|
||||
)
|
||||
test_db.add(user)
|
||||
await test_db.flush()
|
||||
if with_identity:
|
||||
test_db.add(OAuthIdentity(
|
||||
user_id=user.id, provider="google",
|
||||
provider_subject=f"google_{email}",
|
||||
provider_email_at_link=email,
|
||||
))
|
||||
await test_db.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_form_rejects_oauth_only_user_with_helpful_error(client, test_db):
|
||||
await _make_oauth_only_user(test_db, "oauth-only@example.com")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "oauth-only@example.com", "password": "wontwork"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
body = response.json()
|
||||
assert body["detail"]["error"] == "use_oauth_provider"
|
||||
assert "google" in body["detail"]["providers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_json_rejects_oauth_only_user(client, test_db):
|
||||
await _make_oauth_only_user(test_db, "oauth-only2@example.com")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={"email": "oauth-only2@example.com", "password": "wontwork"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"]["error"] == "use_oauth_provider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_forgot_silent_for_oauth_only_user(client, test_db):
|
||||
"""OAuth-only users get the generic message; no email is sent."""
|
||||
await _make_oauth_only_user(test_db, "oauth-forgot@example.com", with_identity=False)
|
||||
from unittest.mock import AsyncMock, patch
|
||||
with patch("app.core.email.EmailService.send_password_reset_email", new_callable=AsyncMock) as mock_send:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password/forgot",
|
||||
json={"email": "oauth-forgot@example.com"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_for_password_user_still_works(client, test_user):
|
||||
"""Regression: existing password-based login must still succeed."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login/json",
|
||||
json={"email": test_user["email"], "password": test_user["password"]},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["access_token"]
|
||||
85
backend/tests/test_pilot_complimentary_backfill.py
Normal file
85
backend/tests/test_pilot_complimentary_backfill.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Smoke test for the complimentary backfill: assertions about the post-state.
|
||||
The actual migration runs at deploy time; tests use create_all so the
|
||||
migration body isn't executed automatically. We invoke the SQL inline to
|
||||
exercise the same effect."""
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy import select, text, delete
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complimentary_backfill_sets_status_and_inserts_missing_rows(test_db):
|
||||
"""Inline-run the backfill SQL and assert post-state."""
|
||||
# Seed a fresh account with no subscription
|
||||
no_sub_account = Account(name="NoSub", display_code="NOSUB001")
|
||||
test_db.add(no_sub_account)
|
||||
await test_db.flush()
|
||||
|
||||
# Seed an account with a trialing subscription (should become complimentary)
|
||||
trial_account = Account(name="Trial", display_code="TRIAL001")
|
||||
test_db.add(trial_account)
|
||||
await test_db.flush()
|
||||
test_db.add(Subscription(
|
||||
account_id=trial_account.id, plan="free", status="trialing",
|
||||
))
|
||||
|
||||
# Seed an account with a canceled subscription (should be preserved)
|
||||
canceled_account = Account(name="Cancel", display_code="CANCL001")
|
||||
test_db.add(canceled_account)
|
||||
await test_db.flush()
|
||||
test_db.add(Subscription(
|
||||
account_id=canceled_account.id, plan="pro", status="canceled",
|
||||
))
|
||||
await test_db.commit()
|
||||
|
||||
# Run the same SQL the migration runs
|
||||
await test_db.execute(text("""
|
||||
UPDATE subscriptions
|
||||
SET status = 'complimentary', plan = 'pro',
|
||||
current_period_end = NULL, current_period_start = NULL,
|
||||
updated_at = now()
|
||||
WHERE status NOT IN ('canceled', 'past_due')
|
||||
"""))
|
||||
await test_db.execute(text("""
|
||||
INSERT INTO subscriptions (id, account_id, plan, status, cancel_at_period_end, created_at, updated_at)
|
||||
SELECT gen_random_uuid(), a.id, 'pro', 'complimentary', false, now(), now()
|
||||
FROM accounts a
|
||||
WHERE NOT EXISTS (SELECT 1 FROM subscriptions s WHERE s.account_id = a.id)
|
||||
"""))
|
||||
await test_db.commit()
|
||||
|
||||
# All three accounts now have a Subscription
|
||||
no_sub_row = (await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == no_sub_account.id)
|
||||
)).scalar_one()
|
||||
assert no_sub_row.status == "complimentary"
|
||||
assert no_sub_row.plan == "pro"
|
||||
|
||||
trial_row = (await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == trial_account.id)
|
||||
)).scalar_one()
|
||||
assert trial_row.status == "complimentary"
|
||||
assert trial_row.plan == "pro"
|
||||
|
||||
canceled_row = (await test_db.execute(
|
||||
select(Subscription).where(Subscription.account_id == canceled_account.id)
|
||||
)).scalar_one()
|
||||
# Canceled is preserved
|
||||
assert canceled_row.status == "canceled"
|
||||
assert canceled_row.plan == "pro"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complimentary_subscription_passes_active_subscription_guard(
|
||||
client, test_db, test_user, auth_headers
|
||||
):
|
||||
"""The require_active_subscription guard accepts complimentary status."""
|
||||
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||
test_db.add(Subscription(account_id=account_id, plan="pro", status="complimentary"))
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||
assert response.status_code != 402
|
||||
Reference in New Issue
Block a user