feat: add tree forking, custom step tracking, and session sharing

Implement three foundational schema features from the design doc:

- Tree forking with lineage tracking (migration 022): parent_tree_id,
  root_tree_id, fork_depth columns with self-referential FKs and
  composite analytics index
- Custom step enhancement: CustomStepSchema with source tracking
  (ad-hoc, step-library, forked-tree) for backward-compatible JSONB
- Session sharing (migration 023): session_shares and session_share_views
  tables with account-scoped visibility, cryptographic tokens, view
  tracking, and allow_public_shares account policy

Includes 21 new integration tests (9 forking, 12 sharing), SaaS
consultant-recommended denormalizations, rate limiting on public share
access, and test fixture fix for invite code requirement.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Michael Chihlas
2026-02-07 19:10:47 -05:00
parent c8e7aaad1a
commit ffb14cd014
16 changed files with 1345 additions and 8 deletions

View File

@@ -0,0 +1,73 @@
"""add tree forking support with lineage tracking
Revision ID: 022
Revises: 021
Create Date: 2026-02-07
Adds fork relationship and lineage tracking to trees table:
- parent_tree_id: Points to immediate parent tree (NULL for root trees)
- fork_reason: Optional engineer note explaining fork purpose
- parent_updated_at: Snapshot of parent's updated_at at fork time
- root_tree_id: Points to original tree at root of fork chain
- fork_depth: How many forks deep (1 = direct fork, 2 = fork of fork, etc.)
Fork on parent delete behavior: SET NULL (orphaned forks survive)
Root tree delete behavior: SET NULL (lineage preserved via fork_depth)
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '022'
down_revision: Union[str, None] = '021'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add fork tracking columns
op.add_column('trees', sa.Column('parent_tree_id', UUID(as_uuid=True), nullable=True))
op.add_column('trees', sa.Column('fork_reason', sa.String(255), nullable=True))
op.add_column('trees', sa.Column('parent_updated_at', sa.DateTime(timezone=True), nullable=True))
# Add fork lineage columns
op.add_column('trees', sa.Column('root_tree_id', UUID(as_uuid=True), nullable=True))
op.add_column('trees', sa.Column('fork_depth', sa.Integer, nullable=False, server_default='0'))
# Add foreign key constraints
op.create_foreign_key(
'fk_trees_parent_tree_id',
'trees', 'trees',
['parent_tree_id'], ['id'],
ondelete='SET NULL'
)
op.create_foreign_key(
'fk_trees_root_tree_id',
'trees', 'trees',
['root_tree_id'], ['id'],
ondelete='SET NULL'
)
# Add indexes for fork queries
op.create_index('ix_trees_parent_tree_id', 'trees', ['parent_tree_id'])
op.create_index('ix_trees_root_tree_id', 'trees', ['root_tree_id'])
# Composite index for fork analytics (descendants + depth sorting)
op.create_index('ix_trees_fork_analytics', 'trees', ['root_tree_id', 'fork_depth'])
def downgrade() -> None:
op.drop_index('ix_trees_fork_analytics', table_name='trees')
op.drop_index('ix_trees_root_tree_id', table_name='trees')
op.drop_index('ix_trees_parent_tree_id', table_name='trees')
op.drop_constraint('fk_trees_root_tree_id', 'trees', type_='foreignkey')
op.drop_constraint('fk_trees_parent_tree_id', 'trees', type_='foreignkey')
op.drop_column('trees', 'fork_depth')
op.drop_column('trees', 'root_tree_id')
op.drop_column('trees', 'parent_updated_at')
op.drop_column('trees', 'fork_reason')
op.drop_column('trees', 'parent_tree_id')

View File

@@ -0,0 +1,89 @@
"""add session sharing tables
Revision ID: 023
Revises: 022
Create Date: 2026-02-07
Adds session sharing infrastructure:
- session_shares: Share tokens with visibility control (public/account-only)
- session_share_views: Detailed view tracking with viewer identification
- accounts.allow_public_shares: Account-level policy for public share creation
Includes SaaS consultant-recommended denormalizations:
- account_id on session_shares (faster access control, survives user transfers)
- session_id on session_share_views (faster analytics queries)
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = '023'
down_revision: Union[str, None] = '022'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create session_shares table
op.create_table(
'session_shares',
sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')),
sa.Column('session_id', UUID(as_uuid=True), nullable=False),
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
sa.Column('share_token', sa.String(64), nullable=False, unique=True),
sa.Column('share_name', sa.String(100), nullable=True),
sa.Column('visibility', sa.String(20), nullable=False, server_default='public'),
sa.Column('created_by', UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('view_count', sa.Integer, nullable=False, server_default='0'),
sa.Column('last_viewed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('is_active', sa.Boolean, nullable=False, server_default='true'),
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='CASCADE'),
sa.CheckConstraint("visibility IN ('public', 'account')", name='ck_session_shares_visibility')
)
# Create indexes for session_shares
op.create_index('ix_session_shares_session_id', 'session_shares', ['session_id'])
op.create_index('ix_session_shares_account_id', 'session_shares', ['account_id'])
op.create_index('ix_session_shares_share_token', 'session_shares', ['share_token'])
op.create_index('ix_session_shares_created_by', 'session_shares', ['created_by'])
op.create_index('ix_session_shares_expires_at', 'session_shares', ['expires_at'])
op.create_index('ix_session_shares_is_active', 'session_shares', ['is_active'])
# Create session_share_views table
op.create_table(
'session_share_views',
sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')),
sa.Column('share_id', UUID(as_uuid=True), nullable=False),
sa.Column('session_id', UUID(as_uuid=True), nullable=False),
sa.Column('viewer_id', UUID(as_uuid=True), nullable=True),
sa.Column('viewed_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('viewer_ip', sa.String(45), nullable=True),
sa.Column('viewer_user_agent', sa.String(500), nullable=True),
sa.ForeignKeyConstraint(['share_id'], ['session_shares.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['viewer_id'], ['users.id'], ondelete='SET NULL')
)
# Create indexes for session_share_views
op.create_index('ix_session_share_views_share_id', 'session_share_views', ['share_id'])
op.create_index('ix_session_share_views_session_id', 'session_share_views', ['session_id'])
op.create_index('ix_session_share_views_viewer_id', 'session_share_views', ['viewer_id'])
op.create_index('ix_session_share_views_viewed_at', 'session_share_views', ['viewed_at'])
# Add account policy for public shares
op.add_column('accounts', sa.Column('allow_public_shares', sa.Boolean, nullable=False, server_default='true'))
def downgrade() -> None:
op.drop_column('accounts', 'allow_public_shares')
op.drop_table('session_share_views')
op.drop_table('session_shares')

View File

@@ -0,0 +1,287 @@
import secrets
from datetime import datetime, timezone
from typing import Annotated, Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import IntegrityError
from app.core.database import get_db
from app.models.session import Session
from app.models.session_share import SessionShare, SessionShareView
from app.models.user import User
from app.models.account import Account
from app.schemas.session_share import ShareCreate, ShareResponse, SharePublicView
from app.api.deps import get_current_active_user, require_engineer_or_admin
from app.core.audit import log_audit
from app.core.rate_limit import limiter
router = APIRouter(tags=["shares"])
def build_share_response(share: SessionShare) -> ShareResponse:
return ShareResponse(
id=share.id,
session_id=share.session_id,
account_id=share.account_id,
share_token=share.share_token,
share_name=share.share_name,
visibility=share.visibility,
created_by=share.created_by,
created_at=share.created_at,
updated_at=share.updated_at,
expires_at=share.expires_at,
view_count=share.view_count,
last_viewed_at=share.last_viewed_at,
is_active=share.is_active,
)
# --- Session Share CRUD ---
@router.post(
"/sessions/{session_id}/shares",
response_model=ShareResponse,
status_code=status.HTTP_201_CREATED
)
async def create_share(
session_id: UUID,
share_data: ShareCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_engineer_or_admin)]
):
"""Create a share link for a session.
Only the session owner can create shares.
Public shares require account.allow_public_shares policy.
"""
# Verify session exists and user owns it
result = await db.execute(
select(Session).where(Session.id == session_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
if session.user_id != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the session owner can create share links"
)
# Require account_id for account-scoped shares
if share_data.visibility == "account" and not current_user.account_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create account-scoped share without an account"
)
# Check account policy for public shares
if share_data.visibility == "public" and current_user.account_id:
account_result = await db.execute(
select(Account).where(Account.id == current_user.account_id)
)
account = account_result.scalar_one_or_none()
if account and not account.allow_public_shares:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Your organization does not allow public session sharing. Use account-only visibility."
)
# Generate token with collision retry
max_retries = 3
for attempt in range(max_retries):
try:
share_token = secrets.token_urlsafe(48)
share = SessionShare(
session_id=session_id,
account_id=current_user.account_id,
share_token=share_token,
share_name=share_data.share_name,
visibility=share_data.visibility,
created_by=current_user.id,
expires_at=share_data.expires_at,
)
db.add(share)
await db.flush()
await log_audit(db, current_user.id, "share.create", "session_share", share.id,
{"session_id": str(session_id), "visibility": share_data.visibility})
await db.commit()
await db.refresh(share)
return build_share_response(share)
except IntegrityError as e:
await db.rollback()
if "session_shares_share_token_key" in str(e) and attempt < max_retries - 1:
continue
raise
@router.get("/shares/my-shares", response_model=list[ShareResponse])
async def list_my_shares(
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100)
):
"""List all shares created by the current user."""
result = await db.execute(
select(SessionShare)
.where(
SessionShare.created_by == current_user.id,
SessionShare.is_active == True
)
.order_by(SessionShare.created_at.desc())
.offset(skip)
.limit(limit)
)
shares = result.scalars().all()
return [build_share_response(s) for s in shares]
@router.delete("/shares/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
async def revoke_share(
share_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Revoke a share link (soft delete - sets is_active=False)."""
result = await db.execute(
select(SessionShare).where(SessionShare.id == share_id)
)
share = result.scalar_one_or_none()
if not share:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share not found"
)
if share.created_by != current_user.id and not current_user.is_super_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only the share creator can revoke it"
)
share.is_active = False
await log_audit(db, current_user.id, "share.revoke", "session_share", share.id,
{"session_id": str(share.session_id)})
await db.commit()
return None
# --- Public Share Access ---
async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[User]:
"""Try to extract authenticated user from request, return None if not authenticated."""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header.replace("Bearer ", "")
try:
from app.core.security import decode_token
payload = decode_token(token)
if not payload or payload.get("type") != "access":
return None
user_id = payload.get("sub")
if not user_id:
return None
result = await db.execute(select(User).where(User.id == UUID(user_id)))
return result.scalar_one_or_none()
except Exception:
return None
@router.get("/share/{share_token}", response_model=SharePublicView)
@limiter.limit("30/minute")
async def access_share(
share_token: str,
request: Request,
db: Annotated[AsyncSession, Depends(get_db)],
):
"""Access a shared session via share token.
Public shares: No authentication required.
Account-only shares: Requires authentication + account membership.
"""
current_user = await _get_optional_user(request, db)
# Lookup share
result = await db.execute(
select(SessionShare)
.options(joinedload(SessionShare.session))
.where(SessionShare.share_token == share_token)
)
share = result.scalar_one_or_none()
# Validate share
if not share or not share.is_active:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share not found or has been revoked"
)
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_410_GONE,
detail="Share link has expired"
)
# Check visibility
if share.visibility == "account":
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="This share requires authentication"
)
if current_user.account_id != share.account_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this session"
)
# Record view
session = share.session
view = SessionShareView(
share_id=share.id,
session_id=session.id,
viewer_id=current_user.id if current_user else None,
viewer_ip=request.client.host if request.client else None,
viewer_user_agent=request.headers.get("user-agent"),
)
db.add(view)
share.view_count += 1
share.last_viewed_at = datetime.now(timezone.utc)
await db.commit()
# Build read-only response
tree_snapshot = session.tree_snapshot or {}
return SharePublicView(
session_id=session.id,
tree_name=tree_snapshot.get("question", "Untitled Tree"),
tree_description=tree_snapshot.get("description"),
tree_structure=tree_snapshot,
path_taken=session.path_taken or [],
decisions=session.decisions or [],
custom_steps=session.custom_steps or [],
started_at=session.started_at,
completed_at=session.completed_at,
ticket_number=session.ticket_number,
client_name=session.client_name,
share_name=share.share_name,
visibility=share.visibility,
)

View File

@@ -12,7 +12,7 @@ from app.models.user import User
from app.models.category import TreeCategory
from app.models.tag import TreeTag, tree_tag_assignments
from app.models.folder import UserFolder, user_folder_trees
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo, ForkCreate, ForkInfo
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
from app.core.permissions import can_edit_tree, can_access_tree
from app.core.subscriptions import check_tree_limit
@@ -73,8 +73,8 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
)
def build_full_tree_response(tree: Tree) -> TreeResponse:
"""Build TreeResponse with all details including category_info and tags."""
def build_full_tree_response(tree: Tree, parent_tree: Tree = None) -> TreeResponse:
"""Build TreeResponse with all details including category_info, tags, and fork_info."""
category_info = None
if tree.category_rel:
category_info = CategoryInfo(
@@ -83,6 +83,20 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
slug=tree.category_rel.slug
)
fork_info = None
if tree.parent_tree_id or tree.fork_depth > 0:
has_updates = False
if parent_tree and tree.parent_updated_at:
has_updates = parent_tree.updated_at > tree.parent_updated_at
fork_info = ForkInfo(
parent_tree_id=tree.parent_tree_id,
root_tree_id=tree.root_tree_id,
fork_reason=tree.fork_reason,
fork_depth=tree.fork_depth,
parent_updated_at=tree.parent_updated_at,
has_parent_updates=has_updates
)
return TreeResponse(
id=tree.id,
name=tree.name,
@@ -91,6 +105,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
category_id=tree.category_id,
category_info=category_info,
tags=tree.tag_names,
fork_info=fork_info,
tree_structure=tree.tree_structure,
author_id=tree.author_id,
account_id=tree.account_id,
@@ -561,3 +576,166 @@ async def delete_tree(
{"tree_name": tree.name})
await db.commit()
return None
# --- Fork Endpoints ---
@router.post("/{tree_id}/fork", response_model=TreeResponse, status_code=status.HTTP_201_CREATED)
async def fork_tree(
tree_id: UUID,
fork_data: ForkCreate,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(require_engineer_or_admin)]
):
"""Fork a tree to create a personal copy.
Engineers can fork any tree they can access (public, account, or default).
Fork inherits tree_structure but gets new ownership.
"""
# Load parent tree
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == tree_id)
)
parent = result.scalar_one_or_none()
if not parent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
if not can_access_tree(current_user, parent):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
)
# Check subscription tree limit
if current_user.account_id:
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
if not can_create:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
)
# Build fork
fork_name = fork_data.name or f"Fork of {parent.name}"
fork = Tree(
name=fork_name,
description=parent.description,
category=parent.category,
category_id=parent.category_id,
tree_structure=parent.tree_structure,
author_id=current_user.id,
account_id=current_user.account_id,
is_public=False,
is_default=False,
version=1,
# Fork tracking
parent_tree_id=parent.id,
fork_reason=fork_data.fork_reason,
parent_updated_at=parent.updated_at,
# Lineage tracking
root_tree_id=parent.root_tree_id if parent.root_tree_id else parent.id,
fork_depth=parent.fork_depth + 1,
)
db.add(fork)
await db.flush()
await log_audit(db, current_user.id, "tree.fork", "tree", fork.id,
{"parent_tree_id": str(parent.id), "parent_name": parent.name,
"fork_reason": fork_data.fork_reason})
await db.commit()
# Reload with relationships
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == fork.id)
)
fork = result.scalar_one()
return build_full_tree_response(fork, parent_tree=parent)
@router.get("/{tree_id}/forks", response_model=list[TreeListResponse])
async def list_forks(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)],
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=100)
):
"""List all direct forks of a tree."""
# Verify parent exists and user can access it
parent_result = await db.execute(select(Tree).where(Tree.id == tree_id))
parent = parent_result.scalar_one_or_none()
if not parent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tree not found"
)
if not can_access_tree(current_user, parent):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have access to this tree"
)
# Query direct forks, filtered by access
query = select(Tree).options(
selectinload(Tree.category_rel),
selectinload(Tree.tags)
).where(
Tree.parent_tree_id == tree_id,
Tree.is_active == True,
build_tree_access_filter(current_user)
).order_by(Tree.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
forks = result.scalars().unique().all()
return [build_tree_response(tree) for tree in forks]
@router.get("/{tree_id}/lineage", response_model=list[TreeListResponse])
async def get_tree_lineage(
tree_id: UUID,
db: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""Get the fork lineage chain from current tree back to root.
Returns ordered list: [current tree, parent, grandparent, ..., root]
Limited to 10 levels to prevent infinite loops.
"""
lineage = []
current_id = tree_id
visited = set()
max_depth = 10
for _ in range(max_depth):
if current_id is None or current_id in visited:
break
visited.add(current_id)
result = await db.execute(
select(Tree)
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
.where(Tree.id == current_id)
)
tree = result.scalar_one_or_none()
if not tree:
break
lineage.append(build_tree_response(tree))
current_id = tree.parent_tree_id
return lineage

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares
api_router = APIRouter()
@@ -15,3 +15,4 @@ api_router.include_router(steps.router)
api_router.include_router(admin.router)
api_router.include_router(accounts.router)
api_router.include_router(webhooks.router)
api_router.include_router(shares.router)

View File

@@ -15,6 +15,7 @@ from .step_category import StepCategory
from .step_library import StepLibrary, StepRating, StepUsageLog
from .refresh_token import RefreshToken
from .audit_log import AuditLog
from .session_share import SessionShare, SessionShareView
__all__ = [
"User",
@@ -38,4 +39,6 @@ __all__ = [
"StepUsageLog",
"RefreshToken",
"AuditLog",
"SessionShare",
"SessionShareView",
]

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey
from sqlalchemy import String, DateTime, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
@@ -26,6 +26,13 @@ class Account(Base):
stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
allow_public_shares: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
server_default="true",
comment="Policy: engineers can create public shares. Only affects NEW shares (grandfathered)."
)
# Relationships
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")

View File

@@ -1,12 +1,15 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, Any
from typing import Optional, Any, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Text
import sqlalchemy as sa
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.core.database import Base
if TYPE_CHECKING:
from app.models.session_share import SessionShare
class Session(Base):
__tablename__ = "sessions"
@@ -53,3 +56,4 @@ class Session(Base):
tree: Mapped["Tree"] = relationship("Tree", back_populates="sessions")
user: Mapped["User"] = relationship("User", back_populates="sessions")
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")

View File

@@ -0,0 +1,152 @@
import uuid
from datetime import datetime, timezone
from typing import Optional, TYPE_CHECKING
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer, CheckConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import UUID
from app.core.database import Base
if TYPE_CHECKING:
from app.models.session import Session
from app.models.user import User
from app.models.account import Account
class SessionShare(Base):
__tablename__ = "session_shares"
__table_args__ = (
CheckConstraint(
"visibility IN ('public', 'account')",
name='ck_session_shares_visibility'
),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True
)
account_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("accounts.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Account that owns this share (denormalized from session at creation)"
)
share_token: Mapped[str] = mapped_column(
String(64),
unique=True,
nullable=False,
index=True,
comment="URL-safe random token (48 bytes -> 64 base64 chars)"
)
share_name: Mapped[Optional[str]] = mapped_column(
String(100),
nullable=True,
comment="Optional label: 'Training link', 'Customer escalation #1234'"
)
visibility: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="public",
comment="public = anyone with link, account = account members only"
)
created_by: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc)
)
expires_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True,
index=True,
comment="Optional expiration for time-limited shares"
)
view_count: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0
)
last_viewed_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
index=True
)
# Relationships
session: Mapped["Session"] = relationship("Session", back_populates="shares")
account: Mapped["Account"] = relationship("Account")
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
views: Mapped[list["SessionShareView"]] = relationship(
"SessionShareView",
back_populates="share",
cascade="all, delete-orphan"
)
class SessionShareView(Base):
__tablename__ = "session_share_views"
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4
)
share_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("session_shares.id", ondelete="CASCADE"),
nullable=False,
index=True
)
session_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("sessions.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="Denormalized from share for analytics queries"
)
viewer_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="NULL for public shares (unauthenticated views)"
)
viewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
index=True
)
viewer_ip: Mapped[Optional[str]] = mapped_column(
String(45), # IPv6 max length
nullable=True
)
viewer_user_agent: Mapped[Optional[str]] = mapped_column(
String(500),
nullable=True
)
# Relationships
share: Mapped["SessionShare"] = relationship("SessionShare", back_populates="views")
viewer: Mapped[Optional["User"]] = relationship("User")

