Skip to content

Commit cc6fc34

Browse files
rohanyweiya711
authored andcommitted
Squash array_algebra commits after master_array_algebra
lowerer_impl: fix some striding bugs Fixes some formulaic errors in generated striding code along with a test that revealed them. WIP trying to fix lattice construction Add in some changes to iteration lattice construction Add in locator check as well Add in iteration lattice construction changes for intersectLattices() as well Fix iteration lattice comparison (to sort points first) and check for empty points when unionLattices and intersectLattices are called
1 parent 2cb2f82 commit cc6fc34

File tree

3 files changed

+129
-21
lines changed

3 files changed

+129
-21
lines changed

src/lower/lowerer_impl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3917,7 +3917,7 @@ Expr LowererImpl::searchForEndOfWindowPosition(Iterator iterator, ir::Expr start
39173917
Stmt LowererImpl::upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access) {
39183918
taco_iassert(iterator.isWindowed());
39193919
return ir::IfThenElse::make(
3920-
ir::Gte::make(access, ir::Sub::make(iterator.getWindowUpperBound(), iterator.getWindowLowerBound())),
3920+
ir::Gte::make(access, ir::Div::make(ir::Sub::make(iterator.getWindowUpperBound(), iterator.getWindowLowerBound()), iterator.getStride())),
39213921
ir::Break::make()
39223922
);
39233923
}
@@ -3936,7 +3936,7 @@ Stmt LowererImpl::strideBoundsGuard(Iterator iterator, ir::Expr access, bool inc
39363936
}
39373937
// The guard makes sure that the coordinate being accessed is along the stride.
39383938
return ir::IfThenElse::make(
3939-
ir::Neq::make(ir::Rem::make(access, iterator.getStride()), ir::Literal::make(0)),
3939+
ir::Neq::make(ir::Rem::make(ir::Sub::make(access, iterator.getWindowLowerBound()), iterator.getStride()), ir::Literal::make(0)),
39403940
cont
39413941
);
39423942
}
@@ -3948,7 +3948,7 @@ Expr LowererImpl::projectWindowedPositionToCanonicalSpace(Iterator iterator, ir:
39483948

39493949

39503950
Expr LowererImpl::projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr) {
3951-
return ir::Mul::make(ir::Add::make(expr, iterator.getWindowLowerBound()), iterator.getStride());
3951+
return ir::Add::make(ir::Mul::make(expr, iterator.getStride()), iterator.getWindowLowerBound());
39523952
}
39533953

39543954
}

src/lower/merge_lattice.cpp

Lines changed: 116 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,53 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
545545
bool locateLeft = locateFromLeft(left, right);
546546

