1- use cubecl:: frontend:: TensorHandleRef ;
2- use cubecl_core:: { self as cubecl, prelude:: * } ;
3- use cubecl_std:: tensor:: TensorHandle ;
4- use std:: collections:: HashSet ;
5- use std:: env;
6- use std:: sync:: { LazyLock , Mutex } ;
1+
72
83// ================================
94// Constants & tuning parameters
105// ================================
116
12- /// Tile size optimized for 4-element vectorized loads (mov4)
13- const TILE_SIZE_MOV4 : u32 = 32 ;
14- /// Tile size optimized for 2-element vectorized loads (mov2)
15- const TILE_SIZE_MOV2 : u32 = 64 ;
16- /// Number of threads per tile column for cooperative loading
17- const BLOCK_ROWS : u32 = 8 ;
187
198// ===========================================================
209// Host-side utility functions for shape and stride calculations
2110// ===========================================================
2211
23- /// Compute output shape after applying permutation `axes` to `input_shape`.
24- ///
25- /// Example: `infer_output_shape(&[2,3,4], &[1,0,2])` returns `[3,2,4]`
26- fn infer_output_shape ( input_shape : & [ usize ] , axes : & [ usize ] ) -> Vec < usize > {
27- assert_eq ! (
28- axes. len( ) ,
29- input_shape. len( ) ,
30- "axes length must match input shape"
31- ) ;
32- axes. iter ( ) . map ( |& a| input_shape[ a] ) . collect ( )
33- }
12+
3413
3514/// Extract (batch, height, width) dimensions for batch transpose kernels.
3615///
@@ -57,162 +36,10 @@ fn infer_batch_transpose_shape(input_shape: &[usize], _axes: &[usize]) -> (u32,
5736 }
5837}
5938
60- /// Result of dimension folding optimization
61- #[ derive( Debug , Clone ) ]
62- struct FoldedPermutation {
63- /// Folded shape (lower rank, merged contiguous dims)
64- folded_shape : Vec < usize > ,
65- /// Permutation in terms of folded dimensions
66- folded_axes : Vec < usize > ,
67- }
68-
69- /// Fold contiguous dimensions to simplify permutation.
70- ///
71- /// This is a CRITICAL optimization that can turn complex high-rank permutations
72- /// into simple 2D transposes.
73- ///
74- /// Algorithm:
75- /// 1. Identify runs of dimensions that are contiguous in memory (stride[i] == stride[i+1] * shape[i+1])
76- /// 2. Merge those dimensions by multiplying their sizes
77- /// 3. Update the axes permutation to work on the folded dimensions
78- ///
79- /// Example:
80- /// - Input: shape=[8, 16, 32, 64], strides=[32768, 2048, 64, 1], axes=[0, 3, 2, 1]
81- /// - Last two dims are contiguous: stride[2]=64 == stride[3]*shape[3] = 1*64
82- /// - Fold into: shape=[8, 16, 2048], strides=[32768, 2048, 1], axes=[0, 2, 1]
83- /// - Now it's a simple 3D batch transpose!
84- fn fold_contiguous_dimensions (
85- input_shape : & [ usize ] ,
86- input_strides : & [ usize ] ,
87- axes : & [ usize ] ,
88- ) -> FoldedPermutation {
89- let rank = input_shape. len ( ) ;
90-
91- if rank <= 1 {
92- return FoldedPermutation {
93- folded_shape : input_shape. to_vec ( ) ,
94- folded_axes : axes. to_vec ( ) ,
95- } ;
96- }
97-
98- // Find contiguous runs in the INPUT tensor
99- // A run is contiguous if stride[i] == stride[i+1] * shape[i+1]
100- let mut is_contiguous_with_next = vec ! [ false ; rank] ;
101- for i in 0 ..rank - 1 {
102- is_contiguous_with_next[ i] = input_strides[ i] == input_strides[ i + 1 ] * input_shape[ i + 1 ] ;
103- }
104-
105- // Build folded dimensions by merging contiguous runs
106- let mut folded_shape = Vec :: new ( ) ;
107- let mut old_to_new_axis = vec ! [ 0usize ; rank] ; // Maps old axis index to folded axis index
108-
109- let mut i = 0 ;
110- while i < rank {
111- let start = i;
112-
113- // Extend run while contiguous
114- while i < rank - 1 && is_contiguous_with_next[ i] {
115- i += 1 ;
116- }
117-
118- // Merge dimensions [start..=i]
119- let merged_size: usize = ( start..=i) . map ( |j| input_shape[ j] ) . product ( ) ;
120- folded_shape. push ( merged_size) ;
121-
122- // All axes in this run map to the same folded axis
123- let folded_idx = folded_shape. len ( ) - 1 ;
124- for item in old_to_new_axis. iter_mut ( ) . take ( i + 1 ) . skip ( start) {
125- * item = folded_idx;
126- }
12739
128- i += 1 ;
129- }
130-
131- // Now we need to check if the PERMUTATION preserves contiguous runs
132- // If axes permutes within a folded group, we can't use the folding
133- // Example: if dims 2,3 were folded but axes=[0,1,3,2], we can't fold
134- // Also: if dims are folded but get REORDERED, we can't fold (e.g., axes=[1,0] for 2D)
135-
136- // Check if axes respects folded groups
137- let mut axes_respects_folding = true ;
138- for fold_idx in 0 ..folded_shape. len ( ) {
139- // Find all old axes that map to this folded axis
140- let old_axes_in_group: Vec < usize > = ( 0 ..rank)
141- . filter ( |& i| old_to_new_axis[ i] == fold_idx)
142- . collect ( ) ;
143-
144- if old_axes_in_group. len ( ) > 1 {
145- // Check if these axes appear in the SAME ORDER in the permutation
146- // Find their positions in the axes array
147- let mut positions: Vec < usize > = old_axes_in_group
148- . iter ( )
149- . map ( |& old_ax| axes. iter ( ) . position ( |& a| a == old_ax) . unwrap ( ) )
150- . collect ( ) ;
151-
152- // They must be consecutive and in ascending order
153- // This ensures the folded group stays together and in order
154- positions. sort_unstable ( ) ;
155- for j in 0 ..positions. len ( ) - 1 {
156- if positions[ j] + 1 != positions[ j + 1 ] {
157- axes_respects_folding = false ;
158- break ;
159- }
160- }
161-
162- // Verify axes are in ascending order at those positions.
163- // Example: for axes=[1,0], positions=[0,1] but old_axes_in_group=[0,1],
164- // we need axes[positions[0]] < axes[positions[1]]
165- if axes_respects_folding {
166- for j in 0 ..old_axes_in_group. len ( ) - 1 {
167- let pos_j = axes
168- . iter ( )
169- . position ( |& a| a == old_axes_in_group[ j] )
170- . unwrap ( ) ;
171- let pos_jp1 = axes
172- . iter ( )
173- . position ( |& a| a == old_axes_in_group[ j + 1 ] )
174- . unwrap ( ) ;
175- if pos_j > pos_jp1 {
176- // Axes are reversed or out of order within the group - folding not possible
177- axes_respects_folding = false ;
178- break ;
179- }
180- }
181- }
182- }
183- }
184-
185- if !axes_respects_folding {
186- // Folding would produce incorrect results - return original dimensions
187- return FoldedPermutation {
188- folded_shape : input_shape. to_vec ( ) ,
189- folded_axes : axes. to_vec ( ) ,
190- } ;
191- }
192-
193- // Build folded axes: for each position in axes, find which folded group it belongs to
194- // and use the first axis from that group
195- let mut folded_axes = Vec :: new ( ) ;
196- let mut seen_folded = vec ! [ false ; folded_shape. len( ) ] ;
197-
198- for & ax in axes {
199- let folded_idx = old_to_new_axis[ ax] ;
200- if !seen_folded[ folded_idx] {
201- folded_axes. push ( folded_idx) ;
202- seen_folded[ folded_idx] = true ;
203- }
204- }
205-
206- FoldedPermutation {
207- folded_shape,
208- folded_axes,
209- }
210- }
21140// This is the beginning of the permute.rs code
21241// All the kernels are here
21342// ...
21443// ... all the way to the end
21544// ...
216- fn main ( ) {
217- println ! ( "This example is empty. The permute code is in lib.rs" ) ;
218- }
45+
0 commit comments