-
Notifications
You must be signed in to change notification settings - Fork 232
[Fix] Fix bug 0905: tilelang doesn't vectorize B[i,j] = c[i] + A[i,j]
#798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
96128ed
1fb176a
c08b815
9ff2008
cbf7b3c
d8ec462
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,15 +56,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { | |
return vector_size_; | ||
} | ||
|
||
bool GetDynamic() { return dynamic_; } | ||
|
||
PrimExpr GetCondition() { return condition_; } | ||
|
||
private: | ||
void VisitStmt_(const ForNode *node) final { | ||
inner_for_ = node; | ||
iter_map_.Set(node->loop_var, Range(node->min, node->extent)); | ||
|
||
arith::IRVisitorWithAnalyzer::VisitStmt_(node); | ||
} | ||
|
||
|
@@ -117,72 +111,48 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { | |
if (!extent_ptr) | ||
return; | ||
|
||
const DataType &access_type = buffer->dtype; | ||
// i // 2, i % 8 can also be vectorized as factor 16 | ||
int max_vector_size = vector_load_bits_max_ / access_type.bits(); | ||
// so we should disable this GCD optimization | ||
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); | ||
auto last_dim = buffer->shape.back(); | ||
auto mod_set = analyzer_.modular_set(last_dim); | ||
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block | ||
// conditionally tail vectorize | ||
if (buffer->shape.back().as<IntImmNode>()) { | ||
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff); | ||
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base); | ||
// If gcd_base is equal to the last dimension, | ||
// we should analyze the second-to-last dimension | ||
// in relation to the last dimension. | ||
if (gcd_base < Downcast<IntImm>(last_dim)->value) { | ||
max_vector_size = gcd_base; | ||
} | ||
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); | ||
|
||
// Generate strides if not existed | ||
auto strides = buffer->strides; | ||
if (buffer->strides.empty()) { | ||
PrimExpr stride = 1; | ||
for (int i = indices.size() - 1; i >= 0; --i) { | ||
strides.push_back(stride); | ||
stride = stride * buffer->shape[i]; | ||
} | ||
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()}; | ||
// 1. Compute raw element offset | ||
auto strides = buffer->strides; | ||
if (buffer->strides.empty()) { | ||
PrimExpr stride = 1; | ||
for (int i = indices.size() - 1; i >= 0; --i) { | ||
strides.push_back(stride); | ||
stride = stride * buffer->shape[i]; | ||
} | ||
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()}; | ||
} | ||
PrimExpr elem_offset = 0; | ||
for (int i = 0; i < indices.size(); ++i) { | ||
elem_offset += indices[i] * strides[i]; | ||
} | ||
|
||
// Generate and check element offset expression | ||
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides"; | ||
PrimExpr elem_offset = 0; | ||
for (int i = 0; i < indices.size(); ++i) { | ||
elem_offset += indices[i] * strides[i]; | ||
} | ||
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, | ||
inner_for_->extent, vector_size_, | ||
&analyzer_)) { | ||
vector_size_ /= 2; | ||
} | ||
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) { | ||
// dynamic shape load: get the vectorization condition | ||
dynamic_ = true; | ||
PrimExpr offset = buffer.OffsetOf(indices).back(); | ||
condition_ = (FloorMod(offset, vector_size_) == 0); | ||
// 2. If element offset is independent with loop_var, ignore it | ||
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) { | ||
return; | ||
} | ||
|
||
// 3. Tight vectorize bound | ||
int max_vector_size = vector_load_bits_max_ / buffer->dtype.bits(); | ||
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); | ||
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); | ||
|
||
// 4. Try to vectorize buffer load | ||
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, | ||
inner_for_->extent, vector_size_, &analyzer_)) { | ||
vector_size_ /= 2; | ||
} | ||
} | ||
|
||
const int vector_load_bits_max_ = 128; | ||
|
||
const ForNode *inner_for_{}; | ||
Map<Var, Range> iter_map_; | ||
bool has_nonlocal_memory_access_ = false; | ||
int vector_size_ = 128; | ||
// conditionally vectorize | ||
bool dynamic_ = false; | ||
PrimExpr condition_; | ||
}; | ||
|
||
class VectorizeRewriter : public StmtExprMutator { | ||
public: | ||
VectorizeRewriter(const VectorizePlanResult &plan) | ||
: vector_size_(plan.vector_size), condition_(plan.condition), | ||
dynamic_(plan.dynamic) {} | ||
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {} | ||
|
||
private: | ||
Stmt VisitStmt_(const ForNode *node) final { | ||
|
@@ -197,23 +167,19 @@ class VectorizeRewriter : public StmtExprMutator { | |
ICHECK(extent % vector_size_ == 0) | ||
<< "extent: " << extent << " vector_size_: " << vector_size_; | ||
ICHECK(is_zero(fnode->min)); | ||
if (!dynamic_) { // check dynamic shape | ||
if (extent == vector_size_) { | ||
fnode.CopyOnWrite()->kind = ForKind::kVectorized; | ||
return fnode; | ||
} else { | ||
Var inner_var = Var("vec"); | ||
Var outer_var = Var(old_var->name_hint); | ||
Map<Var, PrimExpr> vmap; | ||
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var); | ||
Stmt body = Substitute(fnode->body, vmap); | ||
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); | ||
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, | ||
fnode->thread_binding, fnode->annotations, fnode->span); | ||
return body; | ||
} | ||
} else { | ||
if (extent == vector_size_) { | ||
fnode.CopyOnWrite()->kind = ForKind::kVectorized; | ||
return fnode; | ||
} else { | ||
Var inner_var = Var("vec"); | ||
Var outer_var = Var(old_var->name_hint); | ||
Map<Var, PrimExpr> vmap; | ||
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var); | ||
Stmt body = Substitute(fnode->body, vmap); | ||
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); | ||
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, | ||
fnode->thread_binding, fnode->annotations, fnode->span); | ||
return body; | ||
} | ||
} else { | ||
return ret; | ||
|
@@ -222,18 +188,35 @@ class VectorizeRewriter : public StmtExprMutator { | |
|
||
const ForNode *inner_for_{}; | ||
const int vector_size_; | ||
const PrimExpr condition_; | ||
const bool dynamic_; | ||
}; | ||
|
||
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } | ||
|
||
VectorizePlanResult GetVectorizePlanResult(const For &loop) { | ||
VectorizePlanner planner; | ||
int vector_size = planner.Plan(loop); | ||
bool dynamic = planner.GetDynamic(); | ||
PrimExpr condition = planner.GetCondition(); | ||
return {vector_size, dynamic, condition}; | ||
bool CanProveIndependent(const PrimExpr &expr, Var var, | ||
arith::Analyzer *analyzer) { | ||
// 1. if var doesn't exist, it is independent | ||
struct FindVarVisitor : ExprVisitor { | ||
Var target; | ||
bool found = false; | ||
FindVarVisitor(Var target) : target(std::move(target)) {} | ||
void run(const PrimExpr &expr) { this->VisitExpr(expr); } | ||
void VisitExpr_(const VarNode *node) final { | ||
if (node == target.get()) { | ||
found = true; | ||
} | ||
} | ||
}; | ||
FindVarVisitor visitor(var); | ||
visitor.run(expr); | ||
if (!visitor.found) | ||
return true; | ||
|
||
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v | ||
Var var_1("_t", var.dtype()); | ||
auto expr_1 = Substitute(expr, {{var, var_1}}); | ||
if (analyzer->CanProveEqual(expr, expr_1)) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
bool IndiceCanVectorize(const PrimExpr &expr, Var var, | ||
|
@@ -245,12 +228,7 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, | |
|
||
// Extent must be divisible | ||
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), | ||
0)) | ||
return false; | ||
|
||
// The base offset must be divisible | ||
if (!analyzer->CanProveEqual( | ||
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) { | ||
0)) { | ||
return false; | ||
} | ||
|
||
|
@@ -259,35 +237,34 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, | |
analyzer->Bind(v0, Range(0, target_vectorized_size)); | ||
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv( | ||
iter_var_size, target_vectorized_size)))); | ||
PrimExpr expr_transformed = analyzer->Simplify( | ||
PrimExpr access_pos = analyzer->Simplify( | ||
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); | ||
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); | ||
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); | ||
|
||
// This simplify is necessary for thread region specified | ||
// optimizations. | ||
expr_vectorized = analyzer->Simplify(expr_vectorized); | ||
auto ramp_node = expr_vectorized.as<RampNode>(); | ||
if (!ramp_node) { | ||
// Broadcast value | ||
if (expr_vectorized.dtype().lanes() == 1) | ||
return true; | ||
else | ||
return false; | ||
} else { | ||
return is_one(ramp_node->stride); | ||
// for (int ph_v = target_vectorized_size; ph_v > 1; ph_v /= 2) { | ||
// ph_v: physical load/store vectorized size | ||
// TODO: allow a more generalized vectorize: B[i] = A[i // 2] | ||
auto ph_v = target_vectorized_size; | ||
auto group = target_vectorized_size / ph_v; | ||
// Check if access_pos is contingentous: ap === v0 // group (mod ph_v) | ||
auto is_contingous = | ||
analyzer->CanProveEqual(FloorMod(access_pos, ph_v), FloorDiv(v0, group)); | ||
// Check if access is aligned | ||
auto is_aligned = | ||
analyzer->CanProveEqual(FloorMod(Substitute(expr, {{var, 0}}), ph_v), 0); | ||
if (is_contingous && is_aligned) { | ||
return true; | ||
} | ||
// } | ||
return false; | ||
} | ||
|
||
|
||
For VectorizeLoop(const For &loop, int vectorize_hint) { | ||
VectorizePlanResult res{128, false, 0}; | ||
if (vectorize_hint <= 0) { | ||
res = GetVectorizePlanResult(loop); | ||
vectorize_hint = res.vector_size; | ||
VectorizePlanner planner; | ||
vectorize_hint = planner.Plan(loop); | ||
} | ||
if (vectorize_hint == 1) | ||
return loop; | ||
auto rewriter = VectorizeRewriter(res); | ||
auto rewriter = VectorizeRewriter(vectorize_hint); | ||
return Downcast<For>(rewriter(loop)); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Preserve loop var dtype and extents when splitting; current code may mis-type on non-i32 loops.
Var("vec")
andVar(old_var->name_hint)
default to i32. Ifold_var
is not i32 (e.g., i64 target), this breaks type checking. Also ensureextent
literals match the loop var dtype.📝 Committable suggestion
🤖 Prompt for AI Agents