"content": "\"use client\";\n\nimport { useComposedRefs } from \"@/lib/composition\";\nimport { cn } from \"@/lib/utils\";\nimport { Slot } from \"@radix-ui/react-slot\";\nimport * as React from \"react\";\nimport * as ReactDOM from \"react-dom\";\n\nconst ROOT_NAME = \"MasonryRoot\";\nconst ITEM_NAME = \"MasonryItem\";\n\nconst DATA_LINE_BREAK_ATTR = \"data-masonry-line-break\";\nconst DATA_ITEM_ATTR = \"data-masonry-item\";\n\nconst COLUMN_COUNT = 4;\nconst GAP = 12;\nconst CACHE_MAX_AGE = 5000;\n\nconst MASONRY_ERROR = {\n [ROOT_NAME]: `\\`${ROOT_NAME}\\` components must be within \\`${ROOT_NAME}\\``,\n [ITEM_NAME]: `\\`${ITEM_NAME}\\` must be within \\`${ROOT_NAME}\\``,\n} as const;\n\nconst TAILWIND_BREAKPOINTS = {\n initial: 0,\n sm: 640,\n md: 768,\n lg: 1024,\n xl: 1280,\n \"2xl\": 1536,\n} as const;\n\ntype TailwindBreakpoint = keyof typeof TAILWIND_BREAKPOINTS;\ntype BreakpointValue = TailwindBreakpoint | number;\ntype ResponsiveObject = Partial<Record<BreakpointValue, number>>;\ntype ResponsiveValue = number | ResponsiveObject;\n\nfunction parseBreakpoint(breakpoint: BreakpointValue): number {\n if (typeof breakpoint === \"number\") return breakpoint;\n return breakpoint in TAILWIND_BREAKPOINTS\n ? TAILWIND_BREAKPOINTS[breakpoint]\n : Number(breakpoint);\n}\n\nfunction getInitialValue(value: ResponsiveValue, defaultValue: number): number {\n if (typeof value === \"number\") return value;\n if (\"initial\" in value) return value.initial ?? defaultValue;\n\n const breakpoints = Object.entries(value)\n .map(([key, val]) => ({\n breakpoint: parseBreakpoint(key as BreakpointValue),\n value: val ?? defaultValue,\n }))\n .sort((a, b) => a.breakpoint - b.breakpoint);\n\n return breakpoints[0]?.value ?? defaultValue;\n}\n\nfunction useResponsiveValue({\n value,\n defaultValue,\n mounted,\n}: {\n value: ResponsiveValue;\n defaultValue: number;\n mounted: boolean;\n}): number {\n const initialValue = React.useMemo(\n () => getInitialValue(value, defaultValue),\n [value, defaultValue],\n );\n const [currentValue, setCurrentValue] = React.useState(initialValue);\n\n const onResize = React.useCallback(() => {\n if (!mounted) return;\n if (typeof value === \"number\") {\n setCurrentValue(value);\n return;\n }\n\n const width = window.innerWidth;\n const breakpoints = Object.entries(value)\n .map(([key, val]) => ({\n breakpoint:\n key === \"initial\" ? 0 : parseBreakpoint(key as BreakpointValue),\n value: val ?? defaultValue,\n }))\n .sort((a, b) => b.breakpoint - a.breakpoint);\n\n const newValue =\n breakpoints.find(({ breakpoint }) => width >= breakpoint)?.value ??\n defaultValue;\n setCurrentValue(newValue);\n }, [value, defaultValue, mounted]);\n\n React.useEffect(() => {\n if (!mounted) return;\n\n onResize();\n window.addEventListener(\"resize\", onResize);\n return () => window.removeEventListener(\"resize\", onResize);\n }, [onResize, mounted]);\n\n return currentValue;\n}\n\nconst useIsomorphicLayoutEffect =\n typeof window !== \"undefined\" ? React.useLayoutEffect : React.useEffect;\n\ninterface MasonryContextValue {\n mounted: boolean;\n}\n\nconst MasonryContext = React.createContext<MasonryContextValue | null>(null);\nMasonryContext.displayName = ROOT_NAME;\n\nfunction useMasonryContext(name: keyof typeof MASONRY_ERROR) {\n const context = React.useContext(MasonryContext);\n if (!context) {\n throw new Error(MASONRY_ERROR[name]);\n }\n return context;\n}\n\ntype ItemElement = React.ComponentRef<typeof MasonryItem>;\n\ninterface ItemMeasurement {\n height: number;\n width: number;\n marginTop: number;\n marginBottom: number;\n}\n\ninterface ItemCache {\n measurements: Map<ItemElement, ItemMeasurement>;\n lastUpdate: number;\n timestamps: Map<ItemElement, number>;\n}\n\ninterface MasonryProps extends React.ComponentPropsWithoutRef<\"div\"> {\n columnCount?: ResponsiveValue;\n defaultColumnCount?: number;\n gap?: ResponsiveValue;\n defaultGap?: number;\n linear?: boolean;\n asChild?: boolean;\n}\n\nconst Masonry = React.forwardRef<HTMLDivElement, MasonryProps>(\n (props, forwardedRef) => {\n const {\n children,\n columnCount = COLUMN_COUNT,\n defaultColumnCount = columnCount,\n gap = GAP,\n defaultGap = gap,\n linear = false,\n asChild,\n className,\n style,\n ...rootProps\n } = props;\n const [maxColumnHeight, setMaxColumnHeight] = React.useState<number | null>(\n null,\n );\n const resizeObserverRef = React.useRef<ResizeObserver | null>(null);\n const rafIdRef = React.useRef<number | null>(null);\n const itemCacheRef = React.useRef<ItemCache>({\n measurements: new Map(),\n lastUpdate: 0,\n timestamps: new Map(),\n });\n const collectionRef = React.useRef<HTMLDivElement>(null);\n const composedRef = useComposedRefs(forwardedRef, collectionRef);\n\n const [mounted, setMounted] = React.useState(false);\n React.useLayoutEffect(() => {\n setMounted(true);\n }, []);\n\n const currentColumnCount = useResponsiveValue({\n value: columnCount,\n defaultValue: COLUMN_COUNT,\n mounted,\n });\n const currentGap = useResponsiveValue({\n value: gap,\n defaultValue: GAP,\n mounted,\n });\n const lineBreakCount = currentColumnCount > 0 ? currentColumnCount - 1 : 0;\n\n const getMeasurements = React.useCallback(\n (item: ItemElement): ItemMeasurement | null => {\n const cached = itemCacheRef.current.measurements.get(item);\n const timestamp = itemCacheRef.current.timestamps.get(item);\n const now = Date.now();\n\n if (cached && timestamp && now - timestamp < CACHE_MAX_AGE) {\n return cached;\n }\n\n const itemStyle = window.getComputedStyle(item);\n const marginTop =\n Number.parseFloat(itemStyle.marginTop) || currentGap / 2;\n const marginBottom =\n Number.parseFloat(itemStyle.marginBottom) || currentGap / 2;\n const height = item.offsetHeight + marginTop + marginBottom;\n const width = item.offsetWidth;\n\n if (\n height === 0 ||\n Array.from(item.getElementsByTagName(\"img\")).some(\n (img) => img.clientHeight === 0,\n )\n ) {\n return null;\n }\n\n const measurements = { height, width, marginTop, marginBottom };\n itemCacheRef.current.measurements.set(item, measurements);\n itemCacheRef.current.timestamps.set(item, now);\n itemCacheRef.current.lastUpdate = now;\n\n return measurements;\n },\n [currentGap],\n );\n\n const invalidateCache = React.useCallback(() => {\n itemCacheRef.current.measurements.clear();\n itemCacheRef.current.timestamps.clear();\n itemCacheRef.current.lastUpdate = Date.now();\n }, []);\n\n const calculateLayout = React.useCallback(() => {\n if (!collectionRef.current || !mounted) return;\n\n const items = Array.from(\n collectionRef.current.querySelectorAll<ItemElement>(\n `[${DATA_ITEM_ATTR}]`,\n ),\n );\n\n const columnHeights = new Array(currentColumnCount).fill(0);\n let skip = false;\n let nextOrder = 1;\n\n for (const item of items) {\n if (item.dataset[DATA_LINE_BREAK_ATTR] === \"\") continue;\n const styles: Partial<CSSStyleDeclaration> = {\n position: \"\",\n top: \"\",\n left: \"\",\n width: `calc(${100 / currentColumnCount}% - ${(currentGap * (currentColumnCount - 1)) / currentColumnCount}px)`,\n margin: `${currentGap / 2}px`,\n };\n Object.assign(item.style, styles);\n }\n\n for (const item of items) {\n if (item.dataset[DATA_LINE_BREAK_ATTR] === \"\" || skip) continue;\n\n const itemMeasurement = getMeasurements(item);\n if (!itemMeasurement) {\n skip = true;\n continue;\n }\n\n if (linear) {\n const yPos = columnHeights[nextOrder - 1];\n Object.assign(item.style, {\n position: \"absolute\",\n top: `${yPos}px`,\n left: `${(nextOrder - 1) * (itemMeasurement.width + currentGap)}px`,\n });\n\n columnHeights[nextOrder - 1] = yPos + itemMeasurement.height;\n nextOrder = (nextOrder % currentColumnCount) + 1;\n } else {\n const minColumnIndex = columnHeights.indexOf(\n Math.min(...columnHeights),\n );\n const xPos = minColumnIndex * (itemMeasurement.width + currentGap);\n const yPos = columnHeights[minColumnIndex];\n\n Object.assign(item.style, {\n position: \"absolute\",\n top: `${yPos}px`,\n left: `${xPos}px`,\n });\n\n columnHeights[minColumnIndex] = yPos + itemMeasurement.height;\n }\n }\n\n if (!skip) {\n /**\n * Use flushSync to prevent layout thrashing during React 18 batching\n * @see https://github.com/facebook/react/blob/a8a4742f1c54493df00da648a3f9d26e3db9c8b5/packages/react-dom/src/events/ReactDOMEventListener.js#L294-L350\n */\n ReactDOM.flushSync(() => {\n const maxHeight = Math.max(...columnHeights);\n setMaxColumnHeight(maxHeight > 0 ? maxHeight : null);\n });\n }\n }, [currentColumnCount, currentGap, linear, mounted, getMeasurements]);\n\n useIsomorphicLayoutEffect(() => {\n if (typeof ResizeObserver === \"undefined\") return;\n\n const cleanupResizeObserver = () => {\n if (rafIdRef.current) {\n cancelAnimationFrame(rafIdRef.current);\n }\n if (resizeObserverRef.current) {\n resizeObserverRef.current.disconnect();\n }\n };\n\n resizeObserverRef.current = new ResizeObserver(() => {\n invalidateCache();\n rafIdRef.current = requestAnimationFrame(calculateLayout);\n });\n\n const content = collectionRef.current;\n if (content) {\n resizeObserverRef.current.observe(content);\n for (const child of Array.from(content.children)) {\n resizeObserverRef.current.observe(child);\n }\n }\n\n return cleanupResizeObserver;\n }, [calculateLayout, invalidateCache]);\n\n const initialGridStyle = React.useMemo(\n () => ({\n display: mounted ? \"block\" : \"grid\",\n gridTemplateColumns: !mounted\n ? `repeat(${getInitialValue(defaultColumnCount, 4)}, 1fr)`\n : undefined,\n gap: !mounted ? `${getInitialValue(defaultGap, 16)}px` : undefined,\n }),\n [mounted, defaultColumnCount, defaultGap],\n );\n\n const containerStyle = React.useMemo(\n () => ({\n ...style,\n ...initialGridStyle,\n height: mounted && maxColumnHeight ? `${maxColumnHeight}px` : \"auto\",\n minHeight: \"0px\",\n width: mounted ? `calc(100% - ${currentGap}px)` : \"100%\",\n marginLeft: mounted ? `${currentGap / 2}px` : undefined,\n marginRight: mounted ? `${currentGap / 2}px` : undefined,\n }),\n [style, initialGridStyle, mounted, maxColumnHeight, currentGap],\n );\n\n const contextValue = React.useMemo(() => ({ mounted }), [mounted]);\n\n const RootSlot = asChild ? Slot : \"div\";\n\n return (\n <MasonryContext.Provider value={contextValue}>\n <RootSlot\n {...rootProps}\n ref={composedRef}\n className={cn(\"relative mx-auto w-full\", className)}\n style={containerStyle}\n >\n {children}\n <LineBreaks\n lineBreakCount={lineBreakCount}\n currentColumnCount={currentColumnCount}\n />\n </RootSlot>\n </MasonryContext.Provider>\n );\n },\n);\n\nMasonry.displayName = ROOT_NAME;\n\ninterface LineBreaksProps {\n lineBreakCount: number;\n currentColumnCount: number;\n}\n\nconst LineBreaks = React.memo(\n function LineBreaks({ lineBreakCount, currentColumnCount }: LineBreaksProps) {\n return (\n <>\n {Array.from({ length: lineBreakCount }, (_, i) => {\n const key = `line-break-${currentColumnCount}-${i}`;\n return (\n <span\n key={key}\n {...{ [DATA_LINE_BREAK_ATTR]: \"\" }}\n style={{\n flexBasis: \"100%\",\n width: 0,\n margin: 0,\n padding: 0,\n order: i + 1,\n }}\n />\n );\n })}\n </>\n );\n },\n (prevProps, nextProps) => {\n return prevProps.lineBreakCount === nextProps.lineBreakCount;\n },\n);\n\ninterface MasonryItemProps extends React.ComponentPropsWithoutRef<\"div\"> {\n asChild?: boolean;\n fallback?: React.ReactNode;\n}\n\nconst MasonryItem = React.forwardRef<HTMLDivElement, MasonryItemProps>(\n (props, forwardedRef) => {\n const { asChild, fallback, ...itemProps } = props;\n const context = useMasonryContext(ITEM_NAME);\n\n if (!context.mounted && fallback) {\n return fallback;\n }\n\n const ItemSlot = asChild ? Slot : \"div\";\n\n return (\n <ItemSlot\n {...{ [DATA_ITEM_ATTR]: \"\" }}\n {...itemProps}\n ref={forwardedRef}\n />\n );\n },\n);\n\nMasonryItem.displayName = ITEM_NAME;\n\nconst Root = Masonry;\nconst Item = MasonryItem;\n\nexport {\n Masonry,\n MasonryItem,\n //\n Root,\n Item,\n};\n",
0 commit comments