feat: Slate & Ice Modern aesthetic redesign (#94)
* chore: update Google Fonts to Bricolage Grotesque, IBM Plex Sans, JetBrains Mono Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * chore: update Tailwind config to Slate & Ice theme colors and fonts Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: update CSS variables and glass-card utilities for Slate & Ice theme - Replace all color variables with Slate & Ice palette - Add glass system vars (--glass-bg, --glass-blur, --shadow-float) - Replace legacy glass-card with new variable-driven glass classes - Add breatheGlow, bellWobble, slideDown, fadeInRight keyframes - Update font references to IBM Plex Sans and Bricolage Grotesque Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: recolor BrandLogo to cyan gradient, split BrandWordmark for gradient Flow text Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: update TopBar with glassmorphism backdrop and cyan accent styling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: update Sidebar with glassmorphism backdrop Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add ambient atmosphere gradient orbs behind app shell Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: update QuickStats and SessionsPanel with glass-card styling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add WeeklyCalendar, QuickActions, OpenSessions, RecentActivity dashboard components Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: redesign dashboard layout with calendar, open sessions, and glass-card panels New layout: greeting → calendar+actions → sessions+stats → activity Replaces old QuickStats and SessionsPanel with new dashboard components Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: replace remaining purple hex references with ice-cyan accent Sweep of hardcoded purple hex values (#818cf8, #6366f1) replaced with new cyan accent (#06b6d4) in QuickActions, RecentActivity, QuickLaunch, and SVG brand assets. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * docs: update CLAUDE.md branding and design system for Slate & Ice Modern Updated Last Updated date, branding section (fonts, colors, glass utilities, atmosphere orbs), component styling rules, and Design System section to reflect the new ice-cyan glassmorphism theme. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * docs: add Slate & Ice Modern design doc and implementation plan Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: redesign login page with Slate & Ice Modern design system Apply glassmorphism styling, atmosphere orbs, branded wordmark, and consistent design tokens to match the updated app shell aesthetic. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: raise TopBar z-index so profile dropdown renders above main content Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add AI assistant with in-session copilot and standalone chat with RAG Implements three-phase AI assistant feature: - Phase 0: RAG infrastructure with pgvector embeddings, Voyage AI integration, tree chunking service, and semantic search over team's flow library - Phase 1: In-session copilot panel during flow navigation with contextual AI help, current step awareness, and suggested related flows - Phase 2: Standalone AI chat page with persistent conversation history, pin/delete, and configurable retention policies (account-level) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add account management, email verification, AI fixes, and user guides - Profile settings, account transfer, delete/leave account flows - Email verification with JWT tokens and Resend integration - AI assistant/copilot fixes: markdown rendering, shared RAG helpers, token tracking, input refocus, model_validate usage - User guides hub + detail pages with 13 topic guides - Sidebar and top bar navigation for guides Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: prevent stale chunk errors after deployments - Set Cache-Control no-cache on index.html in nginx so browsers always fetch fresh chunk references after a deploy - Auto-reload on chunk load failures (stale deploy detection) with loop prevention via sessionStorage - Show friendly "App Updated" message if auto-reload doesn't resolve it Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: add email verification toggle to admin settings Adds platform-level toggle to enable/disable email verification. When disabled, the verification banner is hidden and the send endpoint returns 403. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit was merged in pull request #94.
This commit is contained in:
@@ -10,6 +10,10 @@ from alembic import context
|
||||
# Import your models
|
||||
from app.core.database import Base
|
||||
from app.models import User, Team, Tree, Session, Attachment, InviteCode
|
||||
from app.models.email_verification_token import EmailVerificationToken
|
||||
from app.models.tree_embedding import TreeEmbedding
|
||||
from app.models.copilot_conversation import CopilotConversation
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.core.config import settings
|
||||
|
||||
# this is the Alembic Config object
|
||||
|
||||
30
backend/alembic/versions/040_add_user_profile_fields.py
Normal file
30
backend/alembic/versions/040_add_user_profile_fields.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Add user profile fields (phone, job_title, timezone, avatar_url, email_verified_at)
|
||||
|
||||
Revision ID: 040
|
||||
Revises: fb1481317ff6
|
||||
Create Date: 2026-03-03
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = "040"
|
||||
down_revision = "e2d81e82ea5e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("users", sa.Column("phone", sa.String(50), nullable=True))
|
||||
op.add_column("users", sa.Column("job_title", sa.String(255), nullable=True))
|
||||
op.add_column("users", sa.Column("timezone", sa.String(100), nullable=False, server_default="UTC"))
|
||||
op.add_column("users", sa.Column("avatar_url", sa.String(500), nullable=True))
|
||||
op.add_column("users", sa.Column("email_verified_at", sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "email_verified_at")
|
||||
op.drop_column("users", "avatar_url")
|
||||
op.drop_column("users", "timezone")
|
||||
op.drop_column("users", "job_title")
|
||||
op.drop_column("users", "phone")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Add email_verification_tokens table
|
||||
|
||||
Revision ID: 041
|
||||
Revises: 040
|
||||
Create Date: 2026-03-03
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision = "041"
|
||||
down_revision = "040"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"email_verification_tokens",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("token_hash", sa.String(64), unique=True, nullable=False, index=True),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id"), nullable=False, index=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("email_verification_tokens")
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add pgvector extension and tree_embeddings table.
|
||||
|
||||
Revision ID: 042
|
||||
Revises: 041
|
||||
Create Date: 2026-03-04
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers
|
||||
revision: str = "042"
|
||||
down_revision: str = "041"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
op.create_table(
|
||||
"tree_embeddings",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("tree_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("trees.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("chunk_type", sa.String(30), nullable=False),
|
||||
sa.Column("node_type", sa.String(30), nullable=True),
|
||||
sa.Column("node_id", sa.String(100), nullable=True),
|
||||
sa.Column("chunk_text", sa.Text(), nullable=False),
|
||||
sa.Column("embedding_model", sa.String(50), nullable=False, server_default="voyage-3.5"),
|
||||
sa.Column("meta", postgresql.JSONB(), nullable=False, server_default=sa.text("'{}'::jsonb")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
op.execute("ALTER TABLE tree_embeddings ADD COLUMN embedding vector(1024)")
|
||||
|
||||
op.create_index("ix_tree_embeddings_account_id", "tree_embeddings", ["account_id"])
|
||||
op.create_index("ix_tree_embeddings_tree_id", "tree_embeddings", ["tree_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tree_embeddings")
|
||||
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||
39
backend/alembic/versions/043_add_copilot_conversations.py
Normal file
39
backend/alembic/versions/043_add_copilot_conversations.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Add copilot_conversations table.
|
||||
|
||||
Revision ID: 043
|
||||
Revises: 042
|
||||
Create Date: 2026-03-04
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "043"
|
||||
down_revision: str = "042"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"copilot_conversations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("session_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("sessions.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("tree_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("trees.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("messages", postgresql.JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("current_node_id", sa.String(100), nullable=True),
|
||||
sa.Column("message_count", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("total_input_tokens", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("total_output_tokens", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("copilot_conversations")
|
||||
37
backend/alembic/versions/044_add_assistant_chats.py
Normal file
37
backend/alembic/versions/044_add_assistant_chats.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Add assistant_chats table.
|
||||
|
||||
Revision ID: 044
|
||||
Revises: 043
|
||||
Create Date: 2026-03-04
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "044"
|
||||
down_revision: str = "043"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"assistant_chats",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("account_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("title", sa.String(255), nullable=False, server_default="New Chat"),
|
||||
sa.Column("messages", postgresql.JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb")),
|
||||
sa.Column("message_count", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("total_input_tokens", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("total_output_tokens", sa.Integer(), nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("pinned", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("assistant_chats")
|
||||
31
backend/alembic/versions/045_add_chat_retention_settings.py
Normal file
31
backend/alembic/versions/045_add_chat_retention_settings.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Add chat retention settings to accounts.
|
||||
|
||||
Revision ID: 045
|
||||
Revises: 044
|
||||
Create Date: 2026-03-04
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "045"
|
||||
down_revision: str = "044"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"accounts",
|
||||
sa.Column("chat_retention_days", sa.Integer(), nullable=True, server_default=sa.text("90")),
|
||||
)
|
||||
op.add_column(
|
||||
"accounts",
|
||||
sa.Column("chat_retention_max_count", sa.Integer(), nullable=True, server_default=sa.text("100")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("accounts", "chat_retention_max_count")
|
||||
op.drop_column("accounts", "chat_retention_days")
|
||||
@@ -7,17 +7,20 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.database import get_db
|
||||
from app.core.subscriptions import get_account_subscription, get_plan_limits, get_account_usage
|
||||
from app.core.audit import log_audit
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.core.email import EmailService
|
||||
from app.models.account import Account
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse
|
||||
from app.schemas.account import AccountResponse, AccountUpdate, AccountInviteCreate, AccountInviteResponse, TransferOwnershipRequest
|
||||
from app.schemas.subscription import SubscriptionResponse, PlanLimitsResponse, UsageResponse, SubscriptionDetails
|
||||
from app.schemas.user import UserResponse, AccountRoleUpdate
|
||||
from app.core.security import verify_password
|
||||
from app.api.deps import get_current_active_user, require_account_owner
|
||||
|
||||
router = APIRouter(prefix="/accounts", tags=["accounts"])
|
||||
@@ -142,6 +145,58 @@ async def update_member_role(
|
||||
return user
|
||||
|
||||
|
||||
@router.post("/me/transfer-ownership", response_model=AccountResponse)
|
||||
async def transfer_ownership(
|
||||
data: TransferOwnershipRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Transfer account ownership to another member (owner only)."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
if data.target_user_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot transfer ownership to yourself"
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == data.target_user_id,
|
||||
User.account_id == current_user.account_id
|
||||
)
|
||||
)
|
||||
target_user = result.scalar_one_or_none()
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in your account"
|
||||
)
|
||||
|
||||
# Swap roles
|
||||
current_user.account_role = "engineer"
|
||||
target_user.account_role = "owner"
|
||||
|
||||
# Update account owner
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one()
|
||||
account.owner_id = target_user.id
|
||||
|
||||
await log_audit(
|
||||
db, current_user.id, "account.ownership_transfer", "account", account.id,
|
||||
{"new_owner_id": str(target_user.id)}
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
return account
|
||||
|
||||
|
||||
@router.delete("/me/members/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_member(
|
||||
user_id: UUID,
|
||||
@@ -318,3 +373,95 @@ async def list_invites(
|
||||
.order_by(AccountInvite.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/me/leave")
|
||||
async def leave_account(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Leave the current account (non-owners only). Creates a personal account."""
|
||||
if current_user.account_role == "owner":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Account owners cannot leave. Transfer ownership first."
|
||||
)
|
||||
|
||||
# Create a personal account (same pattern as remove_member)
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
display_code = ''.join(secrets.choice(chars) for _ in range(8))
|
||||
|
||||
new_account = Account(
|
||||
name=f"{current_user.name}'s Account",
|
||||
display_code=display_code,
|
||||
owner_id=current_user.id,
|
||||
)
|
||||
db.add(new_account)
|
||||
await db.flush()
|
||||
|
||||
new_sub = Subscription(
|
||||
account_id=new_account.id,
|
||||
plan="free",
|
||||
status="active",
|
||||
)
|
||||
db.add(new_sub)
|
||||
|
||||
old_account_id = current_user.account_id
|
||||
current_user.account_id = new_account.id
|
||||
current_user.account_role = "owner"
|
||||
|
||||
await log_audit(db, current_user.id, "account.leave", "account", old_account_id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "You have left the account"}
|
||||
|
||||
|
||||
class DeleteAccountRequest(BaseModel):
|
||||
current_password: str
|
||||
|
||||
|
||||
@router.delete("/me")
|
||||
async def delete_account(
|
||||
data: DeleteAccountRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_account_owner)]
|
||||
):
|
||||
"""Delete the current account and soft-delete the user (owner only, no other members)."""
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
# Check no other members
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.account_id == current_user.account_id,
|
||||
User.id != current_user.id,
|
||||
User.deleted_at.is_(None)
|
||||
)
|
||||
)
|
||||
if result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete account with other members. Remove them first."
|
||||
)
|
||||
|
||||
# Soft-delete user
|
||||
current_user.deleted_at = datetime.now(timezone.utc)
|
||||
current_user.is_active = False
|
||||
|
||||
# Revoke all refresh tokens
|
||||
rt_result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.user_id == current_user.id,
|
||||
RefreshToken.revoked_at.is_(None)
|
||||
)
|
||||
)
|
||||
for rt in rt_result.scalars().all():
|
||||
rt.revoked_at = datetime.now(timezone.utc)
|
||||
|
||||
await log_audit(db, current_user.id, "account.delete", "account", current_user.account_id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Account deleted"}
|
||||
|
||||
320
backend/app/api/endpoints/assistant_chat.py
Normal file
320
backend/app/api/endpoints/assistant_chat.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Standalone AI assistant chat endpoints.
|
||||
|
||||
POST /assistant/chats — Create new chat
|
||||
GET /assistant/chats — List chats (paginated, newest first)
|
||||
GET /assistant/chats/{id} — Get chat with messages
|
||||
POST /assistant/chats/{id}/messages — Send message
|
||||
PATCH /assistant/chats/{id} — Update title, pin/unpin
|
||||
DELETE /assistant/chats/{id} — Delete single chat
|
||||
DELETE /assistant/chats — Bulk delete (older_than_days query param)
|
||||
GET /assistant/retention — Get account retention settings
|
||||
PATCH /assistant/retention — Update retention settings (owner only)
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import select, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.rate_limit import limiter
|
||||
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
|
||||
from app.core.config import settings
|
||||
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.schemas.assistant_chat import (
|
||||
ChatCreateRequest,
|
||||
ChatMessageRequest,
|
||||
ChatMessageResponse,
|
||||
ChatListResponse,
|
||||
ChatDetailResponse,
|
||||
ChatUpdateRequest,
|
||||
RetentionSettingsResponse,
|
||||
RetentionSettingsUpdate,
|
||||
)
|
||||
from app.schemas.copilot import SuggestedFlow
|
||||
from app.services import assistant_chat_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/assistant", tags=["assistant-chat"])
|
||||
|
||||
|
||||
def _require_ai_enabled() -> None:
|
||||
if not settings.ai_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chats", response_model=ChatDetailResponse, status_code=201)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_chat(
|
||||
request: Request,
|
||||
data: ChatCreateRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Create a new empty chat conversation."""
|
||||
chat = await assistant_chat_service.create_chat(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
return ChatDetailResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.get("/chats", response_model=list[ChatListResponse])
|
||||
async def list_chats(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
):
|
||||
"""List user's chat conversations (newest first, pinned on top)."""
|
||||
offset = (page - 1) * size
|
||||
result = await db.execute(
|
||||
select(AssistantChat)
|
||||
.where(AssistantChat.user_id == current_user.id)
|
||||
.order_by(AssistantChat.pinned.desc(), AssistantChat.updated_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)
|
||||
chats = result.scalars().all()
|
||||
return [ChatListResponse.model_validate(c) for c in chats]
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatDetailResponse)
|
||||
async def get_chat(
|
||||
chat_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Get a chat with full message history."""
|
||||
result = await db.execute(
|
||||
select(AssistantChat).where(
|
||||
AssistantChat.id == chat_id,
|
||||
AssistantChat.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
chat = result.scalar_one_or_none()
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
return ChatDetailResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.post("/chats/{chat_id}/messages", response_model=ChatMessageResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def post_message(
|
||||
request: Request,
|
||||
chat_id: UUID,
|
||||
data: ChatMessageRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Send a message and get AI response."""
|
||||
_require_ai_enabled()
|
||||
|
||||
allowed, quota_status = await check_ai_quota(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
db=db,
|
||||
billing_anchor=current_user.ai_billing_cycle_anchor_at,
|
||||
is_super_admin=current_user.is_super_admin,
|
||||
)
|
||||
if not allowed:
|
||||
reset_key = "daily_reset_at" if quota_status.get("deny_reason") == "daily" else "monthly_reset_at"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"message": f"AI limit exceeded ({quota_status['deny_reason']})",
|
||||
"reset_at": quota_status.get(reset_key),
|
||||
"quota": quota_status,
|
||||
},
|
||||
)
|
||||
|
||||
plan = await get_user_plan(current_user.account_id, db)
|
||||
|
||||
try:
|
||||
ai_content, suggested_flows, chat = await assistant_chat_service.send_message(
|
||||
chat_id=chat_id,
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
message=data.message,
|
||||
db=db,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Assistant chat message failed: %s", e)
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=None,
|
||||
generation_type="assistant_message",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"assistant_chat_id": str(chat_id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||
)
|
||||
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=None,
|
||||
generation_type="assistant_message",
|
||||
tier=plan,
|
||||
input_tokens=chat.total_input_tokens,
|
||||
output_tokens=chat.total_output_tokens,
|
||||
estimated_cost=(
|
||||
chat.total_input_tokens * 1.0 / 1_000_000
|
||||
+ chat.total_output_tokens * 5.0 / 1_000_000
|
||||
),
|
||||
succeeded=True,
|
||||
counts_toward_quota=False,
|
||||
error_code=None,
|
||||
extra_data={"assistant_chat_id": str(chat_id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ChatMessageResponse(
|
||||
content=ai_content,
|
||||
suggested_flows=[SuggestedFlow.model_validate(sf) for sf in suggested_flows],
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/chats/{chat_id}", response_model=ChatDetailResponse)
|
||||
async def update_chat(
|
||||
chat_id: UUID,
|
||||
data: ChatUpdateRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Update chat title or pin/unpin."""
|
||||
result = await db.execute(
|
||||
select(AssistantChat).where(
|
||||
AssistantChat.id == chat_id,
|
||||
AssistantChat.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
chat = result.scalar_one_or_none()
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
|
||||
if data.title is not None:
|
||||
chat.title = data.title
|
||||
if data.pinned is not None:
|
||||
chat.pinned = data.pinned
|
||||
|
||||
await db.commit()
|
||||
return ChatDetailResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", status_code=204)
|
||||
async def delete_chat(
|
||||
chat_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Delete a single chat."""
|
||||
result = await db.execute(
|
||||
select(AssistantChat).where(
|
||||
AssistantChat.id == chat_id,
|
||||
AssistantChat.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
chat = result.scalar_one_or_none()
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
|
||||
await db.delete(chat)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.delete("/chats", status_code=204)
|
||||
async def bulk_delete_chats(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
older_than_days: int = Query(..., ge=1),
|
||||
):
|
||||
"""Bulk delete chats older than N days (skips pinned)."""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=older_than_days)
|
||||
await db.execute(
|
||||
delete(AssistantChat).where(
|
||||
AssistantChat.user_id == current_user.id,
|
||||
AssistantChat.pinned == False, # noqa: E712
|
||||
AssistantChat.updated_at < cutoff,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.get("/retention", response_model=RetentionSettingsResponse)
|
||||
async def get_retention_settings(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Get account chat retention settings."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
return RetentionSettingsResponse(
|
||||
chat_retention_days=account.chat_retention_days,
|
||||
chat_retention_max_count=account.chat_retention_max_count,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/retention", response_model=RetentionSettingsResponse)
|
||||
async def update_retention_settings(
|
||||
data: RetentionSettingsUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Update account chat retention settings (account owner only)."""
|
||||
result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = result.scalar_one_or_none()
|
||||
if not account:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Account not found")
|
||||
|
||||
if account.owner_id != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the account owner can update retention settings",
|
||||
)
|
||||
|
||||
if data.chat_retention_days is not None:
|
||||
account.chat_retention_days = data.chat_retention_days
|
||||
if data.chat_retention_max_count is not None:
|
||||
account.chat_retention_max_count = data.chat_retention_max_count
|
||||
|
||||
await db.commit()
|
||||
|
||||
return RetentionSettingsResponse(
|
||||
chat_retention_days=account.chat_retention_days,
|
||||
chat_retention_max_count=account.chat_retention_max_count,
|
||||
)
|
||||
@@ -7,6 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.core.config import settings
|
||||
from app.core.settings_manager import SettingsManager
|
||||
from app.core.database import get_db
|
||||
from app.core.rate_limit import limiter
|
||||
from app.core.security import (
|
||||
@@ -15,6 +16,7 @@ from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
create_password_reset_token,
|
||||
create_email_verification_token,
|
||||
decode_token,
|
||||
hash_token,
|
||||
)
|
||||
@@ -24,7 +26,7 @@ from app.models.refresh_token import RefreshToken
|
||||
from app.models.account import Account
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.account_invite import AccountInvite
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin
|
||||
from app.schemas.user import UserCreate, UserResponse, UserLogin, UserUpdate
|
||||
from app.schemas.token import Token
|
||||
from app.schemas.auth_password import (
|
||||
ChangePasswordRequest,
|
||||
@@ -34,6 +36,7 @@ from app.schemas.auth_password import (
|
||||
ResetPasswordRequest,
|
||||
)
|
||||
from app.models.password_reset_token import PasswordResetToken
|
||||
from app.models.email_verification_token import EmailVerificationToken
|
||||
from app.core.email import EmailService
|
||||
from app.api.deps import get_current_active_user, get_refresh_token_payload
|
||||
from app.core.audit import log_audit
|
||||
@@ -351,6 +354,54 @@ async def get_me(
|
||||
return current_user
|
||||
|
||||
|
||||
@router.patch("/me", response_model=UserResponse)
|
||||
async def update_me(
|
||||
data: UserUpdate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Update current user's profile (name, email)."""
|
||||
update_fields = data.model_fields_set - {"current_password"}
|
||||
if not update_fields:
|
||||
return current_user
|
||||
|
||||
# Email change requires current_password
|
||||
if "email" in data.model_fields_set:
|
||||
if not data.current_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is required to change email"
|
||||
)
|
||||
if not verify_password(data.current_password, current_user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
# Check uniqueness
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == data.email, User.id != current_user.id)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
current_user.email = data.email
|
||||
|
||||
if "name" in data.model_fields_set and data.name is not None:
|
||||
current_user.name = data.name
|
||||
|
||||
# Handle simple string profile fields
|
||||
for field in ("phone", "job_title", "timezone"):
|
||||
if field in data.model_fields_set:
|
||||
setattr(current_user, field, getattr(data, field))
|
||||
|
||||
await log_audit(db, current_user.id, "auth.profile_update", "user", current_user.id)
|
||||
await db.commit()
|
||||
await db.refresh(current_user)
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
payload: Annotated[dict, Depends(get_refresh_token_payload)],
|
||||
@@ -543,3 +594,113 @@ async def reset_password(
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Password has been reset successfully"}
|
||||
|
||||
|
||||
@router.get("/email/verification-status")
|
||||
async def get_verification_status(
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Check if email verification is enabled on the platform."""
|
||||
enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
return {"enabled": enabled}
|
||||
|
||||
|
||||
@router.post("/email/send-verification")
|
||||
@limiter.limit("3/minute")
|
||||
async def send_verification_email(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Send an email verification link to the current user."""
|
||||
verification_enabled = await SettingsManager.get("email_verification_enabled", db, default=True)
|
||||
if not verification_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Email verification is currently disabled"
|
||||
)
|
||||
|
||||
if current_user.email_verified_at is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is already verified"
|
||||
)
|
||||
|
||||
raw_token = create_email_verification_token(str(current_user.id))
|
||||
payload = decode_token(raw_token)
|
||||
if payload and payload.get("jti"):
|
||||
token_record = EmailVerificationToken(
|
||||
token_hash=hash_token(payload["jti"]),
|
||||
user_id=current_user.id,
|
||||
expires_at=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
|
||||
)
|
||||
db.add(token_record)
|
||||
await db.commit()
|
||||
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={raw_token}"
|
||||
await EmailService.send_email_verification_email(
|
||||
to_email=current_user.email,
|
||||
verification_url=verification_url,
|
||||
)
|
||||
|
||||
return {"message": "Verification email sent"}
|
||||
|
||||
|
||||
@router.post("/email/verify")
|
||||
async def verify_email(
|
||||
data: dict,
|
||||
db: Annotated[AsyncSession, Depends(get_db)]
|
||||
):
|
||||
"""Verify an email using a token. Public endpoint."""
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token is required"
|
||||
)
|
||||
|
||||
payload = decode_token(token)
|
||||
if not payload or payload.get("type") != "email_verification":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired verification token"
|
||||
)
|
||||
|
||||
jti = payload.get("jti")
|
||||
user_id = payload.get("sub")
|
||||
if not jti or not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid verification token"
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(EmailVerificationToken).where(
|
||||
EmailVerificationToken.token_hash == hash_token(jti)
|
||||
)
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if not token_record or not token_record.is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Verification token has already been used or has expired"
|
||||
)
|
||||
|
||||
# Mark token as used
|
||||
token_record.used_at = datetime.now(timezone.utc)
|
||||
|
||||
# Mark user email as verified
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid verification token"
|
||||
)
|
||||
|
||||
user.email_verified_at = datetime.now(timezone.utc)
|
||||
await log_audit(db, user.id, "auth.email_verified", "user", user.id)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Email verified successfully"}
|
||||
|
||||
192
backend/app/api/endpoints/copilot.py
Normal file
192
backend/app/api/endpoints/copilot.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""In-session copilot endpoints.
|
||||
|
||||
Contextual AI assistant during flow navigation:
|
||||
POST /copilot/conversations — Start conversation (requires tree_id)
|
||||
POST /copilot/conversations/{id}/messages — Send message, get response + suggestions
|
||||
GET /copilot/conversations/{id} — Get conversation history
|
||||
"""
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.rate_limit import limiter
|
||||
from app.api.deps import get_current_active_user, get_db, require_engineer_or_admin
|
||||
from app.core.config import settings
|
||||
from app.core.ai_quota_service import check_ai_quota, record_ai_usage, get_user_plan
|
||||
from app.models.user import User
|
||||
from app.schemas.copilot import (
|
||||
CopilotStartRequest,
|
||||
CopilotStartResponse,
|
||||
CopilotMessageRequest,
|
||||
CopilotMessageResponse,
|
||||
CopilotConversationResponse,
|
||||
SuggestedFlow,
|
||||
)
|
||||
from app.models.copilot_conversation import CopilotConversation
|
||||
from app.services import copilot_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/copilot", tags=["copilot"])
|
||||
|
||||
|
||||
def _require_ai_enabled() -> None:
|
||||
if not settings.ai_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="AI is not configured. Set GOOGLE_AI_API_KEY or ANTHROPIC_API_KEY.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/conversations", response_model=CopilotStartResponse, status_code=201)
|
||||
@limiter.limit("10/minute")
|
||||
async def start_conversation(
|
||||
request: Request,
|
||||
data: CopilotStartRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Start a new copilot conversation for a flow."""
|
||||
_require_ai_enabled()
|
||||
|
||||
allowed, quota_status = await check_ai_quota(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
db=db,
|
||||
billing_anchor=current_user.ai_billing_cycle_anchor_at,
|
||||
is_super_admin=current_user.is_super_admin,
|
||||
)
|
||||
if not allowed:
|
||||
reset_key = "daily_reset_at" if quota_status.get("deny_reason") == "daily" else "monthly_reset_at"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"message": f"AI limit exceeded ({quota_status['deny_reason']})",
|
||||
"reset_at": quota_status.get(reset_key),
|
||||
"quota": quota_status,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
conversation, greeting = await copilot_service.start_conversation(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
tree_id=data.tree_id,
|
||||
session_id=data.session_id,
|
||||
current_node_id=data.current_node_id,
|
||||
db=db,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Copilot conversation start failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return CopilotStartResponse(
|
||||
conversation_id=conversation.id,
|
||||
greeting=greeting,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/messages", response_model=CopilotMessageResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def post_message(
|
||||
request: Request,
|
||||
conversation_id: UUID,
|
||||
data: CopilotMessageRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
_: None = Depends(require_engineer_or_admin),
|
||||
):
|
||||
"""Send a message and get AI response with flow suggestions."""
|
||||
_require_ai_enabled()
|
||||
|
||||
plan = await get_user_plan(current_user.account_id, db)
|
||||
|
||||
try:
|
||||
ai_content, suggested_flows, conversation = await copilot_service.send_message(
|
||||
conversation_id=conversation_id,
|
||||
user_id=current_user.id,
|
||||
message=data.message,
|
||||
current_node_id=data.current_node_id,
|
||||
db=db,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Copilot message failed: %s", e)
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=None,
|
||||
generation_type="copilot_message",
|
||||
tier=plan,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
estimated_cost=0,
|
||||
succeeded=False,
|
||||
counts_toward_quota=False,
|
||||
error_code=type(e).__name__,
|
||||
extra_data={"copilot_conversation_id": str(conversation_id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"AI provider error ({type(e).__name__}). Please try again.",
|
||||
)
|
||||
|
||||
await record_ai_usage(
|
||||
user_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
conversation_id=None,
|
||||
generation_type="copilot_message",
|
||||
tier=plan,
|
||||
input_tokens=conversation.total_input_tokens,
|
||||
output_tokens=conversation.total_output_tokens,
|
||||
estimated_cost=(
|
||||
conversation.total_input_tokens * 1.0 / 1_000_000
|
||||
+ conversation.total_output_tokens * 5.0 / 1_000_000
|
||||
),
|
||||
succeeded=True,
|
||||
counts_toward_quota=False,
|
||||
error_code=None,
|
||||
extra_data={"copilot_conversation_id": str(conversation_id)},
|
||||
db=db,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return CopilotMessageResponse(
|
||||
content=ai_content,
|
||||
suggested_flows=[SuggestedFlow.model_validate(sf) for sf in suggested_flows],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=CopilotConversationResponse)
|
||||
async def get_conversation(
|
||||
conversation_id: UUID,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Get copilot conversation history."""
|
||||
result = await db.execute(
|
||||
select(CopilotConversation).where(
|
||||
CopilotConversation.id == conversation_id,
|
||||
CopilotConversation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
conversation = result.scalar_one_or_none()
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
|
||||
|
||||
return CopilotConversationResponse.model_validate(conversation)
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
@@ -29,6 +30,7 @@ from app.core.audit import log_audit
|
||||
from app.core.config import settings
|
||||
from app.core.tree_validation import can_publish_tree
|
||||
from app.core.step_sync import sync_steps_from_tree, deactivate_synced_steps_for_tree
|
||||
from app.services.rag_service import index_tree as rag_index_tree
|
||||
|
||||
router = APIRouter(prefix="/trees", tags=["trees"])
|
||||
|
||||
@@ -542,6 +544,13 @@ async def create_tree(
|
||||
)
|
||||
tree = result.scalar_one()
|
||||
|
||||
# Index tree for RAG (best-effort, don't fail the request)
|
||||
try:
|
||||
await rag_index_tree(tree.id, db)
|
||||
await db.commit()
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning("RAG indexing failed for tree %s", tree.id)
|
||||
|
||||
return build_full_tree_response(tree)
|
||||
|
||||
|
||||
@@ -725,6 +734,13 @@ async def update_tree(
|
||||
)
|
||||
tree = result.scalar_one()
|
||||
|
||||
# Re-index tree for RAG (best-effort)
|
||||
try:
|
||||
await rag_index_tree(tree.id, db)
|
||||
await db.commit()
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning("RAG re-indexing failed for tree %s", tree.id)
|
||||
|
||||
return build_full_tree_response(tree)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from app.api.endpoints import feedback
|
||||
from app.api.endpoints import ai_builder
|
||||
from app.api.endpoints import ai_fix
|
||||
from app.api.endpoints import ai_chat
|
||||
from app.api.endpoints import copilot
|
||||
from app.api.endpoints import assistant_chat
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -40,3 +42,5 @@ api_router.include_router(feedback.router)
|
||||
api_router.include_router(ai_builder.router)
|
||||
api_router.include_router(ai_fix.router)
|
||||
api_router.include_router(ai_chat.router)
|
||||
api_router.include_router(copilot.router)
|
||||
api_router.include_router(assistant_chat.router)
|
||||
|
||||
@@ -115,7 +115,7 @@ async def check_ai_quota(
|
||||
select(func.count(AIUsage.id)).where(
|
||||
AIUsage.user_id == user_id,
|
||||
AIUsage.succeeded == True, # noqa: E712
|
||||
AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate"]),
|
||||
AIUsage.generation_type.in_(["scaffold", "branch_detail", "chat_message", "chat_generate", "copilot_message", "assistant_message"]),
|
||||
AIUsage.created_at >= day_start,
|
||||
)
|
||||
) or 0
|
||||
|
||||
@@ -84,6 +84,11 @@ class Settings(BaseSettings):
|
||||
AI_MODEL_GEMINI: str = "gemini-2.5-flash"
|
||||
AI_MODEL_ANTHROPIC: str = "claude-haiku-4-5-20251001"
|
||||
|
||||
# Embedding / RAG
|
||||
VOYAGE_API_KEY: Optional[str] = None
|
||||
EMBEDDING_MODEL: str = "voyage-3.5"
|
||||
EMBEDDING_DIMENSIONS: int = 1024
|
||||
|
||||
@property
|
||||
def ai_enabled(self) -> bool:
|
||||
"""Check if any AI provider is configured."""
|
||||
|
||||
@@ -163,6 +163,39 @@ class EmailService:
|
||||
logger.exception("Failed to send account invite email to %s", to_email)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def send_email_verification_email(
|
||||
to_email: str,
|
||||
verification_url: str,
|
||||
) -> bool:
|
||||
if not settings.email_enabled:
|
||||
logger.warning("Email not sent — RESEND_API_KEY not configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
import resend
|
||||
|
||||
resend.api_key = settings.RESEND_API_KEY
|
||||
|
||||
subject = "Verify Your Email — ResolutionFlow"
|
||||
|
||||
html = _render_email_verification_html(verification_url=verification_url)
|
||||
|
||||
resend.Emails.send(
|
||||
{
|
||||
"from": settings.FROM_EMAIL,
|
||||
"to": [to_email],
|
||||
"subject": subject,
|
||||
"html": html,
|
||||
}
|
||||
)
|
||||
logger.info("Verification email sent to %s", to_email)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to send verification email to %s", to_email)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def send_feedback_email(
|
||||
to_email: str,
|
||||
@@ -485,6 +518,38 @@ def _render_feedback_html(
|
||||
</body></html>"""
|
||||
|
||||
|
||||
def _render_email_verification_html(verification_url: str) -> str:
|
||||
return f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><meta name="viewport" content="width=device-width"></head>
|
||||
<body style="margin:0;padding:0;background:#000;font-family:'Inter',Helvetica,Arial,sans-serif;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" style="background:#000;padding:40px 0;">
|
||||
<tr><td align="center">
|
||||
<table width="560" cellpadding="0" cellspacing="0" style="background:#111;border:1px solid rgba(255,255,255,0.06);border-radius:16px;">
|
||||
<tr><td style="padding:40px 40px 24px;text-align:center;">
|
||||
<h1 style="margin:0;color:#fff;font-size:24px;font-weight:600;">ResolutionFlow</h1>
|
||||
<p style="margin:8px 0 0;color:#a0a0a0;font-size:14px;">Decision Tree Platform for MSP Professionals</p>
|
||||
</td></tr>
|
||||
<tr><td style="padding:0 40px 24px;">
|
||||
<p style="margin:0;color:#e0e0e0;font-size:16px;line-height:1.6;">
|
||||
Please verify your email address by clicking the button below. This link expires in 24 hours.
|
||||
</p>
|
||||
</td></tr>
|
||||
<tr><td style="padding:0 40px 32px;text-align:center;">
|
||||
<a href="{verification_url}" style="display:inline-block;background:#fff;color:#000;font-size:16px;font-weight:600;text-decoration:none;padding:14px 40px;border-radius:8px;">
|
||||
Verify Email
|
||||
</a>
|
||||
</td></tr>
|
||||
<tr><td style="padding:0 40px 32px;">
|
||||
<p style="margin:0;color:#666;font-size:12px;text-align:center;">
|
||||
If you didn't create an account, you can safely ignore this email.
|
||||
</p>
|
||||
</td></tr>
|
||||
</table>
|
||||
</td></tr>
|
||||
</table>
|
||||
</body></html>"""
|
||||
|
||||
|
||||
def _render_feedback_confirmation_html(
|
||||
feedback_type: str,
|
||||
message_preview: str,
|
||||
|
||||
@@ -70,6 +70,19 @@ def create_password_reset_token(user_id: str) -> str:
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_email_verification_token(user_id: str) -> str:
|
||||
"""Create a JWT email verification token (24-hour expiry, unique JTI)."""
|
||||
jti = str(uuid.uuid4())
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
to_encode = {
|
||||
"sub": user_id,
|
||||
"type": "email_verification",
|
||||
"jti": jti,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def generate_temp_password(length: int = 16) -> str:
|
||||
"""Generate a temporary password with guaranteed complexity.
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.middleware import RequestLoggingMiddleware, ErrorLoggingMiddleware
|
||||
from app.core.rate_limit import limiter
|
||||
from app.api.router import api_router
|
||||
from app.core.scheduler import scheduler, load_all_schedules, _cleanup_expired_ai_conversations
|
||||
from app.services.retention_cleanup import cleanup_expired_chats
|
||||
from app.core.service_account import ensure_service_account
|
||||
|
||||
# Initialize logging configuration
|
||||
@@ -122,6 +123,15 @@ async def lifespan(app: FastAPI):
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
# Chat retention cleanup (daily)
|
||||
scheduler.add_job(
|
||||
cleanup_expired_chats,
|
||||
trigger="interval",
|
||||
hours=24,
|
||||
id="cleanup_expired_chats",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
# Auto-seed trees in background on PR environments
|
||||
seed_task = None
|
||||
if settings.SEED_ON_DEPLOY:
|
||||
|
||||
@@ -29,6 +29,9 @@ from .feedback import Feedback
|
||||
from .ai_conversation import AIConversation
|
||||
from .ai_usage import AIUsage
|
||||
from .ai_chat_session import AIChatSession
|
||||
from .tree_embedding import TreeEmbedding
|
||||
from .copilot_conversation import CopilotConversation
|
||||
from .assistant_chat import AssistantChat
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -69,4 +72,7 @@ __all__ = [
|
||||
"AIConversation",
|
||||
"AIUsage",
|
||||
"AIChatSession",
|
||||
"TreeEmbedding",
|
||||
"CopilotConversation",
|
||||
"AssistantChat",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
@@ -35,6 +35,14 @@ class Account(Base):
|
||||
comment="Policy: engineers can create public shares. Only affects NEW shares (grandfathered)."
|
||||
)
|
||||
|
||||
# Chat retention settings
|
||||
chat_retention_days: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, nullable=True, default=90, server_default="90"
|
||||
)
|
||||
chat_retention_max_count: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, nullable=True, default=100, server_default="100"
|
||||
)
|
||||
|
||||
# Relationships
|
||||
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")
|
||||
users: Mapped[list["User"]] = relationship("User", foreign_keys="[User.account_id]", back_populates="account")
|
||||
|
||||
59
backend/app/models/assistant_chat.py
Normal file
59
backend/app/models/assistant_chat.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Standalone AI assistant chat model.
|
||||
|
||||
Persistent conversation history for general IT questions with RAG context.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Integer, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class AssistantChat(Base):
|
||||
__tablename__ = "assistant_chats"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
title: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, default="New Chat"
|
||||
)
|
||||
messages: Mapped[list[dict[str, Any]]] = mapped_column(
|
||||
JSONB, nullable=False, default=list
|
||||
)
|
||||
message_count: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
total_input_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
total_output_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
pinned: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
69
backend/app/models/copilot_conversation.py
Normal file
69
backend/app/models/copilot_conversation.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Copilot conversation model for in-session AI assistant.
|
||||
|
||||
Tracks conversation state during flow navigation with contextual AI help.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class CopilotConversation(Base):
|
||||
__tablename__ = "copilot_conversations"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
session_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("sessions.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
tree_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
messages: Mapped[list[dict[str, Any]]] = mapped_column(
|
||||
JSONB, nullable=False, default=list
|
||||
)
|
||||
current_node_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(100), nullable=True
|
||||
)
|
||||
message_count: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
total_input_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
total_output_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
expires_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False
|
||||
)
|
||||
42
backend/app/models/email_verification_token.py
Normal file
42
backend/app/models/email_verification_token.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class EmailVerificationToken(Base):
|
||||
__tablename__ = "email_verification_tokens"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_used(self) -> bool:
|
||||
return self.used_at is not None
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return not self.is_used and not self.is_expired
|
||||
72
backend/app/models/tree_embedding.py
Normal file
72
backend/app/models/tree_embedding.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Tree embedding storage for RAG-powered AI assistant.
|
||||
|
||||
Stores vector embeddings of tree content chunks for semantic search.
|
||||
Each tree is split into multiple chunks (node, solution, tree_summary)
|
||||
and embedded via Voyage AI for cosine similarity retrieval.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
# pgvector column type — imported at runtime to avoid import errors
|
||||
# when pgvector is not installed locally
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class TreeEmbedding(Base):
|
||||
__tablename__ = "tree_embeddings"
|
||||
__table_args__ = (
|
||||
Index("ix_tree_embeddings_account_id", "account_id"),
|
||||
Index("ix_tree_embeddings_tree_id", "tree_id"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
tree_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
account_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
)
|
||||
chunk_type: Mapped[str] = mapped_column(
|
||||
String(30),
|
||||
nullable=False,
|
||||
comment="node | solution | tree_summary",
|
||||
)
|
||||
node_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(30), nullable=True
|
||||
)
|
||||
node_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(100), nullable=True
|
||||
)
|
||||
chunk_text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
embedding_model: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, default="voyage-3.5"
|
||||
)
|
||||
# The embedding column is created via migration with vector(1024) type
|
||||
# We store it as a generic column here and handle it in queries
|
||||
meta: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB, nullable=False, default=dict
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -68,6 +68,15 @@ class User(Base):
|
||||
)
|
||||
last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Profile fields
|
||||
phone: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
|
||||
job_title: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
timezone: Mapped[str] = mapped_column(String(100), nullable=False, default="UTC", server_default="UTC")
|
||||
avatar_url: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
email_verified_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# AI billing cycle anchor (for quota reset calculation)
|
||||
ai_billing_cycle_anchor_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
|
||||
@@ -20,6 +20,11 @@ class AccountUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
|
||||
|
||||
class TransferOwnershipRequest(BaseModel):
|
||||
current_password: str
|
||||
target_user_id: UUID
|
||||
|
||||
|
||||
class AccountInviteCreate(BaseModel):
|
||||
email: str = Field(..., max_length=255)
|
||||
role: str = Field("engineer", pattern="^(engineer|viewer)$")
|
||||
|
||||
59
backend/app/schemas/assistant_chat.py
Normal file
59
backend/app/schemas/assistant_chat.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Pydantic schemas for standalone AI assistant chat."""
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.copilot import SuggestedFlow
|
||||
|
||||
|
||||
class ChatCreateRequest(BaseModel):
|
||||
"""Empty body — creates a new blank conversation."""
|
||||
pass
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=8000)
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
content: str
|
||||
suggested_flows: list[SuggestedFlow] = []
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
id: UUID
|
||||
title: str
|
||||
message_count: int
|
||||
pinned: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ChatDetailResponse(BaseModel):
|
||||
id: UUID
|
||||
title: str
|
||||
messages: list[dict[str, Any]]
|
||||
message_count: int
|
||||
pinned: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ChatUpdateRequest(BaseModel):
|
||||
title: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
pinned: Optional[bool] = None
|
||||
|
||||
|
||||
class RetentionSettingsResponse(BaseModel):
|
||||
chat_retention_days: Optional[int]
|
||||
chat_retention_max_count: Optional[int]
|
||||
|
||||
|
||||
class RetentionSettingsUpdate(BaseModel):
|
||||
chat_retention_days: Optional[int] = Field(None, ge=1, le=365)
|
||||
chat_retention_max_count: Optional[int] = Field(None, ge=10, le=10000)
|
||||
44
backend/app/schemas/copilot.py
Normal file
44
backend/app/schemas/copilot.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Pydantic schemas for the in-session copilot."""
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SuggestedFlow(BaseModel):
|
||||
tree_id: UUID
|
||||
tree_name: str
|
||||
tree_type: str
|
||||
relevance_snippet: str
|
||||
|
||||
|
||||
class CopilotStartRequest(BaseModel):
|
||||
tree_id: UUID
|
||||
session_id: Optional[UUID] = None
|
||||
current_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class CopilotStartResponse(BaseModel):
|
||||
conversation_id: UUID
|
||||
greeting: str
|
||||
|
||||
|
||||
class CopilotMessageRequest(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=4000)
|
||||
current_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class CopilotMessageResponse(BaseModel):
|
||||
content: str
|
||||
suggested_flows: list[SuggestedFlow] = []
|
||||
|
||||
|
||||
class CopilotConversationResponse(BaseModel):
|
||||
id: UUID
|
||||
tree_id: UUID
|
||||
messages: list[dict[str, Any]]
|
||||
current_node_id: Optional[str] = None
|
||||
message_count: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -30,6 +30,10 @@ class UserCreate(UserBase):
|
||||
class UserUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
email: Optional[EmailStr] = None
|
||||
current_password: Optional[str] = Field(None, description="Required when changing email")
|
||||
phone: Optional[str] = Field(None, max_length=50)
|
||||
job_title: Optional[str] = Field(None, max_length=255)
|
||||
timezone: Optional[str] = Field(None, max_length=100)
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
@@ -48,6 +52,11 @@ class UserResponse(UserBase):
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
phone: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
timezone: str = "UTC"
|
||||
avatar_url: Optional[str] = None
|
||||
email_verified_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
122
backend/app/services/assistant_chat_service.py
Normal file
122
backend/app/services/assistant_chat_service.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Standalone AI assistant chat service with RAG context.
|
||||
|
||||
Provides persistent conversation history for general IT questions
|
||||
with semantic search over the team's flow library.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
from app.services.rag_service import search as rag_search, build_rag_context, extract_suggested_flows
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ASSISTANT_SYSTEM_PROMPT = """You are a Senior Systems and Network Engineer with 15+ years of experience working in Managed Service Provider (MSP) environments. You specialize in:
|
||||
- Windows Server, Active Directory, Group Policy, and Hybrid Identity (Entra ID)
|
||||
- Networking (TCP/IP, DNS, DHCP, VPN, firewall troubleshooting, Cisco/Fortinet)
|
||||
- Virtualization (VMware, Hyper-V) and cloud platforms (Azure, AWS, M365)
|
||||
- Endpoint management, RMM tools, and PSA platforms (ConnectWise, Datto, Kaseya)
|
||||
- PowerShell scripting and automation
|
||||
|
||||
When answering:
|
||||
- Be direct and actionable — MSP engineers need fast, practical answers
|
||||
- Include specific commands, paths, and config values when relevant
|
||||
- Mention potential risks or gotchas before suggesting changes
|
||||
- If a relevant troubleshooting flow exists in the team's library, reference it
|
||||
- Keep responses concise but thorough — prefer bullet points and code blocks
|
||||
- Format code with proper markdown code blocks
|
||||
"""
|
||||
|
||||
|
||||
def _auto_title(message: str) -> str:
|
||||
"""Generate a short title from the first user message."""
|
||||
title = message.strip()[:100]
|
||||
if len(message) > 100:
|
||||
title = title.rsplit(" ", 1)[0] + "..."
|
||||
return title
|
||||
|
||||
|
||||
async def create_chat(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
) -> AssistantChat:
|
||||
"""Create a new empty chat."""
|
||||
chat = AssistantChat(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
messages=[],
|
||||
)
|
||||
db.add(chat)
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
async def send_message(
|
||||
chat_id: UUID,
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
message: str,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str, list[dict[str, Any]], AssistantChat]:
|
||||
"""Send a user message and get AI response.
|
||||
|
||||
Returns (ai_content, suggested_flows, chat).
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(AssistantChat).where(
|
||||
AssistantChat.id == chat_id,
|
||||
AssistantChat.user_id == user_id,
|
||||
)
|
||||
)
|
||||
chat = result.scalar_one_or_none()
|
||||
if not chat:
|
||||
raise ValueError("Chat not found")
|
||||
|
||||
# Auto-title from first message
|
||||
if chat.message_count == 0:
|
||||
chat.title = _auto_title(message)
|
||||
|
||||
# RAG search
|
||||
rag_results = await rag_search(
|
||||
query=message,
|
||||
account_id=account_id,
|
||||
db=db,
|
||||
limit=8,
|
||||
)
|
||||
|
||||
# Build system prompt
|
||||
system_prompt = ASSISTANT_SYSTEM_PROMPT + build_rag_context(rag_results)
|
||||
|
||||
# Build messages for AI
|
||||
ai_messages = []
|
||||
for msg in chat.messages:
|
||||
if msg["role"] in ("user", "assistant"):
|
||||
ai_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
ai_messages.append({"role": "user", "content": message})
|
||||
|
||||
# Call AI
|
||||
provider = get_ai_provider()
|
||||
ai_content, input_tokens, output_tokens = await provider.generate_text(
|
||||
system_prompt=system_prompt,
|
||||
messages=ai_messages,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
# Update chat
|
||||
msgs = list(chat.messages)
|
||||
msgs.append({"role": "user", "content": message})
|
||||
msgs.append({"role": "assistant", "content": ai_content})
|
||||
chat.messages = msgs
|
||||
chat.message_count += 2
|
||||
chat.total_input_tokens += input_tokens
|
||||
chat.total_output_tokens += output_tokens
|
||||
|
||||
suggested_flows = extract_suggested_flows(rag_results)
|
||||
|
||||
return ai_content, suggested_flows, chat
|
||||
202
backend/app/services/copilot_service.py
Normal file
202
backend/app/services/copilot_service.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Copilot service — in-session AI assistant with RAG context.
|
||||
|
||||
Builds system prompts with current flow context and RAG results,
|
||||
manages conversation state, and returns AI responses with flow suggestions.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.ai_provider import get_ai_provider
|
||||
from app.models.tree import Tree
|
||||
from app.models.copilot_conversation import CopilotConversation
|
||||
from app.services.rag_service import search as rag_search, build_rag_context, extract_suggested_flows
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COPILOT_SYSTEM_PROMPT = """You are a Senior Systems and Network Engineer with 15+ years of experience working in Managed Service Provider (MSP) environments. You specialize in:
|
||||
- Windows Server, Active Directory, Group Policy, and Hybrid Identity (Entra ID)
|
||||
- Networking (TCP/IP, DNS, DHCP, VPN, firewall troubleshooting, Cisco/Fortinet)
|
||||
- Virtualization (VMware, Hyper-V) and cloud platforms (Azure, AWS, M365)
|
||||
- Endpoint management, RMM tools, and PSA platforms (ConnectWise, Datto, Kaseya)
|
||||
- PowerShell scripting and automation
|
||||
|
||||
You are acting as an in-session copilot while the user navigates a troubleshooting or procedural flow. You can see the flow context and their current position.
|
||||
|
||||
When answering:
|
||||
- Be direct and actionable — MSP engineers need fast, practical answers
|
||||
- Include specific commands, paths, and config values when relevant
|
||||
- Mention potential risks or gotchas before suggesting changes
|
||||
- If a relevant troubleshooting flow exists in the team's library, reference it
|
||||
- Keep responses concise but thorough — prefer bullet points and code blocks
|
||||
"""
|
||||
|
||||
|
||||
def _build_flow_context(tree: Tree, current_node_id: Optional[str]) -> str:
|
||||
"""Build flow context string for the system prompt."""
|
||||
parts = [
|
||||
f"\n--- CURRENT FLOW CONTEXT ---",
|
||||
f"Flow: {tree.name}",
|
||||
f"Type: {tree.tree_type}",
|
||||
]
|
||||
if tree.description:
|
||||
parts.append(f"Description: {tree.description}")
|
||||
|
||||
if current_node_id and tree.tree_structure:
|
||||
node = _find_node(tree.tree_structure, current_node_id)
|
||||
if node:
|
||||
parts.append(f"Current node type: {node.get('type', 'unknown')}")
|
||||
parts.append(f"Current node: {node.get('content', node.get('label', 'Unknown'))}")
|
||||
# Add options if it's a question/decision node
|
||||
children = node.get("children", [])
|
||||
if children and isinstance(children, list):
|
||||
option_labels = [
|
||||
c.get("label", c.get("content", ""))
|
||||
for c in children if isinstance(c, dict)
|
||||
]
|
||||
if option_labels:
|
||||
parts.append(f"Available options: {', '.join(option_labels)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _find_node(structure: dict, node_id: str) -> Optional[dict]:
|
||||
"""Recursively find a node by ID in tree structure."""
|
||||
if structure.get("id") == node_id:
|
||||
return structure
|
||||
for child in structure.get("children", []):
|
||||
if isinstance(child, dict):
|
||||
found = _find_node(child, node_id)
|
||||
if found:
|
||||
return found
|
||||
# Check steps array for procedural flows
|
||||
for step in structure.get("steps", []):
|
||||
if isinstance(step, dict):
|
||||
found = _find_node(step, node_id)
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
|
||||
async def start_conversation(
|
||||
user_id: UUID,
|
||||
account_id: UUID,
|
||||
tree_id: UUID,
|
||||
session_id: Optional[UUID],
|
||||
current_node_id: Optional[str],
|
||||
db: AsyncSession,
|
||||
) -> tuple[CopilotConversation, str]:
|
||||
"""Start a new copilot conversation.
|
||||
|
||||
Returns (conversation, greeting_message).
|
||||
"""
|
||||
# Load tree
|
||||
result = await db.execute(
|
||||
select(Tree).options(selectinload(Tree.tags)).where(Tree.id == tree_id)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
if not tree:
|
||||
raise ValueError(f"Tree {tree_id} not found")
|
||||
|
||||
conversation = CopilotConversation(
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
tree_id=tree_id,
|
||||
session_id=session_id,
|
||||
current_node_id=current_node_id,
|
||||
messages=[],
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=24),
|
||||
)
|
||||
db.add(conversation)
|
||||
await db.flush()
|
||||
|
||||
greeting = f"I'm your copilot for this **{tree.tree_type}** flow: **{tree.name}**. Ask me anything about the current step, alternative approaches, or related troubleshooting tips."
|
||||
|
||||
conversation.messages = [{"role": "assistant", "content": greeting}]
|
||||
conversation.message_count = 1
|
||||
|
||||
return conversation, greeting
|
||||
|
||||
|
||||
async def send_message(
|
||||
conversation_id: UUID,
|
||||
user_id: UUID,
|
||||
message: str,
|
||||
current_node_id: Optional[str],
|
||||
db: AsyncSession,
|
||||
) -> tuple[str, list[dict[str, Any]], CopilotConversation]:
|
||||
"""Send a user message and get AI response.
|
||||
|
||||
Returns (ai_content, suggested_flows, conversation).
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(CopilotConversation).where(
|
||||
CopilotConversation.id == conversation_id,
|
||||
CopilotConversation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
conversation = result.scalar_one_or_none()
|
||||
if not conversation:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
if conversation.expires_at < datetime.now(timezone.utc):
|
||||
raise ValueError("Conversation has expired")
|
||||
|
||||
# Load tree for context
|
||||
tree_result = await db.execute(
|
||||
select(Tree).options(selectinload(Tree.tags)).where(Tree.id == conversation.tree_id)
|
||||
)
|
||||
tree = tree_result.scalar_one_or_none()
|
||||
if not tree:
|
||||
raise ValueError("Associated flow not found")
|
||||
|
||||
# Update current node
|
||||
if current_node_id:
|
||||
conversation.current_node_id = current_node_id
|
||||
|
||||
# RAG search
|
||||
rag_results = await rag_search(
|
||||
query=message,
|
||||
account_id=conversation.account_id,
|
||||
db=db,
|
||||
limit=8,
|
||||
)
|
||||
|
||||
# Build system prompt
|
||||
system_prompt = COPILOT_SYSTEM_PROMPT
|
||||
system_prompt += _build_flow_context(tree, conversation.current_node_id)
|
||||
system_prompt += build_rag_context(rag_results)
|
||||
|
||||
# Build messages for AI
|
||||
ai_messages = []
|
||||
for msg in conversation.messages:
|
||||
if msg["role"] in ("user", "assistant"):
|
||||
ai_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
ai_messages.append({"role": "user", "content": message})
|
||||
|
||||
# Call AI
|
||||
provider = get_ai_provider()
|
||||
ai_content, input_tokens, output_tokens = await provider.generate_text(
|
||||
system_prompt=system_prompt,
|
||||
messages=ai_messages,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
# Update conversation
|
||||
msgs = list(conversation.messages)
|
||||
msgs.append({"role": "user", "content": message})
|
||||
msgs.append({"role": "assistant", "content": ai_content})
|
||||
conversation.messages = msgs
|
||||
conversation.message_count += 2
|
||||
conversation.total_input_tokens += input_tokens
|
||||
conversation.total_output_tokens += output_tokens
|
||||
|
||||
# Extract suggested flows
|
||||
suggested_flows = extract_suggested_flows(rag_results, exclude_tree_id=tree.id)
|
||||
|
||||
return ai_content, suggested_flows, conversation
|
||||
78
backend/app/services/embedding_service.py
Normal file
78
backend/app/services/embedding_service.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Embedding provider abstraction for RAG.
|
||||
|
||||
Uses Voyage AI (voyage-3.5, 1024 dims) as the embedding provider.
|
||||
Supports document and query input types for asymmetric search.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_embedding(
|
||||
text: str,
|
||||
input_type: str = "document",
|
||||
) -> Optional[list[float]]:
|
||||
"""Get embedding vector for text using Voyage AI.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
input_type: "document" for indexing, "query" for search queries.
|
||||
|
||||
Returns:
|
||||
List of floats (1024 dims) or None if embedding service unavailable.
|
||||
"""
|
||||
if not settings.VOYAGE_API_KEY:
|
||||
logger.warning("VOYAGE_API_KEY not set — embedding service unavailable")
|
||||
return None
|
||||
|
||||
try:
|
||||
import voyageai
|
||||
|
||||
client = voyageai.AsyncClient(api_key=settings.VOYAGE_API_KEY)
|
||||
result = await client.embed(
|
||||
texts=[text],
|
||||
model=settings.EMBEDDING_MODEL,
|
||||
input_type=input_type,
|
||||
)
|
||||
return result.embeddings[0]
|
||||
except Exception as e:
|
||||
logger.error("Embedding failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
async def get_embeddings_batch(
|
||||
texts: list[str],
|
||||
input_type: str = "document",
|
||||
) -> Optional[list[list[float]]]:
|
||||
"""Get embedding vectors for multiple texts in a single API call.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed.
|
||||
input_type: "document" for indexing, "query" for search queries.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors or None if service unavailable.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if not settings.VOYAGE_API_KEY:
|
||||
logger.warning("VOYAGE_API_KEY not set — embedding service unavailable")
|
||||
return None
|
||||
|
||||
try:
|
||||
import voyageai
|
||||
|
||||
client = voyageai.AsyncClient(api_key=settings.VOYAGE_API_KEY)
|
||||
result = await client.embed(
|
||||
texts=texts,
|
||||
model=settings.EMBEDDING_MODEL,
|
||||
input_type=input_type,
|
||||
)
|
||||
return result.embeddings
|
||||
except Exception as e:
|
||||
logger.error("Batch embedding failed: %s", e)
|
||||
return None
|
||||
209
backend/app/services/rag_service.py
Normal file
209
backend/app/services/rag_service.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""RAG service — index trees and search embeddings for AI context.
|
||||
|
||||
Orchestrates tree chunking, embedding, and semantic search over the
|
||||
team's flow library via pgvector cosine similarity.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import text, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.tree import Tree
|
||||
from app.models.tree_embedding import TreeEmbedding
|
||||
from app.services.embedding_service import get_embedding, get_embeddings_batch
|
||||
from app.services.tree_chunker import chunk_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def index_tree(tree_id: UUID, db: AsyncSession) -> int:
|
||||
"""Chunk and embed a tree, storing results in tree_embeddings.
|
||||
|
||||
Deletes existing embeddings for this tree before re-indexing.
|
||||
Returns the number of chunks indexed.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.tags))
|
||||
.where(Tree.id == tree_id)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
if not tree:
|
||||
logger.warning("index_tree: tree %s not found", tree_id)
|
||||
return 0
|
||||
|
||||
# Delete existing embeddings
|
||||
await db.execute(
|
||||
delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
|
||||
)
|
||||
|
||||
# Chunk the tree
|
||||
tag_names = [t.name for t in tree.tags] if tree.tags else []
|
||||
chunks = chunk_tree(
|
||||
tree_name=tree.name,
|
||||
tree_type=tree.tree_type,
|
||||
description=tree.description,
|
||||
tags=tag_names,
|
||||
tree_structure=tree.tree_structure,
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
logger.info("index_tree: no chunks for tree %s", tree_id)
|
||||
return 0
|
||||
|
||||
# Get embeddings for all chunks in batch
|
||||
texts = [c["chunk_text"] for c in chunks]
|
||||
embeddings = await get_embeddings_batch(texts, input_type="document")
|
||||
|
||||
if embeddings is None:
|
||||
logger.warning("index_tree: embedding service unavailable for tree %s", tree_id)
|
||||
return 0
|
||||
|
||||
# Insert embeddings
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
embedding_str = "[" + ",".join(str(v) for v in embedding) + "]"
|
||||
await db.execute(
|
||||
text("""
|
||||
INSERT INTO tree_embeddings
|
||||
(tree_id, account_id, chunk_type, node_type, node_id, chunk_text, embedding_model, embedding, meta)
|
||||
VALUES
|
||||
(:tree_id, :account_id, :chunk_type, :node_type, :node_id, :chunk_text, :embedding_model, :embedding::vector, :meta::jsonb)
|
||||
"""),
|
||||
{
|
||||
"tree_id": str(tree_id),
|
||||
"account_id": str(tree.account_id) if tree.account_id else None,
|
||||
"chunk_type": chunk["chunk_type"],
|
||||
"node_type": chunk.get("node_type"),
|
||||
"node_id": chunk.get("node_id"),
|
||||
"chunk_text": chunk["chunk_text"],
|
||||
"embedding_model": "voyage-3.5",
|
||||
"embedding": embedding_str,
|
||||
"meta": "{}",
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("index_tree: indexed %d chunks for tree %s", len(chunks), tree_id)
|
||||
return len(chunks)
|
||||
|
||||
|
||||
async def delete_tree_embeddings(tree_id: UUID, db: AsyncSession) -> None:
|
||||
"""Delete all embeddings for a tree."""
|
||||
await db.execute(
|
||||
delete(TreeEmbedding).where(TreeEmbedding.tree_id == tree_id)
|
||||
)
|
||||
|
||||
|
||||
async def search(
|
||||
query: str,
|
||||
account_id: UUID,
|
||||
db: AsyncSession,
|
||||
limit: int = 8,
|
||||
exclude_tree_id: Optional[UUID] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Semantic search over team's flow library.
|
||||
|
||||
Args:
|
||||
query: Natural language search query.
|
||||
account_id: Scope search to team's flows.
|
||||
db: Database session.
|
||||
limit: Max results to return.
|
||||
exclude_tree_id: Exclude chunks from this tree (for copilot context).
|
||||
|
||||
Returns:
|
||||
List of dicts with tree_id, tree_name, tree_type, chunk_text, chunk_type, similarity.
|
||||
"""
|
||||
query_embedding = await get_embedding(query, input_type="query")
|
||||
if query_embedding is None:
|
||||
return []
|
||||
|
||||
embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"
|
||||
|
||||
exclude_clause = ""
|
||||
params: dict[str, Any] = {
|
||||
"embedding": embedding_str,
|
||||
"account_id": str(account_id),
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
if exclude_tree_id:
|
||||
exclude_clause = "AND te.tree_id != :exclude_tree_id"
|
||||
params["exclude_tree_id"] = str(exclude_tree_id)
|
||||
|
||||
result = await db.execute(
|
||||
text(f"""
|
||||
SELECT
|
||||
te.tree_id,
|
||||
t.name as tree_name,
|
||||
t.tree_type,
|
||||
te.chunk_text,
|
||||
te.chunk_type,
|
||||
te.node_id,
|
||||
1 - (te.embedding <=> :embedding::vector) as similarity
|
||||
FROM tree_embeddings te
|
||||
JOIN trees t ON t.id = te.tree_id
|
||||
WHERE te.account_id = :account_id
|
||||
AND t.deleted_at IS NULL
|
||||
{exclude_clause}
|
||||
ORDER BY te.embedding <=> :embedding::vector
|
||||
LIMIT :limit
|
||||
"""),
|
||||
params,
|
||||
)
|
||||
|
||||
rows = result.mappings().all()
|
||||
return [
|
||||
{
|
||||
"tree_id": str(row["tree_id"]),
|
||||
"tree_name": row["tree_name"],
|
||||
"tree_type": row["tree_type"],
|
||||
"chunk_text": row["chunk_text"],
|
||||
"chunk_type": row["chunk_type"],
|
||||
"node_id": row["node_id"],
|
||||
"similarity": float(row["similarity"]),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
def build_rag_context(rag_results: list[dict[str, Any]]) -> str:
|
||||
"""Format RAG results into a system prompt section."""
|
||||
if not rag_results:
|
||||
return ""
|
||||
|
||||
parts = ["\n--- RELEVANT FLOWS FROM TEAM LIBRARY ---"]
|
||||
for r in rag_results[:5]: # Cap at 5 for prompt size
|
||||
parts.append(f"- [{r['tree_type']}] {r['tree_name']}: {r['chunk_text'][:200]}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def extract_suggested_flows(
|
||||
rag_results: list[dict[str, Any]],
|
||||
exclude_tree_id: Optional[UUID] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Extract unique suggested flows from RAG results."""
|
||||
seen_tree_ids: set[str] = set()
|
||||
suggestions = []
|
||||
|
||||
for r in rag_results:
|
||||
tid = r["tree_id"]
|
||||
if exclude_tree_id and tid == str(exclude_tree_id):
|
||||
continue
|
||||
if tid in seen_tree_ids:
|
||||
continue
|
||||
if r["similarity"] < 0.3:
|
||||
continue
|
||||
seen_tree_ids.add(tid)
|
||||
suggestions.append({
|
||||
"tree_id": tid,
|
||||
"tree_name": r["tree_name"],
|
||||
"tree_type": r["tree_type"],
|
||||
"relevance_snippet": r["chunk_text"][:150],
|
||||
})
|
||||
|
||||
return suggestions[:3]
|
||||
84
backend/app/services/retention_cleanup.py
Normal file
84
backend/app/services/retention_cleanup.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Chat retention cleanup job.
|
||||
|
||||
Runs daily via APScheduler to enforce account-level retention settings:
|
||||
- Delete non-pinned chats older than chat_retention_days
|
||||
- Delete oldest non-pinned chats when count exceeds chat_retention_max_count
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
|
||||
from app.core.database import async_session_maker
|
||||
from app.models.account import Account
|
||||
from app.models.assistant_chat import AssistantChat
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def cleanup_expired_chats() -> None:
|
||||
"""Enforce chat retention policies for all accounts."""
|
||||
async with async_session_maker() as db:
|
||||
try:
|
||||
result = await db.execute(select(Account))
|
||||
accounts = result.scalars().all()
|
||||
|
||||
total_deleted = 0
|
||||
for account in accounts:
|
||||
deleted = await _cleanup_account_chats(account, db)
|
||||
total_deleted += deleted
|
||||
|
||||
await db.commit()
|
||||
if total_deleted > 0:
|
||||
logger.info("[retention] Cleaned up %d expired chats", total_deleted)
|
||||
except Exception as e:
|
||||
logger.error("[retention] Chat cleanup failed: %s", e)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def _cleanup_account_chats(account: Account, db) -> int:
|
||||
"""Enforce retention for a single account. Returns count deleted."""
|
||||
deleted = 0
|
||||
|
||||
# Age-based retention
|
||||
if account.chat_retention_days:
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=account.chat_retention_days)
|
||||
result = await db.execute(
|
||||
delete(AssistantChat)
|
||||
.where(
|
||||
AssistantChat.account_id == account.id,
|
||||
AssistantChat.pinned == False, # noqa: E712
|
||||
AssistantChat.updated_at < cutoff,
|
||||
)
|
||||
.returning(AssistantChat.id)
|
||||
)
|
||||
deleted += len(result.all())
|
||||
|
||||
# Count-based retention
|
||||
if account.chat_retention_max_count:
|
||||
total = await db.scalar(
|
||||
select(func.count(AssistantChat.id)).where(
|
||||
AssistantChat.account_id == account.id,
|
||||
)
|
||||
) or 0
|
||||
|
||||
if total > account.chat_retention_max_count:
|
||||
excess = total - account.chat_retention_max_count
|
||||
# Get oldest non-pinned chat IDs
|
||||
oldest = await db.execute(
|
||||
select(AssistantChat.id)
|
||||
.where(
|
||||
AssistantChat.account_id == account.id,
|
||||
AssistantChat.pinned == False, # noqa: E712
|
||||
)
|
||||
.order_by(AssistantChat.updated_at.asc())
|
||||
.limit(excess)
|
||||
)
|
||||
ids_to_delete = [row[0] for row in oldest.all()]
|
||||
if ids_to_delete:
|
||||
await db.execute(
|
||||
delete(AssistantChat).where(AssistantChat.id.in_(ids_to_delete))
|
||||
)
|
||||
deleted += len(ids_to_delete)
|
||||
|
||||
return deleted
|
||||
165
backend/app/services/tree_chunker.py
Normal file
165
backend/app/services/tree_chunker.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Tree chunker — converts tree_structure JSON into embeddable text chunks.
|
||||
|
||||
Produces three chunk types:
|
||||
- tree_summary: Name + description + tags + type overview
|
||||
- node: Individual node content with breadcrumb path context
|
||||
- solution: Full solution/action text with path context
|
||||
"""
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_breadcrumb(node: dict, parent_path: str = "") -> str:
|
||||
"""Build a breadcrumb path string for a node."""
|
||||
content = node.get("content", node.get("label", ""))[:80]
|
||||
if parent_path:
|
||||
return f"{parent_path} > {content}"
|
||||
return content
|
||||
|
||||
|
||||
def _chunk_node(
|
||||
node: dict,
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
tags: list[str],
|
||||
parent_path: str = "",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Recursively chunk a node and its children."""
|
||||
chunks = []
|
||||
node_type = node.get("type", "unknown")
|
||||
node_id = node.get("id", "")
|
||||
content = node.get("content", node.get("label", ""))
|
||||
breadcrumb = _get_breadcrumb(node, parent_path)
|
||||
|
||||
# Build chunk text based on node type
|
||||
if node_type in ("question", "decision"):
|
||||
options = node.get("children", [])
|
||||
option_labels = [
|
||||
child.get("label", child.get("content", ""))[:100]
|
||||
for child in options
|
||||
if isinstance(child, dict)
|
||||
]
|
||||
text_parts = [
|
||||
f"[{node_type}] {content}",
|
||||
]
|
||||
if option_labels:
|
||||
text_parts.append(f"Options: {', '.join(option_labels)}")
|
||||
text_parts.append(f"Path: {breadcrumb}")
|
||||
text_parts.append(f"Flow: {tree_name} | Type: {tree_type}")
|
||||
if tags:
|
||||
text_parts.append(f"Tags: {', '.join(tags)}")
|
||||
|
||||
chunks.append({
|
||||
"chunk_type": "node",
|
||||
"node_type": node_type,
|
||||
"node_id": node_id,
|
||||
"chunk_text": "\n".join(text_parts),
|
||||
})
|
||||
|
||||
elif node_type in ("action", "solution", "info", "warning"):
|
||||
text_parts = [
|
||||
f"[{node_type}] {content}",
|
||||
f"Path: {breadcrumb}",
|
||||
f"Flow: {tree_name} | Type: {tree_type}",
|
||||
]
|
||||
if tags:
|
||||
text_parts.append(f"Tags: {', '.join(tags)}")
|
||||
|
||||
chunk_type = "solution" if node_type == "solution" else "node"
|
||||
chunks.append({
|
||||
"chunk_type": chunk_type,
|
||||
"node_type": node_type,
|
||||
"node_id": node_id,
|
||||
"chunk_text": "\n".join(text_parts),
|
||||
})
|
||||
|
||||
elif node_type in ("step", "section_header"):
|
||||
text_parts = [
|
||||
f"[{node_type}] {content}",
|
||||
f"Path: {breadcrumb}",
|
||||
f"Flow: {tree_name} | Type: {tree_type}",
|
||||
]
|
||||
if node.get("description"):
|
||||
text_parts.insert(1, node["description"])
|
||||
if tags:
|
||||
text_parts.append(f"Tags: {', '.join(tags)}")
|
||||
|
||||
chunks.append({
|
||||
"chunk_type": "node",
|
||||
"node_type": node_type,
|
||||
"node_id": node_id,
|
||||
"chunk_text": "\n".join(text_parts),
|
||||
})
|
||||
|
||||
# Recurse into children
|
||||
children = node.get("children", [])
|
||||
if isinstance(children, list):
|
||||
for child in children:
|
||||
if isinstance(child, dict):
|
||||
chunks.extend(
|
||||
_chunk_node(child, tree_name, tree_type, tags, breadcrumb)
|
||||
)
|
||||
|
||||
# Follow next_node_id linked nodes (action nodes)
|
||||
# These are handled at the tree level, not recursively
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_tree(
|
||||
tree_name: str,
|
||||
tree_type: str,
|
||||
description: str | None,
|
||||
tags: list[str],
|
||||
tree_structure: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert a tree into embeddable text chunks.
|
||||
|
||||
Args:
|
||||
tree_name: Name of the flow.
|
||||
tree_type: troubleshooting | procedural | maintenance.
|
||||
description: Flow description.
|
||||
tags: List of tag names.
|
||||
tree_structure: The tree_structure JSONB content.
|
||||
|
||||
Returns:
|
||||
List of chunk dicts with keys: chunk_type, node_type, node_id, chunk_text.
|
||||
"""
|
||||
chunks = []
|
||||
|
||||
# Tree summary chunk
|
||||
summary_parts = [
|
||||
f"Flow: {tree_name}",
|
||||
f"Type: {tree_type}",
|
||||
]
|
||||
if description:
|
||||
summary_parts.append(f"Description: {description}")
|
||||
if tags:
|
||||
summary_parts.append(f"Tags: {', '.join(tags)}")
|
||||
|
||||
chunks.append({
|
||||
"chunk_type": "tree_summary",
|
||||
"node_type": None,
|
||||
"node_id": None,
|
||||
"chunk_text": "\n".join(summary_parts),
|
||||
})
|
||||
|
||||
# Chunk the tree structure nodes
|
||||
root = tree_structure
|
||||
if isinstance(root, dict):
|
||||
# Handle both flat structure and nested
|
||||
if "children" in root or "type" in root:
|
||||
chunks.extend(
|
||||
_chunk_node(root, tree_name, tree_type, tags)
|
||||
)
|
||||
# Handle steps array (procedural flows)
|
||||
if "steps" in root and isinstance(root["steps"], list):
|
||||
for step in root["steps"]:
|
||||
if isinstance(step, dict):
|
||||
chunks.extend(
|
||||
_chunk_node(step, tree_name, tree_type, tags)
|
||||
)
|
||||
|
||||
return chunks
|
||||
@@ -35,6 +35,10 @@ httpx>=0.27.0
|
||||
anthropic>=0.40.0
|
||||
google-genai>=1.0.0
|
||||
|
||||
# RAG / Embeddings
|
||||
pgvector>=0.3.6
|
||||
voyageai>=0.3.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.1
|
||||
croniter>=2.0.0
|
||||
|
||||
109
backend/tests/test_account_lifecycle.py
Normal file
109
backend/tests/test_account_lifecycle.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Tests for leave account and delete account endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestLeaveAccount:
|
||||
"""Test POST /accounts/me/leave."""
|
||||
|
||||
async def test_leave_as_non_owner(self, client: AsyncClient, test_db):
|
||||
"""Non-owner can leave and gets a personal account."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
|
||||
# Register owner
|
||||
owner = await client.post("/api/v1/auth/register", json={
|
||||
"email": "owner@example.com", "password": "TestPassword123!", "name": "Owner",
|
||||
})
|
||||
assert owner.status_code == 201
|
||||
owner_data = owner.json()
|
||||
|
||||
# Login as owner
|
||||
login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "owner@example.com", "password": "TestPassword123!",
|
||||
})
|
||||
owner_headers = {"Authorization": f"Bearer {login.json()['access_token']}"}
|
||||
|
||||
# Register member
|
||||
member = await client.post("/api/v1/auth/register", json={
|
||||
"email": "member@example.com", "password": "TestPassword123!", "name": "Member",
|
||||
})
|
||||
member_id = member.json()["id"]
|
||||
|
||||
# Move member to owner's account
|
||||
result = await test_db.execute(select(User).where(User.id == member_id))
|
||||
member_user = result.scalar_one()
|
||||
member_user.account_id = owner_data["account_id"]
|
||||
member_user.account_role = "engineer"
|
||||
await test_db.commit()
|
||||
|
||||
# Login as member
|
||||
login = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "member@example.com", "password": "TestPassword123!",
|
||||
})
|
||||
member_headers = {"Authorization": f"Bearer {login.json()['access_token']}"}
|
||||
|
||||
# Leave
|
||||
response = await client.post("/api/v1/accounts/me/leave", headers=member_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_leave_as_owner_fails(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Owner cannot leave their own account."""
|
||||
response = await client.post("/api/v1/accounts/me/leave", headers=auth_headers)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDeleteAccount:
|
||||
"""Test DELETE /accounts/me."""
|
||||
|
||||
async def test_delete_success(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Owner with no other members can delete account."""
|
||||
response = await client.request(
|
||||
"DELETE",
|
||||
"/api/v1/accounts/me",
|
||||
json={"current_password": "TestPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_delete_wrong_password(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Wrong password returns 401."""
|
||||
response = await client.request(
|
||||
"DELETE",
|
||||
"/api/v1/accounts/me",
|
||||
json={"current_password": "WrongPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_delete_with_members_fails(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
"""Cannot delete account that has other members."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
|
||||
# Get owner's account_id
|
||||
me = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
account_id = me.json()["account_id"]
|
||||
|
||||
# Register and add member
|
||||
member = await client.post("/api/v1/auth/register", json={
|
||||
"email": "member2@example.com", "password": "TestPassword123!", "name": "Member",
|
||||
})
|
||||
member_id = member.json()["id"]
|
||||
|
||||
result = await test_db.execute(select(User).where(User.id == member_id))
|
||||
member_user = result.scalar_one()
|
||||
member_user.account_id = account_id
|
||||
member_user.account_role = "engineer"
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.request(
|
||||
"DELETE",
|
||||
"/api/v1/accounts/me",
|
||||
json={"current_password": "TestPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
63
backend/tests/test_account_transfer.py
Normal file
63
backend/tests/test_account_transfer.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Tests for account ownership transfer."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestOwnershipTransfer:
|
||||
"""Test POST /accounts/me/transfer-ownership."""
|
||||
|
||||
async def _create_member(self, client: AsyncClient, owner_headers: dict, test_db):
|
||||
"""Register a second user and add them to the owner's account."""
|
||||
from sqlalchemy import select
|
||||
from app.models.user import User
|
||||
|
||||
# Register second user (gets own account)
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "member@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Member User",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
member_id = resp.json()["id"]
|
||||
|
||||
# Get owner's account_id
|
||||
me = await client.get("/api/v1/auth/me", headers=owner_headers)
|
||||
owner_account_id = me.json()["account_id"]
|
||||
|
||||
# Move member to owner's account
|
||||
result = await test_db.execute(select(User).where(User.id == member_id))
|
||||
member = result.scalar_one()
|
||||
member.account_id = owner_account_id
|
||||
member.account_role = "engineer"
|
||||
await test_db.commit()
|
||||
|
||||
return member_id
|
||||
|
||||
async def test_transfer_success(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
member_id = await self._create_member(client, auth_headers, test_db)
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/transfer-ownership",
|
||||
json={"current_password": "TestPassword123!", "target_user_id": member_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["owner_id"] == member_id
|
||||
|
||||
async def test_transfer_self(self, client: AsyncClient, auth_headers: dict, test_user):
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/transfer-ownership",
|
||||
json={"current_password": "TestPassword123!", "target_user_id": test_user["user_data"]["id"]},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_transfer_wrong_password(self, client: AsyncClient, auth_headers: dict, test_db):
|
||||
member_id = await self._create_member(client, auth_headers, test_db)
|
||||
response = await client.post(
|
||||
"/api/v1/accounts/me/transfer-ownership",
|
||||
json={"current_password": "WrongPassword123!", "target_user_id": member_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 401
|
||||
90
backend/tests/test_auth_profile.py
Normal file
90
backend/tests/test_auth_profile.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Tests for PATCH /auth/me profile update endpoint."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestProfileUpdate:
|
||||
"""Test profile update via PATCH /auth/me."""
|
||||
|
||||
async def test_update_name(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Name update works without password."""
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "New Name"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "New Name"
|
||||
|
||||
async def test_update_email_with_password(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Email change with correct password succeeds."""
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"email": "newemail@example.com", "current_password": "TestPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["email"] == "newemail@example.com"
|
||||
|
||||
async def test_update_email_without_password(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Email change without password returns 400."""
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"email": "newemail@example.com"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "password" in response.json()["detail"].lower()
|
||||
|
||||
async def test_update_email_wrong_password(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Email change with wrong password returns 401."""
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"email": "newemail@example.com", "current_password": "WrongPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_update_email_duplicate(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Email change to existing email returns 400."""
|
||||
# Register second user
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "other@example.com",
|
||||
"password": "TestPassword123!",
|
||||
"name": "Other User",
|
||||
})
|
||||
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"email": "other@example.com", "current_password": "TestPassword123!"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "already registered" in response.json()["detail"].lower()
|
||||
|
||||
async def test_get_me_returns_updated_name(self, client: AsyncClient, auth_headers: dict):
|
||||
"""GET /me reflects the updated profile."""
|
||||
await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "Updated User"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
response = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated User"
|
||||
|
||||
async def test_no_changes_returns_current_user(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Empty update returns current user without error."""
|
||||
response = await client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_unauthenticated(self, client: AsyncClient):
|
||||
"""Unauthenticated request returns 401."""
|
||||
response = await client.patch("/api/v1/auth/me", json={"name": "X"})
|
||||
assert response.status_code == 401
|
||||
57
backend/tests/test_email_verification.py
Normal file
57
backend/tests/test_email_verification.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Tests for email verification endpoints."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestEmailVerification:
|
||||
"""Test email verification send + verify flow."""
|
||||
|
||||
async def test_send_verification(self, client: AsyncClient, auth_headers: dict):
|
||||
"""Send verification email returns 200."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/email/send-verification",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "sent" in response.json()["message"].lower()
|
||||
|
||||
async def test_send_verification_already_verified(
|
||||
self, client: AsyncClient, auth_headers: dict, test_db
|
||||
):
|
||||
"""Returns 400 if email is already verified."""
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime, timezone
|
||||
from app.models.user import User
|
||||
|
||||
# Manually mark email as verified
|
||||
await test_db.execute(
|
||||
update(User).where(User.email == "test@example.com").values(
|
||||
email_verified_at=datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
await test_db.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/email/send-verification",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "already verified" in response.json()["detail"].lower()
|
||||
|
||||
async def test_verify_invalid_token(self, client: AsyncClient):
|
||||
"""Invalid token returns 400."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/email/verify",
|
||||
json={"token": "invalid-token"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_verify_missing_token(self, client: AsyncClient):
|
||||
"""Missing token returns 400."""
|
||||
response = await client.post(
|
||||
"/api/v1/auth/email/verify",
|
||||
json={},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
Reference in New Issue
Block a user