diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index f9122f572496..7791c6ac2f1d 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -224,11 +224,18 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { set &slid_dimensions; Scope scope; - // Loops between the loop being slid over and the produce node + // For loops strictly between the loop being slid over and the current + // node (not including the loop being slid over itself). Scope<> enclosing_loops; map replacements; + // The immediately-enclosing For node, and the one enclosing the target + // producer. Replacements are only applied to LetStmts directly inside + // producer_for. + const For *current_for = nullptr; + Stmt producer_for; + using IRMutator::visit; // Check if the dimension at index 'dim_idx' is always pure (i.e. equal to 'dim') @@ -500,6 +507,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { replacements[n + ".min"] = Variable::make(Int(32), prefix + dim + ".min"); replacements[n + ".max"] = Variable::make(Int(32), prefix + dim + ".max"); } + producer_for = Stmt(current_for); // Ok, we have a new min/max required and we're going to // rewrite all the lets that define bounds required. Now @@ -565,6 +573,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { Expr min = expand_expr(op->min, scope); Expr max = expand_expr(op->max, scope); ScopedBinding<> bind(enclosing_loops, op->name); + ScopedValue bind_for(current_for, op); if (equal(min, max)) { // Just treat it like a let Stmt s = LetStmt::make(op->name, min, op->body); @@ -591,7 +600,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { Expr value = op->value; map::iterator iter = replacements.find(op->name); - if (iter != replacements.end()) { + if (iter != replacements.end() && current_for == producer_for.get()) { value = iter->second; replacements.erase(iter); }