Skip to content

Commit 79dd27c

Browse files
committed
Fix matcher
1 parent 9b8a96b commit 79dd27c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/index_notation/index_notation.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2055,13 +2055,15 @@ IndexStmt IndexStmt::wsaccel(TensorVar& ws, const std::vector<IndexVar>& accels,
20552055
}
20562056
set<IndexVar> TempVars;
20572057
match(*this,
2058-
std::function<void(const WhereNode*)>([&](const WhereNode* where) {
2058+
std::function<void(const WhereNode*,Matcher*)>([&](const WhereNode* where,Matcher* ctx) {
20592059
auto Temp = getResultAccesses(where->producer).first[0];
20602060
if (Temp.getTensorVar() == ws) {
20612061
for (auto i :getIndexVars()){
20622062
TempVars.insert(i);
20632063
}
20642064
}
2065+
ctx->match(where->producer);
2066+
ctx->match(where->consumer);
20652067
}));
20662068
for (auto i : accels) {
20672069
if (TempVars.find(i) == TempVars.end()) {

0 commit comments

Comments
 (0)