View File

@@ -79,10 +79,62 @@ class Tree(Base):
)
usage_count: Mapped[int] = mapped_column(Integer, default=0)
# Fork tracking
parent_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="SET NULL"),
nullable=True,
index=True
)
fork_reason: Mapped[Optional[str]] = mapped_column(
String(255),
nullable=True,
comment="Brief reason: 'Added Cisco Meraki steps for our network'"
)
parent_updated_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Snapshot of parent's updated_at when fork created. Compare to detect parent updates."
)
# Fork lineage tracking
root_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
UUID(as_uuid=True),
ForeignKey("trees.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="Original tree at root of fork chain (NULL for non-forked trees)"
)
fork_depth: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
server_default="0",
comment="Fork depth: 0 = original, 1 = direct fork, 2 = fork of fork, etc."
)
# Relationships
author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees")
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees")
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees")
# Fork relationships (self-referential)
parent: Mapped[Optional["Tree"]] = relationship(
"Tree",
remote_side="Tree.id",
foreign_keys=[parent_tree_id],
back_populates="forks"
)
forks: Mapped[list["Tree"]] = relationship(
"Tree",
foreign_keys=[parent_tree_id],
back_populates="parent"
)
root: Mapped[Optional["Tree"]] = relationship(
"Tree",
remote_side="Tree.id",
foreign_keys=[root_tree_id]
)
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree")
# New organization relationships

