Skip to content

Commit caacb7b

Browse files
Fix/perf line size (#953)
1 parent a26fce4 commit caacb7b

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

crates/cubecl-core/src/runtime.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
3838
/// Returns the supported line sizes for the current runtime's compiler.
3939
fn supported_line_sizes() -> &'static [u8];
4040

41+
/// The maximum line size that can be used for global buffer bindings.
42+
fn max_global_line_size() -> u8 {
43+
u8::MAX
44+
}
45+
4146
/// Returns all line sizes that are useful for performing optimal IO operations on the given element.
4247
fn io_optimized_line_sizes(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
4348
let max = (LOAD_WIDTH / elem.size_bits()) as u8;
@@ -50,7 +55,12 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
5055
/// unrolled, and may not support dynamic indexing.
5156
fn io_optimized_line_sizes_unchecked(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
5257
let max = LOAD_WIDTH / elem.size_bits();
53-
(1..max as u8).rev().filter(|v| v.is_power_of_two())
58+
let max = usize::min(Self::max_global_line_size() as usize, max);
59+
60+
// If the max is 8, we want to test 1, 2, 4, 8, i.e. log2(8) + 1 = 4 candidates.
61+
let num_candidates = f32::log2(max as f32) as u32 + 1;
62+
63+
(0..num_candidates).map(|i| 2u8.pow(i)).rev()
5464
}
5565

5666
/// Returns the maximum cube count on each dimension that can be launched.

crates/cubecl-wgpu/src/runtime.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ impl Runtime for WgpuRuntime {
7474
}
7575
}
7676

77+
fn max_global_line_size() -> u8 {
78+
4
79+
}
80+
7781
fn max_cube_count() -> (u32, u32, u32) {
7882
let max_dim = u16::MAX as u32;
7983
(max_dim, max_dim, max_dim)

crates/cubecl/benches/matmul.rs

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use core::marker::PhantomData;
22
use cubecl::prelude::*;
3+
use cubecl_matmul::AsyncReadingStrategy;
34
use cubecl_matmul::components::batch::HypercubeSelection;
45
use cubecl_matmul::components::stage::PartitionBuffering;
56
use cubecl_matmul::components::{
@@ -11,17 +12,10 @@ use cubecl_matmul::kernels::layered::double_unit::DoubleUnitSelectionArgs;
1112
use cubecl_matmul::kernels::layered::ordered_double_buffering::OrderedSelectionArgs;
1213
use cubecl_matmul::kernels::layered::simple::SimpleArgs;
1314
use cubecl_matmul::kernels::layered::simple_unit::SimpleUnitSelectionArgs;
14-
use cubecl_matmul::kernels::layered::{
15-
MatmulSelection, MultiRowStrategy, Selection, TileSizeSelection, closest_factor_pair,
16-
};
1715
use cubecl_matmul::kernels::layered::{Selection, TileSizeSelection};
18-
use cubecl_matmul::{self as matmul};
1916
use cubecl_matmul::{
2017
self as matmul, MatmulInputHandle, SyncPartialReadingStrategy, SyncReadingStrategy,
2118
};
22-
use cubecl_matmul::{self as matmul, SyncPartialReadingStrategy, SyncReadingStrategy};
23-
use cubecl_matmul::{AsyncReadingStrategy, components::MatmulPrecision};
24-
use cubecl_matmul::{SyncPartialReadingStrategy, SyncReadingStrategy};
2519
use std::collections::BTreeMap;
2620
use std::time::Duration;
2721

@@ -98,8 +92,8 @@ impl<R: Runtime, MP: MatmulPrecision> Benchmark for MatmulBench<R, MP> {
9892
matmul_elems.rhs_global,
9993
matmul_elems.rhs_stage,
10094
matmul_elems.rhs_register,
101-
matmul_elems.acc,
102-
matmul_elems.out,
95+
matmul_elems.acc_register,
96+
matmul_elems.acc_global,
10397
self.strategy
10498
)
10599
.to_lowercase()
@@ -145,13 +139,13 @@ fn entry(m: usize, n: usize, k: usize) -> (usize, usize, usize, usize) {
145139
#[allow(dead_code)]
146140
fn run<R: Runtime, MP: MatmulPrecision>(device: R::Device, strategy: matmul::Strategy) {
147141
for tl in [false] {
148-
for tr in [true] {
142+
for tr in [false] {
149143
for (b, m, n, k) in [
150144
// entry(8192, 8192, 8192),
151-
// entry(6144, 6144, 6144),
145+
entry(6144, 6144, 6144),
152146
// entry(4096, 4096, 4096),
153147
// entry(2048, 2048, 2048),
154-
entry(1024, 1024, 1024),
148+
// entry(1024, 1024, 1024),
155149
// entry(512, 512, 512),
156150
// entry(64, 1024, 64),
157151
// entry(32, 1024, 32),
@@ -397,7 +391,7 @@ fn run_algos_wmma<R: Runtime, MP: MatmulPrecision>() {
397391
#[allow(unused)]
398392
fn run_benches<R: Runtime, MP: MatmulPrecision>() {
399393
// run_grid_search::<R, MP>();
400-
run_algos_unit::<R, MP>();
394+
// run_algos_unit::<R, MP>();
401395
run_algos_wmma::<R, MP>();
402396
// run_algos_vecmat::<R, MP>();
403397
}

0 commit comments

Comments
 (0)