diff --git a/examples/distance-rescale/index.html b/examples/distance-rescale/index.html new file mode 100644 index 00000000..ecaad5f8 --- /dev/null +++ b/examples/distance-rescale/index.html @@ -0,0 +1,109 @@ + + + + + + + Spark - Distance Measurement & Rescale + + + + +
Click on the model to select first measurement point
+ +
+ + +
+ +
+
Distance
+
0.000
+
+ + + + + + diff --git a/examples/distance-rescale/main.js b/examples/distance-rescale/main.js new file mode 100644 index 00000000..e15aea5a --- /dev/null +++ b/examples/distance-rescale/main.js @@ -0,0 +1,717 @@ +import { PlyWriter, SparkRenderer, SplatMesh } from "@sparkjsdev/spark"; +import { GUI } from "lil-gui"; +import * as THREE from "three"; +import { OrbitControls } from "three/addons/controls/OrbitControls.js"; +import { getAssetFileURL } from "/examples/js/get-asset-url.js"; + +// ============================================================================ +// Scene Setup +// ============================================================================ + +const scene = new THREE.Scene(); +const camera = new THREE.PerspectiveCamera( + 60, + window.innerWidth / window.innerHeight, + 0.1, + 100000, +); +const renderer = new THREE.WebGLRenderer({ antialias: false }); +renderer.setSize(window.innerWidth, window.innerHeight); +document.body.appendChild(renderer.domElement); + +const spark = new SparkRenderer({ renderer }); +scene.add(spark); + +// Camera controls - using OrbitControls for reliability +const controls = new OrbitControls(camera, renderer.domElement); +controls.enableDamping = true; +controls.dampingFactor = 0.05; +camera.position.set(0, 2, 5); +camera.lookAt(0, 0, 0); + +window.addEventListener("resize", onWindowResize, false); +function onWindowResize() { + camera.aspect = window.innerWidth / window.innerHeight; + camera.updateProjectionMatrix(); + renderer.setSize(window.innerWidth, window.innerHeight); +} + +// ============================================================================ +// State Management +// ============================================================================ + +const state = { + // Point 1 + point1: null, + ray1Origin: null, + ray1Direction: null, + marker1: null, + rayLine1: null, + + // Point 2 + point2: null, + ray2Origin: null, + ray2Direction: null, + marker2: null, + rayLine2: null, + + // Measurement + distanceLine: null, + currentDistance: 0, + + // Interaction + mode: "select1", // 'select1' | 'select2' | 'complete' + dragging: null, // 'point1' | 'point2' | null +}; + +let splatMesh = null; +const raycaster = new THREE.Raycaster(); + +// ============================================================================ +// Visual Elements +// ============================================================================ + +let rayLineLength = 100; // Will be updated based on model size +const MARKER_SCREEN_SIZE = 0.03; // Constant screen-space size (percentage of screen height) +const POINT1_COLOR = 0x00ff00; // Green +const POINT2_COLOR = 0x0088ff; // Blue +const DISTANCE_LINE_COLOR = 0xffff00; // Yellow + +function createMarker(color) { + // Create a group to hold both the sphere and its outline + // Use unit size - will be scaled dynamically based on camera distance + const group = new THREE.Group(); + + // Inner sphere (unit radius = 1) + const geometry = new THREE.SphereGeometry(1, 16, 16); + const material = new THREE.MeshBasicMaterial({ + color, + depthTest: false, + transparent: true, + opacity: 0.9, + }); + const mesh = new THREE.Mesh(geometry, material); + mesh.renderOrder = 1000; + group.add(mesh); + + // Outer ring/outline for better visibility + const ringGeometry = new THREE.RingGeometry(1.2, 1.8, 32); + const ringMaterial = new THREE.MeshBasicMaterial({ + color: 0xffffff, + depthTest: false, + transparent: true, + opacity: 0.8, + side: THREE.DoubleSide, + }); + const ring = new THREE.Mesh(ringGeometry, ringMaterial); + ring.renderOrder = 999; + group.add(ring); + + // Make ring always face camera (billboard) + group.userData.ring = ring; + + return group; +} + +function createRayLine(origin, direction, color) { + const farPoint = origin + .clone() + .add(direction.clone().multiplyScalar(rayLineLength)); + const geometry = new THREE.BufferGeometry().setFromPoints([origin, farPoint]); + const material = new THREE.LineBasicMaterial({ + color, + depthTest: false, + transparent: true, + opacity: 0.6, + }); + const line = new THREE.Line(geometry, material); + line.renderOrder = 998; + return line; +} + +function updateRayLine(line, origin, direction) { + const positions = line.geometry.attributes.position.array; + const farPoint = origin + .clone() + .add(direction.clone().multiplyScalar(rayLineLength)); + positions[0] = origin.x; + positions[1] = origin.y; + positions[2] = origin.z; + positions[3] = farPoint.x; + positions[4] = farPoint.y; + positions[5] = farPoint.z; + line.geometry.attributes.position.needsUpdate = true; +} + +function createDistanceLine() { + const geometry = new THREE.BufferGeometry().setFromPoints([ + new THREE.Vector3(), + new THREE.Vector3(), + ]); + const material = new THREE.LineBasicMaterial({ + color: DISTANCE_LINE_COLOR, + depthTest: false, + linewidth: 2, + }); + const line = new THREE.Line(geometry, material); + line.renderOrder = 997; + return line; +} + +function updateDistanceLine() { + if (!state.distanceLine || !state.point1 || !state.point2) return; + + const positions = state.distanceLine.geometry.attributes.position.array; + positions[0] = state.point1.x; + positions[1] = state.point1.y; + positions[2] = state.point1.z; + positions[3] = state.point2.x; + positions[4] = state.point2.y; + positions[5] = state.point2.z; + state.distanceLine.geometry.attributes.position.needsUpdate = true; +} + +// ============================================================================ +// Mouse / Touch Utilities +// ============================================================================ + +function getMouseNDC(event) { + const rect = renderer.domElement.getBoundingClientRect(); + return new THREE.Vector2( + ((event.clientX - rect.left) / rect.width) * 2 - 1, + -((event.clientY - rect.top) / rect.height) * 2 + 1, + ); +} + +function getHitPoint(ndc) { + if (!splatMesh) return null; + raycaster.setFromCamera(ndc, camera); + const hits = raycaster.intersectObject(splatMesh, false); + if (hits && hits.length > 0) { + return hits[0].point.clone(); + } + return null; +} + +// ============================================================================ +// Point Selection +// ============================================================================ + +function selectPoint1(hitPoint) { + state.point1 = hitPoint.clone(); + state.ray1Origin = camera.position.clone(); + state.ray1Direction = raycaster.ray.direction.clone(); + + // Create marker + if (state.marker1) scene.remove(state.marker1); + state.marker1 = createMarker(POINT1_COLOR); + state.marker1.position.copy(hitPoint); + scene.add(state.marker1); + + // Create ray line + if (state.rayLine1) scene.remove(state.rayLine1); + state.rayLine1 = createRayLine( + state.ray1Origin, + state.ray1Direction, + POINT1_COLOR, + ); + scene.add(state.rayLine1); + + state.mode = "select2"; + updateInstructions("Click on the model to select second measurement point"); +} + +function selectPoint2(hitPoint) { + state.point2 = hitPoint.clone(); + state.ray2Origin = camera.position.clone(); + state.ray2Direction = raycaster.ray.direction.clone(); + + // Create marker + if (state.marker2) scene.remove(state.marker2); + state.marker2 = createMarker(POINT2_COLOR); + state.marker2.position.copy(hitPoint); + scene.add(state.marker2); + + // Create ray line + if (state.rayLine2) scene.remove(state.rayLine2); + state.rayLine2 = createRayLine( + state.ray2Origin, + state.ray2Direction, + POINT2_COLOR, + ); + scene.add(state.rayLine2); + + // Create distance line + if (!state.distanceLine) { + state.distanceLine = createDistanceLine(); + scene.add(state.distanceLine); + } + updateDistanceLine(); + + state.mode = "complete"; + calculateDistance(); + updateInstructions("Drag markers to adjust position along ray lines"); +} + +// ============================================================================ +// Drag Along Ray +// ============================================================================ + +function closestPointOnRay(viewRay, rayOrigin, rayDir, currentPoint) { + // Find the point on the selection ray closest to the view ray + const w0 = rayOrigin.clone().sub(viewRay.origin); + const a = rayDir.dot(rayDir); + const b = rayDir.dot(viewRay.direction); + const c = viewRay.direction.dot(viewRay.direction); + const d = rayDir.dot(w0); + const e = viewRay.direction.dot(w0); + + const denom = a * c - b * b; + if (Math.abs(denom) < 0.0001) { + // Rays are nearly parallel - keep current point + return currentPoint.clone(); + } + + const t = (b * e - c * d) / denom; + + // Very minimal clamping - just prevent going behind ray origin or too far + const minT = 0.01; // Almost at ray origin + const maxT = rayLineLength * 2; // Allow movement beyond visible ray line + const clampedT = Math.max(minT, Math.min(maxT, t)); + return rayOrigin.clone().add(rayDir.clone().multiplyScalar(clampedT)); +} + +function checkMarkerHit(ndc) { + raycaster.setFromCamera(ndc, camera); + + const objects = []; + if (state.marker1) objects.push(state.marker1); + if (state.marker2) objects.push(state.marker2); + + if (objects.length === 0) return null; + + // Use recursive=true to hit children (sphere and ring inside group) + const hits = raycaster.intersectObjects(objects, true); + if (hits.length > 0) { + // Check if the hit object or its parent is marker1 or marker2 + let hitObj = hits[0].object; + while (hitObj) { + if (hitObj === state.marker1) return "point1"; + if (hitObj === state.marker2) return "point2"; + hitObj = hitObj.parent; + } + } + return null; +} + +// ============================================================================ +// Distance Calculation +// ============================================================================ + +function calculateDistance() { + if (!state.point1 || !state.point2) { + state.currentDistance = 0; + return; + } + + state.currentDistance = state.point1.distanceTo(state.point2); + updateDistanceDisplay(state.currentDistance); + guiParams.measuredDistance = state.currentDistance.toFixed(4); +} + +function updateDistanceDisplay(distance) { + const display = document.getElementById("distance-display"); + const value = document.getElementById("distance-value"); + display.style.display = "block"; + value.textContent = distance.toFixed(4); +} + +// ============================================================================ +// Rescaling +// ============================================================================ + +function rescaleModel(newDistance) { + if (!splatMesh || state.currentDistance <= 0) { + console.warn("Cannot rescale: no model or zero distance"); + return; + } + + const scaleFactor = newDistance / state.currentDistance; + + // Scale all splat centers and scales + splatMesh.packedSplats.forEachSplat( + (i, center, scales, quat, opacity, color) => { + center.multiplyScalar(scaleFactor); + scales.multiplyScalar(scaleFactor); + splatMesh.packedSplats.setSplat(i, center, scales, quat, opacity, color); + }, + ); + + splatMesh.packedSplats.needsUpdate = true; + + // Update points and markers + if (state.point1) { + state.point1.multiplyScalar(scaleFactor); + state.marker1.position.copy(state.point1); + state.ray1Origin.multiplyScalar(scaleFactor); + updateRayLine(state.rayLine1, state.ray1Origin, state.ray1Direction); + } + + if (state.point2) { + state.point2.multiplyScalar(scaleFactor); + state.marker2.position.copy(state.point2); + state.ray2Origin.multiplyScalar(scaleFactor); + updateRayLine(state.rayLine2, state.ray2Origin, state.ray2Direction); + } + + updateDistanceLine(); + state.currentDistance = newDistance; + updateDistanceDisplay(newDistance); + guiParams.measuredDistance = newDistance.toFixed(4); +} + +// ============================================================================ +// Reset +// ============================================================================ + +function disposeObject(obj) { + if (!obj) return; + scene.remove(obj); + obj.traverse((child) => { + if (child.geometry) child.geometry.dispose(); + if (child.material) { + if (Array.isArray(child.material)) { + for (const m of child.material) { + m.dispose(); + } + } else { + child.material.dispose(); + } + } + }); +} + +function resetSelection() { + // Remove and dispose visual elements + disposeObject(state.marker1); + state.marker1 = null; + disposeObject(state.marker2); + state.marker2 = null; + disposeObject(state.rayLine1); + state.rayLine1 = null; + disposeObject(state.rayLine2); + state.rayLine2 = null; + disposeObject(state.distanceLine); + state.distanceLine = null; + + // Reset state + state.point1 = null; + state.point2 = null; + state.ray1Origin = null; + state.ray1Direction = null; + state.ray2Origin = null; + state.ray2Direction = null; + state.currentDistance = 0; + state.mode = "select1"; + state.dragging = null; + + // Update UI + document.getElementById("distance-display").style.display = "none"; + guiParams.measuredDistance = "0.0000"; + updateInstructions("Click on the model to select first measurement point"); +} + +// ============================================================================ +// PLY Export +// ============================================================================ + +function exportPly() { + if (!splatMesh) { + console.warn("No model to export"); + return; + } + + const writer = new PlyWriter(splatMesh.packedSplats); + writer.downloadAs("rescaled_model.ply"); +} + +// ============================================================================ +// UI Updates +// ============================================================================ + +function updateInstructions(text) { + document.getElementById("instructions").textContent = text; +} + +// ============================================================================ +// Event Handlers +// ============================================================================ + +let pointerDownPos = null; + +renderer.domElement.addEventListener("pointerdown", (event) => { + pointerDownPos = { x: event.clientX, y: event.clientY }; + + const ndc = getMouseNDC(event); + + // Check if clicking on a marker to start dragging + const markerHit = checkMarkerHit(ndc); + if (markerHit) { + state.dragging = markerHit; + controls.enabled = false; + return; + } +}); + +renderer.domElement.addEventListener("pointermove", (event) => { + if (!state.dragging) return; + + const ndc = getMouseNDC(event); + raycaster.setFromCamera(ndc, camera); + + let newPoint; + if (state.dragging === "point1") { + newPoint = closestPointOnRay( + raycaster.ray, + state.ray1Origin, + state.ray1Direction, + state.point1, + ); + state.point1.copy(newPoint); + state.marker1.position.copy(newPoint); + } else if (state.dragging === "point2") { + newPoint = closestPointOnRay( + raycaster.ray, + state.ray2Origin, + state.ray2Direction, + state.point2, + ); + state.point2.copy(newPoint); + state.marker2.position.copy(newPoint); + } + + updateDistanceLine(); + calculateDistance(); +}); + +renderer.domElement.addEventListener("pointerup", (event) => { + if (state.dragging) { + state.dragging = null; + controls.enabled = true; + return; + } + + // Check if it was a click (not a drag) + if (pointerDownPos) { + const dx = event.clientX - pointerDownPos.x; + const dy = event.clientY - pointerDownPos.y; + if (Math.sqrt(dx * dx + dy * dy) > 5) { + pointerDownPos = null; + return; // Was a drag, not a click + } + } + + if (!splatMesh) return; + + const ndc = getMouseNDC(event); + const hitPoint = getHitPoint(ndc); + + if (!hitPoint) return; + + if (state.mode === "select1") { + selectPoint1(hitPoint); + } else if (state.mode === "select2") { + selectPoint2(hitPoint); + } + + pointerDownPos = null; +}); + +// ============================================================================ +// GUI +// ============================================================================ + +const gui = new GUI(); +const guiParams = { + measuredDistance: "0.0000", + newDistance: 1.0, + reset: resetSelection, + rescale: () => rescaleModel(guiParams.newDistance), + exportPly: exportPly, +}; + +gui + .add(guiParams, "measuredDistance") + .name("Measured Distance") + .listen() + .disable(); +gui.add(guiParams, "newDistance").name("New Distance"); +gui.add(guiParams, "rescale").name("Apply Rescale"); +gui.add(guiParams, "reset").name("Reset Points"); +gui.add(guiParams, "exportPly").name("Export PLY"); + +// ============================================================================ +// File Loading +// ============================================================================ + +async function loadSplatFile(urlOrFile) { + // Remove existing splat mesh + if (splatMesh) { + scene.remove(splatMesh); + splatMesh = null; + } + + resetSelection(); + updateInstructions("Loading model..."); + + try { + if (typeof urlOrFile === "string") { + // Load from URL + console.log("Loading from URL:", urlOrFile); + splatMesh = new SplatMesh({ url: urlOrFile }); + } else { + // Load from File object + console.log("Loading from file:", urlOrFile.name); + const arrayBuffer = await urlOrFile.arrayBuffer(); + console.log("File size:", arrayBuffer.byteLength, "bytes"); + splatMesh = new SplatMesh({ fileBytes: new Uint8Array(arrayBuffer) }); + } + + // Apply rotation to match common PLY orientation + splatMesh.rotation.x = Math.PI; + scene.add(splatMesh); + + await splatMesh.initialized; + console.log(`Loaded ${splatMesh.packedSplats.numSplats} splats`); + + // Auto-center camera on the model + centerCameraOnModel(); + updateInstructions("Click on the model to select first measurement point"); + } catch (error) { + console.error("Error loading splat:", error); + updateInstructions("Error loading model. Check console for details."); + } +} + +function centerCameraOnModel() { + if (!splatMesh) { + console.warn("centerCameraOnModel: no splatMesh"); + return; + } + + try { + // Use built-in getBoundingBox method + const bbox = splatMesh.getBoundingBox(true); + console.log("Bounding box:", bbox); + + const center = new THREE.Vector3(); + bbox.getCenter(center); + const size = new THREE.Vector3(); + bbox.getSize(size); + const maxDim = Math.max(size.x, size.y, size.z); + + console.log( + "Center:", + center.x.toFixed(2), + center.y.toFixed(2), + center.z.toFixed(2), + ); + console.log( + "Size:", + size.x.toFixed(2), + size.y.toFixed(2), + size.z.toFixed(2), + ); + console.log("Max dimension:", maxDim.toFixed(2)); + + if (maxDim === 0 || !Number.isFinite(maxDim)) { + console.warn("Invalid bounding box size"); + return; + } + + // Update ray line length based on model scale + rayLineLength = maxDim * 5; // 5x model size + console.log("Ray line length:", rayLineLength.toFixed(2)); + + // Position camera to see the entire model + const fov = camera.fov * (Math.PI / 180); + const cameraDistance = (maxDim / (2 * Math.tan(fov / 2))) * 1.5; + + camera.position.set(center.x, center.y, center.z + cameraDistance); + camera.lookAt(center); + camera.near = cameraDistance * 0.001; + camera.far = cameraDistance * 10; + camera.updateProjectionMatrix(); + + // Update OrbitControls target + controls.target.copy(center); + controls.update(); + + console.log( + "Camera position:", + camera.position.x.toFixed(2), + camera.position.y.toFixed(2), + camera.position.z.toFixed(2), + ); + console.log("Camera far:", camera.far); + } catch (error) { + console.error("Error computing bounding box:", error); + } +} + +// File input handler +document + .getElementById("file-input") + .addEventListener("change", async (event) => { + const file = event.target.files[0]; + if (file) { + await loadSplatFile(file); + } + }); + +// Load default asset +async function loadDefaultAsset() { + try { + const url = await getAssetFileURL("penguin.spz"); + if (url) { + await loadSplatFile(url); + } + } catch (error) { + console.error("Error loading default asset:", error); + } +} + +loadDefaultAsset(); + +// ============================================================================ +// Render Loop +// ============================================================================ + +function updateMarkerScale(marker) { + if (!marker) return; + + // Calculate distance from camera to marker + const distance = camera.position.distanceTo(marker.position); + + // Calculate scale to maintain constant screen size + // Based on FOV and desired screen percentage + const fov = camera.fov * (Math.PI / 180); + const scale = distance * Math.tan(fov / 2) * MARKER_SCREEN_SIZE; + + marker.scale.setScalar(scale); + + // Billboard: make ring face camera + if (marker.userData.ring) { + marker.userData.ring.lookAt(camera.position); + } +} + +renderer.setAnimationLoop(() => { + controls.update(); + + // Update marker scales to maintain constant screen size + updateMarkerScale(state.marker1); + updateMarkerScale(state.marker2); + + renderer.render(scene, camera); +}); diff --git a/index.html b/index.html index bd1e79ad..5eaedd4b 100644 --- a/index.html +++ b/index.html @@ -139,6 +139,7 @@

