permute and tiled transpose kernel #1013
Open
huy209vn wants to merge 18 commits into tracel-ai:main from huy209vn:feature/transpose-permute
Conversation
Force-pushed from b632097 to 1d3f06e
Force-pushed from c664819 to c79a602
Force-pushed from 47ce546 to e8b171f
Member
Just to be clear, we're not going to merge kernels to permute and transpose tensors; we already have those included with fusion and reshape on Burn. A kernel like that is more appropriate as an example for CubeCL.
Force-pushed from 68a5b7e to ee15be1
Contributor (Author)
There, I put it in an example.
The permute example had compilation errors due to mismatched ComputeClient types and missing trait bounds on Runtime. This commit fixes these by:
- Constraining the Runtime generic parameter to R: Runtime<Server = R> for all affected functions.
- Changing the client parameter type from &ComputeClient<R::Server> to &ComputeClient<R> in all affected functions, aligning with the new Runtime constraint.
- Addressing unused Result warnings by adding let _ = to launch_unchecked calls.
- Removing a redundant use cubecl; import to resolve a clippy warning.
- Including a minor clippy fix in crates/cubecl-cuda/src/compute/server.rs.
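A self-contained sketch of the pattern that commit describes, using hypothetical stand-in types (Runtime, ComputeClient, CudaRuntime, and launch_permute here are illustrative placeholders, not CubeCL's real definitions):

```rust
// Hypothetical stand-ins; not CubeCL's real API, just the shape of the change.
use std::marker::PhantomData;

trait Runtime {
    type Server;
}

struct ComputeClient<R>(PhantomData<R>);

struct CudaRuntime;

// The runtime acts as its own server here, which is what makes the
// `R: Runtime<Server = R>` bound from the commit message satisfiable.
impl Runtime for CudaRuntime {
    type Server = CudaRuntime;
}

// Before (sketch): fn launch_permute<R: Runtime>(client: &ComputeClient<R::Server>)
// After (sketch): the bound ties Server back to R, so the client can be
// parameterized by the runtime itself and both generics line up.
fn launch_permute<R: Runtime<Server = R>>(_client: &ComputeClient<R>) {
    // Real code would build launch arguments and call the kernel here; the
    // commit also discards the launch result (`let _ = ...`) to silence
    // unused-Result warnings.
}

fn main() {
    let client = ComputeClient::<CudaRuntime>(PhantomData);
    launch_permute(&client);
}
```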
High-Performance Permute / Transpose Kernels
This PR introduces a new tensor::permute module with fast transpose paths and a generic fallback:
- Dimension folding merges contiguous dims, reducing many N-D permutations to simpler 2D/3D cases (see the first sketch after this list).
- Tiled transpose kernels for [H,W] and [B,H,W] → [B,W,H] using shared-memory tiles with padding (bank-conflict free). Adaptive tile sizing: 16×16 for small batches (≤4), 32×32 for larger batches (see the second sketch after this list).
- Plane Shuffle Transpose: ultra-fast warp-level shuffle for tiny matrices (≤32 elements). Zero shared memory, zero barriers; ideal for 4×4, 4×8, and 8×4 cases.
- Channel Shuffle: specialized kernel for NCHW → NHWC (axes [0, 2, 3, 1]), common in computer vision workloads.
- Attention Transpose: optimized for [B, H, N, D] → [B, N, H, D] (axes [0, 2, 1, 3]), the standard pattern in multi-head attention.
- Automatic kernel selection: use tiled transpose for [1,0] and [0,2,1] when matrix dims are large; otherwise fall back to the generic stride-mapped kernel (see the last sketch after this list).
- Optional vectorization (mov2/mov4), disabled by default; opt-in to enable.
Full correctness tests and benchmarks included.
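A minimal CPU-side sketch of the dimension-folding step, assuming the usual formulation (the function name and exact folding rules in the PR may differ): input dims that stay adjacent and in order under the permutation are merged, so for example [2, 3, 4, 5] with axes [2, 3, 0, 1] collapses to a plain [6, 20] 2-D transpose.

```rust
/// Merge input dimensions that stay contiguous (and in order) under `perm`.
/// Returns the folded shape and the folded permutation. Illustrative sketch;
/// the PR's actual implementation may differ.
fn fold_dims(shape: &[usize], perm: &[usize]) -> (Vec<usize>, Vec<usize>) {
    assert_eq!(shape.len(), perm.len());
    if perm.is_empty() {
        return (vec![], vec![]);
    }

    // Walk the permutation and start a new group whenever the next input dim
    // does not directly follow the previous one. Each group is a run of
    // consecutive input dims that also stay consecutive in the output.
    let mut groups: Vec<Vec<usize>> = vec![vec![perm[0]]];
    for k in 1..perm.len() {
        if perm[k] == perm[k - 1] + 1 {
            groups.last_mut().unwrap().push(perm[k]);
        } else {
            groups.push(vec![perm[k]]);
        }
    }

    // Order the groups as they appear in the *input* to build the folded shape,
    // then express the output order of the groups as the folded permutation.
    let mut input_order: Vec<usize> = (0..groups.len()).collect();
    input_order.sort_by_key(|&g| groups[g][0]);

    let folded_shape: Vec<usize> = input_order
        .iter()
        .map(|&g| groups[g].iter().map(|&d| shape[d]).product::<usize>())
        .collect();
    let folded_perm: Vec<usize> = (0..groups.len())
        .map(|g| input_order.iter().position(|&x| x == g).unwrap())
        .collect();

    (folded_shape, folded_perm)
}

fn main() {
    // [2, 3, 4, 5] with axes [2, 3, 0, 1] folds to a plain 2-D transpose.
    let (shape, perm) = fold_dims(&[2, 3, 4, 5], &[2, 3, 0, 1]);
    assert_eq!(shape, vec![6, 20]);
    assert_eq!(perm, vec![1, 0]);
}
```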
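For the tiled path, a CPU analogue of the shared-memory scheme, assuming the standard tile-and-pad layout (tile size and names are illustrative): on the GPU the staging buffer lives in shared memory and the extra padding column is what avoids bank conflicts; here it only mirrors the layout.

```rust
const TILE: usize = 32;

/// CPU analogue of a tiled [H, W] -> [W, H] transpose. On the GPU, `tile`
/// would live in shared memory, each block would handle one tile, and the
/// `+ 1` column of padding avoids shared-memory bank conflicts when the tile
/// is read back column-wise. Sketch only, not the PR's kernel.
fn transpose_tiled(input: &[f32], output: &mut [f32], h: usize, w: usize) {
    assert_eq!(input.len(), h * w);
    assert_eq!(output.len(), h * w);
    let mut tile = [[0.0f32; TILE + 1]; TILE];

    for row0 in (0..h).step_by(TILE) {
        for col0 in (0..w).step_by(TILE) {
            let rows = TILE.min(h - row0);
            let cols = TILE.min(w - col0);

            // Stage: read the input tile row-by-row (coalesced reads on GPU).
            for r in 0..rows {
                for c in 0..cols {
                    tile[r][c] = input[(row0 + r) * w + (col0 + c)];
                }
            }
            // Drain: write the transposed tile row-by-row into the output
            // (coalesced writes on GPU, column reads from the padded tile).
            for c in 0..cols {
                for r in 0..rows {
                    output[(col0 + c) * h + (row0 + r)] = tile[r][c];
                }
            }
        }
    }
}

fn main() {
    let (h, w) = (100, 64);
    let input: Vec<f32> = (0..h * w).map(|i| i as f32).collect();
    let mut output = vec![0.0f32; h * w];
    transpose_tiled(&input, &mut output, h, w);
    // output[w_idx][h_idx] == input[h_idx][w_idx]
    assert_eq!(output[3 * h + 2], input[2 * w + 3]);
}
```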
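And a sketch of the automatic kernel selection, with made-up size thresholds (the PR's real dispatch rules and kernel names may differ; the channel-shuffle and attention paths are omitted for brevity):

```rust
/// Which kernel the dispatcher would pick for a folded shape + permutation.
/// Illustrative only: names and the size threshold are assumptions, not the
/// PR's exact dispatch rules.
#[derive(Debug, PartialEq)]
enum Kernel {
    PlaneShuffle,    // tiny matrices, warp-shuffle only
    Tiled2d,         // [H, W] with axes [1, 0]
    TiledBatched,    // [B, H, W] with axes [0, 2, 1]
    GenericFallback, // stride-mapped kernel, always correct
}

fn select_kernel(shape: &[usize], perm: &[usize]) -> Kernel {
    let elems: usize = shape.iter().product();
    match (shape.len(), perm) {
        // Ultra-small 2-D transposes can be done entirely with warp shuffles.
        (2, [1, 0]) if elems <= 32 => Kernel::PlaneShuffle,
        // Large 2-D / batched transposes go through the shared-memory tiles.
        (2, [1, 0]) if shape[0] >= 64 && shape[1] >= 64 => Kernel::Tiled2d,
        (3, [0, 2, 1]) if shape[1] >= 64 && shape[2] >= 64 => Kernel::TiledBatched,
        // Everything else uses the generic stride-mapped kernel.
        _ => Kernel::GenericFallback,
    }
}

fn main() {
    assert_eq!(select_kernel(&[1024, 1024], &[1, 0]), Kernel::Tiled2d);
    assert_eq!(select_kernel(&[4, 8], &[1, 0]), Kernel::PlaneShuffle);
    assert_eq!(select_kernel(&[128, 64, 64], &[2, 0, 1]), Kernel::GenericFallback);
}
```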
Performance (CUDA, RTX 3090)
TEST: 2D Transpose [1024, 1024] axes=[1,0]
TEST: 2D Transpose [4096, 4096] axes=[1,0]
TEST: Batch Transpose [32, 1024, 1024] axes=[0,2,1]
TEST: Batch Transpose [16, 1024, 1024] axes=[0,2,1]
TEST: Batch Transpose [8, 1024, 1024] axes=[0,2,1]
TEST: Batch Transpose [4, 1024, 1024] axes=[0,2,1]
TEST: Batch Transpose [1, 1024, 1024] axes=[0,2,1]
TEST: Batch Transpose [32, 512, 512] axes=[0,2,1]
TEST: Complex Permute [128, 64, 64] axes=[2,0,1]
TEST: Channel Shuffle [32, 256, 56, 56] axes=[0,2,3,1] NCHW→NHWC
TEST: Attention Transpose [8, 32, 512, 64] axes=[0,2,1,3]
TEST: PHASE 4 - Plane Shuffle [4, 4] axes=[1,0] (16 elem)
TEST: PHASE 4 - Plane Shuffle [4, 8] axes=[1,0] (32 elem)
TEST: PHASE 4 - Plane Shuffle [8, 4] axes=[1,0] (32 elem)
Summary
Tiled transpose achieves ~700–830 GB/s, close to the RTX 3090's ~936 GB/s theoretical memory bandwidth.
Batch transpose scales smoothly by adapting tile size to maintain occupancy.
Arbitrary permutations use the generic fallback kernel (slower but always correct).
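For context on the summary numbers, effective transpose bandwidth is typically computed as bytes read plus bytes written over elapsed time. A small helper, with a made-up timing rather than a measurement from this PR:

```rust
use std::time::Duration;

/// Effective bandwidth of a transpose in GB/s: every element is read once and
/// written once. Sketch for sanity-checking numbers like those above.
fn effective_bandwidth_gbs(elems: usize, bytes_per_elem: usize, elapsed: Duration) -> f64 {
    let bytes_moved = 2.0 * (elems * bytes_per_elem) as f64; // read + write
    bytes_moved / elapsed.as_secs_f64() / 1e9
}

fn main() {
    // Example with a made-up timing: a [4096, 4096] f32 transpose moves
    // 2 * 4096^2 * 4 B ≈ 134 MB; finishing in ~0.17 ms would be ~790 GB/s,
    // in the same range as the figures quoted in the summary.
    let gbs = effective_bandwidth_gbs(4096 * 4096, 4, Duration::from_micros(170));
    println!("{gbs:.0} GB/s");
}
```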