View File

@@ -1,9 +1,25 @@
from datetime import datetime
from typing import Optional, Any
from typing import Optional, Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, validator
class CustomStepSchema(BaseModel):
"""Enhanced custom step with source tracking.
Backward compatible: old sessions without new fields load with defaults.
"""
type: str # "decision" | "action" | "solution"
content: str
notes: Optional[str] = None
# Source tracking (new fields, optional for backward compatibility)
source: Literal["ad-hoc", "step-library", "forked-tree"] = "ad-hoc"
source_step_id: Optional[UUID] = None
inserted_at: Optional[datetime] = None
inserted_after_node_id: Optional[str] = None
class DecisionRecord(BaseModel):
node_id: str
question: Optional[str] = None
@@ -24,7 +40,7 @@ class SessionCreate(BaseModel):
class SessionUpdate(BaseModel):
path_taken: Optional[list[str]] = None
decisions: Optional[list[DecisionRecord]] = None
custom_steps: Optional[list[dict[str, Any]]] = None
custom_steps: Optional[list[CustomStepSchema]] = None
ticket_number: Optional[str] = Field(None, max_length=100)
client_name: Optional[str] = Field(None, max_length=255)
scratchpad: Optional[str] = None

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from typing import Optional, Literal
from uuid import UUID
from pydantic import BaseModel, Field
class ShareCreate(BaseModel):
visibility: Literal["public", "account"] = Field("public", description="Share visibility")
share_name: Optional[str] = Field(None, max_length=100, description="Optional label for the share")
expires_at: Optional[datetime] = Field(None, description="Optional expiration datetime")
class ShareResponse(BaseModel):
id: UUID
session_id: UUID
account_id: UUID
share_token: str
share_name: Optional[str] = None
visibility: str
created_by: UUID
created_at: datetime
updated_at: datetime
expires_at: Optional[datetime] = None
view_count: int
last_viewed_at: Optional[datetime] = None
is_active: bool
share_url: Optional[str] = None
class Config:
from_attributes = True
class SharePublicView(BaseModel):
"""Read-only session data returned when accessing a share link."""
session_id: UUID
tree_name: str
tree_description: Optional[str] = None
tree_structure: dict
path_taken: list[str]
decisions: list[dict]
custom_steps: list[dict] = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
ticket_number: Optional[str] = None
client_name: Optional[str] = None
share_name: Optional[str] = None
visibility: str

