@huy209vn (Contributor) commented Nov 7, 2025

High-Performance Permute / Transpose Kernels

This PR introduces a new `tensor::permute` module with fast transpose paths and a generic fallback:

  • Dimension folding merges contiguous dims, reducing many N-D permutations to simpler 2D/3D cases.

  • Tiled transpose kernels for [H,W] and [B,H,W] → [B,W,H], using shared-memory tiles with padding (bank-conflict free).

    • Adaptive tile sizing: 16×16 for small batches (≤4), 32×32 for larger batches.

  • Plane shuffle transpose: ultra-fast warp-level shuffle for tiny matrices (≤32 elements).

    • Zero shared memory, zero barriers; ideal for 4×4, 4×8, and 8×4 cases.

  • Channel shuffle: specialized kernel for NCHW → NHWC (axes [0, 2, 3, 1]), common in computer-vision workloads.

  • Attention transpose: optimized for [B, H, N, D] → [B, N, H, D] (axes [0, 2, 1, 3]), the standard pattern in multi-head attention.

  • Automatic kernel selection:

    • Use tiled transpose for [1,0] and [0,2,1] when matrix dims are large.

    • Otherwise fall back to the generic stride-mapped kernel.

  • Optional vectorization (mov2/mov4), disabled by default. Enable via `CUBECL_VECTORIZE_TRANSPOSE=1`.

  • Full correctness tests and benchmarks included.
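The dimension-folding step can be sketched on the CPU as follows. This is a hypothetical illustration (`fold_dims` is an assumed name, not the PR's actual API): input dims that stay adjacent and in-order under the permutation are merged into one, so e.g. [2,3,4,5] permuted by [2,3,0,1] collapses to a plain 2-D transpose of [6,20].

```rust
/// Hypothetical sketch of dimension folding (assumed name, not the PR's
/// actual API): input dims that stay adjacent and in-order under the
/// permutation are merged into one, reducing the rank of the permute.
fn fold_dims(shape: &[usize], perm: &[usize]) -> (Vec<usize>, Vec<usize>) {
    // Split `perm` into maximal runs of consecutive input dims.
    let mut groups: Vec<Vec<usize>> = vec![vec![perm[0]]];
    for &p in &perm[1..] {
        let last = *groups.last().unwrap().last().unwrap();
        if p == last + 1 {
            groups.last_mut().unwrap().push(p);
        } else {
            groups.push(vec![p]);
        }
    }
    // Rank the groups by their starting position in the input tensor.
    let mut order: Vec<usize> = (0..groups.len()).collect();
    order.sort_by_key(|&g| groups[g][0]);
    // One folded dim per group (member sizes multiplied), in input order.
    let folded_shape: Vec<usize> = order
        .iter()
        .map(|&g| groups[g].iter().map(|&d| shape[d]).product())
        .collect();
    // The folded permutation sends each folded input dim to the output
    // position of its group (groups were built in output order).
    let mut folded_perm = vec![0; groups.len()];
    for (input_pos, &g) in order.iter().enumerate() {
        folded_perm[g] = input_pos;
    }
    (folded_shape, folded_perm)
}

fn main() {
    // [2,3,4,5] with axes [2,3,0,1]: dims (2,3) and (0,1) stay contiguous,
    // so the 4-D permute folds to a 2-D transpose of [6, 20].
    let (shape, perm) = fold_dims(&[2, 3, 4, 5], &[2, 3, 0, 1]);
    assert_eq!(shape, vec![6, 20]);
    assert_eq!(perm, vec![1, 0]);
}
```

After folding, the dispatcher only has to recognize a handful of low-rank patterns ([1,0], [0,2,1], etc.) instead of arbitrary N-D axis orders.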


Performance (CUDA, RTX 3090)

=== CUDA PERMUTE/TRANSPOSE BENCHMARK ===

TEST: 2D Transpose [1024, 1024] axes=[1,0]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.013     | 629.99           | 512   | tiled_transpose |
| F16  | 0.016     | 263.95           | 1024  | tiled_transpose |

TEST: 2D Transpose [4096, 4096] axes=[1,0]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.175     | 765.76           | 32    | tiled_transpose |
| F16  | 0.111     | 602.79           | 64    | tiled_transpose |

TEST: Batch Transpose [32, 1024, 1024] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.338     | 794.61           | 20    | tiled_transpose |
| F16  | 0.218     | 614.74           | 32    | tiled_transpose |

TEST: Batch Transpose [16, 1024, 1024] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.171     | 784.30           | 32    | tiled_transpose |
| F16  | 0.111     | 605.03           | 64    | tiled_transpose |

TEST: Batch Transpose [8, 1024, 1024] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.087     | 772.43           | 64    | tiled_transpose |
| F16  | 0.055     | 609.32           | 128   | tiled_transpose |

TEST: Batch Transpose [4, 1024, 1024] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.052     | 643.56           | 128   | tiled_transpose |
| F16  | 0.033     | 509.56           | 256   | tiled_transpose |

TEST: Batch Transpose [1, 1024, 1024] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.017     | 495.25           | 512   | tiled_transpose |
| F16  | 0.034     | 123.82           | 1024  | tiled_transpose |

TEST: Batch Transpose [32, 512, 512] axes=[0,2,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.088     | 766.60           | 64    | tiled_transpose |
| F16  | 0.053     | 637.85           | 128   | tiled_transpose |

TEST: Complex Permute [128, 64, 64] axes=[2,0,1]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.035     | 119.83           | 1024  | naive_permute (fallback) |
| F16  | 0.035     | 59.70            | 2048  | naive_permute (fallback) |

TEST: Channel Shuffle [32, 256, 56, 56] axes=[0,2,3,1] (NCHW→NHWC)

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.249     | 825.08           | 32    | channel_shuffle_nchw_to_nhwc |
| F16  | 0.170     | 602.81           | 64    | channel_shuffle_nchw_to_nhwc |

TEST: Attention Transpose [8, 32, 512, 64] axes=[0,2,1,3]

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.298     | 224.83           | 64    | attention_transpose_kernel |
| F16  | 0.292     | 114.79           | 128   | attention_transpose_kernel |

TEST: Plane Shuffle [4, 4] axes=[1,0] (16 elem)

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.010     | 0.01             | 16384 | plane_shuffle_transpose |
| F16  | 0.010     | 0.01             | 16384 | plane_shuffle_transpose |

TEST: Plane Shuffle [4, 8] axes=[1,0] (32 elem)

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.011     | 0.02             | 16384 | plane_shuffle_transpose |
| F16  | 0.011     | 0.01             | 16384 | plane_shuffle_transpose |

TEST: Plane Shuffle [8, 4] axes=[1,0] (32 elem)

| Type | Time (ms) | Bandwidth (GB/s) | Iters | Kernel |
|------|-----------|------------------|-------|--------|
| F32  | 0.010     | 0.02             | 16384 | plane_shuffle_transpose |
| F16  | 0.012     | 0.01             | 16384 | plane_shuffle_transpose |
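The reported bandwidth figures are consistent with counting one full read plus one full write of the tensor (my assumption about the benchmark's accounting, not stated in the PR). For the 4096×4096 F32 case:

```rust
// Sanity check of the bandwidth accounting, assuming the benchmark counts
// one full read plus one full write of the tensor.
fn main() {
    // 4096 x 4096 F32 transpose measured at 0.175 ms:
    let bytes = 2.0 * 4096.0_f64 * 4096.0 * 4.0; // read + write, 4 B/elem
    let gbps = bytes / 0.175e-3 / 1e9;
    // ~767 GB/s, in line with the table's 765.76 (times are rounded).
    assert!((gbps - 765.76).abs() < 2.0);
}
```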

Summary

  • Tiled transpose achieves ~700–830 GB/s, near memory bandwidth limits.

  • Batch transpose scales smoothly by adapting tile size to maintain occupancy.

  • Arbitrary permutations use the generic fallback kernel (slower but always correct).
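The generic fallback's index math can be sketched as a CPU reference (assumed name `permute_copy`; the PR's GPU kernel performs the same stride mapping per thread): each output linear index is decomposed into output coordinates, which map through the permutation to an input offset built from the input strides.

```rust
/// Minimal CPU sketch of a generic stride-mapped permute (assumed name,
/// not the PR's actual API). Each output linear index is decomposed into
/// output coordinates, which map through the permutation to an input
/// offset built from the input's row-major strides.
fn permute_copy<T: Copy>(input: &[T], shape: &[usize], perm: &[usize]) -> Vec<T> {
    let rank = shape.len();
    // Row-major strides of the input tensor.
    let mut in_strides = vec![1usize; rank];
    for d in (0..rank.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * shape[d + 1];
    }
    let out_shape: Vec<usize> = perm.iter().map(|&d| shape[d]).collect();
    let numel: usize = shape.iter().product();
    let mut out = Vec::with_capacity(numel);
    for lin in 0..numel {
        // Peel output coordinates off `lin`, innermost dim first.
        let (mut rem, mut in_off) = (lin, 0);
        for od in (0..rank).rev() {
            let coord = rem % out_shape[od];
            rem /= out_shape[od];
            in_off += coord * in_strides[perm[od]];
        }
        out.push(input[in_off]);
    }
    out
}

fn main() {
    // Transpose the 2x3 row-major matrix [[0,1,2],[3,4,5]].
    let out = permute_copy(&[0, 1, 2, 3, 4, 5], &[2, 3], &[1, 0]);
    assert_eq!(out, vec![0, 3, 1, 4, 2, 5]);
}
```

The per-element divisions and the strided reads explain why this path is far slower than the tiled kernels, while remaining correct for any axis order.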

@huy209vn force-pushed the feature/transpose-permute branch 13 times, most recently from b632097 to 1d3f06e (November 14, 2025 07:20)
@huy209vn force-pushed the feature/transpose-permute branch from c664819 to c79a602 (November 14, 2025 07:30)
@huy209vn force-pushed the feature/transpose-permute branch from 47ce546 to e8b171f (November 14, 2025 08:44)
@nathanielsimard (Member) commented:
Just to be clear, we're not going to merge kernels to permute and transpose tensors; we already have those via fusion and reshape in Burn. A kernel like this is more appropriate as an example for CubeCL.

@huy209vn force-pushed the feature/transpose-permute branch from 68a5b7e to ee15be1 (November 27, 2025 14:56)
@huy209vn (Contributor, Author) commented:

There, I put it in an example.

The permute example had compilation errors due to mismatched `ComputeClient` types and missing trait bounds on `Runtime`. This commit fixes these by:
- Constraining the `Runtime` generic parameter to `R: Runtime<Server = R>` for all affected functions.
- Changing the client parameter type from `&ComputeClient<R::Server>` to `&ComputeClient<R>` in all affected functions, aligning with the new `Runtime` constraint.
- Addressing unused `Result` warnings by adding `let _ =` to `launch_unchecked` calls.
- Removing a redundant `use cubecl;` import to resolve a clippy warning.
- Including a minor clippy fix in `crates/cubecl-cuda/src/compute/server.rs`.