547547
// Append all combinations of a and b merge points
548-
for (auto& leftPoint : left.points()) {
549-
for (auto& rightPoint : right.points()) {
550-
points.push_back(intersectPoints(leftPoint, rightPoint, locateLeft));
548+
struct pointSort {
549+
bool operator()(const MergePoint& a, const MergePoint& b) {
550+
size_t left_size = a.iterators().size() + a.locators().size();
551+
size_t right_size = b.iterators().size() + b.locators().size();
552+
return left_size > right_size;
553+
}
554+
} pointSorter;
555+
556+
// Append all combinations of the merge points of a and b
557+
auto sorted_apoint = left.points();
558+
auto sorted_bpoint = right.points();
559+
std::sort(sorted_apoint.begin(), sorted_apoint.end(), pointSorter);
560+
std::sort(sorted_bpoint.begin(), sorted_bpoint.end(), pointSorter);
561+
562+
set<Iterator> apoint_root_set;
563+
if (!sorted_apoint.empty())
564+
apoint_root_set = sorted_apoint.begin()->tensorRegion();
565+
566+
set<Iterator>bpoint_root_set;
567+
if (!sorted_bpoint.empty())
568+
bpoint_root_set = sorted_bpoint.begin()->tensorRegion();
569+
570+
571+
for (auto& apoint : sorted_apoint) {
572+
for (auto& bpoint : sorted_bpoint) {
573+
bool hasIntersection = true;
574+
575+
auto apoint_set = apoint.tensorRegion();
576+
auto bpoint_set = bpoint.tensorRegion();
577+
578+
for (auto& it : apoint_set) {
579+
if (!std::count(bpoint_set.begin(), bpoint_set.end(), it) &&
580+
std::count(bpoint_root_set.begin(), bpoint_root_set.end(), it)) {
581+
hasIntersection = false;
582+
}
583+
}
584+
for (auto& it : bpoint_set) {
585+
if (!std::count(apoint_set.begin(), apoint_set.end(), it) &&
586+
std::count(apoint_root_set.begin(), apoint_root_set.end(), it)) {
587+
hasIntersection = false;
588+
}
589+
}
590+
if (hasIntersection)
591+
points.push_back(intersectPoints(apoint, bpoint, locateLeft));
551592
}
552593
}
594+
std::sort(points.begin(), points.end(), pointSorter);
553595

554596
// Correctness: ensures that points produced on BOTH the left and the
555597
// right lattices are produced in the final intersection.
@@ -561,7 +603,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
561603
// points and resolves conflicts arising between omitters and
562604
// producers
563605
points = removeDuplicatedTensorRegions(points, true);
564-
606+
565607
// Optimization: Removed a subLattice of points if the entire subLattice is
566608
// made of only omitters
567609
// points = removeUnnecessaryOmitterPoints(points);
@@ -581,10 +623,49 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
581623
{
582624
vector<MergePoint> points;
583625

626+
struct pointSort {
627+
bool operator()(const MergePoint& a, const MergePoint& b) {
628+
size_t left_size = a.iterators().size() + a.locators().size();
629+
size_t right_size = b.iterators().size() + b.locators().size();
630+
return left_size > right_size;
631+
}
632+
} pointSorter;
633+
584634
// Append all combinations of the merge points of a and b
585-
for (auto& apoint : left.points()) {
586-
for (auto& bpoint : right.points()) {
587-
points.push_back(unionPoints(apoint, bpoint));
635+
auto sorted_apoint = left.points();
636+
auto sorted_bpoint = right.points();
637+
std::sort(sorted_apoint.begin(), sorted_apoint.end(), pointSorter);
638+
std::sort(sorted_bpoint.begin(), sorted_bpoint.end(), pointSorter);
639+
640+
set<Iterator> apoint_root_set;
641+
if (!sorted_apoint.empty())
642+
apoint_root_set = sorted_apoint.begin()->tensorRegion();
643+
644+
set<Iterator>bpoint_root_set;
645+
if (!sorted_bpoint.empty())
646+
bpoint_root_set = sorted_bpoint.begin()->tensorRegion();
647+
648+
for (auto& apoint : sorted_apoint) {
649+
for (auto& bpoint : sorted_bpoint) {
650+
bool hasIntersection = true;
651+
652+
auto apoint_set = apoint.tensorRegion();
653+
auto bpoint_set = bpoint.tensorRegion();
654+
655+
for (auto& it : apoint_set) {
656+
if (!std::count(bpoint_set.begin(), bpoint_set.end(), it) &&
657+
std::count(bpoint_root_set.begin(), bpoint_root_set.end(), it)) {
658+
hasIntersection = false;
659+
}
660+
}
661+
for (auto& it : bpoint_set) {
662+
if (!std::count(apoint_set.begin(), apoint_set.end(), it) &&
663+
std::count(apoint_root_set.begin(), apoint_root_set.end(), it)) {
664+
hasIntersection = false;
665+
}
666+
}
667+
if (hasIntersection)
668+
points.push_back(unionPoints(apoint, bpoint));
588669
}
589670
}
590671

@@ -594,22 +675,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
594675
// Append the merge points of b
595676
util::append(points, right.points());
596677

597-
struct pointSort {
598-
bool operator()(const MergePoint& a, const MergePoint& b) {
599-
size_t left_size = a.iterators().size() + a.locators().size();
600-
size_t right_size = b.iterators().size() + b.locators().size();
601-
return left_size > right_size;
602-
}
603-
} pointSorter;
604-
605678
std::sort(points.begin(), points.end(), pointSorter);
606679

607680
// Correctness: This ensures that points omitted on BOTH the left and the
608681
// right lattices are omitted in the Union. Needed since some
609682
// subpoints may produce leading to erroneous producer regions
610683
points = correctPointTypesAfterUnion(left.points(), right.points(), points);
611684

612-
613685
// Correctness: Deduplicate regions that are described by multiple lattice
614686
// points and resolves conflicts arising between omitters and
615687
// producers
@@ -675,6 +747,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
675747
*/
676748
static MergePoint unionPoints(MergePoint left, MergePoint right)
677749
{
750+
678751
vector<Iterator> iterators= combine(left.iterators(),right.iterators());
679752
vector<Iterator> locaters = combine(left.locators(), right.locators());
680753
vector<Iterator> results = combine(left.results(), right.results());
@@ -788,6 +861,21 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
788861
return deduplicates;
789862
}
790863

864+
static vector<Iterator>
865+
removeDimensionIterators(const vector<Iterator>& iterators)
866+
{
867+
vector<Iterator> result;
868+
869+
// Remove all but one of the dense iterators, which are all the same.
870+
for (auto& iterator : iterators) {
871+
if (!iterator.isDimensionIterator()) {
872+
result.push_back(iterator);
873+
}
874+
}
875+
return result;
876+
}
877+
878+
791879
static vector<MergePoint>
792880
flipPoints(const vector<MergePoint>& points) {
793881
vector<MergePoint> flippedPoints;
@@ -1123,8 +1211,18 @@ ostream& operator<<(ostream& os, const MergeLattice& ml) {
11231211
}
11241212

11251213
bool operator==(const MergeLattice& a, const MergeLattice& b) {
1126-
auto& apoints = a.points();
1127-
auto& bpoints = b.points();
1214+
auto apoints = a.points();
1215+
auto bpoints = b.points();
1216+
struct pointSort {
1217+
bool operator()(const MergePoint& a, const MergePoint& b) {
1218+
size_t left_size = a.iterators().size() + a.locators().size();
1219+
size_t right_size = b.iterators().size() + b.locators().size();
1220+
return left_size > right_size;
1221+
}
1222+
} pointSorter;
1223+
1224+
std::sort(apoints.begin(), apoints.end(), pointSorter);
1225+
std::sort(bpoints.begin(), bpoints.end(), pointSorter);
11281226
if (apoints.size() != bpoints.size()) {
11291227
return false;
11301228
}

test/tests-windowing.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,16 @@ TEST_P(stride, windowing) {
433433
c(i, j) = a(i(0, 10, 5), j(0, 10, 5)) * b(i(0, 10, 5), j(0, 10, 5));
434434
c.evaluate();
435435
ASSERT_TRUE(equals(c, expectedMul)) << c << endl << expectedMul << endl;
436+
437+
// Test a strided assignment where the stride doesn't start at 0.
438+
c(i, j) = a(i(1, 5, 2), j(2, 6, 2));
439+
c.evaluate();
440+
Tensor<int> expectedAssign2("expectedAssign2", {2, 2}, {Dense, Dense});
441+
expectedAssign2.insert({0, 0}, 3); expectedAssign2.insert({0, 1}, 5);
442+
expectedAssign2.insert({1, 0}, 5); expectedAssign2.insert({1, 1}, 7);
443+
expectedAssign2.pack();
444+
ASSERT_TRUE(equals(c, expectedAssign2)) << c << endl << expectedAssign2 << endl;
445+
436446
}
437447
INSTANTIATE_TEST_CASE_P(
438448
windowing,

0 commit comments

Comments
 (0)