@@ -433,13 +433,16 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
433433 );
434434
435435 IndexSetRel rel = a.getIndexSetRel ();
436+ // / The reduceOp depends on the relation between indexVar sets of rhs and lhs. For rcl and inter, reduceOp
437+ // / must be +=. For lcr, reduceOp must be =. For none and equal, reduceOp can't be decided at this stage.
436438 switch (rel) {
437- case none: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
438- case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
439- case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
440- case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
441- case equal: a = Assignment (a.getLhs (), a.getRhs ());break ;// = OR +=
442- }return a;
439+ case none: a = Assignment (a.getLhs (), a.getRhs ());break ;
440+ case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ;
441+ case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ;
442+ case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ;
443+ case equal: a = Assignment (a.getLhs (), a.getRhs ());break ;
444+ }
445+ return a;
443446 }
444447
445448 Assignment getProducerAssignment (TensorVar& ws,
@@ -450,14 +453,15 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
450453
451454 auto a = ws (iw_vars) = replace (e, substitutions);
452455 IndexSetRel rel = a.getIndexSetRel ();
456+ // / The reduceOp depends on the relation between indexVar sets of rhs and lhs. For rcl and inter, reduceOp
457+ // / must be +=. For lcr, reduceOp must be =. For none and equal, reduceOp can't be decided at this stage.
453458 switch (rel) {
454- case none: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
455- case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
456- case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
457- case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
458- case equal: a = Assignment (a.getLhs (), a.getRhs ());break ;// = OR +=
459+ case none: a = Assignment (a.getLhs (), a.getRhs ());break ;
460+ case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ;
461+ case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ;
462+ case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ;
463+ case equal: a = Assignment (a.getLhs (), a.getRhs ());break ;
459464 }
460-
461465 return a;
462466 }
463467
@@ -565,6 +569,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
565569 }
566570 };
567571
572+ // / RedundantVisitor uses Forall Context to determine reduceOp for none and equal.
573+ // / We assume += is used if a workspace is accessed multiple times, otherwise =.
574+ // / Forall Context describes the related indexVars of the given indexVar at a specific stage. `ctx_stack` implements such concept.
568575 struct RedundantVisitor : public IndexNotationVisitor {
569576 using IndexNotationVisitor::visit;
570577
@@ -608,6 +615,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
608615 Assignment a (node->lhs , node->rhs , node->op );
609616 vector<IndexVar> freeVars = a.getLhs ().getIndexVars ();
610617 set<IndexVar> seen (freeVars.begin (), freeVars.end ());
618+
619+ // / For equal, if some indexVar in lhs has sibling in ctx stack, reduceOp will be +=.
620+ bool is_equal = (a.getIndexSetRel () == equal);
611621 bool has_sibling = false ;
612622 match (a.getRhs (),
613623 std::function<void (const AccessNode*)>([&](const AccessNode* op) {
@@ -619,11 +629,12 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
619629 }
620630 }
621631 }));
622- bool is_equal = (a.getIndexSetRel () == equal);
623- bool is_none = (a.getIndexSetRel () == none);
624632 if (is_equal && has_sibling) {
625633 to_change.push_back (a);
626634 }
635+
636+ // / For none, if ctx_stack except the top contains indexVars in lhs, reduceOp will be +=.
637+ bool is_none = (a.getIndexSetRel () == none);
627638 bool has_outside = true ;
628639 for (auto & var : seen) {
629640 for (auto &svar: ctx_stack) {
0 commit comments