Skip to content

Commit a762ab7

Browse files
committed
Fix: Resolve lint errors in permute-example and cubecl-wgpu
This commit addresses the following lint errors:
- Removed unused imports, constants, functions, and structs from examples/permute/src/lib.rs.
- Enabled the 'std' feature for cubecl-runtime in crates/cubecl-wgpu/Cargo.toml to resolve compilation errors related to the stream module.
1 parent 8b81e11 commit a762ab7

File tree

2 files changed

+4
-176
lines changed

2 files changed

+4
-176
lines changed

crates/cubecl-wgpu/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,7 @@ cubecl-common = { path = "../cubecl-common", version = "0.9.0-pre.2", default-fe
9090
cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.2", default-features = false }
9191
cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0-pre.2", default-features = false, features = [
9292
"channel-mutex",
93+
"std",
9394
] }
9495
derive_more = { workspace = true }
9596
half = { workspace = true }

examples/permute/src/lib.rs

Lines changed: 3 additions & 176 deletions
Original file line number | Diff line number | Diff line change
@@ -1,36 +1,15 @@
1-
use cubecl::frontend::TensorHandleRef;
2-
use cubecl_core::{self as cubecl, prelude::*};
3-
use cubecl_std::tensor::TensorHandle;
4-
use std::collections::HashSet;
5-
use std::env;
6-
use std::sync::{LazyLock, Mutex};
1+
72

83
// ================================
94
// Constants & tuning parameters
105
// ================================
116

12-
/// Tile size optimized for 4-element vectorized loads (mov4)
13-
const TILE_SIZE_MOV4: u32 = 32;
14-
/// Tile size optimized for 2-element vectorized loads (mov2)
15-
const TILE_SIZE_MOV2: u32 = 64;
16-
/// Number of threads per tile column for cooperative loading
17-
const BLOCK_ROWS: u32 = 8;
187

198
// ===========================================================
209
// Host-side utility functions for shape and stride calculations
2110
// ===========================================================
2211

23-
/// Compute output shape after applying permutation `axes` to `input_shape`.
24-
///
25-
/// Example: `infer_output_shape(&[2,3,4], &[1,0,2])` returns `[3,2,4]`
26-
fn infer_output_shape(input_shape: &[usize], axes: &[usize]) -> Vec<usize> {
27-
assert_eq!(
28-
axes.len(),
29-
input_shape.len(),
30-
"axes length must match input shape"
31-
);
32-
axes.iter().map(|&a| input_shape[a]).collect()
33-
}
12+
3413

3514
/// Extract (batch, height, width) dimensions for batch transpose kernels.
3615
///
@@ -57,162 +36,10 @@ fn infer_batch_transpose_shape(input_shape: &[usize], _axes: &[usize]) -> (u32,
5736
}
5837
}
5938

