Skip to content

Commit 9837321

Browse files
committed
avoid infinite loop in einsummatmul transposition
1 parent 8b8f453 commit 9837321

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

core/src/ops/einsum/einsum_matmul.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ impl TypedOp for EinSumMatMul {
153153
match (self.m.as_i64(), self.n.as_i64()) {
154154
(Some(m), Some(n)) => m < n,
155155
(None, Some(n)) => n >= 8,
156-
_ => false,
156+
(Some(_), _) => false,
157+
_ => (self.n.clone() - &self.m).prove_positive_or_zero(),
157158
}
158159
};
159160
if must_transpose {
@@ -167,11 +168,12 @@ impl TypedOp for EinSumMatMul {
167168
&[node.inputs[1], node.inputs[0]],
168169
op,
169170
)
170-
.map(Some);
171+
.map(|p| Some(p.with_context("transposing")));
171172
}
172173
// opt mat mul assumes we have at least one m or n
173174
if self.c_m().is_some() || self.c_n().is_some() {
174-
return optimized_mat_mul(model, node, self);
175+
return optimized_mat_mul(model, node, self)
176+
.map(|opt| opt.map(|p| p.with_context("optimizing")));
175177
}
176178
Ok(None)
177179
}

0 commit comments

Comments
 (0)