diff --git a/backend/app/api/endpoints/accounts.py b/backend/app/api/endpoints/accounts.py index f8dc3744..6b3fab83 100644 --- a/backend/app/api/endpoints/accounts.py +++ b/backend/app/api/endpoints/accounts.py @@ -19,7 +19,7 @@ from app.models.account_invite import AccountInvite from app.models.account_settings import AccountSettings from app.models.subscription import Subscription from app.models.user import User -from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, TransferOwnershipRequest +from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, AccountInviteBulkCreate, AccountInviteBulkResponse, TransferOwnershipRequest from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails from app.schemas.user import UserResponse, AccountRoleUpdate from app.core.security import verify_password @@ -299,6 +299,86 @@ async def create_invite( return invite +@router.post("/me/invites/bulk", response_model=AccountInviteBulkResponse, status_code=status.HTTP_201_CREATED) +async def create_invites_bulk( + payload: AccountInviteBulkCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Create multiple invites in one call (wizard step 3 supports up to N). + Per-row failures are returned in `failed`; successes in `created`.""" + # Lookup account once for email rendering + account_result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = account_result.scalar_one() + + created: list[AccountInvite] = [] + failed: list[dict] = [] + for invite_data in payload.invites: + try: + code = secrets.token_urlsafe(16) + expires_at = None + if invite_data.expires_in_days: + expires_at = datetime.now(timezone.utc) + timedelta(days=invite_data.expires_in_days) + + invite = AccountInvite( + account_id=current_user.account_id, + invited_by_id=current_user.id, + email=invite_data.email, + code=code, + role=invite_data.role, + expires_at=expires_at, + ) + db.add(invite) + await db.flush() + + email_sent = await EmailService.send_account_invite_email( + to_email=invite.email, + code=code, + account_name=account.name, + role=invite.role, + ) + if email_sent: + invite.email_sent_at = datetime.now(timezone.utc) + + created.append(invite) + except Exception as e: + failed.append({"email": invite_data.email, "error": str(e)}) + + await db.commit() + for inv in created: + await db.refresh(inv) + + return AccountInviteBulkResponse(created=created, failed=failed) + + +@router.delete("/me/invites/{invite_id}", status_code=status.HTTP_204_NO_CONTENT) +async def revoke_invite( + invite_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_account_owner)] +): + """Soft-revoke an invitation by setting revoked_at. Idempotent on already- + revoked invites; rejects already-accepted invites.""" + result = await db.execute( + select(AccountInvite).where( + AccountInvite.id == invite_id, + AccountInvite.account_id == current_user.account_id, + ) + ) + invite = result.scalar_one_or_none() + if not invite: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invite not found") + if invite.is_revoked: + return None # idempotent + if invite.is_used: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot revoke an accepted invite") + invite.revoked_at = datetime.now(timezone.utc) + await db.commit() + return None + + @router.post("/me/invites/{invite_id}/resend", response_model=AccountInviteResponse) async def resend_invite( invite_id: UUID, diff --git a/backend/app/schemas/account.py b/backend/app/schemas/account.py index 6909d3d7..3d1e0c28 100644 --- a/backend/app/schemas/account.py +++ b/backend/app/schemas/account.py @@ -42,3 +42,12 @@ class AccountInviteResponse(BaseModel): used_at: Optional[datetime] = None model_config = {"from_attributes": True} + + +class AccountInviteBulkCreate(BaseModel): + invites: list[AccountInviteCreate] + + +class AccountInviteBulkResponse(BaseModel): + created: list[AccountInviteResponse] + failed: list[dict] # entries shaped {"email": str, "error": str} diff --git a/backend/tests/test_account_invite_extensions.py b/backend/tests/test_account_invite_extensions.py index 994903bb..698ec84b 100644 --- a/backend/tests/test_account_invite_extensions.py +++ b/backend/tests/test_account_invite_extensions.py @@ -52,3 +52,129 @@ async def test_create_invite_email_failure_still_creates_row( select(AccountInvite).where(AccountInvite.email == "fail-mail@example.com") )).scalar_one() assert invite.email_sent_at is None + + +@pytest.mark.asyncio +async def test_bulk_invite_creates_n_rows_and_sends_n_emails( + client, test_db, test_user, auth_headers +): + with patch( + "app.core.email.EmailService.send_account_invite_email", + new_callable=AsyncMock, return_value=True, + ) as mock_send: + response = await client.post( + "/api/v1/accounts/me/invites/bulk", + json={"invites": [ + {"email": "a@example.com", "role": "engineer"}, + {"email": "b@example.com", "role": "engineer"}, + {"email": "c@example.com", "role": "viewer"}, + ]}, + headers=auth_headers, + ) + assert response.status_code == 201, response.json() + body = response.json() + assert len(body["created"]) == 3 + assert body["failed"] == [] + assert mock_send.call_count == 3 + + +@pytest.mark.asyncio +async def test_revoke_invite_sets_revoked_at(client, test_db, test_user, auth_headers): + import uuid + from datetime import datetime, timezone, timedelta + from app.models.account_invite import AccountInvite + + invited_by_id = uuid.UUID(test_user["user_data"]["id"]) + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + + invite = AccountInvite( + account_id=account_id, + invited_by_id=invited_by_id, + email="revoked@example.com", + code="REVOKEME01", + role="engineer", + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + ) + test_db.add(invite) + await test_db.commit() + invite_id = invite.id + + response = await client.delete( + f"/api/v1/accounts/me/invites/{invite_id}", + headers=auth_headers, + ) + assert response.status_code == 204 + + await test_db.refresh(invite) + assert invite.revoked_at is not None + assert invite.is_valid is False + + +@pytest.mark.asyncio +async def test_revoke_invite_idempotent(client, test_db, test_user, auth_headers): + import uuid + from datetime import datetime, timezone, timedelta + from app.models.account_invite import AccountInvite + + invited_by_id = uuid.UUID(test_user["user_data"]["id"]) + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + + invite = AccountInvite( + account_id=account_id, + invited_by_id=invited_by_id, + email="revoked2@example.com", + code="REVOKEME02", + role="engineer", + revoked_at=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + ) + test_db.add(invite) + await test_db.commit() + invite_id = invite.id + + response = await client.delete( + f"/api/v1/accounts/me/invites/{invite_id}", + headers=auth_headers, + ) + assert response.status_code == 204 + + +@pytest.mark.asyncio +async def test_revoke_invite_404_when_not_found(client, test_user, auth_headers): + import uuid + response = await client.delete( + f"/api/v1/accounts/me/invites/{uuid.uuid4()}", + headers=auth_headers, + ) + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_revoke_used_invite_returns_400( + client, test_db, test_user, auth_headers +): + import uuid + from datetime import datetime, timezone, timedelta + from app.models.account_invite import AccountInvite + + invited_by_id = uuid.UUID(test_user["user_data"]["id"]) + account_id = uuid.UUID(test_user["user_data"]["account_id"]) + + invite = AccountInvite( + account_id=account_id, + invited_by_id=invited_by_id, + email="used@example.com", + code="USEDCODE01", + role="engineer", + accepted_by_id=invited_by_id, # mark as used + expires_at=datetime.now(timezone.utc) + timedelta(days=7), + ) + test_db.add(invite) + await test_db.commit() + invite_id = invite.id + + response = await client.delete( + f"/api/v1/accounts/me/invites/{invite_id}", + headers=auth_headers, + ) + assert response.status_code == 400