Skip to content

Commit bb716c9

Browse files
authored
Make scales layout constructable at runtime (#945)
1 parent 30366c3 commit bb716c9

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

crates/cubecl-quant/src/layout/scales.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ pub struct PerTensorLayout {
6969
tensor_len: u32,
7070
}
7171

72+
#[cube]
73+
impl PerTensorLayout {
74+
pub fn new(tensor_len: u32) -> Self {
75+
PerTensorLayout { tensor_len }
76+
}
77+
}
78+
7279
#[cube]
7380
impl Layout for PerTensorLayout {
7481
type Coordinates = Coords1d;
@@ -111,6 +118,25 @@ pub struct BlockScaledLayout {
111118
scales_line_size: u32,
112119
}
113120

121+
#[cube]
122+
impl BlockScaledLayout {
123+
pub fn new(
124+
tensor_shape: Sequence<FastDivmod>,
125+
tensor_len: u32,
126+
scales_strides: Sequence<u32>,
127+
#[comptime] block_size: Vec<u8>,
128+
#[comptime] scales_line_size: u32,
129+
) -> Self {
130+
BlockScaledLayout {
131+
tensor_shape,
132+
tensor_len,
133+
scales_strides,
134+
block_size,
135+
scales_line_size,
136+
}
137+
}
138+
}
139+
114140
#[cube]
115141
impl Layout for BlockScaledLayout {
116142
type Coordinates = Coords1d;
@@ -127,6 +153,7 @@ impl Layout for BlockScaledLayout {
127153
let dim = comptime![rank - i - 1];
128154
let block_size_local = comptime![self.block_size[dim as usize] as u32];
129155
let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs);
156+
130157
offs = rem;
131158
scale_offs += (offs_local / block_size_local) * *self.scales_strides.index(dim);
132159
}

0 commit comments

Comments
 (0)