import { useCallback, useMemo } from 'react'
import type { Dispatch, SetStateAction } from 'react'
import type { Node, Edge } from '@xyflow/react'

interface UseDiagramCommandsParams {
  nodes: Node[]
  edges: Edge[]
  pushHistory: (nodes: Node[], edges: Edge[]) => void
  setNodes: Dispatch<SetStateAction<Node[]>>
}

/** Size used for a node dimension when `node.measured` has no value for it. */
const FALLBACK_NODE_SIZE = 100

/** Width of a node as measured, or the fallback size when unmeasured. */
function measuredWidth(node: Node): number {
  return node.measured?.width ?? FALLBACK_NODE_SIZE
}

/** Height of a node as measured, or the fallback size when unmeasured. */
function measuredHeight(node: Node): number {
  return node.measured?.height ?? FALLBACK_NODE_SIZE
}

/**
 * Returns a copy of `nodes` in which every selected node has its position
 * patched with the coordinates produced by `target`; unselected nodes are
 * returned as the same object.
 */
function moveSelected(
  nodes: Node[],
  target: (node: Node) => { x: number } | { y: number },
): Node[] {
  return nodes.map(n =>
    n.selected ? { ...n, position: { ...n.position, ...target(n) } } : n,
  )
}

/**
 * Alignment, distribution and grouping commands for the current selection
 * of a diagram.
 *
 * Every command first checks that enough nodes are selected, then records
 * the current `nodes`/`edges` through `pushHistory` before updating nodes
 * through `setNodes`.
 *
 * @param params.nodes       Current nodes; the selection is derived from `node.selected`.
 * @param params.edges       Current edges; only forwarded to `pushHistory`.
 * @param params.pushHistory Called with the pre-change nodes and edges before each mutation.
 * @param params.setNodes    State setter used to apply the new node list.
 * @returns The command callbacks, the `can*` availability flags and the
 *          currently selected nodes.
 */
