Optimize batch/multi-channel image interpolation (cucim.skimage.transform.warp and cucim.skimage._vendored.ndimage) #1060
Open
grlee77 wants to merge 9 commits into rapidsai:main
Ports the precision fix that was contributed upstream to CuPy. See: cupy/cupy#9769. This fix was also backported to CuPy's 14.x branch for a future v14 release.
add batch axis support to affine transform
add batch_axes kwarg to map_coordinates
fix spline prefiltering for batch_axes case
replace _local_mean_weights computation with an ElementwiseKernel
add test cases for channel/batch axis handling
When the batch axis is the last (contiguous) dimension, reduce kernel launch overhead by having each thread process all batch elements in a loop instead of launching one thread per output pixel.
Key changes:
- Add loop_batch_axis parameter to _generate_interp_custom and coord funcs
- Use raw output with explicit size for looped kernels
- Fix coords indexing in _get_coord_map to use full output index
- Return KernelInfo namedtuple from kernel getters with optional size
This improves performance for RGB/multichannel image transformations where channel_axis=-1 (e.g., resize, rotate, warp_polar with RGB images).
fix bug in 'warp' edge mode (not used by cuCIM itself)
use typing.NamedTuple instead of collections.namedtuple
fix edge case for 1D data: only enable loop_batch_axis if there is at least 1 spatial axis
port float_type changes from upstream
gforsyth
approved these changes
Mar 19, 2026
Contributor
gforsyth
left a comment
Approving with context for future self and others.
ci-codeowners is requested here for the additional files excluded from the verify-copyright pre-commit check. Approving as the files in question are vendored from upstream and so shouldn't have an NVIDIA copyright applied.
Batch-axis optimization for ndimage interpolation kernels
Note: Changes from #1055 are also present here, so that one should be reviewed and merged first.
Overview
These two commits optimize cucim.skimage._vendored.ndimage interpolation functions (affine_transform, rotate, shift, zoom, map_coordinates) for arrays with non-interpolated "batch" dimensions. The most common cases are RGB/RGBA image channels (H, W, C) or a stack of equally sized images (N, H, W) or (N, H, W, C).
Previously for these functions, an (H, W, 3) uint8 image rotation was treated as a full 3D interpolation: each (output pixel, channel) pair launched a separate thread that independently computed 3D affine coordinates and interpolated across a 4×4×4 = 64-neighbor stencil in 3D (for order=3), even though the channel axis requires no interpolation at all. For order>=2, where spline prefiltering is used, that separable prefiltering should also only be applied along the axes that are actually being interpolated.
For SciPy and current cupyx.scipy.ndimage, avoiding the extra interpolation work along batch axes for an (N, H, W, C) image stack requires looping the computation over the batch axes and transforming one 2D slice at a time. With the change in this PR, a single CUDA kernel can be launched with one call on the full array.
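For illustration only, the per-slice loop could look like the sketch below. It uses SciPy on the CPU as a stand-in for cupyx.scipy.ndimage; the array sizes, angle, and offsets are arbitrary assumptions, and order=1 is chosen to keep the comparison simple. The second half shows an equivalent single call with a 4×4 matrix whose N and C rows are identity. Plain SciPy accepts that call but is expected to treat it as a full 4D interpolation, which is the redundant work this PR aims to skip.

```python
import numpy as np
from scipy import ndimage as ndi

rng = np.random.default_rng(0)
stack = rng.random((2, 8, 8, 3)).astype(np.float32)  # (N, H, W, C)

# 2D transform for the (H, W) axes; maps output coordinates to input coordinates.
theta = np.deg2rad(10.0)
matrix_2d = np.array([[np.cos(theta), -np.sin(theta)],
                      [np.sin(theta), np.cos(theta)]])
offset_2d = np.array([0.5, -0.25])

# Looped version: one 2D call per (n, c) slice.
looped = np.empty_like(stack)
for n in range(stack.shape[0]):
    for c in range(stack.shape[3]):
        looped[n, :, :, c] = ndi.affine_transform(
            stack[n, :, :, c], matrix_2d, offset=offset_2d, order=1
        )

# Single call: embed the 2D transform in a 4x4 matrix,
# leaving identity rows (and zero offsets) for the N and C axes.
matrix_4d = np.eye(4)
matrix_4d[1:3, 1:3] = matrix_2d
offset_4d = np.array([0.0, offset_2d[0], offset_2d[1], 0.0])
single = ndi.affine_transform(stack, matrix_4d, offset=offset_4d, order=1)
```

Both paths should agree up to floating-point rounding; the difference this PR targets is how much work the single call does internally.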
Due to the flexibility of the interpolation kernels (various interpolation orders and boundary extension modes), the kernel code itself is not straightforward to review. The conceptual changes made are fairly simple, though, and are explained below via concrete examples.
There are two separate optimizations in this PR:
- batch_axes detection to avoid unnecessary interpolation work along non-interpolated axes
- loop_batch_axis to have a single thread loop over elements of the batch when the last axis is in batch_axes
Of these two, the first provides the majority of the benefit, but the second was still found to be beneficial in cases with relatively few channels on the last axis (e.g. the common case of RGB and RGBA images).
Optimization 1: Batch axes detection
Detects axes where the transform is the identity (e.g., zoom=1, shift=0, or the corresponding affine matrix row is [0...0, 1, 0...0, 0]). These "batch axes" skip interpolation and use a direct index copy instead.
For the (H, W, 3) rotation example, axis 2 (channels) is detected as a batch axis. For order=3, the affine coordinate computation shrinks from 3D to 2D and the interpolation stencil shrinks from 4 * 4 * 4 = 64 to 4 * 4 = 16 neighbors.
Key changes in _ndimage_interpolation.py:
- The _determine_batch_axes_* functions inspect the matrix/zoom/shift to identify batch dimensions.
Key changes in _ndimage_interp_kernels.py:
- _get_coord_* functions accept batch_axes and emit identity transforms (c_j = in_coord[j]) for those axes.
- _generate_interp_custom skips interpolation weight computation, boundary condition checks, and nested stencil loops for batch axes.
The generated kernel changes from 3D interpolation over all 3 axes to 2D interpolation with the channel axis treated as an identity.
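As a rough NumPy illustration of the detection rule described in this section (this is neither the generated kernel nor the PR's _determine_batch_axes_* code; identity_rows and identity_zoom_axes are hypothetical helper names), an axis can be flagged when its matrix row is a unit vector with zero offset, or when its zoom factor is exactly 1. The stencil arithmetic for order=3 follows directly from the number of axes that remain interpolated.

```python
import numpy as np


def identity_rows(matrix, offset):
    """Hypothetical helper: axes j where the transform is c_j = in_coord[j]."""
    matrix = np.asarray(matrix, dtype=float)
    ndim = matrix.shape[0]
    offset = np.broadcast_to(np.asarray(offset, dtype=float), (ndim,))
    eye = np.eye(ndim)
    return tuple(
        j for j in range(ndim)
        if np.array_equal(matrix[j], eye[j]) and offset[j] == 0.0
    )


def identity_zoom_axes(zoom):
    """Hypothetical helper: axes whose zoom factor is exactly 1."""
    return tuple(j for j, z in enumerate(zoom) if z == 1)


# Rotation of an (H, W, 3) image: only the (H, W) block is non-trivial.
theta = np.deg2rad(30.0)
matrix = np.eye(3)
matrix[:2, :2] = [[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]]
batch_axes = identity_rows(matrix, offset=[5.0, -3.0, 0.0])

order = 3
n_interp = matrix.shape[0] - len(batch_axes)
stencil_before = (order + 1) ** matrix.shape[0]  # 4 * 4 * 4
stencil_after = (order + 1) ** n_interp          # 4 * 4

zoom_batch_axes = identity_zoom_axes((2.0, 2.0, 1.0))
```

Only the row needs to be checked for this purpose: a batch axis is one whose own coordinate is copied through unchanged, so no interpolation weights are needed along it.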
Optimization 2: Contiguous batch axis optimization
When the batch axis is the last (innermost) axis, the kernel is restructured so that one thread processes all batch elements at a given spatial position. Spatial interpolation weights and indices are computed once, then reused across a tight inner loop over the batch dimension.
For the (H, W, 3) case, the grid shrinks from H_out × W_out × 3 threads to H_out × W_out threads. Each thread reads all 3 channels at each of the 16 spatial neighbors.
Key changes in _ndimage_interp_kernels.py:
- The loop_batch_axis flag triggers when the batch axis is the last dimension.
- The output argument changes from Y y to raw Y y with manual index computation.
- batch_size is a compile-time constant.
Key changes in _ndimage_interpolation.py:
- Kernel getters return KernelInfo(kernel, size), where size is the reduced spatial-only element count for the looped case.
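The idea of computing spatial weights once and reusing them across the batch axis can be sketched on the CPU as follows. This is not the generated CUDA kernel: it is a simplified NumPy emulation of what a single thread would do, using bilinear (order=1, 4-neighbor) interpolation instead of the 16-neighbor cubic case, and sample_bilinear_looped is a hypothetical name. It assumes 0 <= y <= H-1 and 0 <= x <= W-1.

```python
import numpy as np


def sample_bilinear_looped(img, y, x):
    """Emulate one thread: interpolate every channel of img (H, W, C) at (y, x)."""
    h, w, batch_size = img.shape
    # Spatial indices and weights are computed once per output position.
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    neighbors = ((y0, x0), (y0, x1), (y1, x0), (y1, x1))
    weights = ((1 - wy) * (1 - wx), (1 - wy) * wx, wy * (1 - wx), wy * wx)
    out = np.zeros(batch_size)
    # Inner loop over the contiguous batch axis reuses the same weights.
    for c in range(batch_size):
        acc = 0.0
        for wgt, (yy, xx) in zip(weights, neighbors):
            acc += wgt * img[yy, xx, c]
        out[c] = acc
    return out


img = np.arange(2 * 2 * 3, dtype=float).reshape(2, 2, 3)
center = sample_bilinear_looped(img, 0.5, 0.5)  # mean of the 4 pixels per channel

# Thread counts for a 512 x 512 RGB output, per the description above.
threads_before = 512 * 512 * 3  # one thread per output element
threads_after = 512 * 512       # one thread per spatial position
```

Because the last axis is contiguous, the inner loop reads adjacent memory locations for each neighbor, which is presumably part of why the looped layout helps for small channel counts.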
Performance impact
The following test cases compare the performance of the batch implementation in this PR to using CuPy main directly. The comparison here is to CuPy after the float32 dtype fixes were already incorporated there, so this difference is due purely to the batch axis changes in this PR.
The results in the table below are for cupyx.scipy.ndimage.rotate vs. cucim.skimage._vendored.ndimage.rotate.
Note that cuCIM's cucim.skimage.transform.rotate does not currently support multiple batch axes, but does support rotation of RGB/RGBA images with channels on the last axis. So the (3840, 2160, 4) and (512, 512, 100) cases are most relevant to cuCIM usage.