Skip to content

Commit b3eb8fe

Browse files
authored
fix: Quant matmul line sizes (#978)
1 parent f1e0f12 commit b3eb8fe

File tree

30 files changed: +143 additions, −69 deletions

crates/cubecl-attention/src/base.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ use crate::{
1414

1515
use crate::components::batch::BatchAttentionConfig;
1616
use crate::components::batch::BatchAttentionFamily;
17-
use cubecl_core::frontend::CubePrimitive;
1817

1918
pub enum Strategy {
2019
/// Temporary implementation
@@ -66,9 +65,9 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
6665
out: &TensorHandleRef<R>,
6766
) -> Result<(), AttentionSetupError> {
6867
let line_sizes = AvailableLineSizes::from_elem_types::<R>(
69-
&QG::<AP>::as_type_native_unchecked(),
70-
&MSK::<AP>::as_type_native_unchecked(),
71-
&OG::<AP>::as_type_native_unchecked(),
68+
query.elem_size,
69+
size_of::<MSK<AP>>(),
70+
out.elem_size,
7271
);
7372
let line_sizes = DummyRegisterAlgorithm::filter_line_sizes(line_sizes)
7473
.filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)

crates/cubecl-attention/src/components/line_size.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Debug;
22

3-
use cubecl_core::{LineSizeError, Runtime, ir::StorageType, tensor_line_size_parallel};
3+
use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel};
44

55
use crate::components::{AttentionIdent, AttentionSetupError};
66

@@ -29,11 +29,7 @@ pub struct AvailableLineSizes {
2929
}
3030

3131
impl AvailableLineSizes {
32-
pub fn from_elem_types<R: Runtime>(
33-
elem_in: &StorageType,
34-
elem_mask: &StorageType,
35-
elem_out: &StorageType,
36-
) -> Self {
32+
pub fn from_elem_types<R: Runtime>(elem_in: usize, elem_mask: usize, elem_out: usize) -> Self {
3733
let in_available: Vec<u8> = R::io_optimized_line_sizes_unchecked(elem_in).collect();
3834
let mask_available: Vec<u8> = R::io_optimized_line_sizes_unchecked(elem_mask).collect();
3935
let out_available = R::io_optimized_line_sizes_unchecked(elem_out).collect();

crates/cubecl-attention/src/tests/attention_test_launcher.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ pub fn test_attention_algorithm<A, P, R>(
5858
let out = tensor_raw_parts_output::<P, R>(&client, &problem);
5959

6060
let line_sizes = AvailableLineSizes::from_elem_types::<R>(
61-
&P::EG::as_type_native_unchecked(),
62-
&P::EM::as_type_native_unchecked(),
63-
&P::EG::as_type_native_unchecked(),
61+
size_of::<P::EG>(),
62+
size_of::<P::EM>(),
63+
size_of::<P::EG>(),
6464
);
6565
let line_sizes = A::filter_line_sizes(line_sizes);
6666
let line_sizes = line_sizes

crates/cubecl-convolution/src/components/stage/reader.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use cubecl::prelude::*;
22
use cubecl_core as cubecl;
33
use cubecl_matmul::components::{
4-
MatrixLayout, StageIdent,
5-
stage::{StageMemoryConfig, StridedStage, TilingLayout},
4+
InvalidConfigError, MatrixLayout, StageIdent,
5+
global::memory::GlobalMemoryConfig,
6+
stage::{StageMemoryConfig, StridedStage, TilingLayout, TilingValidation},
67
tile::StridedTile,
78
};
89
use cubecl_std::tensor::layout::Coords2d;
@@ -39,3 +40,13 @@ impl TilingLayout for BiasTilingLayout {
3940
)
4041
}
4142
}
43+
44+
impl TilingValidation for BiasTilingLayout {
45+
fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> {
46+
let stage_width = config.elements_in_stage_col;
47+
if config.global_line_size > stage_width {
48+
return Err(Box::new("Invalid line size"));
49+
}
50+
Ok(())
51+
}
52+
}

crates/cubecl-convolution/src/launch.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ where
181181
Input<Alg, MP>: ConcreteInputsFactory,
182182
Output<Alg, MP>: ConcreteOutputFactory,
183183
{
184-
let line_sizes = AvailableLineSizes::from_types::<R>(
185-
&LhsG::<MP>::as_type_native_unchecked(),
186-
&RhsG::<MP>::as_type_native_unchecked(),
187-
&AccG::<MP>::as_type_native_unchecked(),
184+
let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
185+
input.data().elem_size,
186+
weight.data().elem_size,
187+
out.elem_size,
188188
)
189189
.filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
190190
.filter_rhs_with_tensor(

crates/cubecl-convolution/src/tests/convolution_test_launcher.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
5656
let line_sizes = AvailableLineSizes {
5757
lhs: vec![1],
5858
rhs: vec![1],
59-
out: R::io_optimized_line_sizes_unchecked(&P::EG::as_type_native_unchecked()).collect(),
59+
out: R::io_optimized_line_sizes_unchecked(size_of::<P::EG>()).collect(),
6060
}
6161
.filter_lhs_with_tensor(&lhs.strides, &lhs.shape, problem.lhs_layout)
6262
.filter_rhs_with_tensor(&rhs.strides, &rhs.shape, problem.rhs_layout)

crates/cubecl-core/src/frontend/element/cube_elem.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ pub trait CubePrimitive:
4545
Self::as_type_native().map(|t| t.size_bits())
4646
}
4747

48+
/// Only native element types have a size.
49+
fn size_bits_unchecked() -> usize {
50+
Self::as_type_native_unchecked().size_bits()
51+
}
52+
4853
fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType {
4954
ExpandElementTyped::new(elem)
5055
}

crates/cubecl-core/src/runtime.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
5050
/// Returns all line sizes that are useful to perform optimal IO operation on the given element.
5151
/// Ignores native support, and allows all line sizes. This means the returned size may be
5252
/// unrolled, and may not support dynamic indexing.
53-
fn io_optimized_line_sizes_unchecked(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
54-
let max = LOAD_WIDTH / elem.size_bits();
53+
fn io_optimized_line_sizes_unchecked(size: usize) -> impl Iterator<Item = u8> + Clone {
54+
let size_bits = size * 8;
55+
let max = LOAD_WIDTH / size_bits;
5556
let max = usize::min(Self::max_global_line_size() as usize, max);
5657

5758
// If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.

crates/cubecl-cpu/src/runtime.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ impl Runtime for CpuRuntime {
111111
supported.iter().filter(move |v| **v <= max).cloned()
112112
}
113113

114-
fn io_optimized_line_sizes_unchecked(elem: &StorageType) -> impl Iterator<Item = u8> + Clone {
115-
let max = LOAD_WIDTH / elem.size_bits();
114+
fn io_optimized_line_sizes_unchecked(elem_size: usize) -> impl Iterator<Item = u8> + Clone {
115+
let elem_size_bits = elem_size * 8;
116+
let max = LOAD_WIDTH / elem_size_bits;
116117
(1..max as u8).rev().filter(|v| v.is_power_of_two())
117118
}
118119

crates/cubecl-matmul/src/components/global/args.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,8 @@ impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
157157
);
158158
let data_view =
159159
ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
160-
let scales_view = ViewArg::new::<BatchedGlobalScaleLayout>(
161-
scale.as_array_arg(line_size),
162-
scales_layout,
163-
);
160+
let scales_view =
161+
ViewArg::new::<BatchedGlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
164162
ViewArg::new_quantized(data_view, scales_view, **scheme)
165163
}
166164
};

Comments (0)