Skip to content

Commit 1391421

Browse files
committed
fix none
1 parent 7df35e4 commit 1391421

File tree

1 file changed

+31
-11
lines changed

1 file changed

+31
-11
lines changed

src/index_notation/transformations.cpp

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
331331
);
332332
IndexSetRel rel = a.getIndexSetRel();
333333
switch (rel) {
334-
case none: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // =
334+
case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
335335
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
336336
case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
337337
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
@@ -346,10 +346,16 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
346346
const IndexExpr& e,
347347
map<IndexVar, IndexVar> substitutions) {
348348

349-
auto assignment = ws(iw_vars) = replace(e, substitutions);
350-
if (!assignment.getReductionVars().empty())
351-
assignment = Assignment(assignment.getLhs(), assignment.getRhs(), Add());
352-
return assignment;
349+
auto a = ws(iw_vars) = replace(e, substitutions);
350+
IndexSetRel rel = a.getIndexSetRel();
351+
switch (rel) {
352+
case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
353+
case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
354+
case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
355+
case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
356+
case equal: a = Assignment(a.getLhs(), a.getRhs());break;// = OR +=
357+
}
358+
return a;
353359
}
354360

355361
IndexStmt generateForalls(IndexStmt stmt, vector<IndexVar> indexVars) {
@@ -459,7 +465,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
459465
IndexNotationRewriter::visit(node);
460466
}
461467
};
462-
struct RedundentVisitor: public IndexNotationVisitor {
468+
struct RedundantVisitor: public IndexNotationVisitor {
463469
using IndexNotationVisitor::visit;
464470

465471
std::vector<Assignment>& to_change;
@@ -468,7 +474,7 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
468474
int ctx_num;
469475
const ProvenanceGraph& provGraph;
470476

471-
RedundentVisitor(std::vector<Assignment>& to_change, const ProvenanceGraph& provGraph):to_change(to_change), provGraph(provGraph),ctx_num(0){}
477+
RedundantVisitor(std::vector<Assignment>& to_change, const ProvenanceGraph& provGraph):to_change(to_change), provGraph(provGraph),ctx_num(0){}
472478

473479
void visit(const ForallNode* node) {
474480
Forall foralli(node);
@@ -512,17 +518,31 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
512518
}
513519
}));
514520
bool is_equal = (a.getIndexSetRel() == equal);
521+
bool is_none = (a.getIndexSetRel() == none);
515522

516523
if (is_equal && has_sibling) {
517524
to_change.push_back(a);
518525
}
526+
if (is_none && has_sibling && ctx_num > 1) {
527+
to_change.push_back(a);
528+
}
529+
bool has_outside = false;
530+
for (auto & var : seen) {
531+
if (var!=ctx_stack.back()){
532+
has_outside = true;
533+
break;
534+
}
535+
}
536+
if (is_none && has_sibling && ctx_num == 1 && has_outside) {
537+
to_change.push_back(a);
538+
}
519539
}
520540
};
521541

522-
struct RedundentRewriter: public IndexNotationRewriter {
542+
struct RedundantRewriter: public IndexNotationRewriter {
523543
using IndexNotationRewriter::visit;
524544
std::set<Assignment> to_change;
525-
RedundentRewriter(std::vector<Assignment>& to_change):to_change(to_change.begin(),to_change.end()){}
545+
RedundantRewriter(std::vector<Assignment>& to_change):to_change(to_change.begin(),to_change.end()){}
526546

527547
void visit(const AssignmentNode* node) {
528548
Assignment a(node->lhs, node->rhs, node->op);
@@ -545,9 +565,9 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
545565

546566

547567
std::vector<Assignment> to_change;
548-
RedundentVisitor findVisitor(to_change, provGraph);
568+
RedundantVisitor findVisitor(to_change, provGraph);
549569
stmt.accept(&findVisitor);
550-
RedundentRewriter ReRewriter(to_change);
570+
RedundantRewriter ReRewriter(to_change);
551571
stmt = ReRewriter.rewrite(stmt);
552572

553573
return stmt;

0 commit comments

Comments
 (0)