export function useDiagramCommands({
  nodes,
  edges,
  pushHistory,
  setNodes,
}: UseDiagramCommandsParams) {
  // Memoised so that the callbacks below (which depend on it) keep a stable
  // identity for as long as `nodes` itself does not change.
  const selectedNodes = useMemo(() => nodes.filter(n => n.selected), [nodes])

  // ── Alignment ──────────────────────────────────────────────────────────

  /** Moves every selected node so its left edge sits at the smallest selected x. */
  const alignLeft = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const minX = Math.min(...selectedNodes.map(n => n.position.x))
    setNodes(prev => moveSelected(prev, () => ({ x: minX })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /** Moves every selected node so its right edge sits at the largest selected right edge. */
  const alignRight = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const maxX = Math.max(...selectedNodes.map(n => n.position.x + measuredWidth(n)))
    setNodes(prev => moveSelected(prev, n => ({ x: maxX - measuredWidth(n) })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /** Centres every selected node on the vertical axis through the middle of the selection's x-extent. */
  const alignCenterH = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const minX = Math.min(...selectedNodes.map(n => n.position.x))
    const maxX = Math.max(...selectedNodes.map(n => n.position.x + measuredWidth(n)))
    const centerX = (minX + maxX) / 2
    setNodes(prev => moveSelected(prev, n => ({ x: centerX - measuredWidth(n) / 2 })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /** Moves every selected node so its top edge sits at the smallest selected y. */
  const alignTop = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const minY = Math.min(...selectedNodes.map(n => n.position.y))
    setNodes(prev => moveSelected(prev, () => ({ y: minY })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /** Moves every selected node so its bottom edge sits at the largest selected bottom edge. */
  const alignBottom = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const maxY = Math.max(...selectedNodes.map(n => n.position.y + measuredHeight(n)))
    setNodes(prev => moveSelected(prev, n => ({ y: maxY - measuredHeight(n) })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /** Centres every selected node on the horizontal axis through the middle of the selection's y-extent. */
  const alignCenterV = useCallback(() => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)
    const minY = Math.min(...selectedNodes.map(n => n.position.y))
    const maxY = Math.max(...selectedNodes.map(n => n.position.y + measuredHeight(n)))
    const centerY = (minY + maxY) / 2
    setNodes(prev => moveSelected(prev, n => ({ y: centerY - measuredHeight(n) / 2 })))
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  // ── Distribution ───────────────────────────────────────────────────────

  /**
   * Spreads the selected nodes along x with equal gaps between them. The
   * span runs from the x of the left-most node to the right edge of the
   * node with the largest x; nodes are laid out in order of their current x.
   */
  const distributeHorizontally = useCallback(() => {
    if (selectedNodes.length < 3) return
    const sorted = [...selectedNodes].sort((a, b) => a.position.x - b.position.x)
    const first = sorted[0]
    const last = sorted[sorted.length - 1]
    // Unreachable given the length check above; narrows the indexed accesses.
    if (first === undefined || last === undefined) return
    pushHistory(nodes, edges)

    const minX = first.position.x
    const maxX = last.position.x + measuredWidth(last)
    const totalWidth = sorted.reduce((sum, n) => sum + measuredWidth(n), 0)
    // May be negative when the nodes overlap; they are then packed with overlap.
    const gap = (maxX - minX - totalWidth) / (sorted.length - 1)

    // A Map (not a plain object) so node ids can never hit prototype keys.
    const positions = new Map<string, number>()
    let cursor = minX
    for (const n of sorted) {
      positions.set(n.id, cursor)
      cursor += measuredWidth(n) + gap
    }

    setNodes(prev =>
      prev.map(n => {
        const x = positions.get(n.id)
        return n.selected && x !== undefined
          ? { ...n, position: { ...n.position, x } }
          : n
      }),
    )
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /**
   * Spreads the selected nodes along y with equal gaps between them. The
   * span runs from the y of the top-most node to the bottom edge of the
   * node with the largest y; nodes are laid out in order of their current y.
   */
  const distributeVertically = useCallback(() => {
    if (selectedNodes.length < 3) return
    const sorted = [...selectedNodes].sort((a, b) => a.position.y - b.position.y)
    const first = sorted[0]
    const last = sorted[sorted.length - 1]
    // Unreachable given the length check above; narrows the indexed accesses.
    if (first === undefined || last === undefined) return
    pushHistory(nodes, edges)

    const minY = first.position.y
    const maxY = last.position.y + measuredHeight(last)
    const totalHeight = sorted.reduce((sum, n) => sum + measuredHeight(n), 0)
    // May be negative when the nodes overlap; they are then packed with overlap.
    const gap = (maxY - minY - totalHeight) / (sorted.length - 1)

    // A Map (not a plain object) so node ids can never hit prototype keys.
    const positions = new Map<string, number>()
    let cursor = minY
    for (const n of sorted) {
      positions.set(n.id, cursor)
      cursor += measuredHeight(n) + gap
    }

    setNodes(prev =>
      prev.map(n => {
        const y = positions.get(n.id)
        return n.selected && y !== undefined
          ? { ...n, position: { ...n.position, y } }
          : n
      }),
    )
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  // ── Helpers ────────────────────────────────────────────────────────────

  const canAlign = selectedNodes.length >= 2
  const canDistribute = selectedNodes.length >= 3

  // ── Grouping ───────────────────────────────────────────────────────────

  /**
   * Wraps the selected nodes in a new `group` node sized to their bounding
   * box plus padding, re-parents them to it and clears their selection.
   *
   * @param groupType Stored in the group's `data.groupType`; its capitalised
   *                  form becomes the group's label. Defaults to `'custom'`.
   */
  const groupSelection = useCallback((groupType: string = 'custom') => {
    if (selectedNodes.length < 2) return
    pushHistory(nodes, edges)

    const PADDING = 24
    const minX = Math.min(...selectedNodes.map(n => n.position.x)) - PADDING
    const minY = Math.min(...selectedNodes.map(n => n.position.y)) - PADDING
    const maxX = Math.max(...selectedNodes.map(n => n.position.x + measuredWidth(n))) + PADDING
    const maxY = Math.max(...selectedNodes.map(n => n.position.y + measuredHeight(n))) + PADDING

    const groupId = `group-${Date.now()}`
    const groupNode: Node = {
      id: groupId,
      type: 'group',
      position: { x: minX, y: minY },
      style: { width: maxX - minX, height: maxY - minY },
      data: { label: groupType.charAt(0).toUpperCase() + groupType.slice(1), groupType },
      selected: false,
    }

    // The group node is placed ahead of the nodes that reference it as parent.
    // NOTE(review): selected positions are used as-is when converting to
    // group-relative coordinates; nodes that already have a parentId are not
    // translated out of their old parent's space first — confirm intended.
    setNodes(prev => [
      groupNode,
      ...prev.map(n =>
        n.selected
          ? {
              ...n,
              parentId: groupId,
              extent: 'parent' as const,
              position: { x: n.position.x - minX, y: n.position.y - minY },
              selected: false,
            }
          : n,
      ),
    ])
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  /**
   * Removes every selected `group` node and detaches its direct children,
   * adding the group's position back onto each child's position.
   */
  const ungroupSelection = useCallback(() => {
    const selectedGroups = selectedNodes.filter(n => n.type === 'group')
    if (selectedGroups.length === 0) return
    pushHistory(nodes, edges)

    const groupIds = new Set(selectedGroups.map(g => g.id))

    setNodes(prev => {
      // Read group positions from the latest state, not from the closure.
      const groupPositions = new Map<string, { x: number; y: number }>()
      for (const n of prev) {
        if (groupIds.has(n.id)) groupPositions.set(n.id, n.position)
      }

      return prev
        .filter(n => !groupIds.has(n.id))
        .map(n => {
          if (n.parentId && groupIds.has(n.parentId)) {
            const gPos = groupPositions.get(n.parentId) ?? { x: 0, y: 0 }
            return {
              ...n,
              parentId: undefined,
              extent: undefined,
              position: { x: gPos.x + n.position.x, y: gPos.y + n.position.y },
            }
          }
          return n
        })
    })
  }, [nodes, edges, selectedNodes, pushHistory, setNodes])

  const canGroup = selectedNodes.length >= 2 && !selectedNodes.some(n => n.type === 'group')
  const canUngroup = selectedNodes.some(n => n.type === 'group')

  return {
    alignLeft,
    alignRight,
    alignCenterH,
    alignTop,
    alignBottom,
    alignCenterV,
    distributeHorizontally,
    distributeVertically,
    canAlign,
    canDistribute,
    selectedNodes,
    groupSelection,
    ungroupSelection,
    canGroup,
    canUngroup,
  }
}