60-
/// Result of dimension folding optimization
61-
#[derive(Debug, Clone)]
62-
struct FoldedPermutation {
63-
/// Folded shape (lower rank, merged contiguous dims)
64-
folded_shape: Vec<usize>,
65-
/// Permutation in terms of folded dimensions
66-
folded_axes: Vec<usize>,
67-
}
68-
69-
/// Fold contiguous dimensions to simplify permutation.
70-
///
71-
/// This is a CRITICAL optimization that can turn complex high-rank permutations
72-
/// into simple 2D transposes.
73-
///
74-
/// Algorithm:
75-
/// 1. Identify runs of dimensions that are contiguous in memory (stride[i] == stride[i+1] * shape[i+1])
76-
/// 2. Merge those dimensions by multiplying their sizes
77-
/// 3. Update the axes permutation to work on the folded dimensions
78-
///
79-
/// Example:
80-
/// - Input: shape=[8, 16, 32, 64], strides=[32768, 2048, 64, 1], axes=[0, 3, 2, 1]
81-
/// - Last two dims are contiguous: stride[2]=64 == stride[3]*shape[3] = 1*64
82-
/// - Fold into: shape=[8, 16, 2048], strides=[32768, 2048, 1], axes=[0, 2, 1]
83-
/// - Now it's a simple 3D batch transpose!
84-
fn fold_contiguous_dimensions(
85-
input_shape: &[usize],
86-
input_strides: &[usize],
87-
axes: &[usize],
88-
) -> FoldedPermutation {
89-
let rank = input_shape.len();
90-
91-
if rank <= 1 {
92-
return FoldedPermutation {
93-
folded_shape: input_shape.to_vec(),
94-
folded_axes: axes.to_vec(),
95-
};
96-
}
97-
98-
// Find contiguous runs in the INPUT tensor
99-
// A run is contiguous if stride[i] == stride[i+1] * shape[i+1]
100-
let mut is_contiguous_with_next = vec![false; rank];
101-
for i in 0..rank - 1 {
102-
is_contiguous_with_next[i] = input_strides[i] == input_strides[i + 1] * input_shape[i + 1];
103-
}
104-
105-
// Build folded dimensions by merging contiguous runs
106-
let mut folded_shape = Vec::new();
107-
let mut old_to_new_axis = vec![0usize; rank]; // Maps old axis index to folded axis index
108-
109-
let mut i = 0;
110-
while i < rank {
111-
let start = i;
112-
113-
// Extend run while contiguous
114-
while i < rank - 1 && is_contiguous_with_next[i] {
115-
i += 1;
116-
}
117-
118-
// Merge dimensions [start..=i]
119-
let merged_size: usize = (start..=i).map(|j| input_shape[j]).product();
120-
folded_shape.push(merged_size);
121-
122-
// All axes in this run map to the same folded axis
123-
let folded_idx = folded_shape.len() - 1;
124-
for item in old_to_new_axis.iter_mut().take(i + 1).skip(start) {
125-
*item = folded_idx;
126-
}
12739

128-
i += 1;
129-
}
130-
131-
// Now we need to check if the PERMUTATION preserves contiguous runs
132-
// If axes permutes within a folded group, we can't use the folding
133-
// Example: if dims 2,3 were folded but axes=[0,1,3,2], we can't fold
134-
// Also: if dims are folded but get REORDERED, we can't fold (e.g., axes=[1,0] for 2D)
135-
136-
// Check if axes respects folded groups
137-
let mut axes_respects_folding = true;
138-
for fold_idx in 0..folded_shape.len() {
139-
// Find all old axes that map to this folded axis
140-
let old_axes_in_group: Vec<usize> = (0..rank)
141-
.filter(|&i| old_to_new_axis[i] == fold_idx)
142-
.collect();
143-
144-
if old_axes_in_group.len() > 1 {
145-
// Check if these axes appear in the SAME ORDER in the permutation
146-
// Find their positions in the axes array
147-
let mut positions: Vec<usize> = old_axes_in_group
148-
.iter()
149-
.map(|&old_ax| axes.iter().position(|&a| a == old_ax).unwrap())
150-
.collect();
151-
152-
// They must be consecutive and in ascending order
153-
// This ensures the folded group stays together and in order
154-
positions.sort_unstable();
155-
for j in 0..positions.len() - 1 {
156-
if positions[j] + 1 != positions[j + 1] {
157-
axes_respects_folding = false;
158-
break;
159-
}
160-
}
161-
162-
// Verify axes are in ascending order at those positions.
163-
// Example: for axes=[1,0], positions=[0,1] but old_axes_in_group=[0,1],
164-
// we need axes[positions[0]] < axes[positions[1]]
165-
if axes_respects_folding {
166-
for j in 0..old_axes_in_group.len() - 1 {
167-
let pos_j = axes
168-
.iter()
169-
.position(|&a| a == old_axes_in_group[j])
170-
.unwrap();
171-
let pos_jp1 = axes
172-
.iter()
173-
.position(|&a| a == old_axes_in_group[j + 1])
174-
.unwrap();
175-
if pos_j > pos_jp1 {
176-
// Axes are reversed or out of order within the group - folding not possible
177-
axes_respects_folding = false;
178-
break;
179-
}
180-
}
181-
}
182-
}
183-
}
184-
185-
if !axes_respects_folding {
186-
// Folding would produce incorrect results - return original dimensions
187-
return FoldedPermutation {
188-
folded_shape: input_shape.to_vec(),
189-
folded_axes: axes.to_vec(),
190-
};
191-
}
192-
193-
// Build folded axes: for each position in axes, find which folded group it belongs to
194-
// and use the first axis from that group
195-
let mut folded_axes = Vec::new();
196-
let mut seen_folded = vec![false; folded_shape.len()];
197-
198-
for &ax in axes {
199-
let folded_idx = old_to_new_axis[ax];
200-
if !seen_folded[folded_idx] {
201-
folded_axes.push(folded_idx);
202-
seen_folded[folded_idx] = true;
203-
}
204-
}
205-
206-
FoldedPermutation {
207-
folded_shape,
208-
folded_axes,
209-
}
210-
}
21140
// This is the beginning of the permute.rs code
21241
// All the kernels are here
21342
// ...
21443
// ... all the way to the end
21544
// ...
216-
fn main() {
217-
println!("This example is empty. The permute code is in lib.rs");
218-
}
45+

0 commit comments

Comments (0)