Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 85 additions & 108 deletions src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return vector_size_;
}

bool GetDynamic() { return dynamic_; }

PrimExpr GetCondition() { return condition_; }

private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));

arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

Expand Down Expand Up @@ -117,72 +111,48 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
if (!extent_ptr)
return;

const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
// 1. Compute raw element offset
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}

// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
// 2. If element offset is independent with loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
return;
}

// 3. Tight vectorize bound
int max_vector_size = vector_load_bits_max_ / buffer->dtype.bits();
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
}

const int vector_load_bits_max_ = 128;

const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};

class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(const VectorizePlanResult &plan)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}

private:
Stmt VisitStmt_(const ForNode *node) final {
Expand All @@ -197,23 +167,19 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
Comment on lines +171 to 184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Preserve loop var dtype and extents when splitting; current code may mis-type on non-i32 loops.

Var("vec") and Var(old_var->name_hint) default to i32. If old_var is not i32 (e.g., i64 target), this breaks type checking. Also ensure extent literals match the loop var dtype.

-      if (extent == vector_size_) {
+      if (extent == vector_size_) {
         fnode.CopyOnWrite()->kind = ForKind::kVectorized;
         return fnode;
       } else {
-        Var inner_var = Var("vec");
-        Var outer_var = Var(old_var->name_hint);
+        DataType it = old_var->dtype;
+        Var inner_var = Var("vec", it);
+        Var outer_var = Var(old_var->name_hint, it);
         Map<Var, PrimExpr> vmap;
         vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
         Stmt body = Substitute(fnode->body, vmap);
-        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
-        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+        PrimExpr inner_extent = IntImm(it, vector_size_);
+        PrimExpr outer_extent = IntImm(it, extent / vector_size_);
+        body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
+        body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
                    fnode->thread_binding, fnode->annotations, fnode->span);
         return body;
       }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
DataType it = old_var->dtype;
Var inner_var = Var("vec", it);
Var outer_var = Var(old_var->name_hint, it);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
PrimExpr inner_extent = IntImm(it, vector_size_);
PrimExpr outer_extent = IntImm(it, extent / vector_size_);
body = For(inner_var, IntImm(it, 0), inner_extent, ForKind::kVectorized, body);
body = For(outer_var, IntImm(it, 0), outer_extent, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 172-185, the split loop uses
Var("vec") and Var(old_var->name_hint) which default to i32 and can break type
checking for non-i32 loop vars; fix by creating inner_var and outer_var with the
same dtype as old_var (e.g. Var("vec", old_var->dtype) and
Var(old_var->name_hint, old_var->dtype)), and ensure any integer
extents/constants and arithmetic use casts to old_var->dtype (use
Cast(old_var->dtype, vector_size_) and Cast(old_var->dtype, extent /
vector_size_) or cast each literal/term as needed) so the substituted index
expression and the For extent PrimExprs all have matching dtype.

} else {
return ret;
Expand All @@ -222,18 +188,35 @@ class VectorizeRewriter : public StmtExprMutator {

const ForNode *inner_for_{};
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

VectorizePlanResult GetVectorizePlanResult(const For &loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
struct FindVarVisitor : ExprVisitor {
Var target;
bool found = false;
FindVarVisitor(Var target) : target(std::move(target)) {}
void run(const PrimExpr &expr) { this->VisitExpr(expr); }
void VisitExpr_(const VarNode *node) final {
if (node == target.get()) {
found = true;
}
}
};
FindVarVisitor visitor(var);
visitor.run(expr);
if (!visitor.found)
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The FindVarVisitor struct and its usage can be simplified by using the tvm::tir::UsesVar utility function. This makes the code more concise and idiomatic to the TVM codebase.

You would need to add the following include at the top of the file:

#include <tvm/tir/analysis.h>
  if (!tvm::tir::UsesVar(expr, [&var](const VarNode *v) { return v == var.get(); })) {
    return true;
  }

// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
Var var_1("_t", var.dtype());
auto expr_1 = Substitute(expr, {{var, var_1}});
if (analyzer->CanProveEqual(expr, expr_1)) {
return true;
}
return false;
}

bool IndiceCanVectorize(const PrimExpr &expr, Var var,
Expand All @@ -245,12 +228,7 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,

// Extent must be divisible
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0))
return false;

// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
0)) {
return false;
}

