Compare commits
12 Commits
9b709488d9
...
fix/seed-t
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c64e9ad62 | |||
| f918b766b0 | |||
| fbb41e789c | |||
| 97d36dd400 | |||
| f26f468878 | |||
| 79942c3fd3 | |||
| 4768ae0648 | |||
| e54d6c586a | |||
| 86893562b9 | |||
| b0708ed650 | |||
| 2ef2350de7 | |||
| f4606f073a |
@@ -2,35 +2,35 @@
|
|||||||
|
|
||||||
# HANDOFF.md
|
# HANDOFF.md
|
||||||
|
|
||||||
**Last updated:** 2026-05-02 (post-PR-159 — guides Diátaxis rewrite merged into `main`)
|
**Last updated:** 2026-05-06 (Phase 1 backend complete on `feat/self-serve-signup-spec`)
|
||||||
|
|
||||||
**Active task:** None. Pick next from `.ai/TODO.md` or `03-DEVELOPMENT-ROADMAP.md`.
|
**Active task:** Phase 1 self-serve signup backend foundation — DONE on branch. PR not yet opened.
|
||||||
|
|
||||||
**Just-merged:** PR #159 — In-product User Guides rewritten as 43 Diátaxis how-tos under 10 categories. Drops 3 deprecated guides, renames Step Library → Solutions Library, fixes tip-markdown rendering, adds 14 net-new how-tos for FlowPilot-era surfaces.
|
|
||||||
|
|
||||||
## Where this session ended
|
## Where this session ended
|
||||||
|
|
||||||
PR #159 merged into `main`. CHANGELOG, CURRENT_TASK, SESSION_LOG all updated. See `.ai/CURRENT_TASK.md` "Recently shipped" for the structured rollup.
|
24 commits on top of `main` (`31ca3fb`). All 26 tasks from `docs/superpowers/plans/2026-05-06-self-serve-signup-phase-1-backend.md` complete. Full pytest run is green (1167 passed, 35 deselected). Single alembic head: `c6cbfc534fad`.
|
||||||
|
|
||||||
The 43 guides live at `/guides` in the app. Schema is now category-aware (`Guide.category`, optional `relatedSlugs`); `categories` const drives hub ordering. Browser-verified against engineer + owner test users (sidebar labels, account sub-pages, pilot-screen header buttons, Tasks panel, integration form). tsc and Vite build clean.
|
Phase 1 covered: schema additions (oauth_identities, plan_billing, sales_leads, stripe_events, plus 5 new columns across users/accounts/account_invites), Subscription complimentary status + has_pro_entitlement, the two new guards (`require_active_subscription`, `require_verified_email_after_grace`), full BillingService (start_trial / create_checkout_session / apply_subscription_event / get_billing_state), Stripe webhook handler, Google + Microsoft OAuth callbacks with oauth_identities linking, OAuth-only password guard, register-time verification email + invite email-match, bulk + soft-revoke invite routes, GET /billing/state, and the pilot complimentary backfill migration.
|
||||||
|
|
||||||
|
The conftest's `test_user` fixture was modified to seed a Pro/active Subscription post-register (delete-then-insert) so the new subscription guard doesn't 402 every existing test. Two existing tests adapted because they explicitly assumed the old free-plan default: `test_subscription_limits.py` (the two free-plan tests now downgrade inline) and `test_kb_accelerator.py::TestQuota::test_get_quota` (the `kb_setup` fixture downgrades to free).
|
||||||
|
|
||||||
## Resume point — DO THIS NEXT
|
## Resume point — DO THIS NEXT
|
||||||
|
|
||||||
The issue cleanup plan continues from before this session. Pick up `docs/plans/2026-05-01-issue-cleanup-plan.md` at section 3: **#58 structured "step is wrong" quality signals**. Then section 4 (#60 recurring issue detection) and section 5 (#129 hierarchical guide navigation).
|
1. Open the PR for branch `feat/self-serve-signup-spec`. Use `gh pr create` against `main`. Suggested title: `feat: self-serve signup backend (Phase 1)`. Body should mention dark-launch posture (every new endpoint is gated by env config, not a feature flag — see Task 26 §3 in the plan).
|
||||||
|
2. Phase 2 (frontend + cutover) lives in a sibling plan: `docs/superpowers/plans/2026-05-06-self-serve-signup-phase-2-frontend.md` (assumed; verify path). It's the next logical task once Phase 1 ships.
|
||||||
`$GITEA_TOKEN` is in `.claude/settings.local.json` — confirmed working via the PR-creation API call this session. Issue tracker actions can be done from the code-server LXC via `curl` against `https://gitea.resolutionflow.com/api/v1/...`.
|
|
||||||
|
|
||||||
## Followups deferred from this session
|
## Followups deferred from this session
|
||||||
|
|
||||||
Worth picking up if a related touch happens:
|
- **OAuth callbacks don't call `_store_refresh_token`.** The Google/Microsoft callbacks issue a refresh JWT but never persist its hash to `refresh_tokens` (the password-login flow does via `auth.py:_store_refresh_token`). Result: refresh-token revocation/rotation lookups won't find OAuth-issued tokens. Decide before Phase 2 dark-launch whether to backfill — likely yes, by extracting `_store_refresh_token` to a shared module and calling it from `_sign_in_or_register`.
|
||||||
|
- **`stripe_enabled` was relaxed** in Task 14 from `bool(STRIPE_SECRET_KEY) and bool(STRIPE_WEBHOOK_SECRET)` to just the secret key. The webhook handler in Task 16 independently checks `STRIPE_WEBHOOK_SECRET` before calling `construct_event`, so signature verification is still safe — but if any other code reads `stripe_enabled` and assumes the webhook secret is set, that's a latent bug. Audit before Phase 2 cutover.
|
||||||
- **`change-teammate-role` how-to was dropped** from PR #159 because the test owner account has no non-owner members to inspect the role-change control. Once a teammate is invited via the Membership form on `/account`, verify whether the list exposes a Role dropdown (or some other control) for non-owners and add the guide back to `frontend/src/data/guides.ts` under the `account-admin` category.
|
- **`backend/app/core/stripe_handlers.py`** is a stub module that's no longer referenced after Task 16. Safe to delete in a follow-up; left in place to keep Phase 1 diff focused.
|
||||||
- **Resolve / Escalate modal contents are unverified.** Browser couldn't drive Resolve to completion (test session has 6 pending Tasks gating it; clicking Resolve fired a toast). The how-tos point at the right buttons in the right place, but the exact modal copy and the Escalation Mode wedge specifics are based on project context, not live observation. Worth a quick spot-check the next time a clean test session is available.
|
- **Pilot backfill migration `c6cbfc534fad` has not been applied to prod yet.** It runs once at deploy time and is forward-only.
|
||||||
|
|
||||||
## Environment notes (carry-forward)
|
## Environment notes (carry-forward)
|
||||||
|
|
||||||
- Code-server LXC has bun + docker but no native `python`/`node`/`npm`. Use `docker exec resolutionflow_{backend,frontend} …` for build/test commands.
|
- Code-server LXC has bun + docker but no native python/node/npm. Use `docker exec resolutionflow_{backend,frontend} ...` for build/test commands.
|
||||||
- No `gh` CLI on this LXC — use the Gitea API (`$GITEA_TOKEN`) for PR/issue work, or run `gh` from a host that has it.
|
- Pytest WORKDIR is `/app` — test paths in pytest commands are `tests/<file>`, NOT `backend/tests/<file>`.
|
||||||
|
- Backend pytest cmd: `docker exec resolutionflow_backend pytest tests/<path> -v --override-ini="addopts="`. The full run takes ~25 min.
|
||||||
|
- Alembic via `docker exec -w /app resolutionflow_backend alembic ...`. Never pass `--rev-id`.
|
||||||
|
- No `gh` CLI on this LXC — use the Gitea API (`$GITEA_TOKEN` in `.claude/settings.local.json`) for PR/issue work, or run `gh` from a host that has it.
|
||||||
- Headless Chromium (`/qa`, `/browse`) needs `CONTAINER=1` in the env launching the browse server (LXC namespace constraint).
|
- Headless Chromium (`/qa`, `/browse`) needs `CONTAINER=1` in the env launching the browse server (LXC namespace constraint).
|
||||||
- `/etc/hosts` has `100.64.78.44 docker-01` so the headless browser resolves the bake-in `VITE_API_URL`. The previous handoff claimed this entry was persistent but it was missing on this LXC at the start of this session — re-added via `sudo tee` from a real terminal (the `!` shell prefix can't drive interactive sudo). Confirmed working.
|
|
||||||
- Multi-head alembic state on `main` (heads `070`, `c0f3a4b7e91d`, `024`) is pre-existing. Use `alembic upgrade heads` (plural) if `head` complains.
|
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""subscriptions pilot complimentary backfill
|
||||||
|
|
||||||
|
This migration converts existing pilot/dev accounts to permanent complimentary
|
||||||
|
Pro per the self-serve signup spec section 5. Forward-only; downgrade is
|
||||||
|
prohibited because original status is not preserved.
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "c6cbfc534fad"
|
||||||
|
down_revision: Union[str, None] = "c982a3fc4bf1"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Set status='complimentary' and plan='pro' for all existing accounts that
|
||||||
|
don't have a canceled or past_due subscription. Pilot users transition to
|
||||||
|
permanent complimentary Pro per spec section 5.
|
||||||
|
|
||||||
|
Forward-only — does not preserve original status values."""
|
||||||
|
conn = op.get_bind()
|
||||||
|
# Update existing rows
|
||||||
|
conn.execute(sa.text("""
|
||||||
|
UPDATE subscriptions
|
||||||
|
SET status = 'complimentary', plan = 'pro',
|
||||||
|
current_period_end = NULL, current_period_start = NULL,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE status NOT IN ('canceled', 'past_due')
|
||||||
|
"""))
|
||||||
|
# Backfill: any account without a Subscription row gets one
|
||||||
|
conn.execute(sa.text("""
|
||||||
|
INSERT INTO subscriptions (id, account_id, plan, status, cancel_at_period_end, created_at, updated_at)
|
||||||
|
SELECT gen_random_uuid(), a.id, 'pro', 'complimentary', false, now(), now()
|
||||||
|
FROM accounts a
|
||||||
|
WHERE NOT EXISTS (SELECT 1 FROM subscriptions s WHERE s.account_id = a.id)
|
||||||
|
"""))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot downgrade: original subscription state is not preserved. "
|
||||||
|
"Restore from backup if needed."
|
||||||
|
)
|
||||||
@@ -19,7 +19,7 @@ from app.models.account_invite import AccountInvite
|
|||||||
from app.models.account_settings import AccountSettings
|
from app.models.account_settings import AccountSettings
|
||||||
from app.models.subscription import Subscription
|
from app.models.subscription import Subscription
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, TransferOwnershipRequest
|
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, AccountInviteBulkCreate, AccountInviteBulkResponse, TransferOwnershipRequest
|
||||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||||
from app.core.security import verify_password
|
from app.core.security import verify_password
|
||||||
@@ -260,7 +260,7 @@ async def create_invite(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(require_account_owner)]
|
current_user: Annotated[User, Depends(require_account_owner)]
|
||||||
):
|
):
|
||||||
"""Create an invite to join this account (owner only)."""
|
"""Create an invite to join this account (owner only). Sends invite email."""
|
||||||
code = secrets.token_urlsafe(16)
|
code = secrets.token_urlsafe(16)
|
||||||
|
|
||||||
expires_at = None
|
expires_at = None
|
||||||
@@ -276,11 +276,109 @@ async def create_invite(
|
|||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
)
|
)
|
||||||
db.add(invite)
|
db.add(invite)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
# Lookup account name for email
|
||||||
|
account_result = await db.execute(
|
||||||
|
select(Account).where(Account.id == current_user.account_id)
|
||||||
|
)
|
||||||
|
account = account_result.scalar_one()
|
||||||
|
|
||||||
|
# Send invite email — non-blocking on failure (function returns False on error)
|
||||||
|
email_sent = await EmailService.send_account_invite_email(
|
||||||
|
to_email=invite.email,
|
||||||
|
code=code,
|
||||||
|
account_name=account.name,
|
||||||
|
role=invite.role,
|
||||||
|
)
|
||||||
|
if email_sent:
|
||||||
|
invite.email_sent_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(invite)
|
await db.refresh(invite)
|
||||||
return invite
|
return invite
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/me/invites/bulk", response_model=AccountInviteBulkResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_invites_bulk(
|
||||||
|
payload: AccountInviteBulkCreate,
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(require_account_owner)]
|
||||||
|
):
|
||||||
|
"""Create multiple invites in one call (wizard step 3 supports up to N).
|
||||||
|
Per-row failures are returned in `failed`; successes in `created`."""
|
||||||
|
# Lookup account once for email rendering
|
||||||
|
account_result = await db.execute(
|
||||||
|
select(Account).where(Account.id == current_user.account_id)
|
||||||
|
)
|
||||||
|
account = account_result.scalar_one()
|
||||||
|
|
||||||
|
created: list[AccountInvite] = []
|
||||||
|
failed: list[dict] = []
|
||||||
|
for invite_data in payload.invites:
|
||||||
|
try:
|
||||||
|
code = secrets.token_urlsafe(16)
|
||||||
|
expires_at = None
|
||||||
|
if invite_data.expires_in_days:
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(days=invite_data.expires_in_days)
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=current_user.account_id,
|
||||||
|
invited_by_id=current_user.id,
|
||||||
|
email=invite_data.email,
|
||||||
|
code=code,
|
||||||
|
role=invite_data.role,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(invite)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
email_sent = await EmailService.send_account_invite_email(
|
||||||
|
to_email=invite.email,
|
||||||
|
code=code,
|
||||||
|
account_name=account.name,
|
||||||
|
role=invite.role,
|
||||||
|
)
|
||||||
|
if email_sent:
|
||||||
|
invite.email_sent_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
created.append(invite)
|
||||||
|
except Exception as e:
|
||||||
|
failed.append({"email": invite_data.email, "error": str(e)})
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
for inv in created:
|
||||||
|
await db.refresh(inv)
|
||||||
|
|
||||||
|
return AccountInviteBulkResponse(created=created, failed=failed)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/invites/{invite_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def revoke_invite(
|
||||||
|
invite_id: UUID,
|
||||||
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(require_account_owner)]
|
||||||
|
):
|
||||||
|
"""Soft-revoke an invitation by setting revoked_at. Idempotent on already-
|
||||||
|
revoked invites; rejects already-accepted invites."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(AccountInvite).where(
|
||||||
|
AccountInvite.id == invite_id,
|
||||||
|
AccountInvite.account_id == current_user.account_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
if not invite:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invite not found")
|
||||||
|
if invite.is_revoked:
|
||||||
|
return None # idempotent
|
||||||
|
if invite.is_used:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot revoke an accepted invite")
|
||||||
|
invite.revoked_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/me/invites/{invite_id}/resend", response_model=AccountInviteResponse)
|
@router.post("/me/invites/{invite_id}/resend", response_model=AccountInviteResponse)
|
||||||
async def resend_invite(
|
async def resend_invite(
|
||||||
invite_id: UUID,
|
invite_id: UUID,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
@@ -41,6 +42,8 @@ from app.core.email import EmailService
|
|||||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||||
from app.core.audit import log_audit
|
from app.core.audit import log_audit
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||||
|
|
||||||
|
|
||||||
@@ -62,6 +65,22 @@ def _generate_display_code() -> str:
|
|||||||
return ''.join(secrets.choice(chars) for _ in range(8))
|
return ''.join(secrets.choice(chars) for _ in range(8))
|
||||||
|
|
||||||
|
|
||||||
|
async def _reject_if_oauth_only(db: AsyncSession, user) -> None:
|
||||||
|
"""If the user has no password_hash, raise 400 with a list of linked
|
||||||
|
providers so the client can redirect them to the right OAuth flow."""
|
||||||
|
if user is None or user.password_hash is not None:
|
||||||
|
return
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthIdentity.provider).where(OAuthIdentity.user_id == user.id)
|
||||||
|
)
|
||||||
|
providers = [row for row in result.scalars().all()]
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": "use_oauth_provider", "providers": providers},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
@limiter.limit("3/minute")
|
@limiter.limit("3/minute")
|
||||||
async def register(
|
async def register(
|
||||||
@@ -108,6 +127,12 @@ async def register(
|
|||||||
detail="Account invite code has expired"
|
detail="Account invite code has expired"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if account_invite_record.email.lower() != user_data.email.lower():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"error": "invite_email_mismatch"},
|
||||||
|
)
|
||||||
|
|
||||||
# Validate platform invite code (skip if account invite was provided)
|
# Validate platform invite code (skip if account invite was provided)
|
||||||
invite_code_record = None
|
invite_code_record = None
|
||||||
if not account_invite_record:
|
if not account_invite_record:
|
||||||
@@ -228,6 +253,34 @@ async def register(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(new_user)
|
await db.refresh(new_user)
|
||||||
|
|
||||||
|
# Auto-send verification email for newly-registered users.
|
||||||
|
# Skip silently if verification already done (shouldn't happen for fresh
|
||||||
|
# users, but defensive).
|
||||||
|
if new_user.email_verified_at is None:
|
||||||
|
verification_enabled = await SettingsManager.get(
|
||||||
|
"email_verification_enabled", db, default=True
|
||||||
|
)
|
||||||
|
if verification_enabled:
|
||||||
|
try:
|
||||||
|
raw_token = create_email_verification_token(str(new_user.id))
|
||||||
|
payload = decode_token(raw_token)
|
||||||
|
if payload and payload.get("jti"):
|
||||||
|
token_record = EmailVerificationToken(
|
||||||
|
token_hash=hash_token(payload["jti"]),
|
||||||
|
user_id=new_user.id,
|
||||||
|
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||||
|
)
|
||||||
|
db.add(token_record)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={raw_token}"
|
||||||
|
await EmailService.send_email_verification_email(
|
||||||
|
to_email=new_user.email,
|
||||||
|
verification_url=verification_url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("verification email send failed for %s: %s", new_user.email, e)
|
||||||
|
|
||||||
return new_user
|
return new_user
|
||||||
|
|
||||||
|
|
||||||
@@ -243,6 +296,7 @@ async def login(
|
|||||||
result = await db.execute(select(User).where(User.email == form_data.username))
|
result = await db.execute(select(User).where(User.email == form_data.username))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
await _reject_if_oauth_only(db, user)
|
||||||
if not user or not verify_password(form_data.password, user.password_hash):
|
if not user or not verify_password(form_data.password, user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -280,6 +334,7 @@ async def login_json(
|
|||||||
result = await db.execute(select(User).where(User.email == credentials.email))
|
result = await db.execute(select(User).where(User.email == credentials.email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
await _reject_if_oauth_only(db, user)
|
||||||
if not user or not verify_password(credentials.password, user.password_hash):
|
if not user or not verify_password(credentials.password, user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -445,6 +500,7 @@ async def change_password(
|
|||||||
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
db: Annotated[AsyncSession, Depends(get_admin_db)]
|
||||||
):
|
):
|
||||||
"""Change the current user's password."""
|
"""Change the current user's password."""
|
||||||
|
await _reject_if_oauth_only(db, current_user)
|
||||||
if not verify_password(data.current_password, current_user.password_hash):
|
if not verify_password(data.current_password, current_user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -488,7 +544,7 @@ async def forgot_password(
|
|||||||
result = await db.execute(select(User).where(User.email == data.email))
|
result = await db.execute(select(User).where(User.email == data.email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user:
|
if user and user.password_hash is not None:
|
||||||
# Create reset token JWT
|
# Create reset token JWT
|
||||||
raw_token = create_password_reset_token(str(user.id))
|
raw_token = create_password_reset_token(str(user.id))
|
||||||
payload = decode_token(raw_token)
|
payload = decode_token(raw_token)
|
||||||
|
|||||||
@@ -9,7 +9,11 @@ from app.core.admin_database import get_admin_db
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.models.account import Account
|
from app.models.account import Account
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.billing import CheckoutSessionCreate, CheckoutSessionResponse
|
from app.schemas.billing import (
|
||||||
|
BillingStateResponse,
|
||||||
|
CheckoutSessionCreate,
|
||||||
|
CheckoutSessionResponse,
|
||||||
|
)
|
||||||
from app.services.billing import BillingService
|
from app.services.billing import BillingService
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
@@ -34,3 +38,15 @@ async def create_checkout_session(
|
|||||||
cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan",
|
cancel_url=f"{settings.FRONTEND_URL}/account/billing/select-plan",
|
||||||
)
|
)
|
||||||
return CheckoutSessionResponse(url=url)
|
return CheckoutSessionResponse(url=url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/state", response_model=BillingStateResponse)
|
||||||
|
async def get_billing_state(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
) -> BillingStateResponse:
|
||||||
|
account = (await db.execute(
|
||||||
|
select(Account).where(Account.id == current_user.account_id)
|
||||||
|
)).scalar_one()
|
||||||
|
state = await BillingService.get_billing_state(db, account)
|
||||||
|
return BillingStateResponse(**state)
|
||||||
|
|||||||
123
backend/app/api/endpoints/oauth.py
Normal file
123
backend/app/api/endpoints/oauth.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.admin_database import get_admin_db
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.security import create_access_token, create_refresh_token
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.oauth import OAuthCallbackPayload, OAuthCallbackResponse
|
||||||
|
from app.services.billing import BillingService
|
||||||
|
from app.services.oauth_providers import (
|
||||||
|
google_exchange_code,
|
||||||
|
microsoft_exchange_code,
|
||||||
|
OAuthProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth-oauth"])
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_display_code(length: int = 8) -> str:
|
||||||
|
"""Match the helper used by /auth/register — A-Z + 0-9, length 8."""
|
||||||
|
alphabet = string.ascii_uppercase + string.digits
|
||||||
|
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
async def _sign_in_or_register(
|
||||||
|
db: AsyncSession, provider: str, profile: OAuthProfile
|
||||||
|
) -> tuple[User, bool]:
|
||||||
|
"""Returns (user, is_new_user). Idempotent on (provider, provider_subject)."""
|
||||||
|
identity = (
|
||||||
|
await db.execute(
|
||||||
|
select(OAuthIdentity).where(
|
||||||
|
OAuthIdentity.provider == provider,
|
||||||
|
OAuthIdentity.provider_subject == profile.provider_subject,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if identity:
|
||||||
|
user = (
|
||||||
|
await db.execute(select(User).where(User.id == identity.user_id))
|
||||||
|
).scalar_one()
|
||||||
|
return user, False
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await db.execute(select(User).where(User.email == profile.email))
|
||||||
|
).scalar_one_or_none()
|
||||||
|
is_new_user = user is None
|
||||||
|
if is_new_user:
|
||||||
|
account = Account(
|
||||||
|
name=f"{profile.name}'s Account",
|
||||||
|
display_code=_generate_display_code(),
|
||||||
|
)
|
||||||
|
db.add(account)
|
||||||
|
await db.flush()
|
||||||
|
user = User(
|
||||||
|
email=profile.email,
|
||||||
|
name=profile.name,
|
||||||
|
password_hash=None,
|
||||||
|
account_id=account.id,
|
||||||
|
account_role="owner",
|
||||||
|
role="engineer",
|
||||||
|
email_verified_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
account.owner_id = user.id
|
||||||
|
await db.flush()
|
||||||
|
# start_trial commits internally; flushed account/user above.
|
||||||
|
await BillingService.start_trial(db, account.id)
|
||||||
|
|
||||||
|
db.add(
|
||||||
|
OAuthIdentity(
|
||||||
|
user_id=user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_subject=profile.provider_subject,
|
||||||
|
provider_email_at_link=profile.email,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user, is_new_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/google/callback", response_model=OAuthCallbackResponse)
|
||||||
|
async def google_callback(
|
||||||
|
payload: OAuthCallbackPayload,
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
) -> OAuthCallbackResponse:
|
||||||
|
if not settings.GOOGLE_CLIENT_ID:
|
||||||
|
raise HTTPException(status_code=503, detail="Google sign-in not configured")
|
||||||
|
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/google/callback"
|
||||||
|
profile = await google_exchange_code(payload.code, redirect_uri)
|
||||||
|
user, is_new = await _sign_in_or_register(db, "google", profile)
|
||||||
|
return OAuthCallbackResponse(
|
||||||
|
access_token=create_access_token({"sub": str(user.id)}),
|
||||||
|
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||||
|
is_new_user=is_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/microsoft/callback", response_model=OAuthCallbackResponse)
|
||||||
|
async def microsoft_callback(
|
||||||
|
payload: OAuthCallbackPayload,
|
||||||
|
db: Annotated[AsyncSession, Depends(get_admin_db)],
|
||||||
|
) -> OAuthCallbackResponse:
|
||||||
|
if not settings.MS_CLIENT_ID:
|
||||||
|
raise HTTPException(status_code=503, detail="Microsoft sign-in not configured")
|
||||||
|
redirect_uri = f"{settings.OAUTH_REDIRECT_BASE}/auth/microsoft/callback"
|
||||||
|
profile = await microsoft_exchange_code(payload.code, redirect_uri)
|
||||||
|
user, is_new = await _sign_in_or_register(db, "microsoft", profile)
|
||||||
|
return OAuthCallbackResponse(
|
||||||
|
access_token=create_access_token({"sub": str(user.id)}),
|
||||||
|
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||||
|
is_new_user=is_new,
|
||||||
|
)
|
||||||
@@ -41,6 +41,7 @@ from app.api.endpoints import (
|
|||||||
maintenance_schedules,
|
maintenance_schedules,
|
||||||
network_diagrams,
|
network_diagrams,
|
||||||
notifications,
|
notifications,
|
||||||
|
oauth as oauth_endpoints,
|
||||||
onboarding,
|
onboarding,
|
||||||
public_templates,
|
public_templates,
|
||||||
ratings,
|
ratings,
|
||||||
@@ -82,6 +83,7 @@ api_router = APIRouter()
|
|||||||
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
# in Phase 1. This will need revisiting in Phase 2 when `users` gets RLS.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
api_router.include_router(auth.router)
|
api_router.include_router(auth.router)
|
||||||
|
api_router.include_router(oauth_endpoints.router)
|
||||||
api_router.include_router(billing.router) # Reachable when subscription locked
|
api_router.include_router(billing.router) # Reachable when subscription locked
|
||||||
api_router.include_router(shared.router) # Public share links (no auth)
|
api_router.include_router(shared.router) # Public share links (no auth)
|
||||||
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
api_router.include_router(shares.public_router) # Public session share links (optional auth)
|
||||||
|
|||||||
@@ -194,6 +194,13 @@ class Settings(BaseSettings):
|
|||||||
"""Check if ConnectWise integration is configured."""
|
"""Check if ConnectWise integration is configured."""
|
||||||
return self.CW_CLIENT_ID is not None
|
return self.CW_CLIENT_ID is not None
|
||||||
|
|
||||||
|
# OAuth providers (self-serve signup)
|
||||||
|
GOOGLE_CLIENT_ID: Optional[str] = None
|
||||||
|
GOOGLE_CLIENT_SECRET: Optional[str] = None
|
||||||
|
MS_CLIENT_ID: Optional[str] = None
|
||||||
|
MS_CLIENT_SECRET: Optional[str] = None
|
||||||
|
OAUTH_REDIRECT_BASE: str = "http://localhost:5173"
|
||||||
|
|
||||||
# Monitoring
|
# Monitoring
|
||||||
SENTRY_DSN: Optional[str] = None
|
SENTRY_DSN: Optional[str] = None
|
||||||
|
|
||||||
|
|||||||
@@ -42,3 +42,12 @@ class AccountInviteResponse(BaseModel):
|
|||||||
used_at: Optional[datetime] = None
|
used_at: Optional[datetime] = None
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class AccountInviteBulkCreate(BaseModel):
|
||||||
|
invites: list[AccountInviteCreate]
|
||||||
|
|
||||||
|
|
||||||
|
class AccountInviteBulkResponse(BaseModel):
|
||||||
|
created: list[AccountInviteResponse]
|
||||||
|
failed: list[dict] # entries shaped {"email": str, "error": str}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Literal
|
from typing import Literal, Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -10,3 +11,30 @@ class CheckoutSessionCreate(BaseModel):
|
|||||||
|
|
||||||
class CheckoutSessionResponse(BaseModel):
|
class CheckoutSessionResponse(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionState(BaseModel):
|
||||||
|
status: str
|
||||||
|
plan: str
|
||||||
|
current_period_start: Optional[datetime]
|
||||||
|
current_period_end: Optional[datetime]
|
||||||
|
cancel_at_period_end: bool
|
||||||
|
seat_limit: Optional[int]
|
||||||
|
has_pro_entitlement: bool
|
||||||
|
is_paid: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PlanBillingState(BaseModel):
|
||||||
|
display_name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
monthly_price_cents: Optional[int] = None
|
||||||
|
annual_price_cents: Optional[int] = None
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class BillingStateResponse(BaseModel):
|
||||||
|
subscription: SubscriptionState
|
||||||
|
plan_billing: Optional[PlanBillingState]
|
||||||
|
plan_limits: Dict[str, Any]
|
||||||
|
enabled_features: Dict[str, bool]
|
||||||
|
|||||||
13
backend/app/schemas/oauth.py
Normal file
13
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackPayload(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
is_new_user: bool
|
||||||
@@ -105,6 +105,61 @@ class BillingService:
|
|||||||
)
|
)
|
||||||
return session.url
|
return session.url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_billing_state(db: AsyncSession, account):
|
||||||
|
"""Aggregate Subscription + PlanLimits + PlanBilling + resolved feature
|
||||||
|
flags for the account."""
|
||||||
|
from app.models.plan_limits import PlanLimits
|
||||||
|
from app.models.plan_billing import PlanBilling
|
||||||
|
from app.models.feature_flag import (
|
||||||
|
FeatureFlag, PlanFeatureDefault, AccountFeatureOverride,
|
||||||
|
)
|
||||||
|
|
||||||
|
sub = (await db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == account.id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=404, detail="No subscription for account")
|
||||||
|
|
||||||
|
pl = (await db.execute(
|
||||||
|
select(PlanLimits).where(PlanLimits.plan == sub.plan)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
pb = (await db.execute(
|
||||||
|
select(PlanBilling).where(PlanBilling.plan == sub.plan)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
|
||||||
|
# Resolved feature flags: plan defaults overridden by account overrides
|
||||||
|
defaults = (await db.execute(
|
||||||
|
select(PlanFeatureDefault, FeatureFlag)
|
||||||
|
.join(FeatureFlag, PlanFeatureDefault.flag_id == FeatureFlag.id)
|
||||||
|
.where(PlanFeatureDefault.plan == sub.plan)
|
||||||
|
)).all()
|
||||||
|
resolved = {flag.flag_key: pfd.enabled for pfd, flag in defaults}
|
||||||
|
overrides = (await db.execute(
|
||||||
|
select(AccountFeatureOverride, FeatureFlag)
|
||||||
|
.join(FeatureFlag, AccountFeatureOverride.flag_id == FeatureFlag.id)
|
||||||
|
.where(AccountFeatureOverride.account_id == account.id)
|
||||||
|
)).all()
|
||||||
|
for ovr, flag in overrides:
|
||||||
|
resolved[flag.flag_key] = ovr.enabled
|
||||||
|
|
||||||
|
return {
|
||||||
|
"subscription": {
|
||||||
|
"status": sub.status,
|
||||||
|
"plan": sub.plan,
|
||||||
|
"current_period_start": sub.current_period_start,
|
||||||
|
"current_period_end": sub.current_period_end,
|
||||||
|
"cancel_at_period_end": sub.cancel_at_period_end,
|
||||||
|
"seat_limit": sub.seat_limit,
|
||||||
|
"has_pro_entitlement": sub.has_pro_entitlement,
|
||||||
|
"is_paid": sub.is_paid,
|
||||||
|
},
|
||||||
|
"plan_billing": pb,
|
||||||
|
"plan_limits": _plan_limits_to_dict(pl) if pl else {},
|
||||||
|
"enabled_features": resolved,
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def apply_subscription_event(
|
async def apply_subscription_event(
|
||||||
db: AsyncSession, event_id: str, event_type: str, payload: dict
|
db: AsyncSession, event_id: str, event_type: str, payload: dict
|
||||||
@@ -136,6 +191,10 @@ class BillingService:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _plan_limits_to_dict(pl) -> dict:
|
||||||
|
return {c.name: getattr(pl, c.name) for c in pl.__table__.columns}
|
||||||
|
|
||||||
|
|
||||||
def _excerpt(payload: dict) -> dict:
|
def _excerpt(payload: dict) -> dict:
|
||||||
obj = payload.get("data", {}).get("object", {})
|
obj = payload.get("data", {}).get("object", {})
|
||||||
return {
|
return {
|
||||||
|
|||||||
71
backend/app/services/oauth_providers.py
Normal file
71
backend/app/services/oauth_providers.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""OAuth provider helpers. Each provider exposes:
|
||||||
|
- exchange_code(code, redirect_uri) -> OAuthProfile
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthProfile:
|
||||||
|
provider_subject: str
|
||||||
|
email: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
async def google_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as cli:
|
||||||
|
token_response = await cli.post(
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": settings.GOOGLE_CLIENT_ID,
|
||||||
|
"client_secret": settings.GOOGLE_CLIENT_SECRET,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token_response.raise_for_status()
|
||||||
|
access_token = token_response.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo = await cli.get(
|
||||||
|
"https://openidconnect.googleapis.com/v1/userinfo",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo.raise_for_status()
|
||||||
|
data = userinfo.json()
|
||||||
|
return OAuthProfile(
|
||||||
|
provider_subject=data["sub"],
|
||||||
|
email=data["email"],
|
||||||
|
name=data.get("name") or data["email"].split("@")[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def microsoft_exchange_code(code: str, redirect_uri: str) -> OAuthProfile:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as cli:
|
||||||
|
token_response = await cli.post(
|
||||||
|
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": settings.MS_CLIENT_ID,
|
||||||
|
"client_secret": settings.MS_CLIENT_SECRET,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token_response.raise_for_status()
|
||||||
|
access_token = token_response.json()["access_token"]
|
||||||
|
|
||||||
|
userinfo = await cli.get(
|
||||||
|
"https://graph.microsoft.com/v1.0/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
userinfo.raise_for_status()
|
||||||
|
data = userinfo.json()
|
||||||
|
return OAuthProfile(
|
||||||
|
provider_subject=data["id"],
|
||||||
|
email=data.get("mail") or data["userPrincipalName"],
|
||||||
|
name=data.get("displayName") or data["userPrincipalName"].split("@")[0],
|
||||||
|
)
|
||||||
@@ -97,7 +97,18 @@ async def main() -> None:
|
|||||||
)
|
)
|
||||||
row = result.first()
|
row = result.first()
|
||||||
if row:
|
if row:
|
||||||
print(f" [SKIP] {cfg['email']} already exists")
|
# Backfill email_verified_at for existing rows so older test
|
||||||
|
# users created before this script set the field still bypass
|
||||||
|
# the 7-day verification grace.
|
||||||
|
await conn.execute(
|
||||||
|
text("""
|
||||||
|
UPDATE users
|
||||||
|
SET email_verified_at = COALESCE(email_verified_at, :now)
|
||||||
|
WHERE email = :email
|
||||||
|
"""),
|
||||||
|
{"email": cfg["email"], "now": now},
|
||||||
|
)
|
||||||
|
print(f" [SKIP] {cfg['email']} already exists (email_verified_at backfilled if null)")
|
||||||
if cfg["key"] == "team_admin":
|
if cfg["key"] == "team_admin":
|
||||||
team_account_id = row.account_id
|
team_account_id = row.account_id
|
||||||
continue
|
continue
|
||||||
@@ -130,12 +141,17 @@ async def main() -> None:
|
|||||||
|
|
||||||
# ---- Create User ----
|
# ---- Create User ----
|
||||||
user_id = uuid.uuid4()
|
user_id = uuid.uuid4()
|
||||||
|
# email_verified_at is stamped at seed time so test users bypass the
|
||||||
|
# 7-day verification grace immediately. Without this, fixtures hit
|
||||||
|
# require_verified_email_after_grace once their created_at ages past
|
||||||
|
# 7 days and get walled out of protected routes.
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
text("""
|
text("""
|
||||||
INSERT INTO users (id, email, password_hash, name, role, is_super_admin,
|
INSERT INTO users (id, email, password_hash, name, role, is_super_admin,
|
||||||
is_team_admin, is_active, account_id, account_role, created_at)
|
is_team_admin, is_active, account_id, account_role,
|
||||||
|
created_at, email_verified_at)
|
||||||
VALUES (:id, :email, :pw, :name, 'engineer', :is_sa, :is_ta, true,
|
VALUES (:id, :email, :pw, :name, 'engineer', :is_sa, :is_ta, true,
|
||||||
:account_id, :account_role, :now)
|
:account_id, :account_role, :now, :now)
|
||||||
"""),
|
"""),
|
||||||
{
|
{
|
||||||
"id": user_id,
|
"id": user_id,
|
||||||
|
|||||||
180
backend/tests/test_account_invite_extensions.py
Normal file
180
backend/tests/test_account_invite_extensions.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_invite_sends_email_and_stamps_email_sent_at(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
"""Regression: today's create_invite does NOT send email. After this task, it MUST."""
|
||||||
|
with patch(
|
||||||
|
"app.core.email.EmailService.send_account_invite_email",
|
||||||
|
new_callable=AsyncMock, return_value=True,
|
||||||
|
) as mock_send:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/accounts/me/invites",
|
||||||
|
json={"email": "teammate@example.com", "role": "engineer"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 201, response.json()
|
||||||
|
mock_send.assert_called_once()
|
||||||
|
kwargs = mock_send.call_args.kwargs
|
||||||
|
assert kwargs["to_email"] == "teammate@example.com"
|
||||||
|
assert kwargs["role"] == "engineer"
|
||||||
|
assert kwargs["code"]
|
||||||
|
|
||||||
|
invite = (await test_db.execute(
|
||||||
|
select(AccountInvite).where(AccountInvite.email == "teammate@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
assert invite.email_sent_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_invite_email_failure_still_creates_row(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
"""When EmailService returns False, the invite row is still created but
|
||||||
|
email_sent_at remains NULL."""
|
||||||
|
with patch(
|
||||||
|
"app.core.email.EmailService.send_account_invite_email",
|
||||||
|
new_callable=AsyncMock, return_value=False,
|
||||||
|
):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/accounts/me/invites",
|
||||||
|
json={"email": "fail-mail@example.com", "role": "engineer"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 201
|
||||||
|
|
||||||
|
invite = (await test_db.execute(
|
||||||
|
select(AccountInvite).where(AccountInvite.email == "fail-mail@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
assert invite.email_sent_at is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_invite_creates_n_rows_and_sends_n_emails(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"app.core.email.EmailService.send_account_invite_email",
|
||||||
|
new_callable=AsyncMock, return_value=True,
|
||||||
|
) as mock_send:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/accounts/me/invites/bulk",
|
||||||
|
json={"invites": [
|
||||||
|
{"email": "a@example.com", "role": "engineer"},
|
||||||
|
{"email": "b@example.com", "role": "engineer"},
|
||||||
|
{"email": "c@example.com", "role": "viewer"},
|
||||||
|
]},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 201, response.json()
|
||||||
|
body = response.json()
|
||||||
|
assert len(body["created"]) == 3
|
||||||
|
assert body["failed"] == []
|
||||||
|
assert mock_send.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_invite_sets_revoked_at(client, test_db, test_user, auth_headers):
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
|
||||||
|
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=account_id,
|
||||||
|
invited_by_id=invited_by_id,
|
||||||
|
email="revoked@example.com",
|
||||||
|
code="REVOKEME01",
|
||||||
|
role="engineer",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||||
|
)
|
||||||
|
test_db.add(invite)
|
||||||
|
await test_db.commit()
|
||||||
|
invite_id = invite.id
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
await test_db.refresh(invite)
|
||||||
|
assert invite.revoked_at is not None
|
||||||
|
assert invite.is_valid is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_invite_idempotent(client, test_db, test_user, auth_headers):
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
|
||||||
|
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=account_id,
|
||||||
|
invited_by_id=invited_by_id,
|
||||||
|
email="revoked2@example.com",
|
||||||
|
code="REVOKEME02",
|
||||||
|
role="engineer",
|
||||||
|
revoked_at=datetime.now(timezone.utc),
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||||
|
)
|
||||||
|
test_db.add(invite)
|
||||||
|
await test_db.commit()
|
||||||
|
invite_id = invite.id
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_invite_404_when_not_found(client, test_user, auth_headers):
|
||||||
|
import uuid
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/accounts/me/invites/{uuid.uuid4()}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_used_invite_returns_400(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
|
||||||
|
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=account_id,
|
||||||
|
invited_by_id=invited_by_id,
|
||||||
|
email="used@example.com",
|
||||||
|
code="USEDCODE01",
|
||||||
|
role="engineer",
|
||||||
|
accepted_by_id=invited_by_id, # mark as used
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||||
|
)
|
||||||
|
test_db.add(invite)
|
||||||
|
await test_db.commit()
|
||||||
|
invite_id = invite.id
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/accounts/me/invites/{invite_id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
64
backend/tests/test_billing_state_endpoint.py
Normal file
64
backend/tests/test_billing_state_endpoint.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
from app.models.feature_flag import FeatureFlag, PlanFeatureDefault, AccountFeatureOverride
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_state_returns_subscription_plan_features(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
"""Subscription is already seeded by test_user fixture (pro/active).
|
||||||
|
Add a feature flag default for `pro` and verify it shows up in the response."""
|
||||||
|
flag = FeatureFlag(flag_key="psa_integration", display_name="PSA Integration")
|
||||||
|
test_db.add(flag)
|
||||||
|
await test_db.flush()
|
||||||
|
test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=True))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
body = response.json()
|
||||||
|
assert body["subscription"]["status"] == "active"
|
||||||
|
assert body["subscription"]["plan"] == "pro"
|
||||||
|
assert body["subscription"]["has_pro_entitlement"] is True
|
||||||
|
assert body["subscription"]["is_paid"] is True
|
||||||
|
assert body["enabled_features"]["psa_integration"] is True
|
||||||
|
# plan_limits should be a dict with the seeded pro limits from conftest
|
||||||
|
assert body["plan_limits"]["plan"] == "pro"
|
||||||
|
assert body["plan_limits"]["max_trees"] == 25
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_state_account_override_beats_plan_default(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
flag = FeatureFlag(flag_key="escalation_mode", display_name="Escalation Mode")
|
||||||
|
test_db.add(flag)
|
||||||
|
await test_db.flush()
|
||||||
|
test_db.add(PlanFeatureDefault(plan="pro", flag_id=flag.id, enabled=False))
|
||||||
|
test_db.add(AccountFeatureOverride(
|
||||||
|
account_id=account_id, flag_id=flag.id, enabled=True,
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["enabled_features"]["escalation_mode"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_billing_state_404_when_no_subscription(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
"""Wipe the seeded subscription and verify the endpoint surfaces 404."""
|
||||||
|
from sqlalchemy import delete
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/billing/state", headers=auth_headers)
|
||||||
|
assert response.status_code == 404
|
||||||
98
backend/tests/test_email_verification_autosend.py
Normal file
98
backend/tests/test_email_verification_autosend.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_auto_sends_verification_email(client, test_db):
|
||||||
|
"""Fresh registration triggers send_email_verification_email."""
|
||||||
|
with patch(
|
||||||
|
"app.core.email.EmailService.send_email_verification_email",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_send:
|
||||||
|
response = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "newshop@example.com",
|
||||||
|
"password": "Verystrong1Pwd",
|
||||||
|
"name": "New Shop",
|
||||||
|
})
|
||||||
|
assert response.status_code in (200, 201), response.json()
|
||||||
|
mock_send.assert_called_once()
|
||||||
|
kwargs = mock_send.call_args.kwargs
|
||||||
|
assert kwargs["to_email"] == "newshop@example.com"
|
||||||
|
assert "/verify-email?token=" in kwargs["verification_url"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_with_account_invite_code_email_mismatch_rejected(
|
||||||
|
client, test_db, test_user
|
||||||
|
):
|
||||||
|
"""Invite code is for invited@example.com but user registers with a
|
||||||
|
different email -> 400 invite_email_mismatch."""
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=account_id,
|
||||||
|
invited_by_id=invited_by_id,
|
||||||
|
email="invited@example.com",
|
||||||
|
code="INVITECODE99",
|
||||||
|
role="engineer",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||||
|
)
|
||||||
|
test_db.add(invite)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "wrong-email@example.com",
|
||||||
|
"password": "Verystrong1Pwd",
|
||||||
|
"name": "Wrong Email",
|
||||||
|
"account_invite_code": "INVITECODE99",
|
||||||
|
})
|
||||||
|
assert response.status_code == 400, response.json()
|
||||||
|
assert response.json()["detail"]["error"] == "invite_email_mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_with_account_invite_code_email_match_accepted(
|
||||||
|
client, test_db, test_user
|
||||||
|
):
|
||||||
|
"""Invite code is for invited@example.com - registering with that email
|
||||||
|
succeeds and joins the existing account."""
|
||||||
|
from app.models.account_invite import AccountInvite
|
||||||
|
from app.models.user import User
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
invited_by_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
|
||||||
|
invite = AccountInvite(
|
||||||
|
account_id=account_id,
|
||||||
|
invited_by_id=invited_by_id,
|
||||||
|
email="invited@example.com",
|
||||||
|
code="INVITECODE100",
|
||||||
|
role="engineer",
|
||||||
|
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||||
|
)
|
||||||
|
test_db.add(invite)
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.core.email.EmailService.send_email_verification_email",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
):
|
||||||
|
response = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "invited@example.com",
|
||||||
|
"password": "Verystrong1Pwd",
|
||||||
|
"name": "Invited",
|
||||||
|
"account_invite_code": "INVITECODE100",
|
||||||
|
})
|
||||||
|
assert response.status_code in (200, 201), response.json()
|
||||||
|
|
||||||
|
new_user = (await test_db.execute(
|
||||||
|
select(User).where(User.email == "invited@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
assert new_user.account_id == account_id # joined existing account
|
||||||
@@ -13,6 +13,14 @@ pytestmark = pytest.mark.asyncio
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def kb_setup(client, auth_headers, test_db):
|
async def kb_setup(client, auth_headers, test_db):
|
||||||
"""Seed KB plan limits and return helpers."""
|
"""Seed KB plan limits and return helpers."""
|
||||||
|
# KB tests were authored against a free-plan user. Phase 1 conftest seeds
|
||||||
|
# the test_user with a pro/active Subscription; downgrade to free here so
|
||||||
|
# quota numbers match the original test intent.
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
sub = (await test_db.execute(__import__("sqlalchemy").select(Subscription))).scalar_one()
|
||||||
|
sub.plan = "free"
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
# Update plan_limits with KB columns for 'free' plan
|
# Update plan_limits with KB columns for 'free' plan
|
||||||
await test_db.execute(
|
await test_db.execute(
|
||||||
__import__("sqlalchemy").text("""
|
__import__("sqlalchemy").text("""
|
||||||
|
|||||||
120
backend/tests/test_oauth_callbacks.py
Normal file
120
backend/tests/test_oauth_callbacks.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
from app.services.oauth_providers import OAuthProfile
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_creates_user_account_subscription(
|
||||||
|
client, test_db, monkeypatch
|
||||||
|
):
|
||||||
|
"""Brand-new user via Google OAuth -> User + Account + Subscription + OAuthIdentity."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||||
|
|
||||||
|
profile = OAuthProfile(
|
||||||
|
provider_subject="google_subject_123",
|
||||||
|
email="newuser@example.com",
|
||||||
|
name="New User",
|
||||||
|
)
|
||||||
|
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/google/callback", json={"code": "auth_code_xyz"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
body = response.json()
|
||||||
|
assert body["is_new_user"] is True
|
||||||
|
assert body["access_token"]
|
||||||
|
|
||||||
|
user = (await test_db.execute(
|
||||||
|
select(User).where(User.email == "newuser@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
assert user.password_hash is None
|
||||||
|
assert user.email_verified_at is not None
|
||||||
|
|
||||||
|
identity = (await test_db.execute(
|
||||||
|
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert identity.provider == "google"
|
||||||
|
assert identity.provider_subject == "google_subject_123"
|
||||||
|
|
||||||
|
sub = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == user.account_id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert sub.status == "trialing"
|
||||||
|
assert sub.plan == "pro"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_existing_user_is_idempotent(
|
||||||
|
client, test_db, test_user, monkeypatch
|
||||||
|
):
|
||||||
|
"""When test_user's email is already registered, OAuth links + returns the
|
||||||
|
same user. Two calls with same provider_subject must not duplicate
|
||||||
|
OAuthIdentity rows."""
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", "client_dummy")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_SECRET", "secret_dummy")
|
||||||
|
|
||||||
|
user_id = uuid.UUID(test_user["user_data"]["id"])
|
||||||
|
email = test_user["email"]
|
||||||
|
name = test_user["user_data"]["name"]
|
||||||
|
|
||||||
|
profile = OAuthProfile(
|
||||||
|
provider_subject="google_subject_456",
|
||||||
|
email=email,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
with patch("app.api.endpoints.oauth.google_exchange_code", return_value=profile):
|
||||||
|
r1 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||||
|
r2 = await client.post("/api/v1/auth/google/callback", json={"code": "x"})
|
||||||
|
assert r1.status_code == 200
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert r1.json()["is_new_user"] is False
|
||||||
|
assert r2.json()["is_new_user"] is False
|
||||||
|
|
||||||
|
identities = (await test_db.execute(
|
||||||
|
select(OAuthIdentity).where(OAuthIdentity.user_id == user_id)
|
||||||
|
)).scalars().all()
|
||||||
|
assert len(identities) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_google_callback_503_when_unconfigured(client, monkeypatch):
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_CLIENT_ID", None)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/google/callback", json={"code": "x"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_microsoft_callback_creates_user(client, test_db, monkeypatch):
|
||||||
|
from app.core.config import settings
|
||||||
|
monkeypatch.setattr(settings, "MS_CLIENT_ID", "client_dummy")
|
||||||
|
monkeypatch.setattr(settings, "MS_CLIENT_SECRET", "secret_dummy")
|
||||||
|
|
||||||
|
profile = OAuthProfile(
|
||||||
|
provider_subject="ms_subject_789",
|
||||||
|
email="msuser@example.com",
|
||||||
|
name="MS User",
|
||||||
|
)
|
||||||
|
with patch("app.api.endpoints.oauth.microsoft_exchange_code", return_value=profile):
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/microsoft/callback", json={"code": "auth_code"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, response.json()
|
||||||
|
user = (await test_db.execute(
|
||||||
|
select(User).where(User.email == "msuser@example.com")
|
||||||
|
)).scalar_one()
|
||||||
|
identity = (await test_db.execute(
|
||||||
|
select(OAuthIdentity).where(OAuthIdentity.user_id == user.id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert identity.provider == "microsoft"
|
||||||
83
backend/tests/test_oauth_only_user_paths.py
Normal file
83
backend/tests/test_oauth_only_user_paths.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.oauth_identity import OAuthIdentity
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_oauth_only_user(test_db, email, *, with_identity=True):
|
||||||
|
"""Create an OAuth-only user (password_hash=None) directly in the test DB."""
|
||||||
|
import secrets
|
||||||
|
account = Account(
|
||||||
|
name=f"{email}-acct",
|
||||||
|
display_code=secrets.token_hex(4).upper(),
|
||||||
|
)
|
||||||
|
test_db.add(account)
|
||||||
|
await test_db.flush()
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
name="OAuth User",
|
||||||
|
password_hash=None,
|
||||||
|
account_id=account.id,
|
||||||
|
account_role="owner",
|
||||||
|
)
|
||||||
|
test_db.add(user)
|
||||||
|
await test_db.flush()
|
||||||
|
if with_identity:
|
||||||
|
test_db.add(OAuthIdentity(
|
||||||
|
user_id=user.id, provider="google",
|
||||||
|
provider_subject=f"google_{email}",
|
||||||
|
provider_email_at_link=email,
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_form_rejects_oauth_only_user_with_helpful_error(client, test_db):
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-only@example.com")
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
data={"username": "oauth-only@example.com", "password": "wontwork"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
body = response.json()
|
||||||
|
assert body["detail"]["error"] == "use_oauth_provider"
|
||||||
|
assert "google" in body["detail"]["providers"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_json_rejects_oauth_only_user(client, test_db):
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-only2@example.com")
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": "oauth-only2@example.com", "password": "wontwork"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["detail"]["error"] == "use_oauth_provider"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_password_forgot_silent_for_oauth_only_user(client, test_db):
|
||||||
|
"""OAuth-only users get the generic message; no email is sent."""
|
||||||
|
await _make_oauth_only_user(test_db, "oauth-forgot@example.com", with_identity=False)
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
with patch("app.core.email.EmailService.send_password_reset_email", new_callable=AsyncMock) as mock_send:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/password/forgot",
|
||||||
|
json={"email": "oauth-forgot@example.com"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_for_password_user_still_works(client, test_user):
|
||||||
|
"""Regression: existing password-based login must still succeed."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/login/json",
|
||||||
|
json={"email": test_user["email"], "password": test_user["password"]},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["access_token"]
|
||||||
85
backend/tests/test_pilot_complimentary_backfill.py
Normal file
85
backend/tests/test_pilot_complimentary_backfill.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Smoke test for the complimentary backfill: assertions about the post-state.
|
||||||
|
The actual migration runs at deploy time; tests use create_all so the
|
||||||
|
migration body isn't executed automatically. We invoke the SQL inline to
|
||||||
|
exercise the same effect."""
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select, text, delete
|
||||||
|
from app.models.account import Account
|
||||||
|
from app.models.subscription import Subscription
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complimentary_backfill_sets_status_and_inserts_missing_rows(test_db):
|
||||||
|
"""Inline-run the backfill SQL and assert post-state."""
|
||||||
|
# Seed a fresh account with no subscription
|
||||||
|
no_sub_account = Account(name="NoSub", display_code="NOSUB001")
|
||||||
|
test_db.add(no_sub_account)
|
||||||
|
await test_db.flush()
|
||||||
|
|
||||||
|
# Seed an account with a trialing subscription (should become complimentary)
|
||||||
|
trial_account = Account(name="Trial", display_code="TRIAL001")
|
||||||
|
test_db.add(trial_account)
|
||||||
|
await test_db.flush()
|
||||||
|
test_db.add(Subscription(
|
||||||
|
account_id=trial_account.id, plan="free", status="trialing",
|
||||||
|
))
|
||||||
|
|
||||||
|
# Seed an account with a canceled subscription (should be preserved)
|
||||||
|
canceled_account = Account(name="Cancel", display_code="CANCL001")
|
||||||
|
test_db.add(canceled_account)
|
||||||
|
await test_db.flush()
|
||||||
|
test_db.add(Subscription(
|
||||||
|
account_id=canceled_account.id, plan="pro", status="canceled",
|
||||||
|
))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
# Run the same SQL the migration runs
|
||||||
|
await test_db.execute(text("""
|
||||||
|
UPDATE subscriptions
|
||||||
|
SET status = 'complimentary', plan = 'pro',
|
||||||
|
current_period_end = NULL, current_period_start = NULL,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE status NOT IN ('canceled', 'past_due')
|
||||||
|
"""))
|
||||||
|
await test_db.execute(text("""
|
||||||
|
INSERT INTO subscriptions (id, account_id, plan, status, cancel_at_period_end, created_at, updated_at)
|
||||||
|
SELECT gen_random_uuid(), a.id, 'pro', 'complimentary', false, now(), now()
|
||||||
|
FROM accounts a
|
||||||
|
WHERE NOT EXISTS (SELECT 1 FROM subscriptions s WHERE s.account_id = a.id)
|
||||||
|
"""))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
# All three accounts now have a Subscription
|
||||||
|
no_sub_row = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == no_sub_account.id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert no_sub_row.status == "complimentary"
|
||||||
|
assert no_sub_row.plan == "pro"
|
||||||
|
|
||||||
|
trial_row = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == trial_account.id)
|
||||||
|
)).scalar_one()
|
||||||
|
assert trial_row.status == "complimentary"
|
||||||
|
assert trial_row.plan == "pro"
|
||||||
|
|
||||||
|
canceled_row = (await test_db.execute(
|
||||||
|
select(Subscription).where(Subscription.account_id == canceled_account.id)
|
||||||
|
)).scalar_one()
|
||||||
|
# Canceled is preserved
|
||||||
|
assert canceled_row.status == "canceled"
|
||||||
|
assert canceled_row.plan == "pro"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complimentary_subscription_passes_active_subscription_guard(
|
||||||
|
client, test_db, test_user, auth_headers
|
||||||
|
):
|
||||||
|
"""The require_active_subscription guard accepts complimentary status."""
|
||||||
|
account_id = uuid.UUID(test_user["user_data"]["account_id"])
|
||||||
|
await test_db.execute(delete(Subscription).where(Subscription.account_id == account_id))
|
||||||
|
test_db.add(Subscription(account_id=account_id, plan="pro", status="complimentary"))
|
||||||
|
await test_db.commit()
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/trees", headers=auth_headers)
|
||||||
|
assert response.status_code != 402
|
||||||
Reference in New Issue
Block a user