Skip to content

Commit 2d4b848

Browse files
authored
[Fix] tilelang can now vectorize B[i,j] = c[i] + A[i,j] (#798)
* Fix bug 0905: vectorize with broadcasted value * fix lint error * [Refactor] Use `tvm::tir::UseVar` and use Vectorizer * Add loop size check in vectorize planner * fix lint error
1 parent fa4fd0b commit 2d4b848

File tree

2 files changed

+74
-97
lines changed

2 files changed

+74
-97
lines changed

src/transform/loop_vectorize.cc

Lines changed: 70 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,14 @@
2424

2525
#include "loop_vectorize.h"
2626

27-
#include <tvm/arith/iter_affine_map.h>
28-
#include <tvm/tir/builtin.h>
29-
#include <tvm/tir/stmt_functor.h>
30-
31-
#include <numeric>
32-
33-
#include "../layout/layout.h"
34-
#include "../layout/utils.h"
3527
#include "arith/int_operator.h"
3628
#include "arith/ir_visitor_with_analyzer.h"
3729
#include "common/loop_vectorization_utils.h"
30+
#include "tvm/tir/analysis.h"
31+
#include "tvm/tir/var.h"
32+
#include <tvm/arith/iter_affine_map.h>
33+
#include <tvm/tir/builtin.h>
34+
#include <tvm/tir/stmt_functor.h>
3835

3936
namespace tvm {
4037
namespace tl {
@@ -56,15 +53,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
5653
return vector_size_;
5754
}
5855

59-
bool GetDynamic() { return dynamic_; }
60-
61-
PrimExpr GetCondition() { return condition_; }
62-
6356
private:
6457
void VisitStmt_(const ForNode *node) final {
6558
inner_for_ = node;
66-
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
67-
59+
auto extent_ptr = as_const_int(node->extent);
60+
// Here I disable dynamic shape completely,
61+
// In order to do it, the Planner should accept an analyzer with
62+
// arithmetic info outside to prove the dividiblity of vector size
63+
if (!extent_ptr) {
64+
vector_size_ = 1;
65+
return;
66+
}
67+
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
6868
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
6969
}
7070

@@ -113,76 +113,47 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
113113
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
114114
if (!inner_for_)
115115
return;
116-
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
117-
if (!extent_ptr)
116+
// 1. Compute raw element offset
117+
auto strides = buffer->strides;
118+
if (buffer->strides.empty()) {
119+
PrimExpr stride = 1;
120+
for (int i = indices.size() - 1; i >= 0; --i) {
121+
strides.push_back(stride);
122+
stride = stride * buffer->shape[i];
123+
}
124+
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
125+
}
126+
PrimExpr elem_offset = 0;
127+
for (int i = 0; i < indices.size(); ++i) {
128+
elem_offset += indices[i] * strides[i];
129+
}
130+
131+
// 2. If element offset is independent with loop_var, ignore it
132+
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
118133
return;
134+
}
119135

120-
const DataType &access_type = buffer->dtype;
121-
// i // 2, i % 8 can also be vectorized as factor 16
122-
int max_vector_size = vector_load_bits_max_ / access_type.bits();
123-
// so we should disable this GCD optimization
124-
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
125-
auto last_dim = buffer->shape.back();
126-
auto mod_set = analyzer_.modular_set(last_dim);
127-
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
128-
// conditionally tail vectorize
129-
if (buffer->shape.back().as<IntImmNode>()) {
130-
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
131-
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
132-
// If gcd_base is equal to the last dimension,
133-
// we should analyze the second-to-last dimension
134-
// in relation to the last dimension.
135-
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
136-
max_vector_size = gcd_base;
137-
}
138-
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
139-
140-
// Generate strides if not existed
141-
auto strides = buffer->strides;
142-
if (buffer->strides.empty()) {
143-
PrimExpr stride = 1;
144-
for (int i = indices.size() - 1; i >= 0; --i) {
145-
strides.push_back(stride);
146-
stride = stride * buffer->shape[i];
147-
}
148-
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
149-
}
136+
// 3. Tight vectorize bound
137+
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
138+
buffer->dtype.bits());
150139

151-
// Generate and check element offset expression
152-
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
153-
PrimExpr elem_offset = 0;
154-
for (int i = 0; i < indices.size(); ++i) {
155-
elem_offset += indices[i] * strides[i];
156-
}
157-
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
158-
inner_for_->extent, vector_size_,
159-
&analyzer_)) {
160-
vector_size_ /= 2;
161-
}
162-
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
163-
// dynamic shape load: get the vectorization condition
164-
dynamic_ = true;
165-
PrimExpr offset = buffer.OffsetOf(indices).back();
166-
condition_ = (FloorMod(offset, vector_size_) == 0);
140+
// 4. Try to vectorize buffer load
141+
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
142+
inner_for_->extent, vector_size_, &analyzer_)) {
143+
vector_size_ /= 2;
167144
}
168145
}
169146

