diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 2731a2e4f..3b33fa985 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -24,17 +24,14 @@ #include "loop_vectorize.h" -#include -#include -#include - -#include - -#include "../layout/layout.h" -#include "../layout/utils.h" #include "arith/int_operator.h" #include "arith/ir_visitor_with_analyzer.h" #include "common/loop_vectorization_utils.h" +#include "tvm/tir/analysis.h" +#include "tvm/tir/var.h" +#include +#include +#include namespace tvm { namespace tl { @@ -56,15 +53,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { return vector_size_; } - bool GetDynamic() { return dynamic_; } - - PrimExpr GetCondition() { return condition_; } - private: void VisitStmt_(const ForNode *node) final { inner_for_ = node; - iter_map_.Set(node->loop_var, Range(node->min, node->extent)); - + auto extent_ptr = as_const_int(node->extent); + // Here I disable dynamic shape completely. + // In order to support it, the Planner should accept an analyzer with + // arithmetic info from outside to prove the divisibility of vector size + if (!extent_ptr) { + vector_size_ = 1; + return; + } + vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); arith::IRVisitorWithAnalyzer::VisitStmt_(node); } @@ -113,76 +113,47 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { void UpdateVectorSize(const Array &indices, const Buffer &buffer) { if (!inner_for_) return; - auto extent_ptr = inner_for_->extent.as(); - if (!extent_ptr) + // 1. Compute raw element offset + auto strides = buffer->strides; + if (buffer->strides.empty()) { + PrimExpr stride = 1; + for (int i = indices.size() - 1; i >= 0; --i) { + strides.push_back(stride); + stride = stride * buffer->shape[i]; + } + strides = Array{strides.rbegin(), strides.rend()}; + } + PrimExpr elem_offset = 0; + for (int i = 0; i < indices.size(); ++i) { + elem_offset += indices[i] * strides[i]; + } + + // 2. 
If element offset is independent of loop_var, ignore it + if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) { return; + } - const DataType &access_type = buffer->dtype; - // i // 2, i % 8 can also be vectorized as factor 16 - int max_vector_size = vector_load_bits_max_ / access_type.bits(); - // so we should disable this GCD optimization - max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); - auto last_dim = buffer->shape.back(); - auto mod_set = analyzer_.modular_set(last_dim); - // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block - // conditionally tail vectorize - if (buffer->shape.back().as()) { - max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff); - auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base); - // If gcd_base is equal to the last dimension, - // we should analyze the second-to-last dimension - // in relation to the last dimension. - if (gcd_base < Downcast(last_dim)->value) { - max_vector_size = gcd_base; - } - vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); - - // Generate strides if not existed - auto strides = buffer->strides; - if (buffer->strides.empty()) { - PrimExpr stride = 1; - for (int i = indices.size() - 1; i >= 0; --i) { - strides.push_back(stride); - stride = stride * buffer->shape[i]; - } - strides = Array{strides.rbegin(), strides.rend()}; - } + // 3. 
Tighten the vectorization bound + vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ / + buffer->dtype.bits()); - // Generate and check element offset expression - ICHECK(indices.size() == strides.size()) << "Invalid indices and strides"; - PrimExpr elem_offset = 0; - for (int i = 0; i < indices.size(); ++i) { - elem_offset += indices[i] * strides[i]; - } - while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, vector_size_, - &analyzer_)) { - vector_size_ /= 2; - } - } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) { - // dynamic shape load: get the vectorization condition - dynamic_ = true; - PrimExpr offset = buffer.OffsetOf(indices).back(); - condition_ = (FloorMod(offset, vector_size_) == 0); + // 4. Try to vectorize buffer load + while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, + inner_for_->extent, vector_size_, &analyzer_)) { + vector_size_ /= 2; } } const int vector_load_bits_max_ = 128; const ForNode *inner_for_{}; - Map iter_map_; bool has_nonlocal_memory_access_ = false; int vector_size_ = 128; - // conditionally vectorize - bool dynamic_ = false; - PrimExpr condition_; }; class VectorizeRewriter : public StmtExprMutator { public: - VectorizeRewriter(const VectorizePlanResult &plan) - : vector_size_(plan.vector_size), condition_(plan.condition), - dynamic_(plan.dynamic) {} + VectorizeRewriter(int vector_size) : vector_size_(vector_size) {} private: Stmt VisitStmt_(const ForNode *node) final { @@ -197,23 +168,19 @@ class VectorizeRewriter : public StmtExprMutator { ICHECK(extent % vector_size_ == 0) << "extent: " << extent << " vector_size_: " << vector_size_; ICHECK(is_zero(fnode->min)); - if (!dynamic_) { // check dynamic shape - if (extent == vector_size_) { - fnode.CopyOnWrite()->kind = ForKind::kVectorized; - return fnode; - } else { - Var inner_var = Var("vec"); - Var outer_var = Var(old_var->name_hint); - Map vmap; - vmap.Set(fnode->loop_var, outer_var * vector_size_ + 
inner_var); - Stmt body = Substitute(fnode->body, vmap); - body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); - body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, - fnode->thread_binding, fnode->annotations, fnode->span); - return body; - } - } else { + if (extent == vector_size_) { + fnode.CopyOnWrite()->kind = ForKind::kVectorized; return fnode; + } else { + Var inner_var = Var("vec"); + Var outer_var = Var(old_var->name_hint); + Map vmap; + vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var); + Stmt body = Substitute(fnode->body, vmap); + body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); + body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, + fnode->thread_binding, fnode->annotations, fnode->span); + return body; } } else { return ret; @@ -222,18 +189,25 @@ class VectorizeRewriter : public StmtExprMutator { const ForNode *inner_for_{}; const int vector_size_; - const PrimExpr condition_; - const bool dynamic_; }; int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } -VectorizePlanResult GetVectorizePlanResult(const For &loop) { - VectorizePlanner planner; - int vector_size = planner.Plan(loop); - bool dynamic = planner.GetDynamic(); - PrimExpr condition = planner.GetCondition(); - return {vector_size, dynamic, condition}; +bool CanProveIndependent(const PrimExpr &expr, Var var, + arith::Analyzer *analyzer) { + // 1. if var doesn't exist, it is independent + bool used_var = UsesVar( + expr, [&](const VarNode *v) { return GetRef(v).same_as(var); }); + if (!used_var) { + return true; + } + // 2. 
if \forall v_1, v_2, f(v_1) == f(v_2), f is independent of v + Var var_1("_t", var.dtype()); + auto expr_1 = Substitute(expr, {{var, var_1}}); + if (analyzer->CanProveEqual(expr, expr_1)) { + return true; + } + return false; } bool IndiceCanVectorize(const PrimExpr &expr, Var var, @@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, } For VectorizeLoop(const For &loop, int vectorize_hint) { - VectorizePlanResult res{128, false, 0}; if (vectorize_hint <= 0) { - res = GetVectorizePlanResult(loop); - vectorize_hint = res.vector_size; + VectorizePlanner planner; + vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1) return loop; - auto rewriter = VectorizeRewriter(res); + auto rewriter = VectorizeRewriter(vectorize_hint); return Downcast(rewriter(loop)); } diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 253461e8a..4ab20c668 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop); For VectorizeLoop(const For &loop, int vectorize_hint = -1); +// Check whether expr can be proven independent of var, i.e. the value of expr doesn't change +// when var changes +bool CanProveIndependent(const PrimExpr &expr, Var var, + arith::Analyzer *analyzer); bool IndiceCanVectorize(const PrimExpr &expr, Var var, const PrimExpr &iter_var_size, int target_vectorized_size, arith::Analyzer *analyzer);