Skip to content

Commit b18d2a5

Browse files
committed
Enable tile_dotProduct_2
1 parent 12a6e45 commit b18d2a5

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

src/index_notation/index_notation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,7 @@ IndexStmt IndexStmt::pos(IndexVar i, IndexVar ipos, Access access) const {
19721972

19731973
// Replace all occurrences of i with ipos
19741974
transformed = Transformation(ForAllReplace({i}, {ipos})).apply(transformed, &reason);
1975-
if (!transformed.defined()) {
1975+
if (!transformed.defined()) {
19761976
taco_uerror << reason;
19771977
}
19781978

src/index_notation/transformations.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const {
779779
for (auto i = replacement.rbegin(); i != replacement.rend(); ++i ) {
780780
stmt = forall(*i, stmt);
781781
}
782+
elementsMatched = 0;
782783
}
783784
// else cut out this node
784785
return;

src/lower/lowerer_impl_imperative.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2302,7 +2302,9 @@ std::pair<bool,bool> LowererImplImperative::canAccelerateDenseTemp(Where where)
23022302
return resultVar == tempVar[0] ||
23032303
provGraph.isDerivedFrom(tempVar[0], resultVar);
23042304
});
2305-
2305+
if (resultVars.size() == 0){
2306+
return std::make_pair(false, false);
2307+
}
23062308
if (it == resultVars.end()) {
23072309
return std::make_pair(true, false);
23082310
}

test/tests-workspaces.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ TEST(workspaces, tile_dotProduct_1) {
504504
ASSERT_TENSOR_EQ(expected, A);
505505
}
506506

507-
TEST(workspaces, DISABLED_tile_dotProduct_2) {
507+
TEST(workspaces, tile_dotProduct_2) {
508508
// FIXME: This is also currently disabled since split(...) scheduling commands
509509
// only split on the FIRST INSTANCE of an indexVar (assumes only one).
510510
// This is wrong if the indexVar is not renamed across iw_vars since an indexVar can
@@ -516,8 +516,8 @@ TEST(workspaces, DISABLED_tile_dotProduct_2) {
516516
Tensor<double> C("C", {N}, Format({Dense}));
517517

518518
for (int i = 0; i < N; i++) {
519-
B.insert({i}, (double) i);
520-
C.insert({i}, (double) i);
519+
B.insert({i}, (double) i / N);
520+
C.insert({i}, (double) i / N);
521521
}
522522

523523
B.pack();

0 commit comments

Comments
 (0)