View File

@@ -40,6 +40,24 @@ class TreeUpdate(BaseModel):
tags: Optional[list[str]] = Field(None, max_length=10, description="List of tag names to assign (replaces existing)")
class ForkCreate(BaseModel):
fork_reason: Optional[str] = Field(None, max_length=255, description="Brief reason for forking")
name: Optional[str] = Field(None, min_length=1, max_length=255, description="Name for the fork (defaults to 'Fork of {original name}')")
class ForkInfo(BaseModel):
"""Fork metadata included in tree responses."""
parent_tree_id: Optional[UUID] = None
root_tree_id: Optional[UUID] = None
fork_reason: Optional[str] = None
fork_depth: int = 0
parent_updated_at: Optional[datetime] = None
has_parent_updates: bool = False
class Config:
from_attributes = True
class TreeResponse(TreeBase):
id: UUID
tree_structure: dict[str, Any]
@@ -48,6 +66,7 @@ class TreeResponse(TreeBase):
category_id: Optional[UUID] = None
category_info: Optional[CategoryInfo] = None
tags: list[str] = [] # List of tag names
fork_info: Optional[ForkInfo] = None
is_active: bool
is_public: bool
is_default: bool

View File

@@ -16,6 +16,8 @@ from app.main import app
from app.core.database import Base, get_db
from app.core.config import settings
# Disable invite code requirement for tests
settings.REQUIRE_INVITE_CODE = False
# Test database URL (separate from production)
TEST_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/patherly_test"

