@@ -545,11 +545,53 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
545545 bool locateLeft = locateFromLeft (left, right);
546546
547547 // Append all combinations of a and b merge points
548- for (auto & leftPoint : left.points ()) {
549- for (auto & rightPoint : right.points ()) {
550- points.push_back (intersectPoints (leftPoint, rightPoint, locateLeft));
548+ struct pointSort { // Comparator for std::sort: orders merge points by descending (iterators + locators) count.
549+ bool operator ()(const MergePoint& a, const MergePoint& b) {
550+ size_t left_size = a.iterators ().size () + a.locators ().size (); // combined iterator/locator count of a
551+ size_t right_size = b.iterators ().size () + b.locators ().size (); // combined iterator/locator count of b
552+ return left_size > right_size; // strict '>': a sorts before b only when a is larger; equal sizes compare equivalent
553+ }
554+ } pointSorter;
555+
556+ // Append all combinations of the merge points of a and b
557+ auto sorted_apoint = left.points ();
558+ auto sorted_bpoint = right.points ();
559+ std::sort (sorted_apoint.begin (), sorted_apoint.end (), pointSorter);
560+ std::sort (sorted_bpoint.begin (), sorted_bpoint.end (), pointSorter);
561+
562+ set<Iterator> apoint_root_set;
563+ if (!sorted_apoint.empty ())
564+ apoint_root_set = sorted_apoint.begin ()->tensorRegion ();
565+
566+ set<Iterator>bpoint_root_set;
567+ if (!sorted_bpoint.empty ())
568+ bpoint_root_set = sorted_bpoint.begin ()->tensorRegion ();
569+
570+
571+ for (auto & apoint : sorted_apoint) {
572+ for (auto & bpoint : sorted_bpoint) {
573+ bool hasIntersection = true ;
574+
575+ auto apoint_set = apoint.tensorRegion ();
576+ auto bpoint_set = bpoint.tensorRegion ();
577+
578+ for (auto & it : apoint_set) {
579+ if (!std::count (bpoint_set.begin (), bpoint_set.end (), it) &&
580+ std::count (bpoint_root_set.begin (), bpoint_root_set.end (), it)) {
581+ hasIntersection = false ;
582+ }
583+ }
584+ for (auto & it : bpoint_set) {
585+ if (!std::count (apoint_set.begin (), apoint_set.end (), it) &&
586+ std::count (apoint_root_set.begin (), apoint_root_set.end (), it)) {
587+ hasIntersection = false ;
588+ }
589+ }
590+ if (hasIntersection)
591+ points.push_back (intersectPoints (apoint, bpoint, locateLeft));
551592 }
552593 }
594+ std::sort (points.begin (), points.end (), pointSorter);
553595
554596 // Correctness: ensures that points produced on BOTH the left and the
555597 // right lattices are produced in the final intersection.
@@ -561,7 +603,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
561603 // points and resolves conflicts arising between omitters and
562604 // producers
563605 points = removeDuplicatedTensorRegions (points, true );
564-
606+
565607 // Optimization: Removed a subLattice of points if the entire subLattice is
566608 // made of only omitters
567609 // points = removeUnnecessaryOmitterPoints(points);
@@ -581,10 +623,49 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
581623 {
582624 vector<MergePoint> points;
583625
626+ struct pointSort { // Comparator for std::sort: orders merge points by descending (iterators + locators) count.
627+ bool operator ()(const MergePoint& a, const MergePoint& b) {
628+ size_t left_size = a.iterators ().size () + a.locators ().size (); // combined iterator/locator count of a
629+ size_t right_size = b.iterators ().size () + b.locators ().size (); // combined iterator/locator count of b
630+ return left_size > right_size; // strict '>': a sorts before b only when a is larger; equal sizes compare equivalent
631+ }
632+ } pointSorter;
633+
584634 // Append all combinations of the merge points of a and b
585- for (auto & apoint : left.points ()) {
586- for (auto & bpoint : right.points ()) {
587- points.push_back (unionPoints (apoint, bpoint));
635+ auto sorted_apoint = left.points ();
636+ auto sorted_bpoint = right.points ();
637+ std::sort (sorted_apoint.begin (), sorted_apoint.end (), pointSorter);
638+ std::sort (sorted_bpoint.begin (), sorted_bpoint.end (), pointSorter);
639+
640+ set<Iterator> apoint_root_set;
641+ if (!sorted_apoint.empty ())
642+ apoint_root_set = sorted_apoint.begin ()->tensorRegion ();
643+
644+ set<Iterator>bpoint_root_set;
645+ if (!sorted_bpoint.empty ())
646+ bpoint_root_set = sorted_bpoint.begin ()->tensorRegion ();
647+
648+ for (auto & apoint : sorted_apoint) {
649+ for (auto & bpoint : sorted_bpoint) {
650+ bool hasIntersection = true ;
651+
652+ auto apoint_set = apoint.tensorRegion ();
653+ auto bpoint_set = bpoint.tensorRegion ();
654+
655+ for (auto & it : apoint_set) {
656+ if (!std::count (bpoint_set.begin (), bpoint_set.end (), it) &&
657+ std::count (bpoint_root_set.begin (), bpoint_root_set.end (), it)) {
658+ hasIntersection = false ;
659+ }
660+ }
661+ for (auto & it : bpoint_set) {
662+ if (!std::count (apoint_set.begin (), apoint_set.end (), it) &&
663+ std::count (apoint_root_set.begin (), apoint_root_set.end (), it)) {
664+ hasIntersection = false ;
665+ }
666+ }
667+ if (hasIntersection)
668+ points.push_back (unionPoints (apoint, bpoint));
588669 }
589670 }
590671
@@ -594,22 +675,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
594675 // Append the merge points of b
595676 util::append (points, right.points ());
596677
597- struct pointSort {
598- bool operator ()(const MergePoint& a, const MergePoint& b) {
599- size_t left_size = a.iterators ().size () + a.locators ().size ();
600- size_t right_size = b.iterators ().size () + b.locators ().size ();
601- return left_size > right_size;
602- }
603- } pointSorter;
604-
605678 std::sort (points.begin (), points.end (), pointSorter);
606679
607680 // Correctness: This ensures that points omitted on BOTH the left and the
608681 // right lattices are omitted in the Union. Needed since some
609682 // subpoints may produce leading to erroneous producer regions
610683 points = correctPointTypesAfterUnion (left.points (), right.points (), points);
611684
612-
613685 // Correctness: Deduplicate regions that are described by multiple lattice
614686 // points and resolves conflicts arising between omitters and
615687 // producers
@@ -675,6 +747,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
675747 */
676748 static MergePoint unionPoints (MergePoint left, MergePoint right)
677749 {
750+
678751 vector<Iterator> iterators= combine (left.iterators (),right.iterators ());
679752 vector<Iterator> locaters = combine (left.locators (), right.locators ());
680753 vector<Iterator> results = combine (left.results (), right.results ());
@@ -788,6 +861,21 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA
788861 return deduplicates;
789862 }
790863
864+ static vector<Iterator>
865+ removeDimensionIterators (const vector<Iterator>& iterators) // Returns a copy of `iterators` with the dimension iterators left out.
866+ {
867+ vector<Iterator> result;
868+
869+ // Drop every iterator for which isDimensionIterator() is true; keep the rest in their original order.
870+ for (auto & iterator : iterators) {
871+ if (!iterator.isDimensionIterator ()) {
872+ result.push_back (iterator);
873+ }
874+ }
875+ return result;
876+ }
877+
878+
791879 static vector<MergePoint>
792880 flipPoints (const vector<MergePoint>& points) {
793881 vector<MergePoint> flippedPoints;
@@ -1123,8 +1211,18 @@ ostream& operator<<(ostream& os, const MergeLattice& ml) {
11231211}
11241212
11251213bool operator ==(const MergeLattice& a, const MergeLattice& b) {
1126- auto & apoints = a.points ();
1127- auto & bpoints = b.points ();
1214+ auto apoints = a.points ();
1215+ auto bpoints = b.points ();
1216+ struct pointSort { // Comparator for std::sort: orders merge points by descending (iterators + locators) count.
1217+ bool operator ()(const MergePoint& a, const MergePoint& b) { // a/b are the comparator's own parameters, not the lattices passed to operator==
1218+ size_t left_size = a.iterators ().size () + a.locators ().size (); // combined iterator/locator count of a
1219+ size_t right_size = b.iterators ().size () + b.locators ().size (); // combined iterator/locator count of b
1220+ return left_size > right_size; // strict '>': a sorts before b only when a is larger; equal sizes compare equivalent
1221+ }
1222+ } pointSorter;
1223+
1224+ std::sort (apoints.begin (), apoints.end (), pointSorter);
1225+ std::sort (bpoints.begin (), bpoints.end (), pointSorter);
11281226 if (apoints.size () != bpoints.size ()) {
11291227 return false ;
11301228 }
0 commit comments