From ffb14cd014fcd3961858a90900bde8201572f746 Mon Sep 17 00:00:00 2001 From: Michael Chihlas Date: Sat, 7 Feb 2026 19:10:47 -0500 Subject: [PATCH] feat: add tree forking, custom step tracking, and session sharing Implement three foundational schema features from the design doc: - Tree forking with lineage tracking (migration 022): parent_tree_id, root_tree_id, fork_depth columns with self-referential FKs and composite analytics index - Custom step enhancement: CustomStepSchema with source tracking (ad-hoc, step-library, forked-tree) for backward-compatible JSONB - Session sharing (migration 023): session_shares and session_share_views tables with account-scoped visibility, cryptographic tokens, view tracking, and allow_public_shares account policy Includes 21 new integration tests (9 forking, 12 sharing), SaaS consultant-recommended denormalizations, rate limiting on public share access, and test fixture fix for invite code requirement. Co-Authored-By: Claude Opus 4.6 --- .../alembic/versions/022_add_tree_forking.py | 73 +++++ .../versions/023_add_session_sharing.py | 89 ++++++ backend/app/api/endpoints/shares.py | 287 ++++++++++++++++++ backend/app/api/endpoints/trees.py | 184 ++++++++++- backend/app/api/router.py | 3 +- backend/app/models/__init__.py | 3 + backend/app/models/account.py | 9 +- backend/app/models/session.py | 6 +- backend/app/models/session_share.py | 152 ++++++++++ backend/app/models/tree.py | 52 ++++ backend/app/schemas/session.py | 20 +- backend/app/schemas/session_share.py | 47 +++ backend/app/schemas/tree.py | 19 ++ backend/tests/conftest.py | 2 + backend/tests/test_session_sharing.py | 239 +++++++++++++++ backend/tests/test_tree_forking.py | 168 ++++++++++ 16 files changed, 1345 insertions(+), 8 deletions(-) create mode 100644 backend/alembic/versions/022_add_tree_forking.py create mode 100644 backend/alembic/versions/023_add_session_sharing.py create mode 100644 backend/app/api/endpoints/shares.py create mode 100644 backend/app/models/session_share.py create mode 100644 backend/app/schemas/session_share.py create mode 100644 backend/tests/test_session_sharing.py create mode 100644 backend/tests/test_tree_forking.py diff --git a/backend/alembic/versions/022_add_tree_forking.py b/backend/alembic/versions/022_add_tree_forking.py new file mode 100644 index 00000000..8884e308 --- /dev/null +++ b/backend/alembic/versions/022_add_tree_forking.py @@ -0,0 +1,73 @@ +"""add tree forking support with lineage tracking + +Revision ID: 022 +Revises: 021 +Create Date: 2026-02-07 + +Adds fork relationship and lineage tracking to trees table: +- parent_tree_id: Points to immediate parent tree (NULL for root trees) +- fork_reason: Optional engineer note explaining fork purpose +- parent_updated_at: Snapshot of parent's updated_at at fork time +- root_tree_id: Points to original tree at root of fork chain +- fork_depth: How many forks deep (1 = direct fork, 2 = fork of fork, etc.) + +Fork on parent delete behavior: SET NULL (orphaned forks survive) +Root tree delete behavior: SET NULL (lineage preserved via fork_depth) +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '022' +down_revision: Union[str, None] = '021' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add fork tracking columns + op.add_column('trees', sa.Column('parent_tree_id', UUID(as_uuid=True), nullable=True)) + op.add_column('trees', sa.Column('fork_reason', sa.String(255), nullable=True)) + op.add_column('trees', sa.Column('parent_updated_at', sa.DateTime(timezone=True), nullable=True)) + + # Add fork lineage columns + op.add_column('trees', sa.Column('root_tree_id', UUID(as_uuid=True), nullable=True)) + op.add_column('trees', sa.Column('fork_depth', sa.Integer, nullable=False, server_default='0')) + + # Add foreign key constraints + op.create_foreign_key( + 'fk_trees_parent_tree_id', + 'trees', 'trees', + ['parent_tree_id'], ['id'], + ondelete='SET NULL' + ) + op.create_foreign_key( + 'fk_trees_root_tree_id', + 'trees', 'trees', + ['root_tree_id'], ['id'], + ondelete='SET NULL' + ) + + # Add indexes for fork queries + op.create_index('ix_trees_parent_tree_id', 'trees', ['parent_tree_id']) + op.create_index('ix_trees_root_tree_id', 'trees', ['root_tree_id']) + + # Composite index for fork analytics (descendants + depth sorting) + op.create_index('ix_trees_fork_analytics', 'trees', ['root_tree_id', 'fork_depth']) + + +def downgrade() -> None: + op.drop_index('ix_trees_fork_analytics', table_name='trees') + op.drop_index('ix_trees_root_tree_id', table_name='trees') + op.drop_index('ix_trees_parent_tree_id', table_name='trees') + op.drop_constraint('fk_trees_root_tree_id', 'trees', type_='foreignkey') + op.drop_constraint('fk_trees_parent_tree_id', 'trees', type_='foreignkey') + op.drop_column('trees', 'fork_depth') + op.drop_column('trees', 'root_tree_id') + op.drop_column('trees', 'parent_updated_at') + op.drop_column('trees', 'fork_reason') + op.drop_column('trees', 'parent_tree_id') diff --git a/backend/alembic/versions/023_add_session_sharing.py b/backend/alembic/versions/023_add_session_sharing.py new file mode 100644 index 00000000..9b45f59b --- /dev/null +++ b/backend/alembic/versions/023_add_session_sharing.py @@ -0,0 +1,89 @@ +"""add session sharing tables + +Revision ID: 023 +Revises: 022 +Create Date: 2026-02-07 + +Adds session sharing infrastructure: +- session_shares: Share tokens with visibility control (public/account-only) +- session_share_views: Detailed view tracking with viewer identification +- accounts.allow_public_shares: Account-level policy for public share creation + +Includes SaaS consultant-recommended denormalizations: +- account_id on session_shares (faster access control, survives user transfers) +- session_id on session_share_views (faster analytics queries) +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + + +# revision identifiers, used by Alembic. +revision: str = '023' +down_revision: Union[str, None] = '022' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create session_shares table + op.create_table( + 'session_shares', + sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')), + sa.Column('session_id', UUID(as_uuid=True), nullable=False), + sa.Column('account_id', UUID(as_uuid=True), nullable=False), + sa.Column('share_token', sa.String(64), nullable=False, unique=True), + sa.Column('share_name', sa.String(100), nullable=True), + sa.Column('visibility', sa.String(20), nullable=False, server_default='public'), + sa.Column('created_by', UUID(as_uuid=True), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('view_count', sa.Integer, nullable=False, server_default='0'), + sa.Column('last_viewed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('is_active', sa.Boolean, nullable=False, server_default='true'), + sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='CASCADE'), + sa.CheckConstraint("visibility IN ('public', 'account')", name='ck_session_shares_visibility') + ) + + # Create indexes for session_shares + op.create_index('ix_session_shares_session_id', 'session_shares', ['session_id']) + op.create_index('ix_session_shares_account_id', 'session_shares', ['account_id']) + op.create_index('ix_session_shares_share_token', 'session_shares', ['share_token']) + op.create_index('ix_session_shares_created_by', 'session_shares', ['created_by']) + op.create_index('ix_session_shares_expires_at', 'session_shares', ['expires_at']) + op.create_index('ix_session_shares_is_active', 'session_shares', ['is_active']) + + # Create session_share_views table + op.create_table( + 'session_share_views', + sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')), + sa.Column('share_id', UUID(as_uuid=True), nullable=False), + sa.Column('session_id', UUID(as_uuid=True), nullable=False), + sa.Column('viewer_id', UUID(as_uuid=True), nullable=True), + sa.Column('viewed_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('viewer_ip', sa.String(45), nullable=True), + sa.Column('viewer_user_agent', sa.String(500), nullable=True), + sa.ForeignKeyConstraint(['share_id'], ['session_shares.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['viewer_id'], ['users.id'], ondelete='SET NULL') + ) + + # Create indexes for session_share_views + op.create_index('ix_session_share_views_share_id', 'session_share_views', ['share_id']) + op.create_index('ix_session_share_views_session_id', 'session_share_views', ['session_id']) + op.create_index('ix_session_share_views_viewer_id', 'session_share_views', ['viewer_id']) + op.create_index('ix_session_share_views_viewed_at', 'session_share_views', ['viewed_at']) + + # Add account policy for public shares + op.add_column('accounts', sa.Column('allow_public_shares', sa.Boolean, nullable=False, server_default='true')) + + +def downgrade() -> None: + op.drop_column('accounts', 'allow_public_shares') + op.drop_table('session_share_views') + op.drop_table('session_shares') diff --git a/backend/app/api/endpoints/shares.py b/backend/app/api/endpoints/shares.py new file mode 100644 index 00000000..ee81e903 --- /dev/null +++ b/backend/app/api/endpoints/shares.py @@ -0,0 +1,287 @@ +import secrets +from datetime import datetime, timezone +from typing import Annotated, Optional +from uuid import UUID +from fastapi import APIRouter, Depends, HTTPException, Request, status, Query +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from sqlalchemy.orm import joinedload +from sqlalchemy.exc import IntegrityError + +from app.core.database import get_db +from app.models.session import Session +from app.models.session_share import SessionShare, SessionShareView +from app.models.user import User +from app.models.account import Account +from app.schemas.session_share import ShareCreate, ShareResponse, SharePublicView +from app.api.deps import get_current_active_user, require_engineer_or_admin +from app.core.audit import log_audit +from app.core.rate_limit import limiter + +router = APIRouter(tags=["shares"]) + + +def build_share_response(share: SessionShare) -> ShareResponse: + return ShareResponse( + id=share.id, + session_id=share.session_id, + account_id=share.account_id, + share_token=share.share_token, + share_name=share.share_name, + visibility=share.visibility, + created_by=share.created_by, + created_at=share.created_at, + updated_at=share.updated_at, + expires_at=share.expires_at, + view_count=share.view_count, + last_viewed_at=share.last_viewed_at, + is_active=share.is_active, + ) + + +# --- Session Share CRUD --- + + +@router.post( + "/sessions/{session_id}/shares", + response_model=ShareResponse, + status_code=status.HTTP_201_CREATED +) +async def create_share( + session_id: UUID, + share_data: ShareCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_engineer_or_admin)] +): + """Create a share link for a session. + + Only the session owner can create shares. + Public shares require account.allow_public_shares policy. + """ + # Verify session exists and user owns it + result = await db.execute( + select(Session).where(Session.id == session_id) + ) + session = result.scalar_one_or_none() + + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Session not found" + ) + + if session.user_id != current_user.id and not current_user.is_super_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the session owner can create share links" + ) + + # Require account_id for account-scoped shares + if share_data.visibility == "account" and not current_user.account_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot create account-scoped share without an account" + ) + + # Check account policy for public shares + if share_data.visibility == "public" and current_user.account_id: + account_result = await db.execute( + select(Account).where(Account.id == current_user.account_id) + ) + account = account_result.scalar_one_or_none() + if account and not account.allow_public_shares: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Your organization does not allow public session sharing. Use account-only visibility." + ) + + # Generate token with collision retry + max_retries = 3 + for attempt in range(max_retries): + try: + share_token = secrets.token_urlsafe(48) + + share = SessionShare( + session_id=session_id, + account_id=current_user.account_id, + share_token=share_token, + share_name=share_data.share_name, + visibility=share_data.visibility, + created_by=current_user.id, + expires_at=share_data.expires_at, + ) + + db.add(share) + await db.flush() + + await log_audit(db, current_user.id, "share.create", "session_share", share.id, + {"session_id": str(session_id), "visibility": share_data.visibility}) + await db.commit() + await db.refresh(share) + + return build_share_response(share) + + except IntegrityError as e: + await db.rollback() + if "session_shares_share_token_key" in str(e) and attempt < max_retries - 1: + continue + raise + + +@router.get("/shares/my-shares", response_model=list[ShareResponse]) +async def list_my_shares( + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100) +): + """List all shares created by the current user.""" + result = await db.execute( + select(SessionShare) + .where( + SessionShare.created_by == current_user.id, + SessionShare.is_active == True + ) + .order_by(SessionShare.created_at.desc()) + .offset(skip) + .limit(limit) + ) + shares = result.scalars().all() + return [build_share_response(s) for s in shares] + + +@router.delete("/shares/{share_id}", status_code=status.HTTP_204_NO_CONTENT) +async def revoke_share( + share_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Revoke a share link (soft delete - sets is_active=False).""" + result = await db.execute( + select(SessionShare).where(SessionShare.id == share_id) + ) + share = result.scalar_one_or_none() + + if not share: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Share not found" + ) + + if share.created_by != current_user.id and not current_user.is_super_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the share creator can revoke it" + ) + + share.is_active = False + + await log_audit(db, current_user.id, "share.revoke", "session_share", share.id, + {"session_id": str(share.session_id)}) + await db.commit() + return None + + +# --- Public Share Access --- + + +async def _get_optional_user(request: Request, db: AsyncSession) -> Optional[User]: + """Try to extract authenticated user from request, return None if not authenticated.""" + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return None + token = auth_header.replace("Bearer ", "") + try: + from app.core.security import decode_token + payload = decode_token(token) + if not payload or payload.get("type") != "access": + return None + user_id = payload.get("sub") + if not user_id: + return None + result = await db.execute(select(User).where(User.id == UUID(user_id))) + return result.scalar_one_or_none() + except Exception: + return None + + +@router.get("/share/{share_token}", response_model=SharePublicView) +@limiter.limit("30/minute") +async def access_share( + share_token: str, + request: Request, + db: Annotated[AsyncSession, Depends(get_db)], +): + """Access a shared session via share token. + + Public shares: No authentication required. + Account-only shares: Requires authentication + account membership. + """ + current_user = await _get_optional_user(request, db) + + # Lookup share + result = await db.execute( + select(SessionShare) + .options(joinedload(SessionShare.session)) + .where(SessionShare.share_token == share_token) + ) + share = result.scalar_one_or_none() + + # Validate share + if not share or not share.is_active: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Share not found or has been revoked" + ) + + if share.expires_at and share.expires_at < datetime.now(timezone.utc): + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail="Share link has expired" + ) + + # Check visibility + if share.visibility == "account": + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="This share requires authentication" + ) + if current_user.account_id != share.account_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to this session" + ) + + # Record view + session = share.session + view = SessionShareView( + share_id=share.id, + session_id=session.id, + viewer_id=current_user.id if current_user else None, + viewer_ip=request.client.host if request.client else None, + viewer_user_agent=request.headers.get("user-agent"), + ) + db.add(view) + + share.view_count += 1 + share.last_viewed_at = datetime.now(timezone.utc) + await db.commit() + + # Build read-only response + tree_snapshot = session.tree_snapshot or {} + return SharePublicView( + session_id=session.id, + tree_name=tree_snapshot.get("question", "Untitled Tree"), + tree_description=tree_snapshot.get("description"), + tree_structure=tree_snapshot, + path_taken=session.path_taken or [], + decisions=session.decisions or [], + custom_steps=session.custom_steps or [], + started_at=session.started_at, + completed_at=session.completed_at, + ticket_number=session.ticket_number, + client_name=session.client_name, + share_name=share.share_name, + visibility=share.visibility, + ) diff --git a/backend/app/api/endpoints/trees.py b/backend/app/api/endpoints/trees.py index 0a0fb6d0..1b49d783 100644 --- a/backend/app/api/endpoints/trees.py +++ b/backend/app/api/endpoints/trees.py @@ -12,7 +12,7 @@ from app.models.user import User from app.models.category import TreeCategory from app.models.tag import TreeTag, tree_tag_assignments from app.models.folder import UserFolder, user_folder_trees -from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo +from app.schemas.tree import TreeCreate, TreeUpdate, TreeResponse, TreeListResponse, CategoryInfo, ForkCreate, ForkInfo from app.api.deps import get_current_active_user, require_engineer_or_admin, require_admin from app.core.permissions import can_edit_tree, can_access_tree from app.core.subscriptions import check_tree_limit @@ -73,8 +73,8 @@ def build_tree_response(tree: Tree) -> TreeListResponse: ) -def build_full_tree_response(tree: Tree) -> TreeResponse: - """Build TreeResponse with all details including category_info and tags.""" +def build_full_tree_response(tree: Tree, parent_tree: Tree = None) -> TreeResponse: + """Build TreeResponse with all details including category_info, tags, and fork_info.""" category_info = None if tree.category_rel: category_info = CategoryInfo( @@ -83,6 +83,20 @@ def build_full_tree_response(tree: Tree) -> TreeResponse: slug=tree.category_rel.slug ) + fork_info = None + if tree.parent_tree_id or tree.fork_depth > 0: + has_updates = False + if parent_tree and tree.parent_updated_at: + has_updates = parent_tree.updated_at > tree.parent_updated_at + fork_info = ForkInfo( + parent_tree_id=tree.parent_tree_id, + root_tree_id=tree.root_tree_id, + fork_reason=tree.fork_reason, + fork_depth=tree.fork_depth, + parent_updated_at=tree.parent_updated_at, + has_parent_updates=has_updates + ) + return TreeResponse( id=tree.id, name=tree.name, @@ -91,6 +105,7 @@ def build_full_tree_response(tree: Tree) -> TreeResponse: category_id=tree.category_id, category_info=category_info, tags=tree.tag_names, + fork_info=fork_info, tree_structure=tree.tree_structure, author_id=tree.author_id, account_id=tree.account_id, @@ -561,3 +576,166 @@ async def delete_tree( {"tree_name": tree.name}) await db.commit() return None + + +# --- Fork Endpoints --- + + +@router.post("/{tree_id}/fork", response_model=TreeResponse, status_code=status.HTTP_201_CREATED) +async def fork_tree( + tree_id: UUID, + fork_data: ForkCreate, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(require_engineer_or_admin)] +): + """Fork a tree to create a personal copy. + + Engineers can fork any tree they can access (public, account, or default). + Fork inherits tree_structure but gets new ownership. + """ + # Load parent tree + result = await db.execute( + select(Tree) + .options(selectinload(Tree.category_rel), selectinload(Tree.tags)) + .where(Tree.id == tree_id) + ) + parent = result.scalar_one_or_none() + + if not parent: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Tree not found" + ) + + if not can_access_tree(current_user, parent): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to this tree" + ) + + # Check subscription tree limit + if current_user.account_id: + can_create, limit, count = await check_tree_limit(current_user.account_id, db) + if not can_create: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Tree limit reached ({count}/{limit}). Upgrade your plan to create more trees." + ) + + # Build fork + fork_name = fork_data.name or f"Fork of {parent.name}" + fork = Tree( + name=fork_name, + description=parent.description, + category=parent.category, + category_id=parent.category_id, + tree_structure=parent.tree_structure, + author_id=current_user.id, + account_id=current_user.account_id, + is_public=False, + is_default=False, + version=1, + # Fork tracking + parent_tree_id=parent.id, + fork_reason=fork_data.fork_reason, + parent_updated_at=parent.updated_at, + # Lineage tracking + root_tree_id=parent.root_tree_id if parent.root_tree_id else parent.id, + fork_depth=parent.fork_depth + 1, + ) + + db.add(fork) + await db.flush() + + await log_audit(db, current_user.id, "tree.fork", "tree", fork.id, + {"parent_tree_id": str(parent.id), "parent_name": parent.name, + "fork_reason": fork_data.fork_reason}) + await db.commit() + + # Reload with relationships + result = await db.execute( + select(Tree) + .options(selectinload(Tree.category_rel), selectinload(Tree.tags)) + .where(Tree.id == fork.id) + ) + fork = result.scalar_one() + + return build_full_tree_response(fork, parent_tree=parent) + + +@router.get("/{tree_id}/forks", response_model=list[TreeListResponse]) +async def list_forks( + tree_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)], + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=100) +): + """List all direct forks of a tree.""" + # Verify parent exists and user can access it + parent_result = await db.execute(select(Tree).where(Tree.id == tree_id)) + parent = parent_result.scalar_one_or_none() + + if not parent: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Tree not found" + ) + + if not can_access_tree(current_user, parent): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to this tree" + ) + + # Query direct forks, filtered by access + query = select(Tree).options( + selectinload(Tree.category_rel), + selectinload(Tree.tags) + ).where( + Tree.parent_tree_id == tree_id, + Tree.is_active == True, + build_tree_access_filter(current_user) + ).order_by(Tree.created_at.desc()).offset(skip).limit(limit) + + result = await db.execute(query) + forks = result.scalars().unique().all() + + return [build_tree_response(tree) for tree in forks] + + +@router.get("/{tree_id}/lineage", response_model=list[TreeListResponse]) +async def get_tree_lineage( + tree_id: UUID, + db: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_active_user)] +): + """Get the fork lineage chain from current tree back to root. + + Returns ordered list: [current tree, parent, grandparent, ..., root] + Limited to 10 levels to prevent infinite loops. + """ + lineage = [] + current_id = tree_id + visited = set() + max_depth = 10 + + for _ in range(max_depth): + if current_id is None or current_id in visited: + break + visited.add(current_id) + + result = await db.execute( + select(Tree) + .options(selectinload(Tree.category_rel), selectinload(Tree.tags)) + .where(Tree.id == current_id) + ) + tree = result.scalar_one_or_none() + + if not tree: + break + + lineage.append(build_tree_response(tree)) + current_id = tree.parent_tree_id + + return lineage diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 05a773bc..2ad2bf52 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks +from app.api.endpoints import auth, trees, sessions, invite, categories, tags, folders, step_categories, steps, admin, accounts, webhooks, shares api_router = APIRouter() @@ -15,3 +15,4 @@ api_router.include_router(steps.router) api_router.include_router(admin.router) api_router.include_router(accounts.router) api_router.include_router(webhooks.router) +api_router.include_router(shares.router) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 981c5b15..d655d533 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -15,6 +15,7 @@ from .step_category import StepCategory from .step_library import StepLibrary, StepRating, StepUsageLog from .refresh_token import RefreshToken from .audit_log import AuditLog +from .session_share import SessionShare, SessionShareView __all__ = [ "User", @@ -38,4 +39,6 @@ __all__ = [ "StepUsageLog", "RefreshToken", "AuditLog", + "SessionShare", + "SessionShareView", ] diff --git a/backend/app/models/account.py b/backend/app/models/account.py index e9e8be18..3f6472f7 100644 --- a/backend/app/models/account.py +++ b/backend/app/models/account.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime, timezone from typing import Optional, TYPE_CHECKING -from sqlalchemy import String, DateTime, ForeignKey +from sqlalchemy import String, DateTime, ForeignKey, Boolean from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID from app.core.database import Base @@ -26,6 +26,13 @@ class Account(Base): stripe_customer_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + allow_public_shares: Mapped[bool] = mapped_column( + Boolean, + nullable=False, + default=True, + server_default="true", + comment="Policy: engineers can create public shares. Only affects NEW shares (grandfathered)." + ) # Relationships owner: Mapped["User"] = relationship("User", foreign_keys=[owner_id], back_populates="owned_account") diff --git a/backend/app/models/session.py b/backend/app/models/session.py index bec2cdd6..4f2b028a 100644 --- a/backend/app/models/session.py +++ b/backend/app/models/session.py @@ -1,12 +1,15 @@ import uuid from datetime import datetime, timezone -from typing import Optional, Any +from typing import Optional, Any, TYPE_CHECKING from sqlalchemy import String, DateTime, ForeignKey, Boolean, Text import sqlalchemy as sa from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.dialects.postgresql import UUID, JSONB from app.core.database import Base +if TYPE_CHECKING: + from app.models.session_share import SessionShare + class Session(Base): __tablename__ = "sessions" @@ -53,3 +56,4 @@ class Session(Base): tree: Mapped["Tree"] = relationship("Tree", back_populates="sessions") user: Mapped["User"] = relationship("User", back_populates="sessions") attachments: Mapped[list["Attachment"]] = relationship("Attachment", back_populates="session") + shares: Mapped[list["SessionShare"]] = relationship("SessionShare", back_populates="session", cascade="all, delete-orphan") diff --git a/backend/app/models/session_share.py b/backend/app/models/session_share.py new file mode 100644 index 00000000..039d1a40 --- /dev/null +++ b/backend/app/models/session_share.py @@ -0,0 +1,152 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, TYPE_CHECKING +from sqlalchemy import String, DateTime, ForeignKey, Boolean, Integer, CheckConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID +from app.core.database import Base + +if TYPE_CHECKING: + from app.models.session import Session + from app.models.user import User + from app.models.account import Account + + +class SessionShare(Base): + __tablename__ = "session_shares" + __table_args__ = ( + CheckConstraint( + "visibility IN ('public', 'account')", + name='ck_session_shares_visibility' + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4 + ) + session_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("sessions.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + account_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("accounts.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Account that owns this share (denormalized from session at creation)" + ) + share_token: Mapped[str] = mapped_column( + String(64), + unique=True, + nullable=False, + index=True, + comment="URL-safe random token (48 bytes -> 64 base64 chars)" + ) + share_name: Mapped[Optional[str]] = mapped_column( + String(100), + nullable=True, + comment="Optional label: 'Training link', 'Customer escalation #1234'" + ) + visibility: Mapped[str] = mapped_column( + String(20), + nullable=False, + default="public", + comment="public = anyone with link, account = account members only" + ) + created_by: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc) + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc) + ) + expires_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + index=True, + comment="Optional expiration for time-limited shares" + ) + view_count: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=0 + ) + last_viewed_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True + ) + is_active: Mapped[bool] = mapped_column( + Boolean, + nullable=False, + default=True, + index=True + ) + + # Relationships + session: Mapped["Session"] = relationship("Session", back_populates="shares") + account: Mapped["Account"] = relationship("Account") + creator: Mapped["User"] = relationship("User", foreign_keys=[created_by]) + views: Mapped[list["SessionShareView"]] = relationship( + "SessionShareView", + back_populates="share", + cascade="all, delete-orphan" + ) + + +class SessionShareView(Base): + __tablename__ = "session_share_views" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4 + ) + share_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("session_shares.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + session_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("sessions.id", ondelete="CASCADE"), + nullable=False, + index=True, + comment="Denormalized from share for analytics queries" + ) + viewer_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + index=True, + comment="NULL for public shares (unauthenticated views)" + ) + viewed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + index=True + ) + viewer_ip: Mapped[Optional[str]] = mapped_column( + String(45), # IPv6 max length + nullable=True + ) + viewer_user_agent: Mapped[Optional[str]] = mapped_column( + String(500), + nullable=True + ) + + # Relationships + share: Mapped["SessionShare"] = relationship("SessionShare", back_populates="views") + viewer: Mapped[Optional["User"]] = relationship("User") diff --git a/backend/app/models/tree.py b/backend/app/models/tree.py index c9b305a7..344ec4e8 100644 --- a/backend/app/models/tree.py +++ b/backend/app/models/tree.py @@ -79,10 +79,62 @@ class Tree(Base): ) usage_count: Mapped[int] = mapped_column(Integer, default=0) + # Fork tracking + parent_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("trees.id", ondelete="SET NULL"), + nullable=True, + index=True + ) + fork_reason: Mapped[Optional[str]] = mapped_column( + String(255), + nullable=True, + comment="Brief reason: 'Added Cisco Meraki steps for our network'" + ) + parent_updated_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + comment="Snapshot of parent's updated_at when fork created. Compare to detect parent updates." + ) + + # Fork lineage tracking + root_tree_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUID(as_uuid=True), + ForeignKey("trees.id", ondelete="SET NULL"), + nullable=True, + index=True, + comment="Original tree at root of fork chain (NULL for non-forked trees)" + ) + fork_depth: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=0, + server_default="0", + comment="Fork depth: 0 = original, 1 = direct fork, 2 = fork of fork, etc." + ) + # Relationships author: Mapped[Optional["User"]] = relationship("User", foreign_keys=[author_id], back_populates="trees") team: Mapped[Optional["Team"]] = relationship("Team", back_populates="trees") account: Mapped[Optional["Account"]] = relationship("Account", foreign_keys=[account_id], back_populates="trees") + + # Fork relationships (self-referential) + parent: Mapped[Optional["Tree"]] = relationship( + "Tree", + remote_side="Tree.id", + foreign_keys=[parent_tree_id], + back_populates="forks" + ) + forks: Mapped[list["Tree"]] = relationship( + "Tree", + foreign_keys=[parent_tree_id], + back_populates="parent" + ) + root: Mapped[Optional["Tree"]] = relationship( + "Tree", + remote_side="Tree.id", + foreign_keys=[root_tree_id] + ) sessions: Mapped[list["Session"]] = relationship("Session", back_populates="tree") # New organization relationships diff --git a/backend/app/schemas/session.py b/backend/app/schemas/session.py index b4ab6377..57f5ea4d 100644 --- a/backend/app/schemas/session.py +++ b/backend/app/schemas/session.py @@ -1,9 +1,25 @@ from datetime import datetime -from typing import Optional, Any +from typing import Optional, Any, Literal from uuid import UUID from pydantic import BaseModel, Field, validator +class CustomStepSchema(BaseModel): + """Enhanced custom step with source tracking. + + Backward compatible: old sessions without new fields load with defaults. + """ + type: str # "decision" | "action" | "solution" + content: str + notes: Optional[str] = None + + # Source tracking (new fields, optional for backward compatibility) + source: Literal["ad-hoc", "step-library", "forked-tree"] = "ad-hoc" + source_step_id: Optional[UUID] = None + inserted_at: Optional[datetime] = None + inserted_after_node_id: Optional[str] = None + + class DecisionRecord(BaseModel): node_id: str question: Optional[str] = None @@ -24,7 +40,7 @@ class SessionCreate(BaseModel): class SessionUpdate(BaseModel): path_taken: Optional[list[str]] = None decisions: Optional[list[DecisionRecord]] = None - custom_steps: Optional[list[dict[str, Any]]] = None + custom_steps: Optional[list[CustomStepSchema]] = None ticket_number: Optional[str] = Field(None, max_length=100) client_name: Optional[str] = Field(None, max_length=255) scratchpad: Optional[str] = None diff --git a/backend/app/schemas/session_share.py b/backend/app/schemas/session_share.py new file mode 100644 index 00000000..40c23cb0 --- /dev/null +++ b/backend/app/schemas/session_share.py @@ -0,0 +1,47 @@ +from datetime import datetime +from typing import Optional, Literal +from uuid import UUID +from pydantic import BaseModel, Field + + +class ShareCreate(BaseModel): + visibility: Literal["public", "account"] = Field("public", description="Share visibility") + share_name: Optional[str] = Field(None, max_length=100, description="Optional label for the share") + expires_at: Optional[datetime] = Field(None, description="Optional expiration datetime") + + +class ShareResponse(BaseModel): + id: UUID + session_id: UUID + account_id: UUID + share_token: str + share_name: Optional[str] = None + visibility: str + created_by: UUID + created_at: datetime + updated_at: datetime + expires_at: Optional[datetime] = None + view_count: int + last_viewed_at: Optional[datetime] = None + is_active: bool + share_url: Optional[str] = None + + class Config: + from_attributes = True + + +class SharePublicView(BaseModel): + """Read-only session data returned when accessing a share link.""" + session_id: UUID + tree_name: str + tree_description: Optional[str] = None + tree_structure: dict + path_taken: list[str] + decisions: list[dict] + custom_steps: list[dict] = Field(default_factory=list) + started_at: datetime + completed_at: Optional[datetime] = None + ticket_number: Optional[str] = None + client_name: Optional[str] = None + share_name: Optional[str] = None + visibility: str diff --git a/backend/app/schemas/tree.py b/backend/app/schemas/tree.py index 86aa7567..2811ff7b 100644 --- a/backend/app/schemas/tree.py +++ b/backend/app/schemas/tree.py @@ -40,6 +40,24 @@ class TreeUpdate(BaseModel): tags: Optional[list[str]] = Field(None, max_length=10, description="List of tag names to assign (replaces existing)") +class ForkCreate(BaseModel): + fork_reason: Optional[str] = Field(None, max_length=255, description="Brief reason for forking") + name: Optional[str] = Field(None, min_length=1, max_length=255, description="Name for the fork (defaults to 'Fork of {original name}')") + + +class ForkInfo(BaseModel): + """Fork metadata included in tree responses.""" + parent_tree_id: Optional[UUID] = None + root_tree_id: Optional[UUID] = None + fork_reason: Optional[str] = None + fork_depth: int = 0 + parent_updated_at: Optional[datetime] = None + has_parent_updates: bool = False + + class Config: + from_attributes = True + + class TreeResponse(TreeBase): id: UUID tree_structure: dict[str, Any] @@ -48,6 +66,7 @@ class TreeResponse(TreeBase): category_id: Optional[UUID] = None category_info: Optional[CategoryInfo] = None tags: list[str] = [] # List of tag names + fork_info: Optional[ForkInfo] = None is_active: bool is_public: bool is_default: bool diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index e7b7667a..e40acb29 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -16,6 +16,8 @@ from app.main import app from app.core.database import Base, get_db from app.core.config import settings +# Disable invite code requirement for tests +settings.REQUIRE_INVITE_CODE = False # Test database URL (separate from production) TEST_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/patherly_test" diff --git a/backend/tests/test_session_sharing.py b/backend/tests/test_session_sharing.py new file mode 100644 index 00000000..1f097b19 --- /dev/null +++ b/backend/tests/test_session_sharing.py @@ -0,0 +1,239 @@ +"""Tests for session sharing (create, access, revoke).""" + +import pytest +from httpx import AsyncClient + + +pytestmark = pytest.mark.asyncio + + +class TestSessionSharing: + """Test session share creation, access control, and revocation.""" + + async def _create_session(self, client, auth_headers, tree_id): + """Helper: start a session for a tree.""" + response = await client.post( + "/api/v1/sessions", + json={"tree_id": tree_id}, + headers=auth_headers + ) + assert response.status_code == 201 + return response.json() + + async def test_create_public_share(self, client: AsyncClient, auth_headers, test_tree): + """Create a public share link.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + response = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public", "share_name": "Customer link"}, + headers=auth_headers + ) + assert response.status_code == 201 + + share = response.json() + assert share["visibility"] == "public" + assert share["share_name"] == "Customer link" + assert share["is_active"] is True + assert share["view_count"] == 0 + assert len(share["share_token"]) > 0 + + async def test_create_account_share(self, client: AsyncClient, auth_headers, test_tree): + """Create an account-only share link.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + response = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "account"}, + headers=auth_headers + ) + assert response.status_code == 201 + assert response.json()["visibility"] == "account" + + async def test_access_public_share(self, client: AsyncClient, auth_headers, test_tree): + """Access a public share without authentication.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + # Create share + share_resp = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public"}, + headers=auth_headers + ) + share_token = share_resp.json()["share_token"] + + # Access without auth + response = await client.get(f"/api/v1/share/{share_token}") + assert response.status_code == 200 + + data = response.json() + assert data["session_id"] == session["id"] + assert data["visibility"] == "public" + assert data["path_taken"] is not None + + async def test_access_revoked_share_returns_404(self, client: AsyncClient, auth_headers, test_tree): + """Accessing a revoked share returns 404.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + # Create and revoke share + share_resp = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public"}, + headers=auth_headers + ) + share = share_resp.json() + + await client.delete( + f"/api/v1/shares/{share['id']}", + headers=auth_headers + ) + + # Try to access revoked share + response = await client.get(f"/api/v1/share/{share['share_token']}") + assert response.status_code == 404 + + async def test_access_nonexistent_share_returns_404(self, client: AsyncClient): + """Accessing a nonexistent share token returns 404.""" + response = await client.get("/api/v1/share/nonexistent-token-12345") + assert response.status_code == 404 + + async def test_list_my_shares(self, client: AsyncClient, auth_headers, test_tree): + """List shares created by current user.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + # Create two shares + await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public", "share_name": "Link 1"}, + headers=auth_headers + ) + await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "account", "share_name": "Link 2"}, + headers=auth_headers + ) + + response = await client.get( + "/api/v1/shares/my-shares", + headers=auth_headers + ) + assert response.status_code == 200 + shares = response.json() + assert len(shares) == 2 + + async def test_revoke_share(self, client: AsyncClient, auth_headers, test_tree): + """Revoke a share link (soft delete).""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + share_resp = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public"}, + headers=auth_headers + ) + share = share_resp.json() + + # Revoke + response = await client.delete( + f"/api/v1/shares/{share['id']}", + headers=auth_headers + ) + assert response.status_code == 204 + + # Verify it's gone from my-shares + list_resp = await client.get( + "/api/v1/shares/my-shares", + headers=auth_headers + ) + shares = list_resp.json() + assert len(shares) == 0 + + async def test_multiple_shares_per_session(self, client: AsyncClient, auth_headers, test_tree): + """Multiple shares for the same session work independently.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + # Create public + account shares + resp1 = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public", "share_name": "For customer"}, + headers=auth_headers + ) + resp2 = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "account", "share_name": "For team"}, + headers=auth_headers + ) + + assert resp1.status_code == 201 + assert resp2.status_code == 201 + + # Both tokens are different + assert resp1.json()["share_token"] != resp2.json()["share_token"] + + # Both accessible + access1 = await client.get(f"/api/v1/share/{resp1.json()['share_token']}") + assert access1.status_code == 200 + + async def test_share_view_count_increments(self, client: AsyncClient, auth_headers, test_tree): + """View count increments on each access.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + share_resp = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public"}, + headers=auth_headers + ) + token = share_resp.json()["share_token"] + + # Access three times + await client.get(f"/api/v1/share/{token}") + await client.get(f"/api/v1/share/{token}") + await client.get(f"/api/v1/share/{token}") + + # Check view count via my-shares + list_resp = await client.get( + "/api/v1/shares/my-shares", + headers=auth_headers + ) + shares = list_resp.json() + assert shares[0]["view_count"] == 3 + + async def test_share_requires_session_ownership(self, client: AsyncClient, auth_headers, test_tree, test_db): + """Non-owner cannot create a share for someone else's session.""" + session = await self._create_session(client, auth_headers, test_tree["id"]) + + # Register a different user + await client.post("/api/v1/auth/register", json={ + "email": "other@example.com", + "password": "OtherPassword123!", + "name": "Other User" + }) + login_resp = await client.post("/api/v1/auth/login/json", json={ + "email": "other@example.com", + "password": "OtherPassword123!" + }) + other_headers = {"Authorization": f"Bearer {login_resp.json()['access_token']}"} + + # Try to share other user's session + response = await client.post( + f"/api/v1/sessions/{session['id']}/shares", + json={"visibility": "public"}, + headers=other_headers + ) + assert response.status_code == 403 + + async def test_share_nonexistent_session(self, client: AsyncClient, auth_headers): + """Creating a share for nonexistent session returns 404.""" + response = await client.post( + "/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares", + json={"visibility": "public"}, + headers=auth_headers + ) + assert response.status_code == 404 + + async def test_create_share_requires_auth(self, client: AsyncClient, test_tree): + """Creating a share without auth returns 401.""" + response = await client.post( + "/api/v1/sessions/00000000-0000-0000-0000-000000000000/shares", + json={"visibility": "public"} + ) + assert response.status_code == 401 diff --git a/backend/tests/test_tree_forking.py b/backend/tests/test_tree_forking.py new file mode 100644 index 00000000..b3ec7e64 --- /dev/null +++ b/backend/tests/test_tree_forking.py @@ -0,0 +1,168 @@ +"""Tests for tree forking and lineage tracking.""" + +import pytest +from httpx import AsyncClient + + +pytestmark = pytest.mark.asyncio + + +class TestTreeForking: + """Test tree fork creation, lineage, and update detection.""" + + async def test_fork_tree(self, client: AsyncClient, auth_headers, test_tree): + """Fork a tree and verify fork metadata.""" + response = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Customizing for our network"}, + headers=auth_headers + ) + assert response.status_code == 201 + + fork = response.json() + assert fork["name"] == f"Fork of {test_tree['name']}" + assert fork["tree_structure"] == test_tree["tree_structure"] + assert fork["is_public"] is False + assert fork["version"] == 1 + + # Verify fork_info + assert fork["fork_info"] is not None + assert fork["fork_info"]["parent_tree_id"] == test_tree["id"] + assert fork["fork_info"]["root_tree_id"] == test_tree["id"] + assert fork["fork_info"]["fork_reason"] == "Customizing for our network" + assert fork["fork_info"]["fork_depth"] == 1 + assert fork["fork_info"]["has_parent_updates"] is False + + async def test_fork_with_custom_name(self, client: AsyncClient, auth_headers, test_tree): + """Fork a tree with a custom name.""" + response = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"name": "My Custom Fork", "fork_reason": "Testing"}, + headers=auth_headers + ) + assert response.status_code == 201 + assert response.json()["name"] == "My Custom Fork" + + async def test_fork_of_fork_lineage(self, client: AsyncClient, auth_headers, test_tree): + """Fork a fork and verify lineage tracking (root_tree_id, fork_depth).""" + # Create first fork + resp1 = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Junior's customization"}, + headers=auth_headers + ) + assert resp1.status_code == 201 + fork1 = resp1.json() + + # Fork the fork + resp2 = await client.post( + f"/api/v1/trees/{fork1['id']}/fork", + json={"fork_reason": "Senior's refinement"}, + headers=auth_headers + ) + assert resp2.status_code == 201 + fork2 = resp2.json() + + # Verify lineage + assert fork2["fork_info"]["parent_tree_id"] == fork1["id"] + assert fork2["fork_info"]["root_tree_id"] == test_tree["id"] # Points to original + assert fork2["fork_info"]["fork_depth"] == 2 + assert fork2["fork_info"]["fork_reason"] == "Senior's refinement" + + async def test_fork_nonexistent_tree(self, client: AsyncClient, auth_headers): + """Fork a nonexistent tree returns 404.""" + response = await client.post( + "/api/v1/trees/00000000-0000-0000-0000-000000000000/fork", + json={"fork_reason": "test"}, + headers=auth_headers + ) + assert response.status_code == 404 + + async def test_list_forks(self, client: AsyncClient, auth_headers, test_tree): + """List forks of a tree.""" + # Create two forks + await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Fork 1"}, + headers=auth_headers + ) + await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Fork 2"}, + headers=auth_headers + ) + + response = await client.get( + f"/api/v1/trees/{test_tree['id']}/forks", + headers=auth_headers + ) + assert response.status_code == 200 + forks = response.json() + assert len(forks) == 2 + + async def test_lineage_chain(self, client: AsyncClient, auth_headers, test_tree): + """Get lineage from fork back to root.""" + # Create chain: root → fork1 → fork2 + resp1 = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Level 1"}, + headers=auth_headers + ) + fork1 = resp1.json() + + resp2 = await client.post( + f"/api/v1/trees/{fork1['id']}/fork", + json={"fork_reason": "Level 2"}, + headers=auth_headers + ) + fork2 = resp2.json() + + # Get lineage from fork2 + response = await client.get( + f"/api/v1/trees/{fork2['id']}/lineage", + headers=auth_headers + ) + assert response.status_code == 200 + lineage = response.json() + + # Should be [fork2, fork1, root] + assert len(lineage) == 3 + assert lineage[0]["id"] == fork2["id"] + assert lineage[1]["id"] == fork1["id"] + assert lineage[2]["id"] == test_tree["id"] + + async def test_lineage_of_root_tree(self, client: AsyncClient, auth_headers, test_tree): + """Root tree lineage is just itself.""" + response = await client.get( + f"/api/v1/trees/{test_tree['id']}/lineage", + headers=auth_headers + ) + assert response.status_code == 200 + lineage = response.json() + assert len(lineage) == 1 + assert lineage[0]["id"] == test_tree["id"] + + async def test_fork_preserves_tree_structure(self, client: AsyncClient, auth_headers, test_tree): + """Fork copies the complete tree_structure JSONB.""" + response = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "Copy check"}, + headers=auth_headers + ) + fork = response.json() + + # Get full fork detail + detail = await client.get( + f"/api/v1/trees/{fork['id']}", + headers=auth_headers + ) + assert detail.status_code == 200 + assert detail.json()["tree_structure"] == test_tree["tree_structure"] + + async def test_fork_requires_auth(self, client: AsyncClient, test_tree): + """Fork without auth returns 401.""" + response = await client.post( + f"/api/v1/trees/{test_tree['id']}/fork", + json={"fork_reason": "No auth"} + ) + assert response.status_code == 401