Examples

  • Multiple Viewpoints
  • Procedural Splats
  • Raycasting
  • +
  • Distance Measurement & Rescale
  • Dynamic Lighting
  • Particle Animation
  • Particle Simulation
  • diff --git a/src/PlyWriter.ts b/src/PlyWriter.ts new file mode 100644 index 00000000..31ba56dd --- /dev/null +++ b/src/PlyWriter.ts @@ -0,0 +1,166 @@ +// PLY file format writer for Gaussian Splatting data + +import type { PackedSplats } from "./PackedSplats"; +import { SH_C0 } from "./ply"; + +export type PlyWriterOptions = { + // Output format (default: binary_little_endian) + format?: "binary_little_endian" | "binary_big_endian"; +}; + +/** + * PlyWriter exports PackedSplats data to standard PLY format. + * + * The output PLY file is compatible with common 3DGS tools and can be + * re-imported into Spark or other Gaussian splatting renderers. + */ +export class PlyWriter { + packedSplats: PackedSplats; + options: Required; + + constructor(packedSplats: PackedSplats, options: PlyWriterOptions = {}) { + this.packedSplats = packedSplats; + this.options = { + format: options.format ?? "binary_little_endian", + }; + } + + /** + * Generate the PLY header string. + */ + private generateHeader(): string { + const numSplats = this.packedSplats.numSplats; + // PLY format uses underscores: binary_little_endian, binary_big_endian + const format = this.options.format; + + const lines = [ + "ply", + `format ${format} 1.0`, + `element vertex ${numSplats}`, + "property float x", + "property float y", + "property float z", + "property float scale_0", + "property float scale_1", + "property float scale_2", + "property float rot_0", + "property float rot_1", + "property float rot_2", + "property float rot_3", + "property float opacity", + "property float f_dc_0", + "property float f_dc_1", + "property float f_dc_2", + "end_header", + ]; + + return `${lines.join("\n")}\n`; + } + + /** + * Write binary data for all splats. + * Each splat is 14 float32 values = 56 bytes. + */ + private writeBinaryData(): ArrayBuffer { + const numSplats = this.packedSplats.numSplats; + const bytesPerSplat = 14 * 4; // 14 float32 properties + const buffer = new ArrayBuffer(numSplats * bytesPerSplat); + const dataView = new DataView(buffer); + const littleEndian = this.options.format === "binary_little_endian"; + + let offset = 0; + + this.packedSplats.forEachSplat( + (index, center, scales, quaternion, opacity, color) => { + // Position: x, y, z + dataView.setFloat32(offset, center.x, littleEndian); + offset += 4; + dataView.setFloat32(offset, center.y, littleEndian); + offset += 4; + dataView.setFloat32(offset, center.z, littleEndian); + offset += 4; + + // Scale: log scale (scale_0, scale_1, scale_2) + // Splats with scale=0 are 2DGS, use a very small value + const lnScaleX = scales.x > 0 ? Math.log(scales.x) : -12; + const lnScaleY = scales.y > 0 ? Math.log(scales.y) : -12; + const lnScaleZ = scales.z > 0 ? Math.log(scales.z) : -12; + dataView.setFloat32(offset, lnScaleX, littleEndian); + offset += 4; + dataView.setFloat32(offset, lnScaleY, littleEndian); + offset += 4; + dataView.setFloat32(offset, lnScaleZ, littleEndian); + offset += 4; + + // Rotation: quaternion (rot_0=w, rot_1=x, rot_2=y, rot_3=z) + dataView.setFloat32(offset, quaternion.w, littleEndian); + offset += 4; + dataView.setFloat32(offset, quaternion.x, littleEndian); + offset += 4; + dataView.setFloat32(offset, quaternion.y, littleEndian); + offset += 4; + dataView.setFloat32(offset, quaternion.z, littleEndian); + offset += 4; + + // Opacity: inverse sigmoid + // opacity = 1 / (1 + exp(-x)) => x = -ln(1/opacity - 1) = ln(opacity / (1 - opacity)) + // Clamp opacity to avoid log(0) or log(inf) + const clampedOpacity = Math.max(0.001, Math.min(0.999, opacity)); + const sigmoidOpacity = Math.log(clampedOpacity / (1 - clampedOpacity)); + dataView.setFloat32(offset, sigmoidOpacity, littleEndian); + offset += 4; + + // Color: DC coefficients (f_dc_0, f_dc_1, f_dc_2) + // color = f_dc * SH_C0 + 0.5 => f_dc = (color - 0.5) / SH_C0 + const f_dc_0 = (color.r - 0.5) / SH_C0; + const f_dc_1 = (color.g - 0.5) / SH_C0; + const f_dc_2 = (color.b - 0.5) / SH_C0; + dataView.setFloat32(offset, f_dc_0, littleEndian); + offset += 4; + dataView.setFloat32(offset, f_dc_1, littleEndian); + offset += 4; + dataView.setFloat32(offset, f_dc_2, littleEndian); + offset += 4; + }, + ); + + return buffer; + } + + /** + * Export the PackedSplats as a complete PLY file. + * @returns Uint8Array containing the PLY file data + */ + export(): Uint8Array { + const header = this.generateHeader(); + const headerBytes = new TextEncoder().encode(header); + const binaryData = this.writeBinaryData(); + + // Combine header and binary data + const result = new Uint8Array(headerBytes.length + binaryData.byteLength); + result.set(headerBytes, 0); + result.set(new Uint8Array(binaryData), headerBytes.length); + + return result; + } + + /** + * Export and trigger a file download. + * @param filename The name of the file to download + */ + downloadAs(filename: string): void { + const data = this.export(); + const blob = new Blob([data], { type: "application/octet-stream" }); + const url = URL.createObjectURL(blob); + + const link = document.createElement("a"); + link.href = url; + link.download = filename; + link.style.display = "none"; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + + URL.revokeObjectURL(url); + } +} diff --git a/src/index.ts b/src/index.ts index 8288360b..b8ea944d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,6 +13,7 @@ export { isPcSogs, } from "./SplatLoader"; export { PlyReader } from "./ply"; +export { PlyWriter, type PlyWriterOptions } from "./PlyWriter"; export { SpzReader, SpzWriter, transcodeSpz } from "./spz"; export { PackedSplats, type PackedSplatsOptions } from "./PackedSplats"; diff --git a/test/PlyWriter.test.ts b/test/PlyWriter.test.ts new file mode 100644 index 00000000..bfe18203 --- /dev/null +++ b/test/PlyWriter.test.ts @@ -0,0 +1,619 @@ +import assert from "node:assert"; +import type { PackedSplats } from "../src/PackedSplats.js"; +import { PlyWriter } from "../src/PlyWriter.js"; +import { SH_C0 } from "../src/ply.js"; + +// Mock Vector3-like object +interface Vec3 { + x: number; + y: number; + z: number; +} + +// Mock Quaternion-like object +interface Quat { + x: number; + y: number; + z: number; + w: number; +} + +// Mock Color-like object +interface Col { + r: number; + g: number; + b: number; +} + +// Mock splat data structure +interface MockSplat { + center: Vec3; + scales: Vec3; + quaternion: Quat; + opacity: number; + color: Col; +} + +// Create a mock PackedSplats that mimics the real interface +function createMockPackedSplats(splats: MockSplat[]): PackedSplats { + return { + numSplats: splats.length, + forEachSplat( + callback: ( + index: number, + center: Vec3, + scales: Vec3, + quaternion: Quat, + opacity: number, + color: Col, + ) => void, + ) { + for (let i = 0; i < splats.length; i++) { + const s = splats[i]; + callback(i, s.center, s.scales, s.quaternion, s.opacity, s.color); + } + }, + } as PackedSplats; +} + +// Helper to find header end in PLY data +function findHeaderEnd(data: Uint8Array): number { + const decoder = new TextDecoder(); + for (let i = 0; i < data.length - 10; i++) { + const slice = decoder.decode(data.slice(i, i + 11)); + if (slice === "end_header\n") { + return i + 11; + } + } + return -1; +} + +// Test 1: PlyWriter constructor with default options +{ + const mockSplats = createMockPackedSplats([]); + const writer = new PlyWriter(mockSplats); + + assert.strictEqual( + writer.options.format, + "binary_little_endian", + "Default format should be binary_little_endian", + ); +} + +// Test 2: PlyWriter constructor with custom format +{ + const mockSplats = createMockPackedSplats([]); + const writer = new PlyWriter(mockSplats, { format: "binary_big_endian" }); + + assert.strictEqual( + writer.options.format, + "binary_big_endian", + "Custom format should be respected", + ); +} + +// Test 3: Export generates valid PLY header +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + { + center: { x: 1, y: 1, z: 1 }, + scales: { x: 0.2, y: 0.2, z: 0.2 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.8, + color: { r: 1.0, g: 0.5, b: 0.0 }, + }, + { + center: { x: 2, y: 2, z: 2 }, + scales: { x: 0.3, y: 0.3, z: 0.3 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 1.0, + color: { r: 0.0, g: 1.0, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + assert.ok(headerEndIndex > 0, "Should find end_header marker"); + + const header = new TextDecoder().decode(result.slice(0, headerEndIndex)); + + assert.ok(header.startsWith("ply\n"), "Header should start with 'ply'"); + assert.ok( + header.includes("format binary_little_endian 1.0"), + "Header should include format", + ); + assert.ok( + header.includes("element vertex 3"), + "Header should include correct vertex count", + ); + assert.ok(header.includes("property float x"), "Header should include x"); + assert.ok(header.includes("property float y"), "Header should include y"); + assert.ok(header.includes("property float z"), "Header should include z"); + assert.ok( + header.includes("property float scale_0"), + "Header should include scale_0", + ); + assert.ok( + header.includes("property float scale_1"), + "Header should include scale_1", + ); + assert.ok( + header.includes("property float scale_2"), + "Header should include scale_2", + ); + assert.ok( + header.includes("property float rot_0"), + "Header should include rot_0", + ); + assert.ok( + header.includes("property float rot_1"), + "Header should include rot_1", + ); + assert.ok( + header.includes("property float rot_2"), + "Header should include rot_2", + ); + assert.ok( + header.includes("property float rot_3"), + "Header should include rot_3", + ); + assert.ok( + header.includes("property float opacity"), + "Header should include opacity", + ); + assert.ok( + header.includes("property float f_dc_0"), + "Header should include f_dc_0", + ); + assert.ok( + header.includes("property float f_dc_1"), + "Header should include f_dc_1", + ); + assert.ok( + header.includes("property float f_dc_2"), + "Header should include f_dc_2", + ); + assert.ok(header.includes("end_header"), "Header should end with end_header"); +} + +// Test 4: Export generates correct binary size +{ + const numSplats = 5; + const splats: MockSplat[] = []; + for (let i = 0; i < numSplats; i++) { + splats.push({ + center: { x: i, y: i, z: i }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }); + } + + const mockSplats = createMockPackedSplats(splats); + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + // Each splat is 14 float32 = 56 bytes + const bytesPerSplat = 14 * 4; + const expectedBinarySize = numSplats * bytesPerSplat; + + const headerEndIndex = findHeaderEnd(result); + const binarySize = result.length - headerEndIndex; + + assert.strictEqual( + binarySize, + expectedBinarySize, + `Binary data size should be ${expectedBinarySize} bytes (${numSplats} splats * 56 bytes)`, + ); +} + +// Test 5: Binary data contains correct position values (little endian) +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 1.5, y: 2.5, z: 3.5 }, + scales: { x: 0.1, y: 0.2, z: 0.3 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Position is first 3 floats (little endian) + const x = dataView.getFloat32(0, true); + const y = dataView.getFloat32(4, true); + const z = dataView.getFloat32(8, true); + + assert.strictEqual(x, 1.5, "X position should be 1.5"); + assert.strictEqual(y, 2.5, "Y position should be 2.5"); + assert.strictEqual(z, 3.5, "Z position should be 3.5"); +} + +// Test 6: Scale values are log-encoded +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 1.0, y: Math.E, z: Math.exp(2) }, // log: 0, 1, 2 + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Scale values start at offset 12 (after x, y, z) + const scale0 = dataView.getFloat32(12, true); + const scale1 = dataView.getFloat32(16, true); + const scale2 = dataView.getFloat32(20, true); + + assert.ok( + Math.abs(scale0 - 0) < 0.0001, + `Log scale_0 for scale=1 should be 0, got ${scale0}`, + ); + assert.ok( + Math.abs(scale1 - 1) < 0.0001, + `Log scale_1 for scale=e should be 1, got ${scale1}`, + ); + assert.ok( + Math.abs(scale2 - 2) < 0.0001, + `Log scale_2 for scale=e^2 should be 2, got ${scale2}`, + ); +} + +// Test 7: Zero scale uses fallback value +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0, y: 0, z: 0 }, // Zero scale (2DGS) + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Scale values start at offset 12 + const scale0 = dataView.getFloat32(12, true); + const scale1 = dataView.getFloat32(16, true); + const scale2 = dataView.getFloat32(20, true); + + // Zero scale should use -12 as fallback + assert.strictEqual(scale0, -12, "Zero scale_0 should use -12 fallback"); + assert.strictEqual(scale1, -12, "Zero scale_1 should use -12 fallback"); + assert.strictEqual(scale2, -12, "Zero scale_2 should use -12 fallback"); +} + +// Test 8: Quaternion values are correctly written +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0.1, y: 0.2, z: 0.3, w: 0.9 }, // Custom rotation + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Quaternion starts at offset 24 (after x,y,z,scale0,1,2) + // Order is w, x, y, z (rot_0=w, rot_1=x, rot_2=y, rot_3=z) + const rot0 = dataView.getFloat32(24, true); // w + const rot1 = dataView.getFloat32(28, true); // x + const rot2 = dataView.getFloat32(32, true); // y + const rot3 = dataView.getFloat32(36, true); // z + + assert.ok( + Math.abs(rot0 - 0.9) < 0.0001, + `rot_0 (w) should be 0.9, got ${rot0}`, + ); + assert.ok( + Math.abs(rot1 - 0.1) < 0.0001, + `rot_1 (x) should be 0.1, got ${rot1}`, + ); + assert.ok( + Math.abs(rot2 - 0.2) < 0.0001, + `rot_2 (y) should be 0.2, got ${rot2}`, + ); + assert.ok( + Math.abs(rot3 - 0.3) < 0.0001, + `rot_3 (z) should be 0.3, got ${rot3}`, + ); +} + +// Test 9: Opacity is sigmoid-encoded +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, // sigmoid inverse = ln(0.5/0.5) = 0 + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Opacity is at offset 40 (after x,y,z, scale0,1,2, rot0,1,2,3) + const sigmoidOpacity = dataView.getFloat32(40, true); + + assert.ok( + Math.abs(sigmoidOpacity) < 0.0001, + `Sigmoid opacity for 0.5 should be 0, got ${sigmoidOpacity}`, + ); +} + +// Test 10: Opacity edge cases are clamped +{ + // Test opacity = 1.0 (would be inf without clamping) + const mockSplats1 = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 1.0, // Clamped to 0.999 + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer1 = new PlyWriter(mockSplats1); + const result1 = writer1.export(); + const headerEndIndex1 = findHeaderEnd(result1); + const binaryData1 = result1.slice(headerEndIndex1); + const dataView1 = new DataView(binaryData1.buffer, binaryData1.byteOffset); + const opacity1 = dataView1.getFloat32(40, true); + + assert.ok( + Number.isFinite(opacity1), + `Opacity 1.0 should produce finite value, got ${opacity1}`, + ); + assert.ok(opacity1 > 0, "Opacity 1.0 should produce positive sigmoid value"); + + // Test opacity = 0.0 (would be -inf without clamping) + const mockSplats0 = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.0, // Clamped to 0.001 + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer0 = new PlyWriter(mockSplats0); + const result0 = writer0.export(); + const headerEndIndex0 = findHeaderEnd(result0); + const binaryData0 = result0.slice(headerEndIndex0); + const dataView0 = new DataView(binaryData0.buffer, binaryData0.byteOffset); + const opacity0 = dataView0.getFloat32(40, true); + + assert.ok( + Number.isFinite(opacity0), + `Opacity 0.0 should produce finite value, got ${opacity0}`, + ); + assert.ok(opacity0 < 0, "Opacity 0.0 should produce negative sigmoid value"); +} + +// Test 11: Color DC coefficients are correctly encoded +{ + // color = f_dc * SH_C0 + 0.5 => f_dc = (color - 0.5) / SH_C0 + const testColor = { r: 0.75, g: 0.25, b: 1.0 }; + const expectedDC0 = (testColor.r - 0.5) / SH_C0; + const expectedDC1 = (testColor.g - 0.5) / SH_C0; + const expectedDC2 = (testColor.b - 0.5) / SH_C0; + + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: testColor, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Color DC coefficients start at offset 44 (after opacity) + const f_dc_0 = dataView.getFloat32(44, true); + const f_dc_1 = dataView.getFloat32(48, true); + const f_dc_2 = dataView.getFloat32(52, true); + + assert.ok( + Math.abs(f_dc_0 - expectedDC0) < 0.0001, + `f_dc_0 should be ${expectedDC0}, got ${f_dc_0}`, + ); + assert.ok( + Math.abs(f_dc_1 - expectedDC1) < 0.0001, + `f_dc_1 should be ${expectedDC1}, got ${f_dc_1}`, + ); + assert.ok( + Math.abs(f_dc_2 - expectedDC2) < 0.0001, + `f_dc_2 should be ${expectedDC2}, got ${f_dc_2}`, + ); +} + +// Test 12: Empty PackedSplats exports correctly +{ + const mockSplats = createMockPackedSplats([]); + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const decoder = new TextDecoder(); + const headerStr = decoder.decode(result); + + assert.ok( + headerStr.includes("element vertex 0"), + "Empty export should have 0 vertices", + ); + assert.ok( + headerStr.includes("end_header"), + "Empty export should have valid header", + ); + + // Should only contain header, no binary data + const headerEnd = findHeaderEnd(result); + assert.strictEqual( + result.length, + headerEnd, + "Empty export should have no binary data after header", + ); +} + +// Test 13: Big endian format header +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 1, y: 2, z: 3 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats, { format: "binary_big_endian" }); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const header = new TextDecoder().decode(result.slice(0, headerEndIndex)); + + assert.ok( + header.includes("format binary_big_endian 1.0"), + "Header should specify big endian format", + ); +} + +// Test 14: Big endian binary data is byte-swapped +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 1.5, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const littleWriter = new PlyWriter(mockSplats, { + format: "binary_little_endian", + }); + const bigWriter = new PlyWriter(mockSplats, { format: "binary_big_endian" }); + + const littleResult = littleWriter.export(); + const bigResult = bigWriter.export(); + + const littleHeaderEnd = findHeaderEnd(littleResult); + const bigHeaderEnd = findHeaderEnd(bigResult); + + const littleBinary = littleResult.slice(littleHeaderEnd); + const bigBinary = bigResult.slice(bigHeaderEnd); + + // Read x value from both + const littleView = new DataView(littleBinary.buffer, littleBinary.byteOffset); + const bigView = new DataView(bigBinary.buffer, bigBinary.byteOffset); + + const littleX = littleView.getFloat32(0, true); // Read as little endian + const bigX = bigView.getFloat32(0, false); // Read as big endian + + assert.strictEqual(littleX, 1.5, "Little endian X should be 1.5"); + assert.strictEqual( + bigX, + 1.5, + "Big endian X should be 1.5 when read correctly", + ); +} + +// Test 15: Multiple splats are exported in order +{ + const mockSplats = createMockPackedSplats([ + { + center: { x: 0, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + { + center: { x: 10, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + { + center: { x: 20, y: 0, z: 0 }, + scales: { x: 0.1, y: 0.1, z: 0.1 }, + quaternion: { x: 0, y: 0, z: 0, w: 1 }, + opacity: 0.5, + color: { r: 0.5, g: 0.5, b: 0.5 }, + }, + ]); + + const writer = new PlyWriter(mockSplats); + const result = writer.export(); + + const headerEndIndex = findHeaderEnd(result); + const binaryData = result.slice(headerEndIndex); + const dataView = new DataView(binaryData.buffer, binaryData.byteOffset); + + // Each splat is 56 bytes + const x0 = dataView.getFloat32(0, true); + const x1 = dataView.getFloat32(56, true); + const x2 = dataView.getFloat32(112, true); + + assert.strictEqual(x0, 0, "First splat X should be 0"); + assert.strictEqual(x1, 10, "Second splat X should be 10"); + assert.strictEqual(x2, 20, "Third splat X should be 20"); +} + +console.log("✅ All PlyWriter test cases passed!");