feat: add tree forking, custom step tracking, and session sharing

Implement three foundational schema features from the design doc:

- Tree forking with lineage tracking (migration 022): parent_tree_id,
  root_tree_id, fork_depth columns with self-referential FKs and
  composite analytics index
- Custom step enhancement: CustomStepSchema with source tracking
  (ad-hoc, step-library, forked-tree) for backward-compatible JSONB
- Session sharing (migration 023): session_shares and session_share_views
  tables with account-scoped visibility, cryptographic tokens, view
  tracking, and allow_public_shares account policy

Includes 21 new integration tests (9 forking, 12 sharing), SaaS
consultant-recommended denormalizations, rate limiting on public share
access, and test fixture fix for invite code requirement.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-02-07 19:10:47 -05:00
parent c8e7aaad1a
commit ffb14cd014
16 changed files with 1345 additions and 8 deletions

View File

@@ -0,0 +1,287 @@
import secrets
from datetime import datetime, timezone
from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError
from app.core.database import get_db
from app.models.session import Session
from app.models.session_share import SessionShare, SessionShareView
from app.models.user import User
from app.models.account import Account
from app.schemas.session_share import ShareCreate, ShareResponse, SharePublicView
from app.api.deps import get_current_active_user, require_engineer_or_admin
from app.core.audit import log_audit
from app.core.rate_limit import limiter
router = APIRouter(tags=["shares"])
def build_share_response(share: SessionShare) -> ShareResponse:
return ShareResponse(
id=share.id,
session_id=share.session_id,
account_id=share.account_id,
share_token=share.share_token,
share_name=share.share_name,
visibility=share.visibility,
created_by=share.created_by,
created_at=share.created_at,
updated_at=share.updated_at,
expires_at=share.expires_at,
view_count=share.view_count,
last_viewed_at=share.last_viewed_at,
is_active=share.is_active,
)
# --- Session Share CRUD ---
@router.post(
"/sessions/{session_id}/shares",
response_model=ShareResponse,
status_code=status.HTTP_201_CREATED
)
async def create_share(
session_id: UUID,
share_data: ShareCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_engineer_or_admin)]
):
"""Create a share link for a session.
Only the session owner can create shares.
Public shares require account.allow_public_shares policy.
"""
# Verify session exists and user owns it
result = await db.execute(
select(Session).where(Session.id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
if session.user_id != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the session owner can create share links"
)
# Require account_id for account-scoped shares
if share_data.visibility == "account" and not current_user.account_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create account-scoped share without an account"
)
# Check account policy for public shares
if share_data.visibility == "public" and current_user.account_id:
account_result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = account_result.scalar_one_or_none()
if account and not account.allow_public_shares:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Your organization does not allow public session sharing. Use account-only visibility."
)
# Generate token with collision retry
max_retries = 3
for attempt in range(max_retries):
try:
share_token = secrets.token_urlsafe(48)
share = SessionShare(
session_id=session_id,
account_id=current_user.account_id,
share_token=share_token,
share_name=share_data.share_name,
visibility=share_data.visibility,
created_by=current_user.id,
expires_at=share_data.expires_at,
)
db.add(share)
await db.flush()
await log_audit(db, current_user.id, "share.create", "session_share", share.id,
{"session_id": str(session_id), "visibility": share_data.visibility})
await db.commit()
await db.refresh(share)
return build_share_response(share)
except IntegrityError as e:
await db.rollback()
if "session_shares_share_token_key" in str(e) and attempt < max_retries - 1:
continue
raise
@router.get("/shares/my-shares", response_model=list[ShareResponse])
async def list_my_shares(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100)
):
"""List all shares created by the current user."""
result = await db.execute(
select(SessionShare)
.where(
SessionShare.created_by == current_user.id,
SessionShare.is_active == True
)
.order_by(SessionShare.created_at.desc())
.offset(skip)
.limit(limit)
)
shares = result.scalars().all()
return [build_share_response(s) for s in shares]
@router.delete("/shares/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
async def revoke_share(
share_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Revoke a share link (soft delete - sets is_active=False)."""
result = await db.execute(
select(SessionShare).where(SessionShare.id == share_id)
)
share = result.scalar_one_or_none()
if not share:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share not found"
)
if share.created_by != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the share creator can revoke it"
)
share.is_active = False
await log_audit(db, current_user.id, "share.revoke", "session_share", share.id,
{"session_id": str(share.session_id)})
await db.commit()
return None
# --- Public Share Access ---
async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[User]:
"""Try to extract authenticated user from request, return None if not authenticated."""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header.replace("Bearer ", "")
try:
from app.core.security import decode_token
payload = decode_token(token)
if not payload or payload.get("type") != "access":
return None
user_id = payload.get("sub")
if not user_id:
return None
result = await db.execute(select(User).where(User.id == UUID(user_id)))
return result.scalar_one_or_none()
except Exception:
return None
@router.get("/share/{share_token}", response_model=SharePublicView)
@limiter.limit("30/minute")
async def access_share(
share_token: str,
request: Request,
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Access a shared session via share token.
Public shares: No authentication required.
Account-only shares: Requires authentication + account membership.
"""
current_user = await _get_optional_user(request, db)
# Lookup share
result = await db.execute(
select(SessionShare)
.options(joinedload(SessionShare.session))
.where(SessionShare.share_token == share_token)
)
share = result.scalar_one_or_none()
# Validate share
if not share or not share.is_active:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share not found or has been revoked"
)
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_410_GONE,
detail="Share link has expired"
)
# Check visibility
if share.visibility == "account":
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="This share requires authentication"
)
if current_user.account_id != share.account_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
)
# Record view
session = share.session
view = SessionShareView(
share_id=share.id,
session_id=session.id,
viewer_id=current_user.id if current_user else None,
viewer_ip=request.client.host if request.client else None,
viewer_user_agent=request.headers.get("user-agent"),
)
db.add(view)
share.view_count += 1
share.last_viewed_at = datetime.now(timezone.utc)
await db.commit()
# Build read-only response
tree_snapshot = session.tree_snapshot or {}
return SharePublicView(
session_id=session.id,
tree_name=tree_snapshot.get("question", "Untitled Tree"),
tree_description=tree_snapshot.get("description"),
tree_structure=tree_snapshot,
path_taken=session.path_taken or [],
decisions=session.decisions or [],
custom_steps=session.custom_steps or [],
started_at=session.started_at,
completed_at=session.completed_at,
ticket_number=session.ticket_number,
client_name=session.client_name,
share_name=share.share_name,
visibility=share.visibility,
)

