From 24e4b9f38c73d45813fe1c99cb067d4d74354d0b Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Fri, 23 May 2025 08:37:57 +0000
Subject: [PATCH] feat: Add transform method to Layer class

This commit introduces a new `transform` method to the `Layer` class,
allowing for spatial transformations of layer data. The method takes an
output layer, an inverse transformation function, and optional parameters
for magnification, number of threads, output bounding box, and chunk shape.

Key features:
- Supports arbitrary inverse transformations (e.g., affine, rotation,
  translation, scaling) via a callable.
- Processes data in chunks to handle large datasets efficiently.
- Utilizes multithreading for improved performance.
- Uses nearest-neighbor interpolation for non-integer coordinates and
  zero-fills output voxels whose source coordinates fall outside the input
  layer's bounds.

Unit tests have been added to `webknossos/tests/dataset/test_layer.py` to
verify the functionality, covering:
- Identity, translation, scaling, and affine transformations.
- Different output bounding box configurations.
- Various chunk shapes and numbers of threads.
- Edge cases such as transformations mapping outside input bounds and small
  input layers.
---
 webknossos/tests/dataset/test_layer.py | 457 +++++++++++++++++++++++++
 webknossos/webknossos/dataset/layer.py | 200 ++++++++++-
 2 files changed, 655 insertions(+), 2 deletions(-)

diff --git a/webknossos/tests/dataset/test_layer.py b/webknossos/tests/dataset/test_layer.py
index 34d943006..f69cd4a46 100644
--- a/webknossos/tests/dataset/test_layer.py
+++ b/webknossos/tests/dataset/test_layer.py
@@ -1,9 +1,466 @@
 from pathlib import Path
+from typing import Any, Tuple
 
 import numpy as np
 import tensorstore
 
 import webknossos as wk
+from webknossos import BoundingBox, Dataset, Layer, Mag
+
+
+def create_dataset_with_color_layer(
+    tmp_path: Path,
+    dataset_name: str,
+    layer_name: str,
+    bounding_box_shape: Tuple[int, int, int],
+    dtype: Any = np.uint8,
+    num_channels: int = 1,
+    voxel_size: Tuple[int, int, int] = (1, 1, 1),
+    mag: Mag = Mag(1),
+    data_format: str = "wkw",  # a valid webknossos data format ("raw" is not one)
+    chunk_shape: Tuple[int, int, int] = (64, 64, 64),
+) -> Layer:
+    """Creates a dataset with a single color layer."""
+    dataset_path = tmp_path / dataset_name
+    dataset = Dataset(
+        dataset_path,
+        voxel_size=voxel_size,
+        exist_ok=True,  # Allow re-creation for tests if needed
+    )
+    layer = dataset.add_layer(
+        layer_name,
+        wk.COLOR_CATEGORY,
+        bounding_box=BoundingBox((0, 0, 0), size=bounding_box_shape),
+        dtype_per_channel=dtype,
+        num_channels=num_channels,
+        data_format=data_format,
+    )
+    # Add the requested mag with the given chunk_shape if it does not exist yet.
+    # Each test creates a fresh dataset, so an existing mag (with a possibly
+    # different chunk shape) is left untouched.
+    if mag not in layer.mags:
+        layer.add_mag(mag, chunk_shape=chunk_shape)
+    return layer
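+
+# Note: `Layer.transform` samples with nearest-neighbor interpolation and
+# zero-fills output voxels whose source coordinate falls outside the input
+# layer. The expected arrays in the tests below are built under the same
+# convention.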
+
+
+class TestLayerTransform:
+    def test_identity_transform(self, tmp_path: Path) -> None:
+        """Tests the transform method with an identity transformation."""
+        input_dataset_name = "input_ds_identity"
+        input_layer_name = "input_layer_identity"
+        output_dataset_name = "output_ds_identity"
+        output_layer_name = "output_layer_identity"
+        bbox_shape = (128, 128, 128)  # Large enough to exercise chunking
+        dtype = np.uint8
+        mag_level = Mag(1)
+        chunk_s = (32, 32, 32)  # Small chunks so that multiple chunks are processed
+
+        # 1. Create an input layer with some data
+        input_layer = create_dataset_with_color_layer(
+            tmp_path,
+            input_dataset_name,
+            input_layer_name,
+            bbox_shape,
+            dtype=dtype,
+            mag=mag_level,
+            chunk_shape=chunk_s,
+        )
+        # Populate with random data; the data layout is (C, X, Y, Z).
+        input_data = np.random.randint(0, 255, size=(input_layer.num_channels, *bbox_shape), dtype=dtype)
+        input_mag_view = input_layer.get_mag(mag_level)
+        # Write data to the entire bounding box of the mag view
+        input_mag_view.write(input_data, absolute_bounding_box=input_mag_view.bounding_box)
+
+        # 2. Create an empty output layer with the same properties
+        output_layer = create_dataset_with_color_layer(
+            tmp_path,
+            output_dataset_name,
+            output_layer_name,
+            bbox_shape,
+            dtype=dtype,
+            mag=mag_level,
+            chunk_shape=chunk_s,
+        )
+
+        # 3. Define an identity inverse_transform function
+        def identity_inverse_transform(coords: np.ndarray) -> np.ndarray:
+            # coords is an (N, 3) array in the global output space; for the
+            # identity transform, the input space is identical.
+            return coords
+
+        # 4. Call layer.transform with the identity transform.
+        input_layer.transform(
+            output_layer=output_layer,
+            inverse_transform=identity_inverse_transform,
+            mag=mag_level,
+            num_threads=2,  # Two threads keep CI overhead low while still testing parallelism
+            chunk_shape=chunk_s,
+        )
+
+        # 5. Assert that the data in the output layer is identical to the input layer
+        output_mag_view = output_layer.get_mag(mag_level)
+        output_data = output_mag_view.read(absolute_bounding_box=output_mag_view.bounding_box)
+
+        np.testing.assert_array_equal(
+            output_data,
+            input_data,
+            err_msg="Data in output layer does not match input layer after identity transform.",
+        )
+
+        # Additionally, ensure the bounding boxes were handled correctly.
+        assert output_layer.bounding_box == input_layer.bounding_box
+        assert output_mag_view.bounding_box == input_mag_view.bounding_box
+
+    def test_translation_transform(self, tmp_path: Path) -> None:
+        """Tests the transform method with a translation."""
+        input_dataset_name = "input_ds_translate"
+        input_layer_name = "input_layer_translate"
+        output_dataset_name = "output_ds_translate"
+        output_layer_name = "output_layer_translate"
+
+        bbox_shape_orig = (64, 64, 64)  # Original data size
+        dtype = np.uint8
+        mag_level = Mag(1)
+        chunk_s = (16, 16, 16)
+        num_ch = 1
+
+        # Translation vector (dx, dy, dz) in global coordinates
+        translation_vector = np.array([10, -5, 20])
+
+        # 1. Create an input layer with some data
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_dataset_name, input_layer_name,
+            bbox_shape_orig, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        # Populate with sequential data (wrapped into the uint8 range) for easier verification
+        input_data_flat = (np.arange(np.prod(bbox_shape_orig)) % 256).astype(dtype).reshape(bbox_shape_orig)
+        input_data = input_data_flat[np.newaxis, ...]  # Add channel dimension: (1, X, Y, Z)
+        input_mag_view = input_layer.get_mag(mag_level)
+        input_mag_view.write(input_data, absolute_bounding_box=input_mag_view.bounding_box)
+
+        # 2. Create an empty output layer with the same properties and bounding box.
+        # The transform method does not change the output layer's bounding box; it
+        # writes within output_layer.bounding_box or the given output_bounding_box.
+        # For a pure translation, the written region has the same size as the input.
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_dataset_name, output_layer_name,
+            bbox_shape_orig, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+
+        # 3. Define the inverse_transform for the translation. It maps output
+        # coordinates to input coordinates: if the output is translated by T,
+        # the input is output - T.
+        inverse_translation = lambda coords: coords - translation_vector
+
+        # 4. Call layer.transform
+        input_layer.transform(
+            output_layer=output_layer,
+            inverse_transform=inverse_translation,
+            mag=mag_level,
+            num_threads=1,  # Single-threaded here; other tests cover parallelism
+            chunk_shape=chunk_s,
+        )
+
+        # 5. Manually compute the expected output data
+        expected_output_data = np.zeros_like(input_data)  # 0 is the fill value for out-of-bounds voxels
+
+        # Overlapping region from the output layer's perspective (slices into expected_output_data)
+        out_x_slice = slice(max(0, int(translation_vector[0])), min(bbox_shape_orig[0], int(bbox_shape_orig[0] + translation_vector[0])))
+        out_y_slice = slice(max(0, int(translation_vector[1])), min(bbox_shape_orig[1], int(bbox_shape_orig[1] + translation_vector[1])))
+        out_z_slice = slice(max(0, int(translation_vector[2])), min(bbox_shape_orig[2], int(bbox_shape_orig[2] + translation_vector[2])))
+
+        # Corresponding region from the input layer's perspective (slices into input_data)
+        in_x_slice = slice(max(0, int(-translation_vector[0])), min(bbox_shape_orig[0], int(bbox_shape_orig[0] - translation_vector[0])))
+        in_y_slice = slice(max(0, int(-translation_vector[1])), min(bbox_shape_orig[1], int(bbox_shape_orig[1] - translation_vector[1])))
+        in_z_slice = slice(max(0, int(-translation_vector[2])), min(bbox_shape_orig[2], int(bbox_shape_orig[2] - translation_vector[2])))
+
+        expected_output_data[0, out_x_slice, out_y_slice, out_z_slice] = input_data[0, in_x_slice, in_y_slice, in_z_slice]
+
+        output_mag_view = output_layer.get_mag(mag_level)
+        actual_output_data = output_mag_view.read(absolute_bounding_box=output_mag_view.bounding_box)
+
+        np.testing.assert_array_equal(
+            actual_output_data,
+            expected_output_data,
+            err_msg="Data in output layer does not match expected translated data.",
+        )
+
+    def test_scaling_transform(self, tmp_path: Path) -> None:
+        """Tests the transform method with a 2x scaling transformation."""
+        input_dataset_name = "input_ds_scale"
+        input_layer_name = "input_layer_scale"
+        output_dataset_name = "output_ds_scale"
+        output_layer_name = "output_layer_scale"
+
+        input_bbox_shape = (32, 32, 32)  # Small input for easier manual verification
+        dtype = np.uint16  # A different dtype; the 32768 ramp values fit into uint16
+        mag_level = Mag(1)  # Scaling is tested at Mag(1) for simplicity
+        chunk_s = (16, 16, 16)
+        num_ch = 1
+        scaling_factor = 2.0
+
+        # 1. Create an input layer and populate it with ramp data for predictable scaling
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_dataset_name, input_layer_name,
+            input_bbox_shape, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        input_data_flat = np.arange(np.prod(input_bbox_shape), dtype=dtype).reshape(input_bbox_shape)
+        input_data = input_data_flat[np.newaxis, ...]  # (1, X, Y, Z)
+        input_mag_view = input_layer.get_mag(mag_level)
+        input_mag_view.write(input_data, absolute_bounding_box=input_mag_view.bounding_box)
+
+        # 2. Create the output layer with a scaled-up bounding box. The transform
+        # call below implicitly uses output_layer.bounding_box as its target region.
+        output_bbox_shape = tuple(int(s * scaling_factor) for s in input_bbox_shape)
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_dataset_name, output_layer_name,
+            output_bbox_shape, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+
+        # 3. Define the inverse_transform for scaling:
+        # input_coord = output_coord / scaling_factor
+        inverse_scaling = lambda coords: coords / scaling_factor
+
+        # 4. Call layer.transform (num_threads=0 runs sequentially)
+        input_layer.transform(
+            output_layer=output_layer,
+            inverse_transform=inverse_scaling,
+            mag=mag_level,
+            num_threads=0,
+            chunk_shape=chunk_s,
+        )
+
+        # 5. Manually compute the expected output data (nearest neighbor;
+        # out-of-bounds samples stay zero)
+        expected_output_data = np.zeros((num_ch, *output_bbox_shape), dtype=dtype)
+        for c in range(num_ch):
+            for ox in range(output_bbox_shape[0]):
+                for oy in range(output_bbox_shape[1]):
+                    for oz in range(output_bbox_shape[2]):
+                        # Corresponding input coordinate, rounded to the nearest voxel
+                        ix = int(round(ox / scaling_factor))
+                        iy = int(round(oy / scaling_factor))
+                        iz = int(round(oz / scaling_factor))
+
+                        if (
+                            0 <= ix < input_bbox_shape[0]
+                            and 0 <= iy < input_bbox_shape[1]
+                            and 0 <= iz < input_bbox_shape[2]
+                        ):
+                            expected_output_data[c, ox, oy, oz] = input_data[c, ix, iy, iz]
+
+        output_mag_view = output_layer.get_mag(mag_level)
+        actual_output_data = output_mag_view.read(absolute_bounding_box=output_mag_view.bounding_box)
+
+        np.testing.assert_array_equal(
+            actual_output_data,
+            expected_output_data,
+            err_msg="Data in output layer does not match expected scaled data.",
+        )
+
+    def test_affine_transform_simple_rotation(self, tmp_path: Path) -> None:
+        """Tests a 90-degree rotation around the Z-axis followed by a translation."""
+        input_name, output_name = "input_ds_affine", "output_ds_affine"
+        layer_name = "layer_affine"
+        # The input is a 2x1x1 strip along X, which makes the rotation easy to verify.
+        input_bbox_shape = (2, 1, 1)
+        dtype, mag_level, num_ch = np.uint8, Mag(1), 1
+        chunk_s = (1, 1, 1)
+
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_name, layer_name, input_bbox_shape,
+            dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        # Data: values 100 and 200 along the X axis for the single channel, layout (C, X, Y, Z)
+        input_data = np.array([100, 200], dtype=dtype).reshape(num_ch, *input_bbox_shape)
+        input_mag_view = input_layer.get_mag(mag_level)
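+        # The (C, X, Y, Z) layout means the write below covers the full 2x1x1
+        # bounding box of the mag view with a single channel.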
+        input_mag_view.write(input_data, absolute_bounding_box=input_mag_view.bounding_box)
+
+        # The output layer must be large enough to contain the rotated and translated data:
+        #   input voxels at (0,0,0) and (1,0,0)
+        #   rotated +90 deg around Z (about the origin): (0,0,0) -> (0,0,0); (1,0,0) -> (0,1,0)
+        #   translated by (1,1,0): (0,0,0) -> (1,1,0); (0,1,0) -> (1,2,0)
+        # so the output has to cover at least (1,1,0) through (1,2,0).
+        output_bbox_shape = (2, 3, 1)  # slightly larger than strictly necessary
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_name, layer_name, output_bbox_shape,
+            dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+
+        # Forward transformation: output = T_translate(T_rotate(input)), hence
+        # the inverse is input = T_rotate_inv(T_translate_inv(output)).
+        # The inverse_transform passed to the method operates on global coordinates;
+        # our input data sits at global coordinates (0,0,0) and (1,0,0).
+        #
+        # Rotation by +90 deg around Z (counter-clockwise):
+        #   x' = x*cos(a) - y*sin(a), y' = x*sin(a) + y*cos(a)
+        #   for a = +90 deg: x' = -y, y' = x
+        # Its inverse (-90 deg): x = y', y = -x'.
+        # Inverse translation: subtract (tx, ty, tz).
+        #
+        # Combined inverse:
+        # 1. Start from the output coordinate (ox, oy, oz)
+        # 2. Inverse translate: (ox - tx, oy - ty, oz - tz) -> (ix_t, iy_t, iz_t)
+        # 3. Inverse rotate: (iy_t, -ix_t, iz_t) -> input coordinate (inx, iny, inz)
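+        #
+        # Equivalent matrix formulation (a sketch in plain numpy; rotation_inv is
+        # the -90 deg rotation about Z, applied after undoing the translation):
+        #   rotation_inv = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+        #   p_in = (p_out - np.array([tx, ty, tz])) @ rotation_inv.T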
+
+        tx, ty, tz = 1, 1, 0
+
+        def affine_inverse_transform(output_coords_global: np.ndarray) -> np.ndarray:
+            # output_coords_global is an (N, 3) array
+            translated = output_coords_global - np.array([tx, ty, tz])
+
+            rotated = np.zeros_like(translated)
+            rotated[:, 0] = translated[:, 1]   # x_in = y after inverse translation
+            rotated[:, 1] = -translated[:, 0]  # y_in = -x after inverse translation
+            rotated[:, 2] = translated[:, 2]   # z_in is unchanged
+            return rotated
+
+        input_layer.transform(
+            output_layer, affine_inverse_transform, mag=mag_level, num_threads=None, chunk_shape=chunk_s
+        )
+
+        expected_output = np.zeros((num_ch, *output_bbox_shape), dtype=dtype)
+        # Input (0,0,0) [value 100] -> rotated (0,0,0) -> translated (1,1,0)
+        # Input (1,0,0) [value 200] -> rotated (0,1,0) -> translated (1,2,0)
+        # Both target voxels lie inside the (2, 3, 1) output bounding box.
+        expected_output[0, 1, 1, 0] = 100
+        expected_output[0, 1, 2, 0] = 200
+
+        actual_output = output_layer.get_mag(mag_level).read()
+        np.testing.assert_array_equal(actual_output, expected_output, err_msg="Affine transform failed.")
+
+    def test_output_bounding_box_smaller_and_shifted(self, tmp_path: Path) -> None:
+        """Tests transform when output_bounding_box is smaller than the output layer and shifted."""
+        input_ds, input_layer_n = "in_ds_obb", "in_l_obb"
+        output_ds, output_layer_n = "out_ds_obb", "out_l_obb"
+        input_bbox_shape = (50, 50, 50)
+        dtype, mag, num_ch = np.uint8, Mag(1), 1
+        chunk_s = (10, 10, 10)
+
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_ds, input_layer_n, input_bbox_shape,
+            dtype=dtype, mag=mag, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        # Sequential data, wrapped into the uint8 range
+        input_data = (np.arange(np.prod(input_bbox_shape)) % 256).astype(dtype).reshape(num_ch, *input_bbox_shape)
+        input_layer.get_mag(mag).write(input_data, absolute_bounding_box=input_layer.get_mag(mag).bounding_box)
+
+        # The output layer is larger, but only a small, shifted part of it is written.
+        output_layer_bbox_shape = (60, 60, 60)
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_ds, output_layer_n, output_layer_bbox_shape,
+            dtype=dtype, mag=mag, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        # Fill the output layer with a distinct value to verify that only the
+        # target bounding box is overwritten.
+        fill_value = 77
+        output_layer.get_mag(mag).write(np.full((num_ch, *output_layer_bbox_shape), fill_value, dtype=dtype))
+
+        # Define the output_bounding_box for the transform call (in global coordinates):
+        # a 20x20x20 cube shifted by (5,5,5) in the output layer. With the
+        # identity transform, it reads from the input layer starting at (5,5,5).
+        obb_min_coord = (5, 5, 5)
+        obb_shape = (20, 20, 20)
+        specific_output_bb = BoundingBox(topleft=obb_min_coord, size=obb_shape)
+
+        input_layer.transform(
+            output_layer,
+            inverse_transform=lambda coords: coords,  # Identity
+            mag=mag,
+            output_bounding_box=specific_output_bb,
+            num_threads=2,
+            chunk_shape=chunk_s,
+        )
+
+        # Expected data in the output layer: input_data[0, 5:25, 5:25, 5:25]
+        # overwrites expected_data_in_output_layer[0, 5:25, 5:25, 5:25];
+        # everything else keeps the fill value.
+        expected_data_in_output_layer = np.full((num_ch, *output_layer_bbox_shape), fill_value, dtype=dtype)
+        src_data_slice = (
+            slice(None),
+            slice(obb_min_coord[0], obb_min_coord[0] + obb_shape[0]),
+            slice(obb_min_coord[1], obb_min_coord[1] + obb_shape[1]),
+            slice(obb_min_coord[2], obb_min_coord[2] + obb_shape[2]),
+        )
+        expected_data_in_output_layer[src_data_slice] = input_data[src_data_slice]
+
+        actual_output_data = output_layer.get_mag(mag).read()
+        np.testing.assert_array_equal(actual_output_data, expected_data_in_output_layer)
+
+    def test_transform_all_coords_outside_input(self, tmp_path: Path) -> None:
+        """Tests transform when all transformed coords are outside the input layer's bounds."""
+        input_ds, input_layer_n = "in_ds_outside", "in_l_outside"
+        output_ds, output_layer_n = "out_ds_outside", "out_l_outside"
+        bbox_shape = (10, 10, 10)
+        dtype, mag, num_ch, chunk_s = np.uint8, Mag(1), 1, (5, 5, 5)
+
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_ds, input_layer_n, bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s
+        )
+        input_layer.get_mag(mag).write(np.ones((num_ch, *bbox_shape), dtype=dtype))  # Fill with 1s
+
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_ds, output_layer_n, bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s
+        )
+
+        # The input occupies [0,0,0] through [9,9,9]; shifting by [100,100,100]
+        # maps every output voxel outside the input bounds.
+        translation_far = np.array([100, 100, 100])
+        inverse_transform_far = lambda coords: coords - translation_far
+
+        input_layer.transform(
+            output_layer, inverse_transform_far, mag=mag, num_threads=1, chunk_shape=chunk_s
+        )
+
+        # The expected output is all zeros (out-of-bounds voxels are zero-filled)
+        expected_output = np.zeros((num_ch, *bbox_shape), dtype=dtype)
+        actual_output = output_layer.get_mag(mag).read()
+        np.testing.assert_array_equal(actual_output, expected_output)
+
+    def test_transform_small_input_layer(self, tmp_path: Path) -> None:
+        """Tests transform with a very small (2x2x2) input layer."""
+        input_ds, input_layer_n = "in_ds_small", "in_l_small"
+        output_ds, output_layer_n = "out_ds_small", "out_l_small"
+        input_bbox_shape = (2, 2, 2)
+        dtype, mag, num_ch, chunk_s = np.uint8, Mag(1), 1, (1, 1, 1)  # Chunk size of 1
+
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_ds, input_layer_n, input_bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s
+        )
+        input_data = np.arange(np.prod(input_bbox_shape), dtype=dtype).reshape(num_ch, *input_bbox_shape)
+        input_layer.get_mag(mag).write(input_data, absolute_bounding_box=input_layer.get_mag(mag).bounding_box)
+
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_ds, output_layer_n, input_bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s
+        )
+
+        # Simple identity transform
+        input_layer.transform(
+            output_layer, lambda coords: coords, mag=mag, num_threads=None, chunk_shape=chunk_s
+        )
+
+        actual_output = output_layer.get_mag(mag).read()
+        np.testing.assert_array_equal(actual_output, input_data)
 
 
 def test_add_mag_from_zarrarray(tmp_path: Path) -> None:
diff --git a/webknossos/webknossos/dataset/layer.py b/webknossos/webknossos/dataset/layer.py
index d85417524..13cb0fc77 100644
--- a/webknossos/webknossos/dataset/layer.py
+++ b/webknossos/webknossos/dataset/layer.py
@@ -5,7 +5,8 @@
+from concurrent.futures import ThreadPoolExecutor
 from os import PathLike
 from os.path import relpath
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 from urllib.parse import urlparse
 
 import numpy as np
@@ -14,7 +14,7 @@
 from upath import UPath
 
 from ..client.context import _get_context
-from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike
+from ..geometry import BoundingBox, Mag, NDBoundingBox, Vec3Int, Vec3IntLike
 from ..geometry.mag import MagLike
 from ._array import ArrayException, TensorStoreArray
 from ._downsampling_utils import (
@@ -1619,6 +1619,201 @@ def __repr__(self) -> str:
     def _get_largest_segment_id_maybe(self) -> int | None:
         return None
 
+    def transform(
+        self,
+        output_layer: "Layer",
+        inverse_transform: Callable[[np.ndarray], np.ndarray],
+        mag: Mag = Mag(1),
+        num_threads: int | None = None,
+        output_bounding_box: BoundingBox | None = None,
+        chunk_shape: Tuple[int, int, int] | None = None,
+    ) -> None:
+        """Transforms this layer into output_layer using the inverse_transform function.
+
+        Args:
+            output_layer (Layer): The layer to write the transformed data to.
+            inverse_transform (Callable[[np.ndarray], np.ndarray]): A function that takes an
+                (N, 3) array of coordinates in the output space and returns an (N, 3) array
+                of the corresponding coordinates in the input space (both in Mag(1) units).
+            mag (Mag, optional): The magnification to read from and write to. Defaults to Mag(1).
+            num_threads (int | None, optional): The number of threads for parallel processing.
+                Defaults to None, which uses the ThreadPoolExecutor default; 0 runs sequentially.
+            output_bounding_box (BoundingBox | None, optional): The bounding box in the output
+                layer to fill. Defaults to the bounding box of output_layer.
+            chunk_shape (Tuple[int, int, int] | None, optional): The shape of the chunks to
+                process. Defaults to (64, 64, 64).
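+
+        Example (a sketch; assumes `target_layer` is a writable layer with the
+        same dtype and a compatible bounding box):
+            >>> layer.transform(
+            ...     output_layer=target_layer,
+            ...     inverse_transform=lambda coords: coords - np.array([10, 0, 0]),
+            ... )
+        """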
+        if output_bounding_box is None:
+            output_bounding_box = output_layer.bounding_box
+        if output_bounding_box is None:
+            raise ValueError(
+                "output_bounding_box must be provided if output_layer has no bounding_box"
+            )
+
+        if chunk_shape is None:
+            chunk_shape_vec = Vec3Int(64, 64, 64)
+        else:
+            chunk_shape_vec = Vec3Int(*chunk_shape)
+
+        input_mag_view = self.get_mag(mag)
+        output_mag_view = output_layer.get_mag(mag)
+
+        # Only the output layer is written to; the input layer is merely read.
+        output_layer._ensure_writable()
+
+        mag_factor = mag.to_vec3_int().to_np()
+        # Bounding box of the input layer in mag units, used for bounds checks.
+        input_layer_bbox_mag = input_mag_view.bounding_box.in_mag(mag)
+
+        def process_chunk(chunk_bbox_mag: NDBoundingBox) -> None:
+            # Generate the coordinates of all voxels in this output chunk (mag units).
+            x_coords = np.arange(chunk_bbox_mag.topleft.x, chunk_bbox_mag.bottomright.x)
+            y_coords = np.arange(chunk_bbox_mag.topleft.y, chunk_bbox_mag.bottomright.y)
+            z_coords = np.arange(chunk_bbox_mag.topleft.z, chunk_bbox_mag.bottomright.z)
+            grid_x, grid_y, grid_z = np.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
+
+            # Flatten and stack into an (N, 3) array of output coordinates, N = Dx * Dy * Dz.
+            output_coords_mag = np.stack(
+                [grid_x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=-1
+            )
+
+            # inverse_transform expects global (Mag(1)) coordinates. The chunk
+            # coordinates are already absolute, so only the mag factor is applied.
+            output_coords_global = output_coords_mag * mag_factor
+
+            # Apply the inverse transform to get input coordinates in global space.
+            input_coords_global = inverse_transform(output_coords_global)
+
+            # Convert the input coordinates back to mag units for reading.
+            input_coords_mag = input_coords_global / mag_factor
+
+            # Nearest-neighbor interpolation: round to the nearest voxel.
+            input_coords_mag_rounded = np.round(input_coords_mag).astype(int)
+
+            # Output voxels whose source lies outside the input layer are zero-filled.
+            in_bounds = np.all(
+                (input_coords_mag_rounded >= input_layer_bbox_mag.topleft.to_np())
+                & (input_coords_mag_rounded < input_layer_bbox_mag.bottomright.to_np()),
+                axis=1,
+            )
+            # Clamp so that the gather below only uses valid indices; the clamped
+            # out-of-bounds samples are discarded via in_bounds afterwards.
+            input_coords_mag_rounded = np.clip(
+                input_coords_mag_rounded,
+                input_layer_bbox_mag.topleft.to_np(),
+                input_layer_bbox_mag.bottomright.to_np() - 1,
+            )
+
+            # Reading voxel-by-voxel would be very inefficient. Instead, read the
+            # bounding box enclosing all needed input voxels and gather from it.
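+            # Note: for widely scattered source coordinates (e.g. large rotations),
+            # this enclosing box can be much larger than the chunk itself; a more
+            # elaborate version might split the gather into several smaller reads.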
+            min_input_coords = np.min(input_coords_mag_rounded, axis=0)
+            max_input_coords = np.max(input_coords_mag_rounded, axis=0)
+
+            # Bounding box to read from the input layer (mag units); the +1 makes
+            # the box inclusive of the maximum coordinate.
+            input_read_bbox_mag = BoundingBox(
+                topleft=min_input_coords, size=max_input_coords + 1 - min_input_coords
+            )
+
+            # Restrict the read box to the input layer's data bounds. This is
+            # mostly redundant after clamping, but kept as a safety net.
+            input_read_bbox_mag = input_read_bbox_mag.intersected_with(
+                input_layer_bbox_mag, dont_assert=True
+            )
+
+            num_samples = input_coords_mag_rounded.shape[0]
+            if input_read_bbox_mag.is_empty() or not in_bounds.any():
+                # Every requested coordinate lies outside the input layer.
+                read_data_flat = np.zeros(
+                    (num_samples, self.num_channels), dtype=self.dtype_per_channel
+                )
+            else:
+                # Read the enclosing bounding box from the input layer, shape (C, Dx', Dy', Dz').
+                encompassing_data = input_mag_view.read(
+                    absolute_bounding_box=input_read_bbox_mag.from_mag_to_mag1(mag)
+                )
+
+                # Indices of the requested voxels relative to the start of the read box.
+                relative_coords = input_coords_mag_rounded - input_read_bbox_mag.topleft.to_np()
+
+                # Gather the N requested voxels across all channels: (C, N) -> (N, C).
+                read_data_flat = encompassing_data[
+                    :,
+                    relative_coords[:, 0],
+                    relative_coords[:, 1],
+                    relative_coords[:, 2],
+                ].transpose()
+
+                # Zero-fill the samples whose source was out of bounds.
+                read_data_flat[~in_bounds] = 0
+
+            # The meshgrid above uses indexing="ij" with C-order raveling, so the
+            # flat (N, C) array reshapes directly to (Dx, Dy, Dz, C) and is then
+            # transposed to the channel-first layout (C, Dx, Dy, Dz).
+            dx, dy, dz = chunk_bbox_mag.size.x, chunk_bbox_mag.size.y, chunk_bbox_mag.size.z
+            output_data = read_data_flat.reshape((dx, dy, dz, self.num_channels)).transpose(3, 0, 1, 2)
+
+            # Write the chunk to the output layer (converted back to Mag(1) coordinates).
+            output_mag_view.write(output_data, absolute_bounding_box=chunk_bbox_mag.from_mag_to_mag1(mag))
+
+        # Iterate over the output bounding box in chunks. The bounding box is given
+        # in global coordinates and converted to mag units for chunking.
+        output_bbox_mag = output_bounding_box.in_mag(mag)
+
+        # One task per chunk; chunk() clips the last chunks to the bounding box.
+        tasks = [
+            chunk_bbox_mag
+            for chunk_bbox_mag in output_bbox_mag.chunk(chunk_shape_vec)
+            if not chunk_bbox_mag.is_empty()
+        ]
+
+        if num_threads == 0:
+            # Run sequentially, e.g. for debugging.
+            logging.info(f"Running transform for layer {self.name} sequentially.")
+            for task_bbox in tasks:
+                process_chunk(task_bbox)
+        else:
+            logging.info(
+                f"Running transform for layer {self.name} with up to "
+                f"{num_threads or 'default'} threads."
+            )
+            with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                futures = [executor.submit(process_chunk, chunk_bbox) for chunk_bbox in tasks]
+                for future in futures:
+                    # Wait for completion and re-raise any exception from the workers.
+                    future.result()
+
 
     def as_segmentation_layer(self) -> "SegmentationLayer":
         """Casts into SegmentationLayer."""
         if isinstance(self, SegmentationLayer):
@@ -1626,6 +1821,7 @@ def as_segmentation_layer(self) -> "SegmentationLayer":
         else:
             raise TypeError(f"self is not a SegmentationLayer. Got: {type(self)}")
 
+
     @classmethod
     def _ensure_layer(cls, layer: Union[str, PathLike, "Layer"]) -> "Layer":
         if isinstance(layer, Layer):
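
Example usage of the new method (a sketch; it assumes an existing dataset with
a color layer named "color", and `add_layer_like` plus the layer names are
illustrative):

    import numpy as np
    import webknossos as wk

    dataset = wk.Dataset.open("path/to/dataset")
    layer = dataset.get_layer("color")

    # Prepare a compatible, writable target layer.
    target_layer = dataset.add_layer_like(layer, "color_shifted")
    target_layer.bounding_box = layer.bounding_box
    target_layer.add_mag(1)

    # Shift the layer content by 10 voxels along x: the inverse transform maps
    # output coordinates back to their source coordinates in the input layer.
    layer.transform(
        output_layer=target_layer,
        inverse_transform=lambda coords: coords - np.array([10, 0, 0]),
        num_threads=4,
    )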