170147
const int vector_load_bits_max_ = 128;
171148

172149
const ForNode *inner_for_{};
173-
Map<Var, Range> iter_map_;
174150
bool has_nonlocal_memory_access_ = false;
175151
int vector_size_ = 128;
176-
// conditionally vectorize
177-
bool dynamic_ = false;
178-
PrimExpr condition_;
179152
};
180153

181154
class VectorizeRewriter : public StmtExprMutator {
182155
public:
183-
VectorizeRewriter(const VectorizePlanResult &plan)
184-
: vector_size_(plan.vector_size), condition_(plan.condition),
185-
dynamic_(plan.dynamic) {}
156+
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}
186157

187158
private:
188159
Stmt VisitStmt_(const ForNode *node) final {
@@ -197,23 +168,19 @@ class VectorizeRewriter : public StmtExprMutator {
197168
ICHECK(extent % vector_size_ == 0)
198169
<< "extent: " << extent << " vector_size_: " << vector_size_;
199170
ICHECK(is_zero(fnode->min));
200-
if (!dynamic_) { // check dynamic shape
201-
if (extent == vector_size_) {
202-
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
203-
return fnode;
204-
} else {
205-
Var inner_var = Var("vec");
206-
Var outer_var = Var(old_var->name_hint);
207-
Map<Var, PrimExpr> vmap;
208-
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
209-
Stmt body = Substitute(fnode->body, vmap);
210-
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
211-
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
212-
fnode->thread_binding, fnode->annotations, fnode->span);
213-
return body;
214-
}
215-
} else {
171+
if (extent == vector_size_) {
172+
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
216173
return fnode;
174+
} else {
175+
Var inner_var = Var("vec");
176+
Var outer_var = Var(old_var->name_hint);
177+
Map<Var, PrimExpr> vmap;
178+
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
179+
Stmt body = Substitute(fnode->body, vmap);
180+
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
181+
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
182+
fnode->thread_binding, fnode->annotations, fnode->span);
183+
return body;
217184
}
218185
} else {
219186
return ret;
@@ -222,18 +189,25 @@ class VectorizeRewriter : public StmtExprMutator {
222189

223190
const ForNode *inner_for_{};
224191
const int vector_size_;
225-
const PrimExpr condition_;
226-
const bool dynamic_;
227192
};
228193

229194
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
230195

231-
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
232-
VectorizePlanner planner;
233-
int vector_size = planner.Plan(loop);
234-
bool dynamic = planner.GetDynamic();
235-
PrimExpr condition = planner.GetCondition();
236-
return {vector_size, dynamic, condition};
196+
bool CanProveIndependent(const PrimExpr &expr, Var var,
197+
arith::Analyzer *analyzer) {
198+
// 1. if var doesn't exist, it is independent
199+
bool used_var = UsesVar(
200+
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
201+
if (!used_var) {
202+
return true;
203+
}
204+
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
205+
Var var_1("_t", var.dtype());
206+
auto expr_1 = Substitute(expr, {{var, var_1}});
207+
if (analyzer->CanProveEqual(expr, expr_1)) {
208+
return true;
209+
}
210+
return false;
237211
}
238212

239213
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
280254
}
281255

282256
For VectorizeLoop(const For &loop, int vectorize_hint) {
283-
VectorizePlanResult res{128, false, 0};
284257
if (vectorize_hint <= 0) {
285-
res = GetVectorizePlanResult(loop);
286-
vectorize_hint = res.vector_size;
258+
VectorizePlanner planner;
259+
vectorize_hint = planner.Plan(loop);
287260
}
288261
if (vectorize_hint == 1)
289262
return loop;
290-
auto rewriter = VectorizeRewriter(res);
263+
auto rewriter = VectorizeRewriter(vectorize_hint);
291264
return Downcast<For>(rewriter(loop));
292265
}
293266

src/transform/loop_vectorize.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);
3737

3838
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
3939

40+
// Can prove expr is independent with var, i.e. the value of expr doesn't change
41+
// when var changes
42+
bool CanProveIndependent(const PrimExpr &expr, Var var,
43+
arith::Analyzer *analyzer);
4044
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
4145
const PrimExpr &iter_var_size,
4246
int target_vectorized_size, arith::Analyzer *analyzer);

0 commit comments

Comments
 (0)