Skip to content

Commit 8d7d252

Browse files
committed
lower: fix a bug causing undefined variables when applying fuse
Fixes #355. This commit fixes a bug where the fuse transformation would not generate necessary locator variables when applied to iteration over two dense variables.
1 parent cb4731d commit 8d7d252

File tree

2 files changed

+86
-21
lines changed

2 files changed

+86
-21
lines changed

src/lower/lowerer_impl.cpp

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,35 +2250,50 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) {
22502250
for (Iterator& locator : locators) {
22512251
accessibleIterators.insert(locator);
22522252

2253-
bool doLocate = true;
2253+
// Pull out some logic for constructing the locators for a given iterator.
2254+
auto addLocator = [&](const Iterator& iter) {
2255+
ModeFunction locate = iter.locate(coordinates(iter));
2256+
taco_iassert(isValue(locate.getResults()[1], true));
2257+
Stmt declarePosVar = VarDecl::make(iter.getPosVar(), locate.getResults()[0]);
2258+
result.push_back(declarePosVar);
2259+
};
2260+
2261+
// Look through all of the parent iterators. If any of these iterators
2262+
// are not accessible, we need to construct their accessors before emitting
2263+
// locator's accessors. This is because locator may use the ancestor's
2264+
// variables in its accessors. We add these ancestors into a vector and reverse
2265+
// it so that the highest parent in the tree's accessors get declared first.
2266+
std::vector<Iterator> ancestors;
22542267
for (Iterator ancestorIterator = locator.getParent();
22552268
!ancestorIterator.isRoot() && ancestorIterator.hasLocate();
22562269
ancestorIterator = ancestorIterator.getParent()) {
22572270
if (!accessibleIterators.contains(ancestorIterator)) {
2258-
doLocate = false;
2271+
// Since we're going to emit the locators for this iterator, add it to
2272+
// accessibleIterators so that other locators with this as an ancestor
2273+
// don't do the same.
2274+
accessibleIterators.insert(ancestorIterator);
2275+
ancestors.push_back(ancestorIterator);
22592276
}
22602277
}
2278+
for (auto it = ancestors.rbegin(); it != ancestors.rend(); it++) addLocator(*it);
22612279

2262-
if (doLocate) {
2263-
Iterator locateIterator = locator;
2264-
if (locateIterator.hasPosIter()) {
2265-
taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar()));
2266-
continue; // these will be recovered with separate procedure
2267-
}
2268-
do {
2269-
ModeFunction locate = locateIterator.locate(coordinates(locateIterator));
2270-
taco_iassert(isValue(locate.getResults()[1], true));
2271-
Stmt declarePosVar = VarDecl::make(locateIterator.getPosVar(),
2272-
locate.getResults()[0]);
2273-
result.push_back(declarePosVar);
2274-
2275-
if (locateIterator.isLeaf()) {
2276-
break;
2277-
}
2278-
2279-
locateIterator = locateIterator.getChild();
2280-
} while (accessibleIterators.contains(locateIterator));
2280+
Iterator locateIterator = locator;
2281+
// Position iterators will be recovered with a separate procedure, so
2282+
// don't emit anything if locator is one.
2283+
if (locateIterator.hasPosIter()) {
2284+
taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar()));
2285+
continue;
22812286
}
2287+
2288+
// Once all parent locators have been declared, add the target and all
2289+
// children locators.
2290+
do {
2291+
addLocator(locateIterator);
2292+
if (locateIterator.isLeaf()) {
2293+
break;
2294+
}
2295+
locateIterator = locateIterator.getChild();
2296+
} while (accessibleIterators.contains(locateIterator));
22822297
}
22832298
return result.empty() ? Stmt() : Block::make(result);
22842299
}

test/tests-scheduling.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,56 @@ TEST(scheduling, splitIndexStmt) {
7272
ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt()));
7373
}
7474

75+
TEST(scheduling, fuseDenseLoops) {
76+
auto dim = 4;
77+
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
78+
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
79+
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
80+
IndexVar f("f"), g("g");
81+
for (int i = 0; i < dim; i++) {
82+
for (int j = 0; j < dim; j++) {
83+
for (int k = 0; k < dim; k++) {
84+
A.insert({i, j, k}, i + j + k);
85+
B.insert({i, j, k}, i + j + k);
86+
expected.insert({i, j, k}, 2 * (i + j + k));
87+
}
88+
}
89+
}
90+
A.pack();
91+
B.pack();
92+
expected.pack();
93+
94+
// Helper function to evaluate the target statement and verify the results.
95+
// It takes in a function that applies some scheduling transforms to the
96+
// input IndexStmt, and applies to the point-wise tensor addition below.
97+
// The test is structured this way as TACO does its best to avoid re-compilation
98+
// whenever possible. I.e. changing the stmt that a tensor is compiled with
99+
// doesn't cause compilation to occur again.
100+
auto testFn = [&](IndexStmt modifier (IndexStmt)) {
101+
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
102+
C(i, j, k) = A(i, j, k) + B(i, j, k);
103+
auto stmt = C.getAssignment().concretize();
104+
C.compile(modifier(stmt));
105+
C.evaluate();
106+
ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl;
107+
};
108+
109+
// First, a sanity check with no transformations.
110+
testFn([](IndexStmt stmt) { return stmt; });
111+
// Next, fuse the outer two loops. This tests the original bug in #355.
112+
testFn([](IndexStmt stmt) {
113+
IndexVar f("f");
114+
return stmt.fuse(i, j, f);
115+
});
116+
// Lastly, fuse all of the loops into a single loop. This ensures that
117+
// locators with a chain of ancestors have all of their dependencies
118+
// generated in a valid ordering.
119+
testFn([](IndexStmt stmt) {
120+
IndexVar f("f"), g("g");
121+
return stmt.fuse(i, j, f).fuse(f, k, g);
122+
});
123+
}
124+
75125
TEST(scheduling, lowerDenseMatrixMul) {
76126
Tensor<double> A("A", {4, 4}, {Dense, Dense});
77127
Tensor<double> B("B", {4, 4}, {Dense, Dense});

0 commit comments

Comments
 (0)