View File

@@ -0,0 +1,239 @@
"""Tests for session sharing (create, access, revoke)."""
import pytest
from httpx import AsyncClient
pytestmark = pytest.mark.asyncio
class TestSessionSharing:
"""Test session share creation, access control, and revocation."""
async def _create_session(self, client, auth_headers, tree_id):
"""Helper: start a session for a tree."""
response = await client.post(
"/api/v1/sessions",
json={"tree_id": tree_id},
headers=auth_headers
)
assert response.status_code == 201
return response.json()
async def test_create_public_share(self, client: AsyncClient, auth_headers, test_tree):
"""Create a public share link."""
session = await self._create_session(client, auth_headers, test_tree["id"])
response = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public", "share_name": "Customer link"},
headers=auth_headers
)
assert response.status_code == 201
share = response.json()
assert share["visibility"] == "public"
assert share["share_name"] == "Customer link"
assert share["is_active"] is True
assert share["view_count"] == 0
assert len(share["share_token"]) > 0
async def test_create_account_share(self, client: AsyncClient, auth_headers, test_tree):
"""Create an account-only share link."""
session = await self._create_session(client, auth_headers, test_tree["id"])
response = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "account"},
headers=auth_headers
)
assert response.status_code == 201
assert response.json()["visibility"] == "account"
async def test_access_public_share(self, client: AsyncClient, auth_headers, test_tree):
"""Access a public share without authentication."""
session = await self._create_session(client, auth_headers, test_tree["id"])
# Create share
share_resp = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public"},
headers=auth_headers
)
share_token = share_resp.json()["share_token"]
# Access without auth
response = await client.get(f"/api/v1/share/{share_token}")
assert response.status_code == 200
data = response.json()
assert data["session_id"] == session["id"]
assert data["visibility"] == "public"
assert data["path_taken"] is not None
async def test_access_revoked_share_returns_404(self, client: AsyncClient, auth_headers, test_tree):
"""Accessing a revoked share returns 404."""
session = await self._create_session(client, auth_headers, test_tree["id"])
# Create and revoke share
share_resp = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public"},
headers=auth_headers
)
share = share_resp.json()
await client.delete(
f"/api/v1/shares/{share['id']}",
headers=auth_headers
)
# Try to access revoked share
response = await client.get(f"/api/v1/share/{share['share_token']}")
assert response.status_code == 404
async def test_access_nonexistent_share_returns_404(self, client: AsyncClient):
"""Accessing a nonexistent share token returns 404."""
response = await client.get("/api/v1/share/nonexistent-token-12345")
assert response.status_code == 404
async def test_list_my_shares(self, client: AsyncClient, auth_headers, test_tree):
"""List shares created by current user."""
session = await self._create_session(client, auth_headers, test_tree["id"])
# Create two shares
await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public", "share_name": "Link 1"},
headers=auth_headers
)
await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "account", "share_name": "Link 2"},
headers=auth_headers
)
response = await client.get(
"/api/v1/shares/my-shares",
headers=auth_headers
)
assert response.status_code == 200
shares = response.json()
assert len(shares) == 2
async def test_revoke_share(self, client: AsyncClient, auth_headers, test_tree):
"""Revoke a share link (soft delete)."""
session = await self._create_session(client, auth_headers, test_tree["id"])
share_resp = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public"},
headers=auth_headers
)
share = share_resp.json()
# Revoke
response = await client.delete(
f"/api/v1/shares/{share['id']}",
headers=auth_headers
)
assert response.status_code == 204
# Verify it's gone from my-shares
list_resp = await client.get(
"/api/v1/shares/my-shares",
headers=auth_headers
)
shares = list_resp.json()
assert len(shares) == 0
async def test_multiple_shares_per_session(self, client: AsyncClient, auth_headers, test_tree):
"""Multiple shares for the same session work independently."""
session = await self._create_session(client, auth_headers, test_tree["id"])
# Create public + account shares
resp1 = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public", "share_name": "For customer"},
headers=auth_headers
)
resp2 = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "account", "share_name": "For team"},
headers=auth_headers
)
assert resp1.status_code == 201
assert resp2.status_code == 201
# Both tokens are different
assert resp1.json()["share_token"] != resp2.json()["share_token"]
# Both accessible
access1 = await client.get(f"/api/v1/share/{resp1.json()['share_token']}")
assert access1.status_code == 200
async def test_share_view_count_increments(self, client: AsyncClient, auth_headers, test_tree):
"""View count increments on each access."""
session = await self._create_session(client, auth_headers, test_tree["id"])
share_resp = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public"},
headers=auth_headers
)
token = share_resp.json()["share_token"]
# Access three times
await client.get(f"/api/v1/share/{token}")
await client.get(f"/api/v1/share/{token}")
await client.get(f"/api/v1/share/{token}")
# Check view count via my-shares
list_resp = await client.get(
"/api/v1/shares/my-shares",
headers=auth_headers
)
shares = list_resp.json()
assert shares[0]["view_count"] == 3
async def test_share_requires_session_ownership(self, client: AsyncClient, auth_headers, test_tree, test_db):
"""Non-owner cannot create a share for someone else's session."""
session = await self._create_session(client, auth_headers, test_tree["id"])
# Register a different user
await client.post("/api/v1/auth/register", json={
"email": "other@example.com",
"password": "OtherPassword123!",
"name": "Other User"
})
login_resp = await client.post("/api/v1/auth/login/json", json={
"email": "other@example.com",
"password": "OtherPassword123!"
})
other_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
# Try to share other user's session
response = await client.post(
f"/api/v1/sessions/{session['id']}/shares",
json={"visibility": "public"},
headers=other_headers
)
assert response.status_code == 403
async def test_share_nonexistent_session(self, client: AsyncClient, auth_headers):
"""Creating a share for nonexistent session returns 404."""
response = await client.post(
"/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares",
json={"visibility": "public"},
headers=auth_headers
)
assert response.status_code == 404
async def test_create_share_requires_auth(self, client: AsyncClient, test_tree):
"""Creating a share without auth returns 401."""
response = await client.post(
"/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares",
json={"visibility": "public"}
)
assert response.status_code == 401

