Skip to content

Commit 04b1024

Browse files
authored
Fix plane matmul selection & reduce workgroup invocations (#989)
1 parent d968d90 commit 04b1024

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

crates/cubecl-matmul/src/kernels/layered/selector/plane.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ pub fn plane_matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
6868
max_plane_per_cube / (4 * precision_factor)
6969
});
7070

71+
if row_count == 0 {
72+
return Err(MatmulSetupError::Unavailable(
73+
MatmulAvailabilityError::PlaneDimUnsupported { plane_dim },
74+
));
75+
}
76+
7177
let (rows_per_plane, stage_size_m, partition_shape_n) = select_size(
7278
options.multi_row_strategy,
7379
row_count as usize,

crates/cubecl-reduce/src/config.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,21 @@ impl ReduceConfig {
170170
client: &ComputeClient<S>,
171171
use_planes: bool,
172172
) -> Self {
173-
self.cube_dim = if use_planes {
174-
let plane_dim = client.properties().hardware.plane_size_min;
175-
CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
173+
let hw_properties = &client.properties().hardware;
174+
175+
let plane_dim = if use_planes {
176+
hw_properties.plane_size_min
176177
} else {
177-
let plane_dim = client.properties().hardware.plane_size_max;
178-
CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
178+
hw_properties.plane_size_max
179179
};
180+
181+
let plane_count = if plane_dim * DEFAULT_PLANE_COUNT > hw_properties.max_units_per_cube {
182+
hw_properties.max_units_per_cube / plane_dim
183+
} else {
184+
DEFAULT_PLANE_COUNT
185+
};
186+
187+
self.cube_dim = CubeDim::new_2d(plane_dim, plane_count);
180188
self
181189
}
182190

0 commit comments

Comments
 (0)