Skip to content

Commit 691064f

Browse files
authored
fix: Fix several panics occuring with unrolled line sizes in matmul (#980)
1 parent b3eb8fe commit 691064f

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

crates/cubecl-core/src/post_processing/unroll.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ impl UnrollProcessor {
4444
inst: &Instruction,
4545
mappings: &mut Mappings,
4646
) -> TransformAction {
47+
if matches!(inst.operation, Operation::Free(_)) {
48+
return TransformAction::Ignore;
49+
}
50+
4751
if inst.operation.args().is_none() {
4852
// Detect unhandled ops that can't be reflected
4953
match &inst.operation {
@@ -623,7 +627,7 @@ fn create_unrolled(
623627
allocator.create_local_mut(item)
624628
}
625629
VariableKind::LocalConst { .. } => allocator.create_local(item),
626-
_ => panic!("Out must be local"),
630+
other => panic!("Out must be local, found {other:?}"),
627631
})
628632
.collect()
629633
}

crates/cubecl-matmul/src/components/tile/register/matmul.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ impl<Acc: TileKind> RegisterMatmul<Acc> {
158158
#[unroll(UNROLL)]
159159
for line_within_segment in 0..num_lines_per_segment {
160160
let line = tile.get_line(segment, line_within_segment);
161-
#[unroll(UNROLL)]
161+
#[unroll]
162162
for pos_within_line in 0..line_size {
163163
array[segment * segment_size
164164
+ line_within_segment * line_size
@@ -182,7 +182,7 @@ impl<Acc: TileKind> RegisterMatmul<Acc> {
182182
#[unroll(UNROLL)]
183183
for line_within_segment in 0..num_lines_per_segment {
184184
let line = tile.get_line(segment, line_within_segment);
185-
#[unroll(UNROLL)]
185+
#[unroll]
186186
for pos_within_line in 0..line_size {
187187
array[(line_within_segment * line_size + pos_within_line) * num_segments
188188
+ segment] = ER::cast_from(line[pos_within_line]);

crates/cubecl-matmul/src/components/tile/register/writer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl RegisterStageWriter {
2121
#[unroll(UNROLL)]
2222
for i in 0..comptime!(config.tile_size.mn() / out_line_size) {
2323
let mut line = Line::empty(out_line_size);
24-
#[unroll(UNROLL)]
24+
#[unroll]
2525
for j in 0..comptime!(out_line_size) {
2626
line[j] = acc[i * out_line_size + j];
2727
}

0 commit comments

Comments
 (0)