Skip to content

Commit 30366c3

Browse files
authored
refactor: Refactor block size and fix burn issues (#942)
1 parent f407c26 commit 30366c3

File tree

18 files changed

+253
-165
lines changed

18 files changed

+253
-165
lines changed
Lines changed: 76 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
use core::{
2-
fmt::Display,
3-
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
4-
};
1+
use core::fmt::Display;
52

63
use bytemuck::{Pod, Zeroable};
7-
use float4::E8M0;
8-
use num_traits::{NumCast, ToPrimitive};
94

105
/// An 8-bit unsigned floating point type with 8 exponent bits and no mantissa bits.
116
/// Used for scaling factors.
@@ -37,6 +32,7 @@ impl ue8m0 {
3732
/// other values are truncated and rounded to the nearest representable value.
3833
#[inline]
3934
#[must_use]
35+
#[cfg(feature = "float4")]
4036
pub fn from_f32(value: f32) -> ue8m0 {
4137
Self::from_f64(value as f64)
4238
}
@@ -49,8 +45,9 @@ impl ue8m0 {
4945
/// values are truncated and rounded to the nearest representable value.
5046
#[inline]
5147
#[must_use]
48+
#[cfg(feature = "float4")]
5249
pub fn from_f64(value: f64) -> ue8m0 {
53-
ue8m0(E8M0::from_f64(value).to_bits())
50+
ue8m0(float4::E8M0::from_f64(value).to_bits())
5451
}
5552

5653
/// Converts a [`ue8m0`] into the underlying bit representation.
@@ -65,6 +62,7 @@ impl ue8m0 {
6562
/// This conversion is lossless as all values can be represented exactly in [`f32`].
6663
#[inline]
6764
#[must_use]
65+
#[cfg(feature = "float4")]
6866
pub fn to_f32(self) -> f32 {
6967
self.to_f64() as f32
7068
}
@@ -74,101 +72,110 @@ impl ue8m0 {
7472
/// This conversion is lossless as all values can be represented exactly in [`f64`].
7573
#[inline]
7674
#[must_use]
75+
#[cfg(feature = "float4")]
7776
pub fn to_f64(self) -> f64 {
78-
E8M0::from_bits(self.0).to_f64()
77+
float4::E8M0::from_bits(self.0).to_f64()
7978
}
8079
}
8180

82-
impl Neg for ue8m0 {
83-
type Output = Self;
84-
85-
fn neg(self) -> Self::Output {
86-
Self::from_f32(self.to_f32().neg())
81+
impl Display for ue8m0 {
82+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
83+
write!(f, "{}", self.0)
8784
}
8885
}
8986

90-
impl Mul for ue8m0 {
91-
type Output = Self;
87+
#[cfg(feature = "float4")]
88+
mod numeric {
89+
use num_traits::{NumCast, ToPrimitive};
9290

93-
fn mul(self, rhs: Self) -> Self::Output {
94-
Self::from_f32(self.to_f32() * rhs.to_f32())
95-
}
96-
}
91+
use super::*;
92+
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
9793

98-
impl MulAssign for ue8m0 {
99-
fn mul_assign(&mut self, rhs: Self) {
100-
*self = *self * rhs;
94+
impl Neg for ue8m0 {
95+
type Output = Self;
96+
97+
fn neg(self) -> Self::Output {
98+
Self::from_f32(self.to_f32().neg())
99+
}
101100
}
102-
}
103101

104-
impl Div for ue8m0 {
105-
type Output = Self;
102+
impl Mul for ue8m0 {
103+
type Output = Self;
106104

107-
fn div(self, rhs: Self) -> Self::Output {
108-
Self::from_f32(self.to_f32() / rhs.to_f32())
105+
fn mul(self, rhs: Self) -> Self::Output {
106+
Self::from_f32(self.to_f32() * rhs.to_f32())
107+
}
109108
}
110-
}
111109

112-
impl DivAssign for ue8m0 {
113-
fn div_assign(&mut self, rhs: Self) {
114-
*self = *self / rhs;
110+
impl MulAssign for ue8m0 {
111+
fn mul_assign(&mut self, rhs: Self) {
112+
*self = *self * rhs;
113+
}
115114
}
116-
}
117115

118-
impl Add for ue8m0 {
119-
type Output = Self;
116+
impl Div for ue8m0 {
117+
type Output = Self;
120118

121-
fn add(self, rhs: Self) -> Self::Output {
122-
Self::from_f32(self.to_f32() + rhs.to_f32())
119+
fn div(self, rhs: Self) -> Self::Output {
120+
Self::from_f32(self.to_f32() / rhs.to_f32())
121+
}
123122
}
124-
}
125123

126-
impl AddAssign for ue8m0 {
127-
fn add_assign(&mut self, rhs: Self) {
128-
*self = *self + rhs;
124+
impl DivAssign for ue8m0 {
125+
fn div_assign(&mut self, rhs: Self) {
126+
*self = *self / rhs;
127+
}
129128
}
130-
}
131129

132-
impl Sub for ue8m0 {
133-
type Output = Self;
130+
impl Add for ue8m0 {
131+
type Output = Self;
134132

135-
fn sub(self, rhs: Self) -> Self::Output {
136-
Self::from_f32(self.to_f32() - rhs.to_f32())
133+
fn add(self, rhs: Self) -> Self::Output {
134+
Self::from_f32(self.to_f32() + rhs.to_f32())
135+
}
137136
}
138-
}
139137

140-
impl SubAssign for ue8m0 {
141-
fn sub_assign(&mut self, rhs: Self) {
142-
*self = *self - rhs;
138+
impl AddAssign for ue8m0 {
139+
fn add_assign(&mut self, rhs: Self) {
140+
*self = *self + rhs;
141+
}
143142
}
144-
}
145143

146-
impl ToPrimitive for ue8m0 {
147-
fn to_i64(&self) -> Option<i64> {
148-
Some(ue8m0::to_f32(*self) as i64)
149-
}
144+
impl Sub for ue8m0 {
145+
type Output = Self;
150146

151-
fn to_u64(&self) -> Option<u64> {
152-
Some(ue8m0::to_f64(*self) as u64)
147+
fn sub(self, rhs: Self) -> Self::Output {
148+
Self::from_f32(self.to_f32() - rhs.to_f32())
149+
}
153150
}
154151

155-
fn to_f32(&self) -> Option<f32> {
156-
Some(ue8m0::to_f32(*self))
152+
impl SubAssign for ue8m0 {
153+
fn sub_assign(&mut self, rhs: Self) {
154+
*self = *self - rhs;
155+
}
157156
}
158157

159-
fn to_f64(&self) -> Option<f64> {
160-
Some(ue8m0::to_f64(*self))
161-
}
162-
}
158+
impl ToPrimitive for ue8m0 {
159+
fn to_i64(&self) -> Option<i64> {
160+
Some(ue8m0::to_f32(*self) as i64)
161+
}
162+
163+
fn to_u64(&self) -> Option<u64> {
164+
Some(ue8m0::to_f64(*self) as u64)
165+
}
163166

164-
impl NumCast for ue8m0 {
165-
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
166-
Some(Self::from_f32(n.to_f32()?))
167+
fn to_f32(&self) -> Option<f32> {
168+
Some(ue8m0::to_f32(*self))
169+
}
170+
171+
fn to_f64(&self) -> Option<f64> {
172+
Some(ue8m0::to_f64(*self))
173+
}
167174
}
168-
}
169175

170-
impl Display for ue8m0 {
171-
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
172-
write!(f, "{}", self.0)
176+
impl NumCast for ue8m0 {
177+
fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
178+
Some(Self::from_f32(n.to_f32()?))
179+
}
173180
}
174181
}

crates/cubecl-core/src/runtime_tests/launch.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ pub fn test_kernel_with_comptime_tag<R: Runtime>(client: ComputeClient<R::Server
6161
&client,
6262
CubeCount::Static(1, 1, 1),
6363
CubeDim::default(),
64-
ComptimeTagLaunch::new(array_arg, &"zero".to_string()),
64+
ComptimeTagLaunch::new(array_arg, "zero".to_string()),
6565
);
6666

6767
let actual = client.read_one(handle);
@@ -76,7 +76,7 @@ pub fn test_kernel_with_comptime_tag<R: Runtime>(client: ComputeClient<R::Server
7676
&client,
7777
CubeCount::Static(1, 1, 1),
7878
CubeDim::default(),
79-
ComptimeTagLaunch::new(array_arg, &"not_zero".to_string()),
79+
ComptimeTagLaunch::new(array_arg, "not_zero".to_string()),
8080
);
8181

8282
let actual = client.read_one(handle);

crates/cubecl-macros/src/generate/cube_type/generate_struct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ impl TypeField {
375375
if !self.comptime.is_present() {
376376
quote![#vis #name: <#ty as #launch_arg>::RuntimeArg<'a, R>]
377377
} else {
378-
quote![#vis #name: &'a #ty]
378+
quote![#vis #name: #ty]
379379
}
380380
}
381381

@@ -387,7 +387,7 @@ impl TypeField {
387387
if !self.comptime.is_present() {
388388
quote![#name: <#ty as #launch_arg>::RuntimeArg<'a, R>]
389389
} else {
390-
quote![#name: &'a #ty]
390+
quote![#name: #ty]
391391
}
392392
}
393393

crates/cubecl-quant/src/dequantize.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub fn dequantize_symmetric_packed_values<F: Float, FS: CubePrimitive, QI: Int>(
3131
position: u32,
3232
values: &View<Line<QI>, u32>,
3333
scales: &View<FS, u32>,
34-
#[comptime] scheme: &QuantScheme,
34+
#[comptime] scheme: QuantScheme,
3535
) -> Array<Line<F>> {
3636
dequantize_symmetric_packed_value_at::<F, FS, QI>(position, values[position], scales, scheme)
3737
}
@@ -45,7 +45,7 @@ pub fn dequantize_symmetric_packed_value_at<F: Float, FS: CubePrimitive, QI: Int
4545
position: u32,
4646
values: Line<QI>,
4747
scales: &View<FS, u32>,
48-
#[comptime] scheme: &QuantScheme,
48+
#[comptime] scheme: QuantScheme,
4949
) -> Array<Line<F>> {
5050
dequantize_symmetric_packed_value::<F, FS, QI>(values, scales, position, scheme)
5151
}
@@ -59,7 +59,7 @@ pub fn dequantize_symmetric_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
5959
values: Line<QS>,
6060
scales: &View<FS, u32>,
6161
position: u32,
62-
#[comptime] scheme: &QuantScheme,
62+
#[comptime] scheme: QuantScheme,
6363
) -> Array<Line<F>> {
6464
let line_size_values = values.line_size();
6565
let num_quants = comptime!(scheme.num_quants() as u32);
@@ -120,7 +120,7 @@ fn dequantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
120120
input: &LinearView<Line<u32>>,
121121
scales: &ScalesView<FS>,
122122
output: &mut LinearView<Line<F>, ReadWrite>,
123-
#[comptime] scheme: &QuantScheme,
123+
#[comptime] scheme: QuantScheme,
124124
) {
125125
if !input.is_in_bounds(ABSOLUTE_POS) {
126126
terminate!();
@@ -177,19 +177,19 @@ pub fn launch_ref<R: Runtime, F: Float>(
177177
..
178178
} => match scheme.param {
179179
QuantParam::F32 => {
180-
dequantize_packed::<R, F, f32>(client, values, scheme, params, output)
180+
dequantize_packed::<R, F, f32>(client, values, *scheme, params, output)
181181
}
182182
QuantParam::F16 => {
183-
dequantize_packed::<R, F, f16>(client, values, scheme, params, output)
183+
dequantize_packed::<R, F, f16>(client, values, *scheme, params, output)
184184
}
185185
QuantParam::BF16 => {
186-
dequantize_packed::<R, F, bf16>(client, values, scheme, params, output)
186+
dequantize_packed::<R, F, bf16>(client, values, *scheme, params, output)
187187
}
188188
QuantParam::UE8M0 => {
189-
dequantize_packed::<R, F, ue8m0>(client, values, scheme, params, output)
189+
dequantize_packed::<R, F, ue8m0>(client, values, *scheme, params, output)
190190
}
191191
QuantParam::UE4M3 => {
192-
dequantize_packed::<R, F, e4m3>(client, values, scheme, params, output)
192+
dequantize_packed::<R, F, e4m3>(client, values, *scheme, params, output)
193193
}
194194
},
195195
QuantScheme {
@@ -211,19 +211,19 @@ pub fn launch_ref<R: Runtime, F: Float>(
211211

212212
match scheme.param {
213213
QuantParam::F32 => {
214-
dequantize_native::<R, F, f32>(client, values, scheme, params, output)
214+
dequantize_native::<R, F, f32>(client, values, *scheme, params, output)
215215
}
216216
QuantParam::F16 => {
217-
dequantize_native::<R, F, f16>(client, values, scheme, params, output)
217+
dequantize_native::<R, F, f16>(client, values, *scheme, params, output)
218218
}
219219
QuantParam::BF16 => {
220-
dequantize_native::<R, F, bf16>(client, values, scheme, params, output)
220+
dequantize_native::<R, F, bf16>(client, values, *scheme, params, output)
221221
}
222222
QuantParam::UE8M0 => {
223-
dequantize_native::<R, F, ue8m0>(client, values, scheme, params, output)
223+
dequantize_native::<R, F, ue8m0>(client, values, *scheme, params, output)
224224
}
225225
QuantParam::UE4M3 => {
226-
dequantize_native::<R, F, e4m3>(client, values, scheme, params, output)
226+
dequantize_native::<R, F, e4m3>(client, values, *scheme, params, output)
227227
}
228228
}
229229
}
@@ -240,7 +240,7 @@ pub fn launch_ref<R: Runtime, F: Float>(
240240
fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
241241
client: &ComputeClient<R::Server, R::Channel>,
242242
input: &TensorHandleRef<R>,
243-
scheme: &QuantScheme,
243+
scheme: QuantScheme,
244244
scale: &TensorHandleRef<'_, R>,
245245
output: &TensorHandleRef<R>,
246246
) {
@@ -276,10 +276,10 @@ fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
276276
client,
277277
cube_count,
278278
cube_dim,
279-
linear_view(client, input, &line_size_in),
280-
scales_view(client, input, scale, &1, scheme),
281-
linear_view(client, output, &line_size_out),
282-
scheme.clone(),
279+
linear_view(client, input, line_size_in),
280+
scales_view(client, input, scale, 1, &scheme),
281+
linear_view(client, output, line_size_out),
282+
scheme,
283283
)
284284
};
285285
}
@@ -290,7 +290,7 @@ fn dequantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
290290
fn dequantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
291291
client: &ComputeClient<R::Server, R::Channel>,
292292
input: &TensorHandleRef<R>,
293-
scheme: &QuantScheme,
293+
scheme: QuantScheme,
294294
scale: &TensorHandleRef<'_, R>,
295295
output: &TensorHandleRef<R>,
296296
) {
@@ -333,9 +333,9 @@ fn dequantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
333333
client,
334334
cube_count,
335335
cube_dim,
336-
linear_view(client, input, &line_size),
337-
scales_view(client, input, scale, &1, scheme),
338-
linear_view(client, output, &line_size),
336+
linear_view(client, input, line_size),
337+
scales_view(client, input, scale, 1, &scheme),
338+
linear_view(client, output, line_size),
339339
)
340340
};
341341
}

0 commit comments

Comments
 (0)