feat: add tree forking, custom step tracking, and session sharing
Implement three foundational schema features from the design doc: - Tree forking with lineage tracking (migration 022): parent_tree_id, root_tree_id, fork_depth columns with self-referential FKs and composite analytics index - Custom step enhancement: CustomStepSchema with source tracking (ad-hoc, step-library, forked-tree) for backward-compatible JSONB - Session sharing (migration 023): session_shares and session_share_views tables with account-scoped visibility, cryptographic tokens, view tracking, and allow_public_shares account policy Includes 21 new integration tests (9 forking, 12 sharing), SaaS consultant-recommended denormalizations, rate limiting on public share access, and test fixture fix for invite code requirement. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
287
backend/app/api/endpoints/shares.py
Normal file
287
backend/app/api/endpoints/shares.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.session import Session
|
||||
from app.models.session_share import SessionShare, SessionShareView
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
from app.schemas.session_share import ShareCreate, ShareResponse, SharePublicView
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin
|
||||
from app.core.audit import log_audit
|
||||
from app.core.rate_limit import limiter
|
||||
|
||||
router = APIRouter(tags=["shares"])
|
||||
|
||||
|
||||
def build_share_response(share: SessionShare) -> ShareResponse:
|
||||
return ShareResponse(
|
||||
id=share.id,
|
||||
session_id=share.session_id,
|
||||
account_id=share.account_id,
|
||||
share_token=share.share_token,
|
||||
share_name=share.share_name,
|
||||
visibility=share.visibility,
|
||||
created_by=share.created_by,
|
||||
created_at=share.created_at,
|
||||
updated_at=share.updated_at,
|
||||
expires_at=share.expires_at,
|
||||
view_count=share.view_count,
|
||||
last_viewed_at=share.last_viewed_at,
|
||||
is_active=share.is_active,
|
||||
)
|
||||
|
||||
|
||||
# --- Session Share CRUD ---
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/shares",
|
||||
response_model=ShareResponse,
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_share(
|
||||
session_id: UUID,
|
||||
share_data: ShareCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)]
|
||||
):
|
||||
"""Create a share link for a session.
|
||||
|
||||
Only the session owner can create shares.
|
||||
Public shares require account.allow_public_shares policy.
|
||||
"""
|
||||
# Verify session exists and user owns it
|
||||
result = await db.execute(
|
||||
select(Session).where(Session.id == session_id)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
if session.user_id != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the session owner can create share links"
|
||||
)
|
||||
|
||||
# Require account_id for account-scoped shares
|
||||
if share_data.visibility == "account" and not current_user.account_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create account-scoped share without an account"
|
||||
)
|
||||
|
||||
# Check account policy for public shares
|
||||
if share_data.visibility == "public" and current_user.account_id:
|
||||
account_result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = account_result.scalar_one_or_none()
|
||||
if account and not account.allow_public_shares:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Your organization does not allow public session sharing. Use account-only visibility."
|
||||
)
|
||||
|
||||
# Generate token with collision retry
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
share_token = secrets.token_urlsafe(48)
|
||||
|
||||
share = SessionShare(
|
||||
session_id=session_id,
|
||||
account_id=current_user.account_id,
|
||||
share_token=share_token,
|
||||
share_name=share_data.share_name,
|
||||
visibility=share_data.visibility,
|
||||
created_by=current_user.id,
|
||||
expires_at=share_data.expires_at,
|
||||
)
|
||||
|
||||
db.add(share)
|
||||
await db.flush()
|
||||
|
||||
await log_audit(db, current_user.id, "share.create", "session_share", share.id,
|
||||
{"session_id": str(session_id), "visibility": share_data.visibility})
|
||||
await db.commit()
|
||||
await db.refresh(share)
|
||||
|
||||
return build_share_response(share)
|
||||
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
if "session_shares_share_token_key" in str(e) and attempt < max_retries - 1:
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/shares/my-shares", response_model=list[ShareResponse])
|
||||
async def list_my_shares(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100)
|
||||
):
|
||||
"""List all shares created by the current user."""
|
||||
result = await db.execute(
|
||||
select(SessionShare)
|
||||
.where(
|
||||
SessionShare.created_by == current_user.id,
|
||||
SessionShare.is_active == True
|
||||
)
|
||||
.order_by(SessionShare.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
shares = result.scalars().all()
|
||||
return [build_share_response(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/shares/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def revoke_share(
|
||||
share_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Revoke a share link (soft delete - sets is_active=False)."""
|
||||
result = await db.execute(
|
||||
select(SessionShare).where(SessionShare.id == share_id)
|
||||
)
|
||||
share = result.scalar_one_or_none()
|
||||
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share not found"
|
||||
)
|
||||
|
||||
if share.created_by != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the share creator can revoke it"
|
||||
)
|
||||
|
||||
share.is_active = False
|
||||
|
||||
await log_audit(db, current_user.id, "share.revoke", "session_share", share.id,
|
||||
{"session_id": str(share.session_id)})
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
# --- Public Share Access ---
|
||||
|
||||
|
||||
async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[User]:
|
||||
"""Try to extract authenticated user from request, return None if not authenticated."""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
token = auth_header.replace("Bearer ", "")
|
||||
try:
|
||||
from app.core.security import decode_token
|
||||
payload = decode_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
return None
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
result = await db.execute(select(User).where(User.id == UUID(user_id)))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/share/{share_token}", response_model=SharePublicView)
|
||||
@limiter.limit("30/minute")
|
||||
async def access_share(
|
||||
share_token: str,
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Access a shared session via share token.
|
||||
|
||||
Public shares: No authentication required.
|
||||
Account-only shares: Requires authentication + account membership.
|
||||
"""
|
||||
current_user = await _get_optional_user(request, db)
|
||||
|
||||
# Lookup share
|
||||
result = await db.execute(
|
||||
select(SessionShare)
|
||||
.options(joinedload(SessionShare.session))
|
||||
.where(SessionShare.share_token == share_token)
|
||||
)
|
||||
share = result.scalar_one_or_none()
|
||||
|
||||
# Validate share
|
||||
if not share or not share.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share not found or has been revoked"
|
||||
)
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE,
|
||||
detail="Share link has expired"
|
||||
)
|
||||
|
||||
# Check visibility
|
||||
if share.visibility == "account":
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="This share requires authentication"
|
||||
)
|
||||
if current_user.account_id != share.account_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
)
|
||||
|
||||
# Record view
|
||||
session = share.session
|
||||
view = SessionShareView(
|
||||
share_id=share.id,
|
||||
session_id=session.id,
|
||||
viewer_id=current_user.id if current_user else None,
|
||||
viewer_ip=request.client.host if request.client else None,
|
||||
viewer_user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
db.add(view)
|
||||
|
||||
share.view_count += 1
|
||||
share.last_viewed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
# Build read-only response
|
||||
tree_snapshot = session.tree_snapshot or {}
|
||||
return SharePublicView(
|
||||
session_id=session.id,
|
||||
tree_name=tree_snapshot.get("question", "Untitled Tree"),
|
||||
tree_description=tree_snapshot.get("description"),
|
||||
tree_structure=tree_snapshot,
|
||||
path_taken=session.path_taken or [],
|
||||
decisions=session.decisions or [],
|
||||
custom_steps=session.custom_steps or [],
|
||||
started_at=session.started_at,
|
||||
completed_at=session.completed_at,
|
||||
ticket_number=session.ticket_number,
|
||||
client_name=session.client_name,
|
||||
share_name=share.share_name,
|
||||
visibility=share.visibility,
|
||||
)
|
||||
@@ -12,7 +12,7 @@ from app.models.user import User
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.tag import TreeTag, tree_tag_assignments
|
||||
from app.models.folder import UserFolder, user_folder_trees
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo, ForkCreate, ForkInfo
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
@@ -73,8 +73,8 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
|
||||
)
|
||||
|
||||
|
||||
def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
"""Build TreeResponse with all details including category_info and tags."""
|
||||
def build_full_tree_response(tree: Tree, parent_tree: Tree = None) -> TreeResponse:
|
||||
"""Build TreeResponse with all details including category_info, tags, and fork_info."""
|
||||
category_info = None
|
||||
if tree.category_rel:
|
||||
category_info = CategoryInfo(
|
||||
@@ -83,6 +83,20 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
slug=tree.category_rel.slug
|
||||
)
|
||||
|
||||
fork_info = None
|
||||
if tree.parent_tree_id or tree.fork_depth > 0:
|
||||
has_updates = False
|
||||
if parent_tree and tree.parent_updated_at:
|
||||
has_updates = parent_tree.updated_at > tree.parent_updated_at
|
||||
fork_info = ForkInfo(
|
||||
parent_tree_id=tree.parent_tree_id,
|
||||
root_tree_id=tree.root_tree_id,
|
||||
fork_reason=tree.fork_reason,
|
||||
fork_depth=tree.fork_depth,
|
||||
parent_updated_at=tree.parent_updated_at,
|
||||
has_parent_updates=has_updates
|
||||
)
|
||||
|
||||
return TreeResponse(
|
||||
id=tree.id,
|
||||
name=tree.name,
|
||||
@@ -91,6 +105,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
category_id=tree.category_id,
|
||||
category_info=category_info,
|
||||
tags=tree.tag_names,
|
||||
fork_info=fork_info,
|
||||
tree_structure=tree.tree_structure,
|
||||
author_id=tree.author_id,
|
||||
account_id=tree.account_id,
|
||||
@@ -561,3 +576,166 @@ async def delete_tree(
|
||||
{"tree_name": tree.name})
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
# --- Fork Endpoints ---
|
||||
|
||||
|
||||
@router.post("/{tree_id}/fork", response_model=TreeResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def fork_tree(
|
||||
tree_id: UUID,
|
||||
fork_data: ForkCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)]
|
||||
):
|
||||
"""Fork a tree to create a personal copy.
|
||||
|
||||
Engineers can fork any tree they can access (public, account, or default).
|
||||
Fork inherits tree_structure but gets new ownership.
|
||||
"""
|
||||
# Load parent tree
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == tree_id)
|
||||
)
|
||||
parent = result.scalar_one_or_none()
|
||||
|
||||
if not parent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
if not can_access_tree(current_user, parent):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
)
|
||||
|
||||
# Check subscription tree limit
|
||||
if current_user.account_id:
|
||||
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
|
||||
if not can_create:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
# Build fork
|
||||
fork_name = fork_data.name or f"Fork of {parent.name}"
|
||||
fork = Tree(
|
||||
name=fork_name,
|
||||
description=parent.description,
|
||||
category=parent.category,
|
||||
category_id=parent.category_id,
|
||||
tree_structure=parent.tree_structure,
|
||||
author_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
is_public=False,
|
||||
is_default=False,
|
||||
version=1,
|
||||
# Fork tracking
|
||||
parent_tree_id=parent.id,
|
||||
fork_reason=fork_data.fork_reason,
|
||||
parent_updated_at=parent.updated_at,
|
||||
# Lineage tracking
|
||||
root_tree_id=parent.root_tree_id if parent.root_tree_id else parent.id,
|
||||
fork_depth=parent.fork_depth + 1,
|
||||
)
|
||||
|
||||
db.add(fork)
|
||||
await db.flush()
|
||||
|
||||
await log_audit(db, current_user.id, "tree.fork", "tree", fork.id,
|
||||
{"parent_tree_id": str(parent.id), "parent_name": parent.name,
|
||||
"fork_reason": fork_data.fork_reason})
|
||||
await db.commit()
|
||||
|
||||
# Reload with relationships
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == fork.id)
|
||||
)
|
||||
fork = result.scalar_one()
|
||||
|
||||
return build_full_tree_response(fork, parent_tree=parent)
|
||||
|
||||
|
||||
@router.get("/{tree_id}/forks", response_model=list[TreeListResponse])
|
||||
async def list_forks(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100)
|
||||
):
|
||||
"""List all direct forks of a tree."""
|
||||
# Verify parent exists and user can access it
|
||||
parent_result = await db.execute(select(Tree).where(Tree.id == tree_id))
|
||||
parent = parent_result.scalar_one_or_none()
|
||||
|
||||
if not parent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
if not can_access_tree(current_user, parent):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
)
|
||||
|
||||
# Query direct forks, filtered by access
|
||||
query = select(Tree).options(
|
||||
selectinload(Tree.category_rel),
|
||||
selectinload(Tree.tags)
|
||||
).where(
|
||||
Tree.parent_tree_id == tree_id,
|
||||
Tree.is_active == True,
|
||||
build_tree_access_filter(current_user)
|
||||
).order_by(Tree.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
forks = result.scalars().unique().all()
|
||||
|
||||
return [build_tree_response(tree) for tree in forks]
|
||||
|
||||
|
||||
@router.get("/{tree_id}/lineage", response_model=list[TreeListResponse])
|
||||
async def get_tree_lineage(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get the fork lineage chain from current tree back to root.
|
||||
|
||||
Returns ordered list: [current tree, parent, grandparent, ..., root]
|
||||
Limited to 10 levels to prevent infinite loops.
|
||||
"""
|
||||
lineage = []
|
||||
current_id = tree_id
|
||||
visited = set()
|
||||
max_depth = 10
|
||||
|
||||
for _ in range(max_depth):
|
||||
if current_id is None or current_id in visited:
|
||||
break
|
||||
visited.add(current_id)
|
||||
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == current_id)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
|
||||
if not tree:
|
||||
break
|
||||
|
||||
lineage.append(build_tree_response(tree))
|
||||
current_id = tree.parent_tree_id
|
||||
|
||||
return lineage
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -15,3 +15,4 @@ api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
api_router.include_router(shares.router)
|
||||
|
||||
Reference in New Issue
Block a user