Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 60 additions & 86 deletions src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace tl {
Expand All @@ -56,15 +58,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return vector_size_;
}

bool GetDynamic() { return dynamic_; }

PrimExpr GetCondition() { return condition_; }

private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));

arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

Expand Down Expand Up @@ -117,72 +113,48 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
if (!extent_ptr)
return;

const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
// 1. Compute raw element offset
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}

// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
// 2. If the element offset is independent of loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
return;
}

// 3. Tighten the vectorization bound
int max_vector_size = vector_load_bits_max_ / buffer->dtype.bits();
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
}

const int vector_load_bits_max_ = 128;

const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};

class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(const VectorizePlanResult &plan)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}

private:
Stmt VisitStmt_(const ForNode *node) final {
Expand All @@ -197,23 +169,19 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
Comment on lines +171 to 184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Preserve loop var dtype and extents when splitting; current code may mis-type on non-i32 loops.

Var("vec") and Var(old_var->name_hint) default to i32. If old_var is not i32 (e.g., i64 target), this breaks type checking. Also ensure extent literals match the loop var dtype.

-      if (extent == vector_size_) {
+      if (extent == vector_size_) {
         fnode.CopyOnWrite()->kind = ForKind::kVectorized;
         return fnode;
       } else {
-        Var inner_var = Var("vec");
-        Var outer_var = Var(old_var->name_hint);
+        DataType it = old_var->dtype;
+        Var inner_var = Var("vec", it);
+        Var outer_var = Var(old_var->name_hint, it);
         Map<Var, PrimExpr> vmap;
         vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
         Stmt body = Substitute(fnode->body, vmap);
-        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
-        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+        PrimExpr inner_extent = IntImm(it, vector_size_);
+        PrimExpr outer_extent = IntImm(it, extent / vector_size_);
+        body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
+        body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
                    fnode->thread_binding, fnode->annotations, fnode->span);
         return body;
       }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
DataType it = old_var->dtype;
Var inner_var = Var("vec", it);
Var outer_var = Var(old_var->name_hint, it);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
PrimExpr inner_extent = IntImm(it, vector_size_);
PrimExpr outer_extent = IntImm(it, extent / vector_size_);
body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 172-185, the split loop uses
Var("vec") and Var(old_var->name_hint) which default to i32 and can break type
checking for non-i32 loop vars; fix by creating inner_var and outer_var with the
same dtype as old_var (e.g. Var("vec", old_var->dtype) and
Var(old_var->name_hint, old_var->dtype)), and ensure any integer
extents/constants and arithmetic use casts to old_var->dtype (use
Cast(old_var->dtype, vector_size_) and Cast(old_var->dtype, extent /
vector_size_) or cast each literal/term as needed) so the substituted index
expression and the For extent PrimExprs all have matching dtype.

} else {
return ret;
Expand All @@ -222,18 +190,25 @@ class VectorizeRewriter : public StmtExprMutator {

const ForNode *inner_for_{};
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

VectorizePlanResult GetVectorizePlanResult(const For &loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
bool used_var = UsesVar(
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
if (!used_var) {
return true;
}
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
Var var_1("_t", var.dtype());
auto expr_1 = Substitute(expr, {{var, var_1}});
if (analyzer->CanProveEqual(expr, expr_1)) {
return true;
}
return false;
}

bool IndiceCanVectorize(const PrimExpr &expr, Var var,
Expand Down Expand Up @@ -280,14 +255,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
}

For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
VectorizePlanner planner;
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(res);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}

Expand Down
4 changes: 4 additions & 0 deletions src/transform/loop_vectorize.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);

For VectorizeLoop(const For &loop, int vectorize_hint = -1);

// Returns true if expr can be proven independent of var, i.e. the value of
// expr doesn't change when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer);
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);
Expand Down
Loading