Skip to content

Commit 2eb298e

Browse files
committed
Relaxes requirements for spmm transformation
1 parent d5721d7 commit 2eb298e

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/index_notation/transformations.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,16 +1115,15 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
11151115
}
11161116

11171117
TensorVar B = Baccess.getTensorVar();
1118-
if (B.getFormat().getModeFormats()[0].getName() != "dense" ||
1119-
B.getFormat().getModeFormats()[1].getName() != "compressed" ||
1120-
B.getFormat().getModeOrdering()[0] != 0 ||
1118+
if (B.getFormat().getModeOrdering()[0] != 0 ||
11211119
B.getFormat().getModeOrdering()[1] != 1) {
11221120
return stmt;
11231121
}
11241122

1123+
// We need random access into the first mode or this tensor in order to perform a linear combination of rows
1124+
// algorithm. (I think?)
11251125
TensorVar C = Caccess.getTensorVar();
1126-
if (C.getFormat().getModeFormats()[0].getName() != "dense" ||
1127-
C.getFormat().getModeFormats()[1].getName() != "compressed" ||
1126+
if (C.getFormat().getModeFormats()[0].getName() == "compressed" ||
11281127
C.getFormat().getModeOrdering()[0] != 0 ||
11291128
C.getFormat().getModeOrdering()[1] != 1) {
11301129
return stmt;

0 commit comments

Comments
 (0)