View File

@@ -0,0 +1,168 @@
"""Tests for tree forking and lineage tracking."""
import pytest
from httpx import AsyncClient
pytestmark = pytest.mark.asyncio
class TestTreeForking:
"""Test tree fork creation, lineage, and update detection."""
async def test_fork_tree(self, client: AsyncClient, auth_headers, test_tree):
"""Fork a tree and verify fork metadata."""
response = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Customizing for our network"},
headers=auth_headers
)
assert response.status_code == 201
fork = response.json()
assert fork["name"] == f"Fork of {test_tree['name']}"
assert fork["tree_structure"] == test_tree["tree_structure"]
assert fork["is_public"] is False
assert fork["version"] == 1
# Verify fork_info
assert fork["fork_info"] is not None
assert fork["fork_info"]["parent_tree_id"] == test_tree["id"]
assert fork["fork_info"]["root_tree_id"] == test_tree["id"]
assert fork["fork_info"]["fork_reason"] == "Customizing for our network"
assert fork["fork_info"]["fork_depth"] == 1
assert fork["fork_info"]["has_parent_updates"] is False
async def test_fork_with_custom_name(self, client: AsyncClient, auth_headers, test_tree):
"""Fork a tree with a custom name."""
response = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"name": "My Custom Fork", "fork_reason": "Testing"},
headers=auth_headers
)
assert response.status_code == 201
assert response.json()["name"] == "My Custom Fork"
async def test_fork_of_fork_lineage(self, client: AsyncClient, auth_headers, test_tree):
"""Fork a fork and verify lineage tracking (root_tree_id, fork_depth)."""
# Create first fork
resp1 = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Junior's customization"},
headers=auth_headers
)
assert resp1.status_code == 201
fork1 = resp1.json()
# Fork the fork
resp2 = await client.post(
f"/api/v1/trees/{fork1['id']}/fork",
json={"fork_reason": "Senior's refinement"},
headers=auth_headers
)
assert resp2.status_code == 201
fork2 = resp2.json()
# Verify lineage
assert fork2["fork_info"]["parent_tree_id"] == fork1["id"]
assert fork2["fork_info"]["root_tree_id"] == test_tree["id"] # Points to original
assert fork2["fork_info"]["fork_depth"] == 2
assert fork2["fork_info"]["fork_reason"] == "Senior's refinement"
async def test_fork_nonexistent_tree(self, client: AsyncClient, auth_headers):
"""Fork a nonexistent tree returns 404."""
response = await client.post(
"/api/v1/trees/00000000-0000-0000-0000-000000000000/fork",
json={"fork_reason": "test"},
headers=auth_headers
)
assert response.status_code == 404
async def test_list_forks(self, client: AsyncClient, auth_headers, test_tree):
"""List forks of a tree."""
# Create two forks
await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Fork 1"},
headers=auth_headers
)
await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Fork 2"},
headers=auth_headers
)
response = await client.get(
f"/api/v1/trees/{test_tree['id']}/forks",
headers=auth_headers
)
assert response.status_code == 200
forks = response.json()
assert len(forks) == 2
async def test_lineage_chain(self, client: AsyncClient, auth_headers, test_tree):
"""Get lineage from fork back to root."""
# Create chain: root → fork1 → fork2
resp1 = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Level 1"},
headers=auth_headers
)
fork1 = resp1.json()
resp2 = await client.post(
f"/api/v1/trees/{fork1['id']}/fork",
json={"fork_reason": "Level 2"},
headers=auth_headers
)
fork2 = resp2.json()
# Get lineage from fork2
response = await client.get(
f"/api/v1/trees/{fork2['id']}/lineage",
headers=auth_headers
)
assert response.status_code == 200
lineage = response.json()
# Should be [fork2, fork1, root]
assert len(lineage) == 3
assert lineage[0]["id"] == fork2["id"]
assert lineage[1]["id"] == fork1["id"]
assert lineage[2]["id"] == test_tree["id"]
async def test_lineage_of_root_tree(self, client: AsyncClient, auth_headers, test_tree):
"""Root tree lineage is just itself."""
response = await client.get(
f"/api/v1/trees/{test_tree['id']}/lineage",
headers=auth_headers
)
assert response.status_code == 200
lineage = response.json()
assert len(lineage) == 1
assert lineage[0]["id"] == test_tree["id"]
async def test_fork_preserves_tree_structure(self, client: AsyncClient, auth_headers, test_tree):
"""Fork copies the complete tree_structure JSONB."""
response = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "Copy check"},
headers=auth_headers
)
fork = response.json()
# Get full fork detail
detail = await client.get(
f"/api/v1/trees/{fork['id']}",
headers=auth_headers
)
assert detail.status_code == 200
assert detail.json()["tree_structure"] == test_tree["tree_structure"]
async def test_fork_requires_auth(self, client: AsyncClient, test_tree):
"""Fork without auth returns 401."""
response = await client.post(
f"/api/v1/trees/{test_tree['id']}/fork",
json={"fork_reason": "No auth"}
)
assert response.status_code == 401