Skip to content

Commit 8471869

Browse files
committed
Changes SPMM tranform requirement. Unsure about this
1 parent 46aed13 commit 8471869

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/index_notation/transformations.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,16 +1114,19 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
11141114
return stmt;
11151115
}
11161116

1117+
// I think we can to linear combination of rows as long as there are no permutations in the format and the
1118+
// level formats are ordered. The i -> k -> j loops should iterate over the data structures without issue.
11171119
TensorVar B = Baccess.getTensorVar();
1118-
if (B.getFormat().getModeOrdering()[0] != 0 ||
1120+
if (!B.getFormat().getModeFormats()[0].isOrdered() ||
1121+
!B.getFormat().getModeFormats()[1].isOrdered() ||
1122+
B.getFormat().getModeOrdering()[0] != 0 ||
11191123
B.getFormat().getModeOrdering()[1] != 1) {
11201124
return stmt;
11211125
}
11221126

1123-
// We need random access into the first mode or this tensor in order to perform a linear combination of rows
1124-
// algorithm. (I think?)
11251127
TensorVar C = Caccess.getTensorVar();
1126-
if (!C.getFormat().getModeFormats()[0].hasLocate() ||
1128+
if (!C.getFormat().getModeFormats()[0].isOrdered() ||
1129+
!C.getFormat().getModeFormats()[1].isOrdered() ||
11271130
C.getFormat().getModeOrdering()[0] != 0 ||
11281131
C.getFormat().getModeOrdering()[1] != 1) {
11291132
return stmt;

0 commit comments

Comments
 (0)