167 changes: 70 additions & 97 deletions src/transform/loop_vectorize.cc
@@ -24,17 +24,14 @@

#include "loop_vectorize.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {
@@ -56,15 +53,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return vector_size_;
}

bool GetDynamic() { return dynamic_; }

PrimExpr GetCondition() { return condition_; }

private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));

auto extent_ptr = as_const_int(node->extent);
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the dividiblity of vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
Comment on lines +59 to +67
⚠️ Potential issue

Don’t blanket-disable vectorization on dynamic extents; this regresses prior behavior.

Early-returning with vector_size_ = 1 for non-constant extents disables vectorization on dynamic loops entirely. At minimum, keep the optimistic upper bound and let UpdateVectorSize shrink it, or gate this behind a flag.

-    auto extent_ptr = as_const_int(node->extent);
-    // Here I disable dynamic shape completely,
-    //   In order to do it, the Planner should accept an analyzer with
-    //   arithmetic info outside to prove the dividiblity of vector size
-    if (!extent_ptr) {
-      vector_size_ = 1;
-      return;
-    }
-    vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
+    if (const int64_t* extent_ptr = as_const_int(node->extent)) {
+      vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
+    }  // else keep current vector_size_; UpdateVectorSize will shrink as needed
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto extent_ptr = as_const_int(node->extent);
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the dividiblity of vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
if (const int64_t* extent_ptr = as_const_int(node->extent)) {
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
} // else keep current vector_size_; UpdateVectorSize will shrink as needed
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 59 to 67, the code currently
sets vector_size_ = 1 and returns when extent is not a constant, which disables
vectorization for dynamic extents; instead preserve the existing optimistic
upper bound by removing the early return and not forcing vector_size_ to 1 when
extent_ptr is null, allow UpdateVectorSize (or subsequent analysis) to reduce
vector_size_ based on runtime/shrunken information, or optionally gate the
conservative fallback behind a runtime/config flag; update the logic so that
when extent_ptr is null you skip the ZeroAwareGCD step but do not reset
vector_size_, and add a comment describing the conservative vs optimistic
behavior and the optional flag if you choose to implement one.
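
As a side note, the pattern the review asks for can be sketched without TVM at all; the snippet below is a standalone illustration (ZeroAwareGCDStandIn and TightenVectorSize are hypothetical names, not TVM APIs) of how a constant extent tightens the vector size while a dynamic extent keeps the optimistic bound for later per-access analysis to shrink.

#include <cstdint>
#include <numeric>
#include <optional>

// Stand-in for arith::ZeroAwareGCD: a gcd that treats 0 as "no constraint".
static int64_t ZeroAwareGCDStandIn(int64_t a, int64_t b) {
  if (a == 0) return b;
  if (b == 0) return a;
  return std::gcd(a, b);
}

// Only a known-constant extent tightens the vector size; an unknown (dynamic)
// extent keeps the current optimistic bound so that per-access analysis
// (UpdateVectorSize in the pass) can shrink it later.
static int64_t TightenVectorSize(int64_t current_vector_size,
                                 std::optional<int64_t> const_extent) {
  if (const_extent.has_value())
    return ZeroAwareGCDStandIn(current_vector_size, *const_extent);
  return current_vector_size;
}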

arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

@@ -113,76 +113,47 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr)
// 1. Compute raw element offset
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
PrimExpr elem_offset = 0;
Comment on lines +117 to +126
⚠️ Potential issue

Reverse-stride construction is broken; Array doesn’t support rbegin/rend (build blocker).

Array<PrimExpr>{strides.rbegin(), strides.rend()} won’t compile for TVM’s Array. Build the reversed array explicitly.

-    auto strides = buffer->strides;
-    if (buffer->strides.empty()) {
-      PrimExpr stride = 1;
-      for (int i = indices.size() - 1; i >= 0; --i) {
-        strides.push_back(stride);
-        stride = stride * buffer->shape[i];
-      }
-      strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
-    }
+    Array<PrimExpr> strides = buffer->strides;
+    if (strides.empty()) {
+      // Build row-major strides then reverse to match index order.
+      Array<PrimExpr> tmp;
+      PrimExpr stride = 1;
+      for (int i = static_cast<int>(indices.size()) - 1; i >= 0; --i) {
+        tmp.push_back(stride);
+        stride = stride * buffer->shape[i];
+      }
+      strides = Array<PrimExpr>();
+      for (int i = static_cast<int>(tmp.size()) - 1; i >= 0; --i) {
+        strides.push_back(tmp[i]);
+      }
+    }

Also applies to: 127-129

🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 117–126 (and similarly 127–129),
the code attempts to construct a TVM Array using strides.rbegin()/rend(), which
tvm::Array does not support; replace the rbegin/rend construction with an
explicit reverse build: after filling the std::vector<PrimExpr> strides, create
a new temporary container and push the elements from strides in reverse order
(or std::reverse the vector and then construct Array from its begin/end), then
assign that Array<PrimExpr> to strides; apply the same explicit reverse
construction to the other occurrence at lines 127–129.
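
For reference, the row-major stride computation the fix relies on looks like this in isolation (plain C++ with std::vector instead of tvm::Array, so the explicit reverse step is visible; illustration only, not the committed code).

#include <algorithm>
#include <cstdint>
#include <vector>

// Row-major strides: the last axis has stride 1, each earlier axis' stride is
// the product of all later extents, e.g. shape {2, 3, 4} -> strides {12, 4, 1}.
static std::vector<int64_t> RowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides;
  int64_t stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    strides.push_back(stride);  // collected innermost-first
    stride *= shape[i];
  }
  std::reverse(strides.begin(), strides.end());  // reorder to match the axes
  return strides;
}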

for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}

// 2. If element offset is independent with loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
return;
}

const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
// 3. Tight vectorize bound
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
buffer->dtype.bits());

// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
}

const int vector_load_bits_max_ = 128;

const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};

class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(const VectorizePlanResult &plan)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}

private:
Stmt VisitStmt_(const ForNode *node) final {
@@ -197,23 +168,19 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
Comment on lines +171 to 184
⚠️ Potential issue

Preserve loop var dtype and extents when splitting; current code may mis-type on non-i32 loops.

Var("vec") and Var(old_var->name_hint) default to i32. If old_var is not i32 (e.g., i64 target), this breaks type checking. Also ensure extent literals match the loop var dtype.

-      if (extent == vector_size_) {
+      if (extent == vector_size_) {
         fnode.CopyOnWrite()->kind = ForKind::kVectorized;
         return fnode;
       } else {
-        Var inner_var = Var("vec");
-        Var outer_var = Var(old_var->name_hint);
+        DataType it = old_var->dtype;
+        Var inner_var = Var("vec", it);
+        Var outer_var = Var(old_var->name_hint, it);
         Map<Var, PrimExpr> vmap;
         vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
         Stmt body = Substitute(fnode->body, vmap);
-        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
-        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+        PrimExpr inner_extent = IntImm(it, vector_size_);
+        PrimExpr outer_extent = IntImm(it, extent / vector_size_);
+        body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
+        body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
                    fnode->thread_binding, fnode->annotations, fnode->span);
         return body;
       }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
DataType it = old_var->dtype;
Var inner_var = Var("vec", it);
Var outer_var = Var(old_var->name_hint, it);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
PrimExpr inner_extent = IntImm(it, vector_size_);
PrimExpr outer_extent = IntImm(it, extent / vector_size_);
body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 172-185, the split loop uses
Var("vec") and Var(old_var->name_hint) which default to i32 and can break type
checking for non-i32 loop vars; fix by creating inner_var and outer_var with the
same dtype as old_var (e.g. Var("vec", old_var->dtype) and
Var(old_var->name_hint, old_var->dtype)), and ensure any integer
extents/constants and arithmetic use casts to old_var->dtype (use
Cast(old_var->dtype, vector_size_) and Cast(old_var->dtype, extent /
vector_size_) or cast each literal/term as needed) so the substituted index
expression and the For extent PrimExprs all have matching dtype.
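
A plain-C++ analogue (not TVM API) of why the split index arithmetic has to stay in the original loop var's dtype: truncating the outer index to 32 bits silently corrupts the reconstructed offset once the extent exceeds the 32-bit range.

#include <cstdint>
#include <iostream>

int main() {
  // Splitting i in [0, extent) into i = outer * vec + inner.
  const int64_t extent = int64_t{1} << 34;      // larger than int32 can index
  const int64_t vec = 4;
  const int64_t outer = extent / vec - 1;       // last outer iteration
  const int64_t inner = vec - 1;
  const int32_t outer_i32 = static_cast<int32_t>(outer);  // silent truncation

  std::cout << outer * vec + inner << "\n";               // correct 64-bit index
  std::cout << int64_t{outer_i32} * vec + inner << "\n";  // wrong: becomes -1
  return 0;
}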

} else {
return ret;
@@ -222,18 +189,25 @@ class VectorizeRewriter : public StmtExprMutator {

const ForNode *inner_for_{};
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

VectorizePlanResult GetVectorizePlanResult(const For &loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
bool used_var = UsesVar(
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
if (!used_var) {
return true;
}
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
Var var_1("_t", var.dtype());
auto expr_1 = Substitute(expr, {{var, var_1}});
if (analyzer->CanProveEqual(expr, expr_1)) {
return true;
}
return false;
}

bool IndiceCanVectorize(const PrimExpr &expr, Var var,
@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
}

For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
VectorizePlanner planner;
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(res);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}

4 changes: 4 additions & 0 deletions src/transform/loop_vectorize.h
@@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);

For VectorizeLoop(const For &loop, int vectorize_hint = -1);

// Can prove expr is independent with var, i.e. the value of expr doesn't change
// when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer);
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);