
Optimize batch/multi-channel image interpolation (cucim.skimage.transform.warp and cucim.skimage._vendored.ndimage)#1060

Open
grlee77 wants to merge 9 commits into rapidsai:main from grlee77:grelee/ndinterp-batch-2026

Conversation

Contributor

@grlee77 grlee77 commented Mar 19, 2026

Batch-axis optimization for ndimage interpolation kernels

Note: Changes from #1055 are also present here, so that one should be reviewed and merged first.

Overview

These changes optimize the cucim.skimage._vendored.ndimage interpolation functions (affine_transform, rotate, shift, zoom, map_coordinates) for arrays with non-interpolated "batch" dimensions. The most common cases are RGB/RGBA image channels (H, W, C) or a stack of equally sized images (N, H, W) or (N, H, W, C).

Previously, these functions treated an (H, W, 3) uint8 image rotation as a full 3D interpolation: each output pixel/channel combination got its own thread, which independently computed 3D affine coordinates and interpolated across a 4 × 4 × 4 = 64-neighbor stencil (for order=3), even though the channel axis requires no interpolation at all. Similarly, for order >= 2, the separable spline prefiltering should only be applied along the axes that are actually interpolated.

For SciPy and current cupyx.scipy.ndimage, to avoid the extra interpolation work along batch axes for a (N, H, W, C) image stack, one would need to loop the computation like this:

for batch_idx in range(N):
    for channel_idx in range(C):
        output[batch_idx, :, :, channel_idx] = rotate(
            image[batch_idx, :, :, channel_idx],
            angle=5, axes=(1, 0), mode='reflect', order=3,
        )

With the change in this PR, by contrast, a single CUDA kernel can be launched using just

output = rotate(image, angle=5, axes=(2, 1), mode='reflect', order=3)
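The same batch-axis semantics can be checked on the CPU with SciPy, whose rotate already treats the non-plane axes as pass-through batch axes. The sketch below uses scipy.ndimage as a stand-in for the GPU path (reshape=False is added so shapes match for the comparison; it is not part of the example above):

```python
# CPU sketch of the batched/looped equivalence, using SciPy as a stand-in
# for the GPU path. scipy.ndimage.rotate only interpolates within the plane
# selected by `axes`, so a single call on an (N, H, W, C) stack matches the
# explicit per-slice loop.
import numpy as np
from scipy.ndimage import rotate

rng = np.random.default_rng(0)
image = rng.random((2, 32, 32, 3))  # (N, H, W, C) stack

# Single call: axes=(2, 1) selects the (W, H) plane; N and C pass through.
batched = rotate(image, angle=5, axes=(2, 1), reshape=False,
                 mode='reflect', order=3)

# Equivalent explicit loop over the batch and channel axes.
looped = np.empty_like(batched)
for n in range(image.shape[0]):
    for c in range(image.shape[3]):
        looped[n, :, :, c] = rotate(image[n, :, :, c], angle=5,
                                    axes=(1, 0), reshape=False,
                                    mode='reflect', order=3)

assert np.allclose(batched, looped)
```

On the CPU both paths do the same per-plane work; the point of this PR is that on the GPU the batched call now also avoids redundant interpolation along N and C, rather than merely avoiding Python-level loop overhead.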

Due to the flexibility of the interpolation kernels (various interpolation orders and boundary extension modes), the kernel code itself is not straightforward to review. The conceptual changes are fairly simple, though, and are explained below via concrete examples.

There are two separate optimizations in this PR:

  1. batch_axes detection to avoid unnecessary interpolation work along non-interpolated axes
    • this reduces the number of axes involved in the kernel loop to only the non-batch axes instead of all axes
    • the same is true for the separable spline prefiltering
  2. loop_batch_axes to have a single thread loop over elements of the batch when the last axis is in batch_axes
    • in other words instead of 1 pixel per GPU thread, there are N batch elements/channels computed per GPU thread

Of these two, the first provides the majority of the benefit, but the second was still found to be beneficial in cases with relatively few channels on the last axis (e.g. the common case of RGB and RGBA images).

Optimization 1: Batch axes detection

Detects axes where the transform is the identity (e.g., zoom=1, shift=0, or the corresponding affine matrix row is [0...0, 1, 0...0, 0]). These "batch axes" skip interpolation and use a direct index copy instead.

For the (H, W, 3) rotation example, axis 2 (channels) is detected as a batch axis. The affine coordinate computation for order=3 shrinks from 3D to 2D, and the interpolation stencil shrinks from 4 × 4 × 4 = 64 to 4 × 4 = 16 neighbors.
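A minimal Python sketch of the detection logic (illustrative only; the PR's actual helpers are the _determine_batch_axes_* functions, and find_batch_axes below is a hypothetical name):

```python
# Illustrative sketch of batch-axis detection for an ndim x (ndim + 1)
# augmented affine matrix: axis j is a batch axis when its row maps the
# input coordinate straight through (unit vector, zero offset).
import numpy as np

def find_batch_axes(matrix):
    """Return axes whose affine row is the identity [0..0, 1, 0..0 | 0]."""
    ndim = matrix.shape[0]
    batch_axes = []
    for j in range(ndim):
        identity_row = np.zeros(ndim + 1)
        identity_row[j] = 1.0  # c_j = in_coord[j] + 0
        if np.array_equal(matrix[j], identity_row):
            batch_axes.append(j)
    return tuple(batch_axes)

# 2D rotation embedded in a 3D transform: channel axis 2 is untouched.
theta = np.deg2rad(5)
mat = np.array([
    [np.cos(theta), -np.sin(theta), 0.0, 0.0],
    [np.sin(theta),  np.cos(theta), 0.0, 0.0],
    [0.0,            0.0,           1.0, 0.0],  # identity row -> batch axis
])
assert find_batch_axes(mat) == (2,)
```

For zoom and shift the analogous checks are zoom[j] == 1 and shift[j] == 0, as described above.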

Key changes in _ndimage_interpolation.py:

  • Transform analysis functions (_determine_batch_axes_*) inspect the matrix/zoom/shift to identify batch dimensions.
  • Batch axes are propagated to the kernel getters.

Key changes in _ndimage_interp_kernels.py:

  • All _get_coord_* functions accept batch_axes and emit identity transforms (c_j = in_coord[j]) for those axes.
  • _generate_interp_custom skips interpolation weight computation, boundary condition checks, and nested stencil loops for batch axes.

Generated kernel before (3D interpolation, all 3 axes):

// 3D affine coordinate computation (3 axes × 4 matrix elements each)
W c_0 = 0.0;
c_0 += mat[0] * in_coord[0]; c_0 += mat[1] * in_coord[1]; c_0 += mat[2] * in_coord[2]; c_0 += mat[3];
W c_1 = ...;  // same for axis 1
W c_2 = ...;  // same for axis 2 (channel) — unnecessary!

// Bounds check on all 3 axes including channels
if ((c_0 < 0) || ... || (c_2 < 0) || (c_2 > xsize_2 - 1)) { ... }

// Triply-nested stencil loop: 4 × 4 × 4 = 64 iterations
for (k_0 = 0; k_0 <= 3; k_0++) {          // axis 0
    // weights_0, boundary conditions for axis 0 (×4)
    for (k_1 = 0; k_1 <= 3; k_1++) {      // axis 1
        // weights_1, boundary conditions for axis 1 (×4)
        for (k_2 = 0; k_2 <= 3; k_2++) {  // axis 2 (channels) — unnecessary!
            // weights_2, boundary conditions for axis 2 (×4)
            out += x[ic_0 + ic_1 + ic_2] * (w_0 * w_1 * w_2);
        }
    }
}

Generated kernel after (2D interpolation, channel axis = identity):

// 2D affine coordinate computation (channel axis skipped)
W c_0 = 0.0;
c_0 += mat[0] * in_coord[0]; c_0 += mat[1] * in_coord[1]; c_0 += mat[3];
W c_1 = ...;

// Bounds check on 2 spatial axes only
if ((c_0 < 0) || ... || (c_1 > xsize_1 - 1)) { ... }

// Doubly-nested stencil: 4 × 4 = 16 iterations
// Channel axis uses w_2 = 1.0, ic_2 = in_coord[2] * sx_2 (identity)
for (k_0 = 0; k_0 <= 3; k_0++) {      // axis 0
    for (k_1 = 0; k_1 <= 3; k_1++) {  // axis 1
        {  // batch axis 2: weight=1, no interpolation
            out += x[ic_0 + ic_1 + ic_2] * (w_0 * w_1 * w_2);
        }
    }
}
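The prefiltering side of this optimization can be illustrated on the CPU: because the spline prefilter is separable, running scipy.ndimage.spline_filter1d along only the spatial axes of an (H, W, C) image matches prefiltering each channel independently (a SciPy sketch of the idea, not the PR's CUDA code):

```python
# Sketch: separable spline prefiltering restricted to the spatial axes.
# Filtering along axes 0 and 1 of an (H, W, C) image is equivalent to
# prefiltering each channel's 2D slice on its own; the channel (batch)
# axis needs no filtering at all.
import numpy as np
from scipy.ndimage import spline_filter, spline_filter1d

rng = np.random.default_rng(0)
image = rng.random((16, 16, 3))  # (H, W, C); axis 2 is a batch axis

# Prefilter only the two spatial axes; the channel axis is untouched.
filtered = spline_filter1d(image, order=3, axis=0)
filtered = spline_filter1d(filtered, order=3, axis=1)

# Reference: the full 2D prefilter applied per channel.
for c in range(image.shape[2]):
    assert np.allclose(filtered[:, :, c],
                       spline_filter(image[:, :, c], order=3))
```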

Optimization 2: Contiguous batch axis optimization

When the batch axis is the last (innermost) axis, the kernel is restructured so that one thread processes all batch elements at a given spatial position. Spatial interpolation weights and indices are computed once, then reused across a tight inner loop over the batch dimension.

For the (H, W, 3) case, the grid shrinks from H_out × W_out × 3 threads to H_out × W_out threads. Each thread reads all 3 channels at each of the 16 spatial neighbors.
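The weight-reuse idea can be sketched in NumPy for the simplest case, order=1 (bilinear) sampling of an (H, W, C) image. sample_bilinear is a hypothetical helper for illustration, not the PR's generated CUDA:

```python
# NumPy sketch of the batch-loop idea with bilinear (order=1) sampling:
# the four corner weights and indices depend only on the spatial (y, x)
# coordinate, so they are computed once and reused for every channel.
import numpy as np

def sample_bilinear(img, y, x):
    """Sample all channels of an (H, W, C) image at one (y, x) position."""
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    wy, wx = y - y0, x - x0
    # Spatial weights: computed once, shared by all C channels.
    weights = [(1 - wy) * (1 - wx), (1 - wy) * wx, wy * (1 - wx), wy * wx]
    corners = [(y0, x0), (y0, x0 + 1), (y0 + 1, x0), (y0 + 1, x0 + 1)]
    out = np.zeros(img.shape[2])
    for w, (iy, ix) in zip(weights, corners):
        out += w * img[iy, ix, :]  # inner "batch loop" over channels
    return out

img = np.arange(2 * 2 * 3, dtype=float).reshape(2, 2, 3)
# Sampling at the cell center averages the four corners, per channel.
assert np.allclose(sample_bilinear(img, 0.5, 0.5), [4.5, 5.5, 6.5])
```

In the generated CUDA kernel the same structure appears with cubic (4-tap) stencils: the spatial loops are outermost and the batch loop is innermost, so neighboring channel reads are contiguous in memory.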

Key changes in _ndimage_interp_kernels.py:

  • loop_batch_axis flag triggers when the batch axis is the last dimension.
  • Output switches from an indexed output parameter (Y y) to a raw output (raw Y y) with manual index computation.
  • Index unraveling covers only spatial dimensions; batch_size is a compile-time constant.
  • Stencil loops iterate over spatial axes only; the innermost loop iterates over batch elements, accumulating into a per-batch output array.

Key changes in _ndimage_interpolation.py:

  • Kernel getters return a KernelInfo(kernel, size) where size is the reduced spatial-only element count for the looped case.
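A sketch of that KernelInfo return type (the PR uses typing.NamedTuple; the exact field defaults here are an assumption based on the description above):

```python
# Sketch of the KernelInfo value returned by the kernel getters: the
# compiled kernel plus an optional reduced launch size for looped kernels.
from typing import Any, NamedTuple, Optional

class KernelInfo(NamedTuple):
    kernel: Any                  # the compiled interpolation kernel
    size: Optional[int] = None   # spatial-only element count, if looped

# For an (H, W, C) = (1080, 1920, 3) output with loop_batch_axis enabled,
# the launch covers only the spatial positions, not every channel.
info = KernelInfo(kernel=None, size=1080 * 1920)
assert info.size == 2073600
```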

Example of generated kernel code with batch loop:

// One thread per spatial position (not per pixel×channel)
const unsigned int batch_size = 3;
const unsigned int out_base_idx = i * batch_size;

// Index unraveling: 2D spatial only (no channel dimension)
in_coord[1] = ...; in_coord[0] = ...;

// 2D affine coordinate computation (same as commit 1)
W c_0 = ...; W c_1 = ...;

// Spatial bounds check -> fill all channels with cval if out of bounds
if (out_of_bounds) {
    for (batch_idx = 0; batch_idx < batch_size; batch_idx++)
        y[out_base_idx + batch_idx] = cval;
} else {
    // Compute spline weights/indices for spatial axes ONCE
    W weights_0[4]; ...  // axis 0 weights and boundary indices
    W weights_1[4]; ...  // axis 1 weights and boundary indices

    float out_batch[3] = {0, 0, 0};

    // Spatial stencil loops (outer)
    for (k_0 = 0; k_0 <= 3; k_0++) {
        for (k_1 = 0; k_1 <= 3; k_1++) {
            W spatial_weight = w_0 * w_1;  // computed once
            int ic_base = ic_0 + ic_1;     // computed once

            // Batch loop (innermost) -- reads 3 channels at same neighbor
            for (batch_idx = 0; batch_idx < batch_size; batch_idx++) {
                float val = x[ic_base + batch_idx];
                out_batch[batch_idx] += val * spatial_weight;
            }
        }
    }

    for (batch_idx = 0; batch_idx < batch_size; batch_idx++)
        y[out_base_idx + batch_idx] = (Y)rint(out_batch[batch_idx]);
}

Performance impact

The following test cases compare the performance of the batch implementation in this PR to using CuPy main directly. The comparison is to CuPy after the float32 dtype fixes were already incorporated there, so the difference shown is due purely to the batch axis changes in this PR.

The results in the table below are for cupyx.scipy.ndimage.rotate vs. cucim.skimage._vendored.ndimage.rotate

Note that cuCIM's cucim.skimage.transform.rotate does not currently support multiple batch-axes, but does support rotation of RGB/RGBA images with channels on the last axis. So the (3840, 2160, 4) and (512, 512, 100) cases are most relevant to cuCIM usage.

| shape | axes | dtype | order | prefilter | CuPy (ms) | This PR (ms) | accel. |
| --- | --- | --- | --- | --- | --- | --- | --- |
| (3840, 2160, 4) | (1, 0) | uint8 | 0 | False | 0.528 | 0.551 | 0.957 |
| (3840, 2160, 4) | (1, 0) | uint8 | 1 | False | 0.586 | 0.547 | 1.071 |
| (3840, 2160, 4) | (1, 0) | uint8 | 2 | False | 1.022 | 0.561 | 1.821 |
| (3840, 2160, 4) | (1, 0) | uint8 | 3 | False | 1.530 | 0.619 | 2.470 |
| (3840, 2160, 4) | (1, 0) | uint8 | 4 | False | 2.241 | 0.672 | 3.336 |
| (3840, 2160, 4) | (1, 0) | uint8 | 5 | False | 6.226 | 0.749 | 8.307 |
| (3840, 2160, 4) | (1, 0) | uint8 | 2 | True | 3.742 | 2.794 | 1.339 |
| (3840, 2160, 4) | (1, 0) | uint8 | 3 | True | 4.222 | 2.844 | 1.485 |
| (3840, 2160, 4) | (1, 0) | uint8 | 4 | True | 6.904 | 4.564 | 1.513 |
| (3840, 2160, 4) | (1, 0) | uint8 | 5 | True | 10.802 | 4.648 | 2.324 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 0 | False | 0.730 | 0.732 | 0.998 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 1 | False | 0.891 | 0.847 | 1.051 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 2 | False | 3.339 | 1.443 | 2.313 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 3 | False | 5.911 | 1.918 | 3.081 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 4 | False | 9.663 | 2.436 | 3.967 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 5 | False | 27.540 | 3.135 | 8.785 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 2 | True | 6.159 | 2.855 | 2.158 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 3 | True | 8.730 | 3.316 | 2.632 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 4 | True | 14.571 | 4.623 | 3.152 |
| (16, 1920, 1080, 4) | (2, 1) | uint8 | 5 | True | 33.358 | 5.491 | 6.075 |
| (512, 512, 100) | (1, 0) | uint8 | 0 | False | 0.531 | 0.549 | 0.967 |
| (512, 512, 100) | (1, 0) | uint8 | 1 | False | 0.587 | 0.557 | 1.055 |
| (512, 512, 100) | (1, 0) | uint8 | 2 | False | 1.070 | 0.574 | 1.864 |
| (512, 512, 100) | (1, 0) | uint8 | 3 | False | 1.611 | 0.621 | 2.594 |
| (512, 512, 100) | (1, 0) | uint8 | 4 | False | 2.373 | 0.678 | 3.500 |
| (512, 512, 100) | (1, 0) | uint8 | 5 | False | 6.415 | 0.763 | 8.404 |
| (512, 512, 100) | (1, 0) | uint8 | 2 | True | 3.820 | 2.826 | 1.352 |
| (512, 512, 100) | (1, 0) | uint8 | 3 | True | 4.301 | 2.878 | 1.494 |
| (512, 512, 100) | (1, 0) | uint8 | 4 | True | 7.043 | 4.616 | 1.526 |
| (512, 512, 100) | (1, 0) | uint8 | 5 | True | 11.006 | 4.750 | 2.317 |
| (512, 100, 512) | (2, 0) | uint8 | 0 | False | 0.889 | 0.874 | 1.017 |
| (512, 100, 512) | (2, 0) | uint8 | 1 | False | 1.211 | 1.095 | 1.105 |
| (512, 100, 512) | (2, 0) | uint8 | 2 | False | 5.336 | 2.090 | 2.554 |
| (512, 100, 512) | (2, 0) | uint8 | 3 | False | 9.375 | 2.964 | 3.162 |
| (512, 100, 512) | (2, 0) | uint8 | 4 | False | 15.537 | 3.891 | 3.993 |
| (512, 100, 512) | (2, 0) | uint8 | 5 | False | 45.219 | 4.750 | 9.520 |
| (512, 100, 512) | (2, 0) | uint8 | 2 | True | 8.171 | 4.043 | 2.021 |
| (512, 100, 512) | (2, 0) | uint8 | 3 | True | 12.276 | 4.932 | 2.489 |
| (512, 100, 512) | (2, 0) | uint8 | 4 | True | 20.294 | 7.173 | 2.829 |
| (512, 100, 512) | (2, 0) | uint8 | 5 | True | 49.587 | 8.141 | 6.091 |
| (100, 512, 512) | (2, 1) | uint8 | 0 | False | 0.749 | 0.744 | 1.008 |
| (100, 512, 512) | (2, 1) | uint8 | 1 | False | 0.916 | 0.868 | 1.056 |
| (100, 512, 512) | (2, 1) | uint8 | 2 | False | 3.545 | 1.502 | 2.361 |
| (100, 512, 512) | (2, 1) | uint8 | 3 | False | 6.196 | 1.969 | 3.147 |
| (100, 512, 512) | (2, 1) | uint8 | 4 | False | 10.281 | 2.513 | 4.091 |
| (100, 512, 512) | (2, 1) | uint8 | 5 | False | 29.419 | 3.285 | 8.957 |
| (100, 512, 512) | (2, 1) | uint8 | 2 | True | 6.335 | 2.964 | 2.137 |
| (100, 512, 512) | (2, 1) | uint8 | 3 | True | 8.998 | 3.411 | 2.638 |
| (100, 512, 512) | (2, 1) | uint8 | 4 | True | 14.985 | 4.834 | 3.100 |
| (100, 512, 512) | (2, 1) | uint8 | 5 | True | 33.993 | 5.477 | 6.206 |

grlee77 and others added 7 commits March 19, 2026 11:03

  • Ports the precision fix that was contributed upstream to CuPy (see cupy/cupy#9769). This fix was also backported to CuPy's 14.x branch for a future v14 release.
  • add batch axis support to affine transform
  • add batch_axes kwarg to map_coordinates
  • fix spline prefiltering for batch_axes case
  • replace _local_mean_weights computation with an ElementwiseKernel
  • add test cases for channel/batch axis handling
  • When the batch axis is the last (contiguous) dimension, reduce kernel launch overhead by having each thread process all batch elements in a loop instead of launching one thread per output pixel. Key changes: add loop_batch_axis parameter to _generate_interp_custom and coord funcs; use raw output with explicit size for looped kernels; fix coords indexing in _get_coord_map to use the full output index; return KernelInfo namedtuple from kernel getters with optional size. This improves performance for RGB/multichannel image transformations where channel_axis=-1 (e.g., resize, rotate, warp_polar with RGB images).
  • fix bug in 'warp' edge mode (not used by cuCIM itself)
  • use typing.NamedTuple instead of collections.namedtuple
  • fix edge case for 1D data: only enable loop_batch_axis if there is at least 1 spatial axis
  • port float_type changes from upstream
@grlee77 grlee77 added this to the v26.06.00 milestone Mar 19, 2026
@grlee77 grlee77 requested review from a team as code owners March 19, 2026 15:49
@grlee77 grlee77 requested a review from gforsyth March 19, 2026 15:49
@grlee77 grlee77 added the non-breaking Introduces a non-breaking change label Mar 19, 2026
@grlee77 grlee77 added this to cucim Mar 19, 2026
@grlee77 grlee77 added performance Performance improvement improvement Improves an existing functionality labels Mar 19, 2026
@grlee77 grlee77 self-assigned this Mar 19, 2026
Contributor

@gforsyth gforsyth left a comment


Approving with context for future self and others.

ci-codeowners is requested here for the additional files excluded from the verify-copyright pre-commit check. Approving as the files in question are vendored from upstream and so shouldn't have an NVIDIA copyright applied.