View File

@@ -12,7 +12,7 @@ from app.models.user import User
from app.models.category import TreeCategory
from app.models.tag import TreeTag, tree_tag_assignments
from app.models.folder import UserFolder, user_folder_trees
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo, ForkCreate, ForkInfo
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.subscriptions import check_tree_limit
@@ -73,8 +73,8 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
)
def build_full_tree_response(tree: Tree) -> TreeResponse:
"""Build TreeResponse with all details including category_info and tags."""
def build_full_tree_response(tree: Tree, parent_tree: Tree = None) -> TreeResponse:
"""Build TreeResponse with all details including category_info, tags, and fork_info."""
category_info = None
if tree.category_rel:
category_info = CategoryInfo(
@@ -83,6 +83,20 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
slug=tree.category_rel.slug
)
fork_info = None
if tree.parent_tree_id or tree.fork_depth > 0:
has_updates = False
if parent_tree and tree.parent_updated_at:
has_updates = parent_tree.updated_at > tree.parent_updated_at
fork_info = ForkInfo(
parent_tree_id=tree.parent_tree_id,
root_tree_id=tree.root_tree_id,
fork_reason=tree.fork_reason,
fork_depth=tree.fork_depth,
parent_updated_at=tree.parent_updated_at,
has_parent_updates=has_updates
)
return TreeResponse(
id=tree.id,
name=tree.name,
@@ -91,6 +105,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
category_id=tree.category_id,
category_info=category_info,
tags=tree.tag_names,
fork_info=fork_info,
tree_structure=tree.tree_structure,
author_id=tree.author_id,
account_id=tree.account_id,
@@ -561,3 +576,166 @@ async def delete_tree(
{"tree_name": tree.name})
await db.commit()
return None
# --- Fork Endpoints ---
@router.post("/{tree_id}/fork", response_model=TreeResponse, status_code=status.HTTP_201_CREATED)
async def fork_tree(
tree_id: UUID,
fork_data: ForkCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_engineer_or_admin)]
):
"""Fork a tree to create a personal copy.
Engineers can fork any tree they can access (public, account, or default).
Fork inherits tree_structure but gets new ownership.
"""
# Load parent tree
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == tree_id)
)
parent = result.scalar_one_or_none()
if not parent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
if not can_access_tree(current_user, parent):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
)
# Check subscription tree limit
if current_user.account_id:
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
if not can_create:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
)
# Build fork
fork_name = fork_data.name or f"Fork of {parent.name}"
fork = Tree(
name=fork_name,
description=parent.description,
category=parent.category,
category_id=parent.category_id,
tree_structure=parent.tree_structure,
author_id=current_user.id,
account_id=current_user.account_id,
is_public=False,
is_default=False,
version=1,
# Fork tracking
parent_tree_id=parent.id,
fork_reason=fork_data.fork_reason,
parent_updated_at=parent.updated_at,
# Lineage tracking
root_tree_id=parent.root_tree_id if parent.root_tree_id else parent.id,
fork_depth=parent.fork_depth + 1,
)
db.add(fork)
await db.flush()
await log_audit(db, current_user.id, "tree.fork", "tree", fork.id,
{"parent_tree_id": str(parent.id), "parent_name": parent.name,
"fork_reason": fork_data.fork_reason})
await db.commit()
# Reload with relationships
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == fork.id)
)
fork = result.scalar_one()
return build_full_tree_response(fork, parent_tree=parent)
@router.get("/{tree_id}/forks", response_model=list[TreeListResponse])
async def list_forks(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100)
):
"""List all direct forks of a tree."""
# Verify parent exists and user can access it
parent_result = await db.execute(select(Tree).where(Tree.id == tree_id))
parent = parent_result.scalar_one_or_none()
if not parent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
if not can_access_tree(current_user, parent):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
)
# Query direct forks, filtered by access
query = select(Tree).options(
selectinload(Tree.category_rel),
selectinload(Tree.tags)
).where(
Tree.parent_tree_id == tree_id,
Tree.is_active == True,
build_tree_access_filter(current_user)
).order_by(Tree.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
forks = result.scalars().unique().all()
return [build_tree_response(tree) for tree in forks]
@router.get("/{tree_id}/lineage", response_model=list[TreeListResponse])
async def get_tree_lineage(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get the fork lineage chain from current tree back to root.
Returns ordered list: [current tree, parent, grandparent, ..., root]
Limited to 10 levels to prevent infinite loops.
"""
lineage = []
current_id = tree_id
visited = set()
max_depth = 10
for _ in range(max_depth):
if current_id is None or current_id in visited:
break
visited.add(current_id)
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == current_id)
)
tree = result.scalar_one_or_none()
if not tree:
break
lineage.append(build_tree_response(tree))
current_id = tree.parent_tree_id
return lineage

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares
api_router = APIRouter()
@@ -15,3 +15,4 @@ api_router.include_router(steps.router)
api_router.include_router(admin.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)
api_router.include_router(shares.router)

View File

@@ -15,6 +15,7 @@ from .step_category import StepCategory
from .step_library import StepLibrary, StepRating, StepUsageLog
from .refresh_token import RefreshToken
from .audit_log import AuditLog
from .session_share import SessionShare, SessionShareView
__all__ = [
"User",
@@ -38,4 +39,6 @@ __all__ = [
"StepUsageLog",
"RefreshToken",
"AuditLog",
"SessionShare",
"SessionShareView",
]

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey
from sqlalchemy import String, DateTime, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
@@ -26,6 +26,13 @@ class Account(Base):
stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
allow_public_shares: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
server_default="true",
comment="Policy: engineers can create public shares. Only affects NEW shares (grandfathered)."
)
# Relationships
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")

View File

@@ -1,12 +1,15 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from typing import Optional, Any, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Text
import sqlalchemy as sa
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
if TYPE_CHECKING:
from app.models.session_share import SessionShare
class Session(Base):
__tablename__ = "sessions"
@@ -53,3 +56,4 @@ class Session(Base):
tree: Mapped["Tree"] = relationship("Tree", back_populates="sessions")
user: Mapped["User"] = relationship("User", back_populates="sessions")
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")

View File

@@ -0,0 +1,152 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer, CheckConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.session import Session
from app.models.user import User
from app.models.account import Account
class SessionShare(Base):
__tablename__ = "session_shares"
__table_args__ = (
CheckConstraint(
"visibility IN ('public', 'account')",
name='ck_session_shares_visibility'
),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Account that owns this share (denormalized from session at creation)"
)
share_token: Mapped[str] = mapped_column(
String(64),
unique=True,
nullable=False,
index=True,
comment="URL-safe random token (48 bytes -> 64 base64 chars)"
)
share_name: Mapped[Optional[str]] = mapped_column(
String(100),
nullable=True,
comment="Optional label: 'Training link', 'Customer escalation #1234'"
)
visibility: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="public",
comment="public = anyone with link, account = account members only"
)
created_by: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc)
)
expires_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True,
index=True,
comment="Optional expiration for time-limited shares"
)
view_count: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0
)
last_viewed_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
index=True
)
# Relationships
session: Mapped["Session"] = relationship("Session", back_populates="shares")
account: Mapped["Account"] = relationship("Account")
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
views: Mapped[list["SessionShareView"]] = relationship(
"SessionShareView",
back_populates="share",
cascade="all, delete-orphan"
)
class SessionShareView(Base):
__tablename__ = "session_share_views"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4
)
share_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("session_shares.id", ondelete="CASCADE"),
nullable=False,
index=True
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Denormalized from share for analytics queries"
)
viewer_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="NULL for public shares (unauthenticated views)"
)
viewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
index=True
)
viewer_ip: Mapped[Optional[str]] = mapped_column(
String(45), # IPv6 max length
nullable=True
)
viewer_user_agent: Mapped[Optional[str]] = mapped_column(
String(500),
nullable=True
)
# Relationships
share: Mapped["SessionShare"] = relationship("SessionShare", back_populates="views")
viewer: Mapped[Optional["User"]] = relationship("User")

View File

@@ -79,10 +79,62 @@ class Tree(Base):
)
usage_count: Mapped[int] = mapped_column(Integer, default=0)
# Fork tracking
parent_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="SET NULL"),
nullable=True,
index=True
)
fork_reason: Mapped[Optional[str]] = mapped_column(
String(255),
nullable=True,
comment="Brief reason: 'Added Cisco Meraki steps for our network'"
)
parent_updated_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Snapshot of parent's updated_at when fork created. Compare to detect parent updates."
)
# Fork lineage tracking
root_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="Original tree at root of fork chain (NULL for non-forked trees)"
)
fork_depth: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
server_default="0",
comment="Fork depth: 0 = original, 1 = direct fork, 2 = fork of fork, etc."
)
# Relationships
author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees")
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees")
# Fork relationships (self-referential)
parent: Mapped[Optional["Tree"]] = relationship(
"Tree",
remote_side="Tree.id",
foreign_keys=[parent_tree_id],
back_populates="forks"
)
forks: Mapped[list["Tree"]] = relationship(
"Tree",
foreign_keys=[parent_tree_id],
back_populates="parent"
)
root: Mapped[Optional["Tree"]] = relationship(
"Tree",
remote_side="Tree.id",
foreign_keys=[root_tree_id]
)
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree")
# New organization relationships

View File

@@ -1,9 +1,25 @@
from datetime import datetime
from typing import Optional, Any
from typing import Optional, Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, validator
class CustomStepSchema(BaseModel):
"""Enhanced custom step with source tracking.
Backward compatible: old sessions without new fields load with defaults.
"""
type: str # "decision" | "action" | "solution"
content: str
notes: Optional[str] = None
# Source tracking (new fields, optional for backward compatibility)
source: Literal["ad-hoc", "step-library", "forked-tree"] = "ad-hoc"
source_step_id: Optional[UUID] = None
inserted_at: Optional[datetime] = None
inserted_after_node_id: Optional[str] = None
class DecisionRecord(BaseModel):
node_id: str
question: Optional[str] = None
@@ -24,7 +40,7 @@ class SessionCreate(BaseModel):
class SessionUpdate(BaseModel):
path_taken: Optional[list[str]] = None
decisions: Optional[list[DecisionRecord]] = None
custom_steps: Optional[list[dict[str, Any]]] = None
custom_steps: Optional[list[CustomStepSchema]] = None
ticket_number: Optional[str] = Field(None, max_length=100)
client_name: Optional[str] = Field(None, max_length=255)
scratchpad: Optional[str] = None

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from typing import Optional, Literal
from uuid import UUID
from pydantic import BaseModel, Field
class ShareCreate(BaseModel):
visibility: Literal["public", "account"] = Field("public", description="Share visibility")
share_name: Optional[str] = Field(None, max_length=100, description="Optional label for the share")
expires_at: Optional[datetime] = Field(None, description="Optional expiration datetime")
class ShareResponse(BaseModel):
id: UUID
session_id: UUID
account_id: UUID
share_token: str
share_name: Optional[str] = None
visibility: str
created_by: UUID
created_at: datetime
updated_at: datetime
expires_at: Optional[datetime] = None
view_count: int
last_viewed_at: Optional[datetime] = None
is_active: bool
share_url: Optional[str] = None
class Config:
from_attributes = True
class SharePublicView(BaseModel):
"""Read-only session data returned when accessing a share link."""
session_id: UUID
tree_name: str
tree_description: Optional[str] = None
tree_structure: dict
path_taken: list[str]
decisions: list[dict]
custom_steps: list[dict] = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
ticket_number: Optional[str] = None
client_name: Optional[str] = None
share_name: Optional[str] = None
visibility: str

View File

@@ -40,6 +40,24 @@ class TreeUpdate(BaseModel):
tags: Optional[list[str]] = Field(None, max_length=10, description="List of tag names to assign (replaces existing)")
class ForkCreate(BaseModel):
fork_reason: Optional[str] = Field(None, max_length=255, description="Brief reason for forking")
name: Optional[str] = Field(None, min_length=1, max_length=255, description="Name for the fork (defaults to 'Fork of {original name}')")
class ForkInfo(BaseModel):
"""Fork metadata included in tree responses."""
parent_tree_id: Optional[UUID] = None
root_tree_id: Optional[UUID] = None
fork_reason: Optional[str] = None
fork_depth: int = 0
parent_updated_at: Optional[datetime] = None
has_parent_updates: bool = False
class Config:
from_attributes = True
class TreeResponse(TreeBase):
id: UUID
tree_structure: dict[str, Any]
@@ -48,6 +66,7 @@ class TreeResponse(TreeBase):
category_id: Optional[UUID] = None
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
fork_info: Optional[ForkInfo] = None
is_active: bool
is_public: bool
is_default: bool