File tree Expand file tree Collapse file tree 1 file changed +27
-0
lines changed
crates/cubecl-quant/src/layout Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change @@ -69,6 +69,13 @@ pub struct PerTensorLayout {
6969 tensor_len : u32 ,
7070}
7171
72+ #[ cube]
73+ impl PerTensorLayout {
74+ pub fn new ( tensor_len : u32 ) -> Self {
75+ PerTensorLayout { tensor_len }
76+ }
77+ }
78+
7279#[ cube]
7380impl Layout for PerTensorLayout {
7481 type Coordinates = Coords1d ;
@@ -111,6 +118,25 @@ pub struct BlockScaledLayout {
111118 scales_line_size : u32 ,
112119}
113120
121+ #[ cube]
122+ impl BlockScaledLayout {
123+ pub fn new (
124+ tensor_shape : Sequence < FastDivmod > ,
125+ tensor_len : u32 ,
126+ scales_strides : Sequence < u32 > ,
127+ #[ comptime] block_size : Vec < u8 > ,
128+ #[ comptime] scales_line_size : u32 ,
129+ ) -> Self {
130+ BlockScaledLayout {
131+ tensor_shape,
132+ tensor_len,
133+ scales_strides,
134+ block_size,
135+ scales_line_size,
136+ }
137+ }
138+ }
139+
114140#[ cube]
115141impl Layout for BlockScaledLayout {
116142 type Coordinates = Coords1d ;
@@ -127,6 +153,7 @@ impl Layout for BlockScaledLayout {
127153 let dim = comptime ! [ rank - i - 1 ] ;
128154 let block_size_local = comptime ! [ self . block_size[ dim as usize ] as u32 ] ;
129155 let ( rem, offs_local) = self . tensor_shape . index ( dim) . div_mod ( offs) ;
156+
130157 offs = rem;
131158 scale_offs += ( offs_local / block_size_local) * * self . scales_strides . index ( dim) ;
132159 }
You can’t perform that action at this time.
0 commit comments