feat: add tree forking, custom step tracking, and session sharing
Implement three foundational schema features from the design doc: - Tree forking with lineage tracking (migration 022): parent_tree_id, root_tree_id, fork_depth columns with self-referential FKs and composite analytics index - Custom step enhancement: CustomStepSchema with source tracking (ad-hoc, step-library, forked-tree) for backward-compatible JSONB - Session sharing (migration 023): session_shares and session_share_views tables with account-scoped visibility, cryptographic tokens, view tracking, and allow_public_shares account policy Includes 21 new integration tests (9 forking, 12 sharing), SaaS consultant-recommended denormalizations, rate limiting on public share access, and test fixture fix for invite code requirement. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
73
backend/alembic/versions/022_add_tree_forking.py
Normal file
73
backend/alembic/versions/022_add_tree_forking.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""add tree forking support with lineage tracking
|
||||
|
||||
Revision ID: 022
|
||||
Revises: 021
|
||||
Create Date: 2026-02-07
|
||||
|
||||
Adds fork relationship and lineage tracking to trees table:
|
||||
- parent_tree_id: Points to immediate parent tree (NULL for root trees)
|
||||
- fork_reason: Optional engineer note explaining fork purpose
|
||||
- parent_updated_at: Snapshot of parent's updated_at at fork time
|
||||
- root_tree_id: Points to original tree at root of fork chain
|
||||
- fork_depth: How many forks deep (1 = direct fork, 2 = fork of fork, etc.)
|
||||
|
||||
Fork on parent delete behavior: SET NULL (orphaned forks survive)
|
||||
Root tree delete behavior: SET NULL (lineage preserved via fork_depth)
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '022'
|
||||
down_revision: Union[str, None] = '021'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add fork tracking columns
|
||||
op.add_column('trees', sa.Column('parent_tree_id', UUID(as_uuid=True), nullable=True))
|
||||
op.add_column('trees', sa.Column('fork_reason', sa.String(255), nullable=True))
|
||||
op.add_column('trees', sa.Column('parent_updated_at', sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
# Add fork lineage columns
|
||||
op.add_column('trees', sa.Column('root_tree_id', UUID(as_uuid=True), nullable=True))
|
||||
op.add_column('trees', sa.Column('fork_depth', sa.Integer, nullable=False, server_default='0'))
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
'fk_trees_parent_tree_id',
|
||||
'trees', 'trees',
|
||||
['parent_tree_id'], ['id'],
|
||||
ondelete='SET NULL'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_trees_root_tree_id',
|
||||
'trees', 'trees',
|
||||
['root_tree_id'], ['id'],
|
||||
ondelete='SET NULL'
|
||||
)
|
||||
|
||||
# Add indexes for fork queries
|
||||
op.create_index('ix_trees_parent_tree_id', 'trees', ['parent_tree_id'])
|
||||
op.create_index('ix_trees_root_tree_id', 'trees', ['root_tree_id'])
|
||||
|
||||
# Composite index for fork analytics (descendants + depth sorting)
|
||||
op.create_index('ix_trees_fork_analytics', 'trees', ['root_tree_id', 'fork_depth'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_trees_fork_analytics', table_name='trees')
|
||||
op.drop_index('ix_trees_root_tree_id', table_name='trees')
|
||||
op.drop_index('ix_trees_parent_tree_id', table_name='trees')
|
||||
op.drop_constraint('fk_trees_root_tree_id', 'trees', type_='foreignkey')
|
||||
op.drop_constraint('fk_trees_parent_tree_id', 'trees', type_='foreignkey')
|
||||
op.drop_column('trees', 'fork_depth')
|
||||
op.drop_column('trees', 'root_tree_id')
|
||||
op.drop_column('trees', 'parent_updated_at')
|
||||
op.drop_column('trees', 'fork_reason')
|
||||
op.drop_column('trees', 'parent_tree_id')
|
||||
89
backend/alembic/versions/023_add_session_sharing.py
Normal file
89
backend/alembic/versions/023_add_session_sharing.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""add session sharing tables
|
||||
|
||||
Revision ID: 023
|
||||
Revises: 022
|
||||
Create Date: 2026-02-07
|
||||
|
||||
Adds session sharing infrastructure:
|
||||
- session_shares: Share tokens with visibility control (public/account-only)
|
||||
- session_share_views: Detailed view tracking with viewer identification
|
||||
- accounts.allow_public_shares: Account-level policy for public share creation
|
||||
|
||||
Includes SaaS consultant-recommended denormalizations:
|
||||
- account_id on session_shares (faster access control, survives user transfers)
|
||||
- session_id on session_share_views (faster analytics queries)
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '023'
|
||||
down_revision: Union[str, None] = '022'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create session_shares table
|
||||
op.create_table(
|
||||
'session_shares',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')),
|
||||
sa.Column('session_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('account_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('share_token', sa.String(64), nullable=False, unique=True),
|
||||
sa.Column('share_name', sa.String(100), nullable=True),
|
||||
sa.Column('visibility', sa.String(20), nullable=False, server_default='public'),
|
||||
sa.Column('created_by', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('view_count', sa.Integer, nullable=False, server_default='0'),
|
||||
sa.Column('last_viewed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean, nullable=False, server_default='true'),
|
||||
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.CheckConstraint("visibility IN ('public', 'account')", name='ck_session_shares_visibility')
|
||||
)
|
||||
|
||||
# Create indexes for session_shares
|
||||
op.create_index('ix_session_shares_session_id', 'session_shares', ['session_id'])
|
||||
op.create_index('ix_session_shares_account_id', 'session_shares', ['account_id'])
|
||||
op.create_index('ix_session_shares_share_token', 'session_shares', ['share_token'])
|
||||
op.create_index('ix_session_shares_created_by', 'session_shares', ['created_by'])
|
||||
op.create_index('ix_session_shares_expires_at', 'session_shares', ['expires_at'])
|
||||
op.create_index('ix_session_shares_is_active', 'session_shares', ['is_active'])
|
||||
|
||||
# Create session_share_views table
|
||||
op.create_table(
|
||||
'session_share_views',
|
||||
sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')),
|
||||
sa.Column('share_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('session_id', UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('viewer_id', UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('viewed_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('viewer_ip', sa.String(45), nullable=True),
|
||||
sa.Column('viewer_user_agent', sa.String(500), nullable=True),
|
||||
sa.ForeignKeyConstraint(['share_id'], ['session_shares.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['viewer_id'], ['users.id'], ondelete='SET NULL')
|
||||
)
|
||||
|
||||
# Create indexes for session_share_views
|
||||
op.create_index('ix_session_share_views_share_id', 'session_share_views', ['share_id'])
|
||||
op.create_index('ix_session_share_views_session_id', 'session_share_views', ['session_id'])
|
||||
op.create_index('ix_session_share_views_viewer_id', 'session_share_views', ['viewer_id'])
|
||||
op.create_index('ix_session_share_views_viewed_at', 'session_share_views', ['viewed_at'])
|
||||
|
||||
# Add account policy for public shares
|
||||
op.add_column('accounts', sa.Column('allow_public_shares', sa.Boolean, nullable=False, server_default='true'))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('accounts', 'allow_public_shares')
|
||||
op.drop_table('session_share_views')
|
||||
op.drop_table('session_shares')
|
||||
287
backend/app/api/endpoints/shares.py
Normal file
287
backend/app/api/endpoints/shares.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.session import Session
|
||||
from app.models.session_share import SessionShare, SessionShareView
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
from app.schemas.session_share import ShareCreate, ShareResponse, SharePublicView
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin
|
||||
from app.core.audit import log_audit
|
||||
from app.core.rate_limit import limiter
|
||||
|
||||
router = APIRouter(tags=["shares"])
|
||||
|
||||
|
||||
def build_share_response(share: SessionShare) -> ShareResponse:
|
||||
return ShareResponse(
|
||||
id=share.id,
|
||||
session_id=share.session_id,
|
||||
account_id=share.account_id,
|
||||
share_token=share.share_token,
|
||||
share_name=share.share_name,
|
||||
visibility=share.visibility,
|
||||
created_by=share.created_by,
|
||||
created_at=share.created_at,
|
||||
updated_at=share.updated_at,
|
||||
expires_at=share.expires_at,
|
||||
view_count=share.view_count,
|
||||
last_viewed_at=share.last_viewed_at,
|
||||
is_active=share.is_active,
|
||||
)
|
||||
|
||||
|
||||
# --- Session Share CRUD ---
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/shares",
|
||||
response_model=ShareResponse,
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_share(
|
||||
session_id: UUID,
|
||||
share_data: ShareCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)]
|
||||
):
|
||||
"""Create a share link for a session.
|
||||
|
||||
Only the session owner can create shares.
|
||||
Public shares require account.allow_public_shares policy.
|
||||
"""
|
||||
# Verify session exists and user owns it
|
||||
result = await db.execute(
|
||||
select(Session).where(Session.id == session_id)
|
||||
)
|
||||
session = result.scalar_one_or_none()
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found"
|
||||
)
|
||||
|
||||
if session.user_id != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the session owner can create share links"
|
||||
)
|
||||
|
||||
# Require account_id for account-scoped shares
|
||||
if share_data.visibility == "account" and not current_user.account_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create account-scoped share without an account"
|
||||
)
|
||||
|
||||
# Check account policy for public shares
|
||||
if share_data.visibility == "public" and current_user.account_id:
|
||||
account_result = await db.execute(
|
||||
select(Account).where(Account.id == current_user.account_id)
|
||||
)
|
||||
account = account_result.scalar_one_or_none()
|
||||
if account and not account.allow_public_shares:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Your organization does not allow public session sharing. Use account-only visibility."
|
||||
)
|
||||
|
||||
# Generate token with collision retry
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
share_token = secrets.token_urlsafe(48)
|
||||
|
||||
share = SessionShare(
|
||||
session_id=session_id,
|
||||
account_id=current_user.account_id,
|
||||
share_token=share_token,
|
||||
share_name=share_data.share_name,
|
||||
visibility=share_data.visibility,
|
||||
created_by=current_user.id,
|
||||
expires_at=share_data.expires_at,
|
||||
)
|
||||
|
||||
db.add(share)
|
||||
await db.flush()
|
||||
|
||||
await log_audit(db, current_user.id, "share.create", "session_share", share.id,
|
||||
{"session_id": str(session_id), "visibility": share_data.visibility})
|
||||
await db.commit()
|
||||
await db.refresh(share)
|
||||
|
||||
return build_share_response(share)
|
||||
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
if "session_shares_share_token_key" in str(e) and attempt < max_retries - 1:
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/shares/my-shares", response_model=list[ShareResponse])
|
||||
async def list_my_shares(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100)
|
||||
):
|
||||
"""List all shares created by the current user."""
|
||||
result = await db.execute(
|
||||
select(SessionShare)
|
||||
.where(
|
||||
SessionShare.created_by == current_user.id,
|
||||
SessionShare.is_active == True
|
||||
)
|
||||
.order_by(SessionShare.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
shares = result.scalars().all()
|
||||
return [build_share_response(s) for s in shares]
|
||||
|
||||
|
||||
@router.delete("/shares/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def revoke_share(
|
||||
share_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Revoke a share link (soft delete - sets is_active=False)."""
|
||||
result = await db.execute(
|
||||
select(SessionShare).where(SessionShare.id == share_id)
|
||||
)
|
||||
share = result.scalar_one_or_none()
|
||||
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share not found"
|
||||
)
|
||||
|
||||
if share.created_by != current_user.id and not current_user.is_super_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only the share creator can revoke it"
|
||||
)
|
||||
|
||||
share.is_active = False
|
||||
|
||||
await log_audit(db, current_user.id, "share.revoke", "session_share", share.id,
|
||||
{"session_id": str(share.session_id)})
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
# --- Public Share Access ---
|
||||
|
||||
|
||||
async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[User]:
|
||||
"""Try to extract authenticated user from request, return None if not authenticated."""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
token = auth_header.replace("Bearer ", "")
|
||||
try:
|
||||
from app.core.security import decode_token
|
||||
payload = decode_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
return None
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
result = await db.execute(select(User).where(User.id == UUID(user_id)))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/share/{share_token}", response_model=SharePublicView)
|
||||
@limiter.limit("30/minute")
|
||||
async def access_share(
|
||||
share_token: str,
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Access a shared session via share token.
|
||||
|
||||
Public shares: No authentication required.
|
||||
Account-only shares: Requires authentication + account membership.
|
||||
"""
|
||||
current_user = await _get_optional_user(request, db)
|
||||
|
||||
# Lookup share
|
||||
result = await db.execute(
|
||||
select(SessionShare)
|
||||
.options(joinedload(SessionShare.session))
|
||||
.where(SessionShare.share_token == share_token)
|
||||
)
|
||||
share = result.scalar_one_or_none()
|
||||
|
||||
# Validate share
|
||||
if not share or not share.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share not found or has been revoked"
|
||||
)
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE,
|
||||
detail="Share link has expired"
|
||||
)
|
||||
|
||||
# Check visibility
|
||||
if share.visibility == "account":
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="This share requires authentication"
|
||||
)
|
||||
if current_user.account_id != share.account_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this session"
|
||||
)
|
||||
|
||||
# Record view
|
||||
session = share.session
|
||||
view = SessionShareView(
|
||||
share_id=share.id,
|
||||
session_id=session.id,
|
||||
viewer_id=current_user.id if current_user else None,
|
||||
viewer_ip=request.client.host if request.client else None,
|
||||
viewer_user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
db.add(view)
|
||||
|
||||
share.view_count += 1
|
||||
share.last_viewed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
# Build read-only response
|
||||
tree_snapshot = session.tree_snapshot or {}
|
||||
return SharePublicView(
|
||||
session_id=session.id,
|
||||
tree_name=tree_snapshot.get("question", "Untitled Tree"),
|
||||
tree_description=tree_snapshot.get("description"),
|
||||
tree_structure=tree_snapshot,
|
||||
path_taken=session.path_taken or [],
|
||||
decisions=session.decisions or [],
|
||||
custom_steps=session.custom_steps or [],
|
||||
started_at=session.started_at,
|
||||
completed_at=session.completed_at,
|
||||
ticket_number=session.ticket_number,
|
||||
client_name=session.client_name,
|
||||
share_name=share.share_name,
|
||||
visibility=share.visibility,
|
||||
)
|
||||
@@ -12,7 +12,7 @@ from app.models.user import User
|
||||
from app.models.category import TreeCategory
|
||||
from app.models.tag import TreeTag, tree_tag_assignments
|
||||
from app.models.folder import UserFolder, user_folder_trees
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo
|
||||
from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo, ForkCreate, ForkInfo
|
||||
from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin
|
||||
from app.core.permissions import can_edit_tree, can_access_tree
|
||||
from app.core.subscriptions import check_tree_limit
|
||||
@@ -73,8 +73,8 @@ def build_tree_response(tree: Tree) -> TreeListResponse:
|
||||
)
|
||||
|
||||
|
||||
def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
"""Build TreeResponse with all details including category_info and tags."""
|
||||
def build_full_tree_response(tree: Tree, parent_tree: Tree = None) -> TreeResponse:
|
||||
"""Build TreeResponse with all details including category_info, tags, and fork_info."""
|
||||
category_info = None
|
||||
if tree.category_rel:
|
||||
category_info = CategoryInfo(
|
||||
@@ -83,6 +83,20 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
slug=tree.category_rel.slug
|
||||
)
|
||||
|
||||
fork_info = None
|
||||
if tree.parent_tree_id or tree.fork_depth > 0:
|
||||
has_updates = False
|
||||
if parent_tree and tree.parent_updated_at:
|
||||
has_updates = parent_tree.updated_at > tree.parent_updated_at
|
||||
fork_info = ForkInfo(
|
||||
parent_tree_id=tree.parent_tree_id,
|
||||
root_tree_id=tree.root_tree_id,
|
||||
fork_reason=tree.fork_reason,
|
||||
fork_depth=tree.fork_depth,
|
||||
parent_updated_at=tree.parent_updated_at,
|
||||
has_parent_updates=has_updates
|
||||
)
|
||||
|
||||
return TreeResponse(
|
||||
id=tree.id,
|
||||
name=tree.name,
|
||||
@@ -91,6 +105,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse:
|
||||
category_id=tree.category_id,
|
||||
category_info=category_info,
|
||||
tags=tree.tag_names,
|
||||
fork_info=fork_info,
|
||||
tree_structure=tree.tree_structure,
|
||||
author_id=tree.author_id,
|
||||
account_id=tree.account_id,
|
||||
@@ -561,3 +576,166 @@ async def delete_tree(
|
||||
{"tree_name": tree.name})
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
# --- Fork Endpoints ---
|
||||
|
||||
|
||||
@router.post("/{tree_id}/fork", response_model=TreeResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def fork_tree(
|
||||
tree_id: UUID,
|
||||
fork_data: ForkCreate,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(require_engineer_or_admin)]
|
||||
):
|
||||
"""Fork a tree to create a personal copy.
|
||||
|
||||
Engineers can fork any tree they can access (public, account, or default).
|
||||
Fork inherits tree_structure but gets new ownership.
|
||||
"""
|
||||
# Load parent tree
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == tree_id)
|
||||
)
|
||||
parent = result.scalar_one_or_none()
|
||||
|
||||
if not parent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
if not can_access_tree(current_user, parent):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
)
|
||||
|
||||
# Check subscription tree limit
|
||||
if current_user.account_id:
|
||||
can_create, limit, count = await check_tree_limit(current_user.account_id, db)
|
||||
if not can_create:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees."
|
||||
)
|
||||
|
||||
# Build fork
|
||||
fork_name = fork_data.name or f"Fork of {parent.name}"
|
||||
fork = Tree(
|
||||
name=fork_name,
|
||||
description=parent.description,
|
||||
category=parent.category,
|
||||
category_id=parent.category_id,
|
||||
tree_structure=parent.tree_structure,
|
||||
author_id=current_user.id,
|
||||
account_id=current_user.account_id,
|
||||
is_public=False,
|
||||
is_default=False,
|
||||
version=1,
|
||||
# Fork tracking
|
||||
parent_tree_id=parent.id,
|
||||
fork_reason=fork_data.fork_reason,
|
||||
parent_updated_at=parent.updated_at,
|
||||
# Lineage tracking
|
||||
root_tree_id=parent.root_tree_id if parent.root_tree_id else parent.id,
|
||||
fork_depth=parent.fork_depth + 1,
|
||||
)
|
||||
|
||||
db.add(fork)
|
||||
await db.flush()
|
||||
|
||||
await log_audit(db, current_user.id, "tree.fork", "tree", fork.id,
|
||||
{"parent_tree_id": str(parent.id), "parent_name": parent.name,
|
||||
"fork_reason": fork_data.fork_reason})
|
||||
await db.commit()
|
||||
|
||||
# Reload with relationships
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == fork.id)
|
||||
)
|
||||
fork = result.scalar_one()
|
||||
|
||||
return build_full_tree_response(fork, parent_tree=parent)
|
||||
|
||||
|
||||
@router.get("/{tree_id}/forks", response_model=list[TreeListResponse])
|
||||
async def list_forks(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100)
|
||||
):
|
||||
"""List all direct forks of a tree."""
|
||||
# Verify parent exists and user can access it
|
||||
parent_result = await db.execute(select(Tree).where(Tree.id == tree_id))
|
||||
parent = parent_result.scalar_one_or_none()
|
||||
|
||||
if not parent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tree not found"
|
||||
)
|
||||
|
||||
if not can_access_tree(current_user, parent):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have access to this tree"
|
||||
)
|
||||
|
||||
# Query direct forks, filtered by access
|
||||
query = select(Tree).options(
|
||||
selectinload(Tree.category_rel),
|
||||
selectinload(Tree.tags)
|
||||
).where(
|
||||
Tree.parent_tree_id == tree_id,
|
||||
Tree.is_active == True,
|
||||
build_tree_access_filter(current_user)
|
||||
).order_by(Tree.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
forks = result.scalars().unique().all()
|
||||
|
||||
return [build_tree_response(tree) for tree in forks]
|
||||
|
||||
|
||||
@router.get("/{tree_id}/lineage", response_model=list[TreeListResponse])
|
||||
async def get_tree_lineage(
|
||||
tree_id: UUID,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
):
|
||||
"""Get the fork lineage chain from current tree back to root.
|
||||
|
||||
Returns ordered list: [current tree, parent, grandparent, ..., root]
|
||||
Limited to 10 levels to prevent infinite loops.
|
||||
"""
|
||||
lineage = []
|
||||
current_id = tree_id
|
||||
visited = set()
|
||||
max_depth = 10
|
||||
|
||||
for _ in range(max_depth):
|
||||
if current_id is None or current_id in visited:
|
||||
break
|
||||
visited.add(current_id)
|
||||
|
||||
result = await db.execute(
|
||||
select(Tree)
|
||||
.options(selectinload(Tree.category_rel), selectinload(Tree.tags))
|
||||
.where(Tree.id == current_id)
|
||||
)
|
||||
tree = result.scalar_one_or_none()
|
||||
|
||||
if not tree:
|
||||
break
|
||||
|
||||
lineage.append(build_tree_response(tree))
|
||||
current_id = tree.parent_tree_id
|
||||
|
||||
return lineage
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks
|
||||
from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -15,3 +15,4 @@ api_router.include_router(steps.router)
|
||||
api_router.include_router(admin.router)
|
||||
api_router.include_router(accounts.router)
|
||||
api_router.include_router(webhooks.router)
|
||||
api_router.include_router(shares.router)
|
||||
|
||||
@@ -15,6 +15,7 @@ from .step_category import StepCategory
|
||||
from .step_library import StepLibrary, StepRating, StepUsageLog
|
||||
from .refresh_token import RefreshToken
|
||||
from .audit_log import AuditLog
|
||||
from .session_share import SessionShare, SessionShareView
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
@@ -38,4 +39,6 @@ __all__ = [
|
||||
"StepUsageLog",
|
||||
"RefreshToken",
|
||||
"AuditLog",
|
||||
"SessionShare",
|
||||
"SessionShareView",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
@@ -26,6 +26,13 @@ class Account(Base):
|
||||
stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
allow_public_shares: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
server_default="true",
|
||||
comment="Policy: engineers can create public shares. Only affects NEW shares (grandfathered)."
|
||||
)
|
||||
|
||||
# Relationships
|
||||
owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account")
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Text
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.session_share import SessionShare
|
||||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = "sessions"
|
||||
@@ -53,3 +56,4 @@ class Session(Base):
|
||||
tree: Mapped["Tree"] = relationship("Tree", back_populates="sessions")
|
||||
user: Mapped["User"] = relationship("User", back_populates="sessions")
|
||||
attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session")
|
||||
shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan")
|
||||
|
||||
152
backend/app/models/session_share.py
Normal file
152
backend/app/models/session_share.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer, CheckConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.session import Session
|
||||
from app.models.user import User
|
||||
from app.models.account import Account
|
||||
|
||||
|
||||
class SessionShare(Base):
|
||||
__tablename__ = "session_shares"
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"visibility IN ('public', 'account')",
|
||||
name='ck_session_shares_visibility'
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4
|
||||
)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
account_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("accounts.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Account that owns this share (denormalized from session at creation)"
|
||||
)
|
||||
share_token: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="URL-safe random token (48 bytes -> 64 base64 chars)"
|
||||
)
|
||||
share_name: Mapped[Optional[str]] = mapped_column(
|
||||
String(100),
|
||||
nullable=True,
|
||||
comment="Optional label: 'Training link', 'Customer escalation #1234'"
|
||||
)
|
||||
visibility: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="public",
|
||||
comment="public = anyone with link, account = account members only"
|
||||
)
|
||||
created_by: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
expires_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="Optional expiration for time-limited shares"
|
||||
)
|
||||
view_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0
|
||||
)
|
||||
last_viewed_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
index=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
session: Mapped["Session"] = relationship("Session", back_populates="shares")
|
||||
account: Mapped["Account"] = relationship("Account")
|
||||
creator: Mapped["User"] = relationship("User", foreign_keys=[created_by])
|
||||
views: Mapped[list["SessionShareView"]] = relationship(
|
||||
"SessionShareView",
|
||||
back_populates="share",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class SessionShareView(Base):
|
||||
__tablename__ = "session_share_views"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4
|
||||
)
|
||||
share_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("session_shares.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
session_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("sessions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="Denormalized from share for analytics queries"
|
||||
)
|
||||
viewer_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="NULL for public shares (unauthenticated views)"
|
||||
)
|
||||
viewed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
index=True
|
||||
)
|
||||
viewer_ip: Mapped[Optional[str]] = mapped_column(
|
||||
String(45), # IPv6 max length
|
||||
nullable=True
|
||||
)
|
||||
viewer_user_agent: Mapped[Optional[str]] = mapped_column(
|
||||
String(500),
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
share: Mapped["SessionShare"] = relationship("SessionShare", back_populates="views")
|
||||
viewer: Mapped[Optional["User"]] = relationship("User")
|
||||
@@ -79,10 +79,62 @@ class Tree(Base):
|
||||
)
|
||||
usage_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# Fork tracking
|
||||
parent_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
fork_reason: Mapped[Optional[str]] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
comment="Brief reason: 'Added Cisco Meraki steps for our network'"
|
||||
)
|
||||
parent_updated_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Snapshot of parent's updated_at when fork created. Compare to detect parent updates."
|
||||
)
|
||||
|
||||
# Fork lineage tracking
|
||||
root_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("trees.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="Original tree at root of fork chain (NULL for non-forked trees)"
|
||||
)
|
||||
fork_depth: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default="0",
|
||||
comment="Fork depth: 0 = original, 1 = direct fork, 2 = fork of fork, etc."
|
||||
)
|
||||
|
||||
# Relationships
|
||||
author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees")
|
||||
team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees")
|
||||
account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees")
|
||||
|
||||
# Fork relationships (self-referential)
|
||||
parent: Mapped[Optional["Tree"]] = relationship(
|
||||
"Tree",
|
||||
remote_side="Tree.id",
|
||||
foreign_keys=[parent_tree_id],
|
||||
back_populates="forks"
|
||||
)
|
||||
forks: Mapped[list["Tree"]] = relationship(
|
||||
"Tree",
|
||||
foreign_keys=[parent_tree_id],
|
||||
back_populates="parent"
|
||||
)
|
||||
root: Mapped[Optional["Tree"]] = relationship(
|
||||
"Tree",
|
||||
remote_side="Tree.id",
|
||||
foreign_keys=[root_tree_id]
|
||||
)
|
||||
sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree")
|
||||
|
||||
# New organization relationships
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Literal
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class CustomStepSchema(BaseModel):
|
||||
"""Enhanced custom step with source tracking.
|
||||
|
||||
Backward compatible: old sessions without new fields load with defaults.
|
||||
"""
|
||||
type: str # "decision" | "action" | "solution"
|
||||
content: str
|
||||
notes: Optional[str] = None
|
||||
|
||||
# Source tracking (new fields, optional for backward compatibility)
|
||||
source: Literal["ad-hoc", "step-library", "forked-tree"] = "ad-hoc"
|
||||
source_step_id: Optional[UUID] = None
|
||||
inserted_at: Optional[datetime] = None
|
||||
inserted_after_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class DecisionRecord(BaseModel):
|
||||
node_id: str
|
||||
question: Optional[str] = None
|
||||
@@ -24,7 +40,7 @@ class SessionCreate(BaseModel):
|
||||
class SessionUpdate(BaseModel):
|
||||
path_taken: Optional[list[str]] = None
|
||||
decisions: Optional[list[DecisionRecord]] = None
|
||||
custom_steps: Optional[list[dict[str, Any]]] = None
|
||||
custom_steps: Optional[list[CustomStepSchema]] = None
|
||||
ticket_number: Optional[str] = Field(None, max_length=100)
|
||||
client_name: Optional[str] = Field(None, max_length=255)
|
||||
scratchpad: Optional[str] = None
|
||||
|
||||
47
backend/app/schemas/session_share.py
Normal file
47
backend/app/schemas/session_share.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Literal
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ShareCreate(BaseModel):
|
||||
visibility: Literal["public", "account"] = Field("public", description="Share visibility")
|
||||
share_name: Optional[str] = Field(None, max_length=100, description="Optional label for the share")
|
||||
expires_at: Optional[datetime] = Field(None, description="Optional expiration datetime")
|
||||
|
||||
|
||||
class ShareResponse(BaseModel):
|
||||
id: UUID
|
||||
session_id: UUID
|
||||
account_id: UUID
|
||||
share_token: str
|
||||
share_name: Optional[str] = None
|
||||
visibility: str
|
||||
created_by: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
expires_at: Optional[datetime] = None
|
||||
view_count: int
|
||||
last_viewed_at: Optional[datetime] = None
|
||||
is_active: bool
|
||||
share_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SharePublicView(BaseModel):
|
||||
"""Read-only session data returned when accessing a share link."""
|
||||
session_id: UUID
|
||||
tree_name: str
|
||||
tree_description: Optional[str] = None
|
||||
tree_structure: dict
|
||||
path_taken: list[str]
|
||||
decisions: list[dict]
|
||||
custom_steps: list[dict] = Field(default_factory=list)
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
ticket_number: Optional[str] = None
|
||||
client_name: Optional[str] = None
|
||||
share_name: Optional[str] = None
|
||||
visibility: str
|
||||
@@ -40,6 +40,24 @@ class TreeUpdate(BaseModel):
|
||||
tags: Optional[list[str]] = Field(None, max_length=10, description="List of tag names to assign (replaces existing)")
|
||||
|
||||
|
||||
class ForkCreate(BaseModel):
|
||||
fork_reason: Optional[str] = Field(None, max_length=255, description="Brief reason for forking")
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255, description="Name for the fork (defaults to 'Fork of {original name}')")
|
||||
|
||||
|
||||
class ForkInfo(BaseModel):
|
||||
"""Fork metadata included in tree responses."""
|
||||
parent_tree_id: Optional[UUID] = None
|
||||
root_tree_id: Optional[UUID] = None
|
||||
fork_reason: Optional[str] = None
|
||||
fork_depth: int = 0
|
||||
parent_updated_at: Optional[datetime] = None
|
||||
has_parent_updates: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TreeResponse(TreeBase):
|
||||
id: UUID
|
||||
tree_structure: dict[str, Any]
|
||||
@@ -48,6 +66,7 @@ class TreeResponse(TreeBase):
|
||||
category_id: Optional[UUID] = None
|
||||
category_info: Optional[CategoryInfo] = None
|
||||
tags: list[str] = [] # List of tag names
|
||||
fork_info: Optional[ForkInfo] = None
|
||||
is_active: bool
|
||||
is_public: bool
|
||||
is_default: bool
|
||||
|
||||
@@ -16,6 +16,8 @@ from app.main import app
|
||||
from app.core.database import Base, get_db
|
||||
from app.core.config import settings
|
||||
|
||||
# Disable invite code requirement for tests
|
||||
settings.REQUIRE_INVITE_CODE = False
|
||||
|
||||
# Test database URL (separate from production)
|
||||
TEST_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/patherly_test"
|
||||
|
||||
239
backend/tests/test_session_sharing.py
Normal file
239
backend/tests/test_session_sharing.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Tests for session sharing (create, access, revoke)."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestSessionSharing:
|
||||
"""Test session share creation, access control, and revocation."""
|
||||
|
||||
async def _create_session(self, client, auth_headers, tree_id):
|
||||
"""Helper: start a session for a tree."""
|
||||
response = await client.post(
|
||||
"/api/v1/sessions",
|
||||
json={"tree_id": tree_id},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
async def test_create_public_share(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Create a public share link."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public", "share_name": "Customer link"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
share = response.json()
|
||||
assert share["visibility"] == "public"
|
||||
assert share["share_name"] == "Customer link"
|
||||
assert share["is_active"] is True
|
||||
assert share["view_count"] == 0
|
||||
assert len(share["share_token"]) > 0
|
||||
|
||||
async def test_create_account_share(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Create an account-only share link."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "account"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["visibility"] == "account"
|
||||
|
||||
async def test_access_public_share(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Access a public share without authentication."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
# Create share
|
||||
share_resp = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=auth_headers
|
||||
)
|
||||
share_token = share_resp.json()["share_token"]
|
||||
|
||||
# Access without auth
|
||||
response = await client.get(f"/api/v1/share/{share_token}")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert data["session_id"] == session["id"]
|
||||
assert data["visibility"] == "public"
|
||||
assert data["path_taken"] is not None
|
||||
|
||||
async def test_access_revoked_share_returns_404(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Accessing a revoked share returns 404."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
# Create and revoke share
|
||||
share_resp = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=auth_headers
|
||||
)
|
||||
share = share_resp.json()
|
||||
|
||||
await client.delete(
|
||||
f"/api/v1/shares/{share['id']}",
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
# Try to access revoked share
|
||||
response = await client.get(f"/api/v1/share/{share['share_token']}")
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_access_nonexistent_share_returns_404(self, client: AsyncClient):
|
||||
"""Accessing a nonexistent share token returns 404."""
|
||||
response = await client.get("/api/v1/share/nonexistent-token-12345")
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_list_my_shares(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""List shares created by current user."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
# Create two shares
|
||||
await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public", "share_name": "Link 1"},
|
||||
headers=auth_headers
|
||||
)
|
||||
await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "account", "share_name": "Link 2"},
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/shares/my-shares",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
shares = response.json()
|
||||
assert len(shares) == 2
|
||||
|
||||
async def test_revoke_share(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Revoke a share link (soft delete)."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
share_resp = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=auth_headers
|
||||
)
|
||||
share = share_resp.json()
|
||||
|
||||
# Revoke
|
||||
response = await client.delete(
|
||||
f"/api/v1/shares/{share['id']}",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify it's gone from my-shares
|
||||
list_resp = await client.get(
|
||||
"/api/v1/shares/my-shares",
|
||||
headers=auth_headers
|
||||
)
|
||||
shares = list_resp.json()
|
||||
assert len(shares) == 0
|
||||
|
||||
async def test_multiple_shares_per_session(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Multiple shares for the same session work independently."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
# Create public + account shares
|
||||
resp1 = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public", "share_name": "For customer"},
|
||||
headers=auth_headers
|
||||
)
|
||||
resp2 = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "account", "share_name": "For team"},
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
assert resp1.status_code == 201
|
||||
assert resp2.status_code == 201
|
||||
|
||||
# Both tokens are different
|
||||
assert resp1.json()["share_token"] != resp2.json()["share_token"]
|
||||
|
||||
# Both accessible
|
||||
access1 = await client.get(f"/api/v1/share/{resp1.json()['share_token']}")
|
||||
assert access1.status_code == 200
|
||||
|
||||
async def test_share_view_count_increments(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""View count increments on each access."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
share_resp = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=auth_headers
|
||||
)
|
||||
token = share_resp.json()["share_token"]
|
||||
|
||||
# Access three times
|
||||
await client.get(f"/api/v1/share/{token}")
|
||||
await client.get(f"/api/v1/share/{token}")
|
||||
await client.get(f"/api/v1/share/{token}")
|
||||
|
||||
# Check view count via my-shares
|
||||
list_resp = await client.get(
|
||||
"/api/v1/shares/my-shares",
|
||||
headers=auth_headers
|
||||
)
|
||||
shares = list_resp.json()
|
||||
assert shares[0]["view_count"] == 3
|
||||
|
||||
async def test_share_requires_session_ownership(self, client: AsyncClient, auth_headers, test_tree, test_db):
|
||||
"""Non-owner cannot create a share for someone else's session."""
|
||||
session = await self._create_session(client, auth_headers, test_tree["id"])
|
||||
|
||||
# Register a different user
|
||||
await client.post("/api/v1/auth/register", json={
|
||||
"email": "other@example.com",
|
||||
"password": "OtherPassword123!",
|
||||
"name": "Other User"
|
||||
})
|
||||
login_resp = await client.post("/api/v1/auth/login/json", json={
|
||||
"email": "other@example.com",
|
||||
"password": "OtherPassword123!"
|
||||
})
|
||||
other_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"}
|
||||
|
||||
# Try to share other user's session
|
||||
response = await client.post(
|
||||
f"/api/v1/sessions/{session['id']}/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=other_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_share_nonexistent_session(self, client: AsyncClient, auth_headers):
|
||||
"""Creating a share for nonexistent session returns 404."""
|
||||
response = await client.post(
|
||||
"/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares",
|
||||
json={"visibility": "public"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_create_share_requires_auth(self, client: AsyncClient, test_tree):
|
||||
"""Creating a share without auth returns 401."""
|
||||
response = await client.post(
|
||||
"/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares",
|
||||
json={"visibility": "public"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
168
backend/tests/test_tree_forking.py
Normal file
168
backend/tests/test_tree_forking.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for tree forking and lineage tracking."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class TestTreeForking:
|
||||
"""Test tree fork creation, lineage, and update detection."""
|
||||
|
||||
async def test_fork_tree(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Fork a tree and verify fork metadata."""
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Customizing for our network"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
fork = response.json()
|
||||
assert fork["name"] == f"Fork of {test_tree['name']}"
|
||||
assert fork["tree_structure"] == test_tree["tree_structure"]
|
||||
assert fork["is_public"] is False
|
||||
assert fork["version"] == 1
|
||||
|
||||
# Verify fork_info
|
||||
assert fork["fork_info"] is not None
|
||||
assert fork["fork_info"]["parent_tree_id"] == test_tree["id"]
|
||||
assert fork["fork_info"]["root_tree_id"] == test_tree["id"]
|
||||
assert fork["fork_info"]["fork_reason"] == "Customizing for our network"
|
||||
assert fork["fork_info"]["fork_depth"] == 1
|
||||
assert fork["fork_info"]["has_parent_updates"] is False
|
||||
|
||||
async def test_fork_with_custom_name(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Fork a tree with a custom name."""
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"name": "My Custom Fork", "fork_reason": "Testing"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 201
|
||||
assert response.json()["name"] == "My Custom Fork"
|
||||
|
||||
async def test_fork_of_fork_lineage(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Fork a fork and verify lineage tracking (root_tree_id, fork_depth)."""
|
||||
# Create first fork
|
||||
resp1 = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Junior's customization"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert resp1.status_code == 201
|
||||
fork1 = resp1.json()
|
||||
|
||||
# Fork the fork
|
||||
resp2 = await client.post(
|
||||
f"/api/v1/trees/{fork1['id']}/fork",
|
||||
json={"fork_reason": "Senior's refinement"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert resp2.status_code == 201
|
||||
fork2 = resp2.json()
|
||||
|
||||
# Verify lineage
|
||||
assert fork2["fork_info"]["parent_tree_id"] == fork1["id"]
|
||||
assert fork2["fork_info"]["root_tree_id"] == test_tree["id"] # Points to original
|
||||
assert fork2["fork_info"]["fork_depth"] == 2
|
||||
assert fork2["fork_info"]["fork_reason"] == "Senior's refinement"
|
||||
|
||||
async def test_fork_nonexistent_tree(self, client: AsyncClient, auth_headers):
|
||||
"""Fork a nonexistent tree returns 404."""
|
||||
response = await client.post(
|
||||
"/api/v1/trees/00000000-0000-0000-0000-000000000000/fork",
|
||||
json={"fork_reason": "test"},
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_list_forks(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""List forks of a tree."""
|
||||
# Create two forks
|
||||
await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Fork 1"},
|
||||
headers=auth_headers
|
||||
)
|
||||
await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Fork 2"},
|
||||
headers=auth_headers
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/trees/{test_tree['id']}/forks",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
forks = response.json()
|
||||
assert len(forks) == 2
|
||||
|
||||
async def test_lineage_chain(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Get lineage from fork back to root."""
|
||||
# Create chain: root → fork1 → fork2
|
||||
resp1 = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Level 1"},
|
||||
headers=auth_headers
|
||||
)
|
||||
fork1 = resp1.json()
|
||||
|
||||
resp2 = await client.post(
|
||||
f"/api/v1/trees/{fork1['id']}/fork",
|
||||
json={"fork_reason": "Level 2"},
|
||||
headers=auth_headers
|
||||
)
|
||||
fork2 = resp2.json()
|
||||
|
||||
# Get lineage from fork2
|
||||
response = await client.get(
|
||||
f"/api/v1/trees/{fork2['id']}/lineage",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
lineage = response.json()
|
||||
|
||||
# Should be [fork2, fork1, root]
|
||||
assert len(lineage) == 3
|
||||
assert lineage[0]["id"] == fork2["id"]
|
||||
assert lineage[1]["id"] == fork1["id"]
|
||||
assert lineage[2]["id"] == test_tree["id"]
|
||||
|
||||
async def test_lineage_of_root_tree(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Root tree lineage is just itself."""
|
||||
response = await client.get(
|
||||
f"/api/v1/trees/{test_tree['id']}/lineage",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
lineage = response.json()
|
||||
assert len(lineage) == 1
|
||||
assert lineage[0]["id"] == test_tree["id"]
|
||||
|
||||
async def test_fork_preserves_tree_structure(self, client: AsyncClient, auth_headers, test_tree):
|
||||
"""Fork copies the complete tree_structure JSONB."""
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "Copy check"},
|
||||
headers=auth_headers
|
||||
)
|
||||
fork = response.json()
|
||||
|
||||
# Get full fork detail
|
||||
detail = await client.get(
|
||||
f"/api/v1/trees/{fork['id']}",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["tree_structure"] == test_tree["tree_structure"]
|
||||
|
||||
async def test_fork_requires_auth(self, client: AsyncClient, test_tree):
|
||||
"""Fork without auth returns 401."""
|
||||
response = await client.post(
|
||||
f"/api/v1/trees/{test_tree['id']}/fork",
|
||||
json={"fork_reason": "No auth"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
Reference in New Issue
Block a user