@@ -331,7 +331,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
331331 );
332332 IndexSetRel rel = a.getIndexSetRel ();
333333 switch (rel) {
334- case none: a = Assignment (a.getLhs (), a.getRhs (), Add () );break ; // =
334+ case none: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
335335 case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
336336 case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
337337 case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
@@ -346,10 +346,16 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
346346 const IndexExpr& e,
347347 map<IndexVar, IndexVar> substitutions) {
348348
349- auto assignment = ws (iw_vars) = replace (e, substitutions);
350- if (!assignment.getReductionVars ().empty ())
351- assignment = Assignment (assignment.getLhs (), assignment.getRhs (), Add ());
352- return assignment;
349+ auto a = ws (iw_vars) = replace (e, substitutions);
350+ IndexSetRel rel = a.getIndexSetRel ();
351+ switch (rel) {
352+ case none: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
353+ case rcl: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
354+ case lcr: a = Assignment (a.getLhs (), a.getRhs ());break ; // =
355+ case inter: a = Assignment (a.getLhs (), a.getRhs (), Add ());break ; // +=
356+ case equal: a = Assignment (a.getLhs (), a.getRhs ());break ;// = OR +=
357+ }
358+ return a;
353359 }
354360
355361 IndexStmt generateForalls (IndexStmt stmt, vector<IndexVar> indexVars) {
@@ -459,7 +465,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
459465 IndexNotationRewriter::visit (node);
460466 }
461467 };
462- struct RedundentVisitor : public IndexNotationVisitor {
468+ struct RedundantVisitor : public IndexNotationVisitor {
463469 using IndexNotationVisitor::visit;
464470
465471 std::vector<Assignment>& to_change;
@@ -468,7 +474,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
468474 int ctx_num;
469475 const ProvenanceGraph& provGraph;
470476
471- RedundentVisitor (std::vector<Assignment>& to_change, const ProvenanceGraph& provGraph):to_change(to_change), provGraph(provGraph),ctx_num(0 ){}
477+ RedundantVisitor (std::vector<Assignment>& to_change, const ProvenanceGraph& provGraph):to_change(to_change), provGraph(provGraph),ctx_num(0 ){}
472478
473479 void visit (const ForallNode* node) {
474480 Forall foralli (node);
@@ -512,17 +518,31 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
512518 }
513519 }));
514520 bool is_equal = (a.getIndexSetRel () == equal);
521+ bool is_none = (a.getIndexSetRel () == none);
515522
516523 if (is_equal && has_sibling) {
517524 to_change.push_back (a);
518525 }
526+ if (is_none && has_sibling && ctx_num > 1 ) {
527+ to_change.push_back (a);
528+ }
529+ bool has_outside = false ;
530+ for (auto & var : seen) {
531+ if (var!=ctx_stack.back ()){
532+ has_outside = true ;
533+ break ;
534+ }
535+ }
536+ if (is_none && has_sibling && ctx_num == 1 && has_outside) {
537+ to_change.push_back (a);
538+ }
519539 }
520540 };
521541
522- struct RedundentRewriter : public IndexNotationRewriter {
542+ struct RedundantRewriter : public IndexNotationRewriter {
523543 using IndexNotationRewriter::visit;
524544 std::set<Assignment> to_change;
525- RedundentRewriter (std::vector<Assignment>& to_change):to_change(to_change.begin(),to_change.end()){}
545+ RedundantRewriter (std::vector<Assignment>& to_change):to_change(to_change.begin(),to_change.end()){}
526546
527547 void visit (const AssignmentNode* node) {
528548 Assignment a (node->lhs , node->rhs , node->op );
@@ -545,9 +565,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
545565
546566
547567 std::vector<Assignment> to_change;
548- RedundentVisitor findVisitor (to_change, provGraph);
568+ RedundantVisitor findVisitor (to_change, provGraph);
549569 stmt.accept (&findVisitor);
550- RedundentRewriter ReRewriter (to_change);
570+ RedundantRewriter ReRewriter (to_change);
551571 stmt = ReRewriter.rewrite (stmt);
552572
553573 return stmt;
0 commit comments