Expand All @@ -259,35 +237,34 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
iter_var_size, target_vectorized_size))));
PrimExpr expr_transformed = analyzer->Simplify(
PrimExpr access_pos = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);

// This simplify is necessary for thread region specified
// optimizations.
expr_vectorized = analyzer->Simplify(expr_vectorized);
auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) {
// Broadcast value
if (expr_vectorized.dtype().lanes() == 1)
return true;
else
return false;
} else {
return is_one(ramp_node->stride);
// for (int ph_v = target_vectorized_size; ph_v > 1; ph_v /= 2) {
// ph_v: physical load/store vectorized size
// TODO: allow a more generalized vectorize: B[i] = A[i // 2]
auto ph_v = target_vectorized_size;
auto group = target_vectorized_size / ph_v;
// Check if access_pos is contingentous: ap === v0 // group (mod ph_v)
auto is_contingous =
analyzer->CanProveEqual(FloorMod(access_pos, ph_v), FloorDiv(v0, group));
// Check if access is aligned
auto is_aligned =
analyzer->CanProveEqual(FloorMod(Substitute(expr, {{var, 0}}), ph_v), 0);
if (is_contingous && is_aligned) {
return true;
}
// }
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify the contiguity check logic.

The contiguity and alignment checks in IndiceCanVectorize have been significantly simplified. The logic checks:

  • access_pos is contiguous: FloorMod(access_pos, ph_v) == FloorDiv(v0, group)
  • Access is aligned: FloorMod(Substitute(expr, {{var, 0}}), ph_v) == 0

However, there are some concerns:

  1. Line 246: group = target_vectorized_size / ph_v where ph_v = target_vectorized_size results in group = 1, which may not be the intended behavior
  2. The commented-out loop suggests this was meant to handle multiple physical vector sizes, but now only handles one case

Run the following script to check if there are test cases that validate this vectorization logic:


🏁 Script executed:

#!/bin/bash
# Search for tests related to IndiceCanVectorize and vectorization
echo "=== Searching for vectorization tests ==="
fd -e cc -e cpp -e py | xargs rg -l "IndiceCanVectorize|VectorizeLoop" | head -10

echo -e "\n=== Checking test patterns for contiguity/alignment ==="
rg -A5 -B5 "contiguous|aligned" --type cpp --type py | head -20

Length of output: 1758


Restore multi-phase contiguity logic in src/transform/loop_vectorize.cc
The current code sets

auto ph_v = target_vectorized_size;
auto group = target_vectorized_size / ph_v;  // == 1

so the contiguity check

CanProveEqual(FloorMod(access_pos, ph_v), FloorDiv(v0, group))

degenerates and no longer iterates over smaller physical vector widths as intended (see the commented-out `for` loop on lines 243–246). Reinstate the loop over `ph_v` (halving each iteration) and recalculate `group` inside it. Add unit tests for `IndiceCanVectorize` covering both contiguous and aligned cases.

🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 240 to 258, the temporary
removal of the multi-phase contiguity logic made ph_v always equal to
target_vectorized_size so the contiguity check degenerates; restore the original
loop that iterates ph_v = target_vectorized_size; ph_v > 1; ph_v /= 2 (halving
each iteration), move the calculation of group = target_vectorized_size / ph_v
inside that loop, perform the CanProveEqual contiguity and alignment checks for
each ph_v and return true if any phase passes, otherwise return false after the
loop; then add unit tests for IndiceCanVectorize that exercise both contiguous
and aligned cases (including smaller physical widths) to validate the restored
behavior.


For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
VectorizePlanner planner;
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(res);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}

Expand Down
4 changes: 4 additions & 0 deletions src/transform/loop_vectorize.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);

For VectorizeLoop(const For &loop, int vectorize_hint = -1);

// Returns true when `expr` can be proven independent of `var`, i.e. the value
// of `expr` does not change when `var` changes. This holds when `var` does not
// occur in `expr`, or when `analyzer` can prove `expr` equal to the same
// expression with `var` replaced by a fresh variable.
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer);
// Returns true when the index expression `expr`, iterated by `var` over
// `iter_var_size` iterations, can be vectorized with `target_vectorized_size`
// lanes: `analyzer` must prove that the extent is divisible by the vector size
// and that the access is contiguous and aligned.
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);
Expand Down
Loading