24
24
25
25
#include " loop_vectorize.h"
26
26
27
- #include < tvm/arith/iter_affine_map.h>
28
- #include < tvm/tir/builtin.h>
29
- #include < tvm/tir/stmt_functor.h>
30
-
31
- #include < numeric>
32
-
33
- #include " ../layout/layout.h"
34
- #include " ../layout/utils.h"
35
27
#include " arith/int_operator.h"
36
28
#include " arith/ir_visitor_with_analyzer.h"
37
29
#include " common/loop_vectorization_utils.h"
30
+ #include " tvm/tir/analysis.h"
31
+ #include " tvm/tir/var.h"
32
+ #include < tvm/arith/iter_affine_map.h>
33
+ #include < tvm/tir/builtin.h>
34
+ #include < tvm/tir/stmt_functor.h>
38
35
39
36
namespace tvm {
40
37
namespace tl {
@@ -56,15 +53,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
56
53
return vector_size_;
57
54
}
58
55
59
- bool GetDynamic () { return dynamic_; }
60
-
61
- PrimExpr GetCondition () { return condition_; }
62
-
63
56
private:
64
57
void VisitStmt_ (const ForNode *node) final {
65
58
inner_for_ = node;
66
- iter_map_.Set (node->loop_var , Range (node->min , node->extent ));
67
-
59
+ auto extent_ptr = as_const_int (node->extent );
60
+ // Dynamic shapes are disabled completely here.
61
+ // To support them, the Planner should accept an analyzer with outside
62
+ // arithmetic info so it can prove the divisibility of the vector size.
63
+ if (!extent_ptr) {
64
+ vector_size_ = 1 ;
65
+ return ;
66
+ }
67
+ vector_size_ = arith::ZeroAwareGCD (vector_size_, *extent_ptr);
68
68
arith::IRVisitorWithAnalyzer::VisitStmt_ (node);
69
69
}
70
70
@@ -113,76 +113,47 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
113
113
void UpdateVectorSize (const Array<PrimExpr> &indices, const Buffer &buffer) {
114
114
if (!inner_for_)
115
115
return ;
116
- auto extent_ptr = inner_for_->extent .as <IntImmNode>();
117
- if (!extent_ptr)
116
+ // 1. Compute raw element offset
117
+ auto strides = buffer->strides ;
118
+ if (buffer->strides .empty ()) {
119
+ PrimExpr stride = 1 ;
120
+ for (int i = indices.size () - 1 ; i >= 0 ; --i) {
121
+ strides.push_back (stride);
122
+ stride = stride * buffer->shape [i];
123
+ }
124
+ strides = Array<PrimExpr>{strides.rbegin (), strides.rend ()};
125
+ }
126
+ PrimExpr elem_offset = 0 ;
127
+ for (int i = 0 ; i < indices.size (); ++i) {
128
+ elem_offset += indices[i] * strides[i];
129
+ }
130
+
131
+ // 2. If the element offset is independent of loop_var, ignore it
132
+ if (CanProveIndependent (elem_offset, inner_for_->loop_var , &analyzer_)) {
118
133
return ;
134
+ }
119
135
120
- const DataType &access_type = buffer->dtype ;
121
- // i // 2, i % 8 can also be vectorized as factor 16
122
- int max_vector_size = vector_load_bits_max_ / access_type.bits ();
123
- // so we should disable this GCD optimization
124
- max_vector_size = arith::ZeroAwareGCD (max_vector_size, extent_ptr->value );
125
- auto last_dim = buffer->shape .back ();
126
- auto mod_set = analyzer_.modular_set (last_dim);
127
- // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
128
- // conditionally tail vectorize
129
- if (buffer->shape .back ().as <IntImmNode>()) {
130
- max_vector_size = arith::ZeroAwareGCD (max_vector_size, mod_set->coeff );
131
- auto gcd_base = arith::ZeroAwareGCD (max_vector_size, mod_set->base );
132
- // If gcd_base is equal to the last dimension,
133
- // we should analyze the second-to-last dimension
134
- // in relation to the last dimension.
135
- if (gcd_base < Downcast<IntImm>(last_dim)->value ) {
136
- max_vector_size = gcd_base;
137
- }
138
- vector_size_ = arith::ZeroAwareGCD (max_vector_size, vector_size_);
139
-
140
- // Generate strides if not existed
141
- auto strides = buffer->strides ;
142
- if (buffer->strides .empty ()) {
143
- PrimExpr stride = 1 ;
144
- for (int i = indices.size () - 1 ; i >= 0 ; --i) {
145
- strides.push_back (stride);
146
- stride = stride * buffer->shape [i];
147
- }
148
- strides = Array<PrimExpr>{strides.rbegin (), strides.rend ()};
149
- }
136
+ // 3. Tighten the vectorize bound
137
+ vector_size_ = arith::ZeroAwareGCD (vector_size_, vector_load_bits_max_ /
138
+ buffer->dtype .bits ());
150
139
151
- // Generate and check element offset expression
152
- ICHECK (indices.size () == strides.size ()) << " Invalid indices and strides" ;
153
- PrimExpr elem_offset = 0 ;
154
- for (int i = 0 ; i < indices.size (); ++i) {
155
- elem_offset += indices[i] * strides[i];
156
- }
157
- while (!IndiceCanVectorize (elem_offset, inner_for_->loop_var ,
158
- inner_for_->extent , vector_size_,
159
- &analyzer_)) {
160
- vector_size_ /= 2 ;
161
- }
162
- } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype .bits ()) {
163
- // dynamic shape load: get the vectorization condition
164
- dynamic_ = true ;
165
- PrimExpr offset = buffer.OffsetOf (indices).back ();
166
- condition_ = (FloorMod (offset, vector_size_) == 0 );
140
+ // 4. Try to vectorize buffer load
141
+ while (!IndiceCanVectorize (elem_offset, inner_for_->loop_var ,
142
+ inner_for_->extent , vector_size_, &analyzer_)) {
143
+ vector_size_ /= 2 ;
167
144
}
168
145
}
169
146
170
147
const int vector_load_bits_max_ = 128 ;
171
148
172
149
const ForNode *inner_for_{};
173
- Map<Var, Range> iter_map_;
174
150
bool has_nonlocal_memory_access_ = false ;
175
151
int vector_size_ = 128 ;
176
- // conditionally vectorize
177
- bool dynamic_ = false ;
178
- PrimExpr condition_;
179
152
};
180
153
181
154
class VectorizeRewriter : public StmtExprMutator {
182
155
public:
183
- VectorizeRewriter (const VectorizePlanResult &plan)
184
- : vector_size_(plan.vector_size), condition_(plan.condition),
185
- dynamic_ (plan.dynamic) {}
156
+ VectorizeRewriter (int vector_size) : vector_size_(vector_size) {}
186
157
187
158
private:
188
159
Stmt VisitStmt_ (const ForNode *node) final {
@@ -197,23 +168,19 @@ class VectorizeRewriter : public StmtExprMutator {
197
168
ICHECK (extent % vector_size_ == 0 )
198
169
<< " extent: " << extent << " vector_size_: " << vector_size_;
199
170
ICHECK (is_zero (fnode->min ));
200
- if (!dynamic_) { // check dynamic shape
201
- if (extent == vector_size_) {
202
- fnode.CopyOnWrite ()->kind = ForKind::kVectorized ;
203
- return fnode;
204
- } else {
205
- Var inner_var = Var (" vec" );
206
- Var outer_var = Var (old_var->name_hint );
207
- Map<Var, PrimExpr> vmap;
208
- vmap.Set (fnode->loop_var , outer_var * vector_size_ + inner_var);
209
- Stmt body = Substitute (fnode->body , vmap);
210
- body = For (inner_var, 0 , vector_size_, ForKind::kVectorized , body);
211
- body = For (outer_var, 0 , extent / vector_size_, fnode->kind , body,
212
- fnode->thread_binding , fnode->annotations , fnode->span );
213
- return body;
214
- }
215
- } else {
171
+ if (extent == vector_size_) {
172
+ fnode.CopyOnWrite ()->kind = ForKind::kVectorized ;
216
173
return fnode;
174
+ } else {
175
+ Var inner_var = Var (" vec" );
176
+ Var outer_var = Var (old_var->name_hint );
177
+ Map<Var, PrimExpr> vmap;
178
+ vmap.Set (fnode->loop_var , outer_var * vector_size_ + inner_var);
179
+ Stmt body = Substitute (fnode->body , vmap);
180
+ body = For (inner_var, 0 , vector_size_, ForKind::kVectorized , body);
181
+ body = For (outer_var, 0 , extent / vector_size_, fnode->kind , body,
182
+ fnode->thread_binding , fnode->annotations , fnode->span );
183
+ return body;
217
184
}
218
185
} else {
219
186
return ret;
@@ -222,18 +189,25 @@ class VectorizeRewriter : public StmtExprMutator {
222
189
223
190
const ForNode *inner_for_{};
224
191
const int vector_size_;
225
- const PrimExpr condition_;
226
- const bool dynamic_;
227
192
};
228
193
229
194
int GetVectorizeSize (const For &loop) { return VectorizePlanner ().Plan (loop); }
230
195
231
- VectorizePlanResult GetVectorizePlanResult (const For &loop) {
232
- VectorizePlanner planner;
233
- int vector_size = planner.Plan (loop);
234
- bool dynamic = planner.GetDynamic ();
235
- PrimExpr condition = planner.GetCondition ();
236
- return {vector_size, dynamic, condition};
196
+ bool CanProveIndependent (const PrimExpr &expr, Var var,
197
+ arith::Analyzer *analyzer) {
198
+ // 1. if var does not occur in expr, expr is independent of it
199
+ bool used_var = UsesVar (
200
+ expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as (var); });
201
+ if (!used_var) {
202
+ return true ;
203
+ }
204
+ // 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent of v
205
+ Var var_1 (" _t" , var.dtype ());
206
+ auto expr_1 = Substitute (expr, {{var, var_1}});
207
+ if (analyzer->CanProveEqual (expr, expr_1)) {
208
+ return true ;
209
+ }
210
+ return false ;
237
211
}
238
212
239
213
bool IndiceCanVectorize (const PrimExpr &expr, Var var,
@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
280
254
}
281
255
282
256
For VectorizeLoop (const For &loop, int vectorize_hint) {
283
- VectorizePlanResult res{128 , false , 0 };
284
257
if (vectorize_hint <= 0 ) {
285
- res = GetVectorizePlanResult (loop) ;
286
- vectorize_hint = res. vector_size ;
258
+ VectorizePlanner planner ;
259
+ vectorize_hint = planner. Plan (loop) ;
287
260
}
288
261
if (vectorize_hint == 1 )
289
262
return loop;
290
- auto rewriter = VectorizeRewriter (res );
263
+ auto rewriter = VectorizeRewriter (vectorize_hint );
291
264
return Downcast<For>(rewriter (loop));
292
265
}
293
266
0 commit comments