diff --git a/.clang-tidy b/.clang-tidy
index 8631d9211..e4a5f5519 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -42,6 +42,7 @@ Checks: >
   -cppcoreguidelines-pro-type-static-cast-downcast,
   -performance-unnecessary-value-param,
   -performance-enum-size,
+  -clang-analyzer-deadcode.DeadStores,
 
 WarningsAsErrors: '*'
diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py
index 52e78f807..ee6ad8aed 100644
--- a/examples/cast/example_group_per_split_token_cast_to_fp8.py
+++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -29,7 +29,7 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor
         y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
         y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
         y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
-        row_offset = T.alloc_local((1,), "int32")
+        row_offset = T.alloc_fragment((1,), "int32")
 
         T.annotate_layout({
             y_local:
diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 4d2723f49..bb1b79133 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
 
 DataType cuTensorMapType() { return DataType::UInt(8, 128); }
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 8ed37896f..1e4d4f4d1 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
     "tl.ptxas_register_usage_level";
 static constexpr const char *kEnablePTXASVerboseOutput =
     "tl.enable_ptxas_verbose_output";
+static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
 static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
 /*!
  * \brief Whether to disable dynamic tail split
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 3aae1f262..543de9090 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -92,10 +92,14 @@ TileOperator GemmNode::Clone() const {
 }
 
 GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
+
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
+  bool allow_wgmma =
+      !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
+      TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
+      CheckWGMMA();
   if (allow_wgmma) {
     return GemmInst::kWGMMA;
   } else if (TargetIsCDNA(target)) {
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 19d17a6ee..402bbdc2b 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -128,9 +128,13 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
  * visitor's reducer_info_map_. Continues traversal into the loop body.
  */
 void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
-  ICHECK(op->kind == ForKind::kParallel);
-  p->loop_vars_.push_back(
-      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
+  if (op->kind == ForKind::kParallel)
+    p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
+  else
+    p->inner_vars_.Set(op->loop_var,
+                       IterVar(Range(op->min, op->extent), op->loop_var,
+                               IterVarType::kOrdered));
   p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   auto reducer_info_map =
       op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   }
   auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
     Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
+    DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
+               << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
+    Fragment result;
     if (IsCommonAccessIndice(buffer)) {
-      return src_layout;
+      result = src_layout;
     } else {
       Var rep;
       auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                               IterVarType::kDataPar);
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
-      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
-          ->BindThreadRange(T.thread_bounds);
+      loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
+      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
+        if (auto opt_var = objref.as<Var>();
+            opt_var && inner_vars_.count(*opt_var)) {
+          std::ostringstream oss;
+          oss << "loop_var_to_thread = " << loop_var_to_thread
+              << " contains inner var " << *opt_var;
+          throw LayoutConflictException(oss.str());
+        }
+      });
+      result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
+                   ->BindThreadRange(T.thread_bounds);
     }
+    DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
+               << result->DebugOutput() << '\n';
+    return result;
   };
   if (source_buffer.defined()) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
           IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
       int vector_size = GetVectorizeSize(maybe_remapped_root_);
+      DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
+
       PrimExpr loop_total_size = 1;
       for (Stmt l = root_; l.as<For>().has_value();
            l = l.as<For>().value()->body)
         loop_total_size = loop_total_size * l.as<For>().value()->extent;
+      DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
+                 << '\n';
       while (!analyzer_.CanProve(
                  floormod(loop_total_size,
                           T.thread_bounds->extent * vector_size) == 0) &&
              vector_size > 1)
         vector_size /= 2;
+      DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
+                 << vector_size << '\n';
 
       // Check if coalesced_width is defined
       if (auto coalesced_width =
@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
           LOG(FATAL) << "coalesced_width should be an IntImmNode.";
         }
       }
+      DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
+                 << " ############# vector_size = " << vector_size
+                 << ", thread_bounds = " << T.thread_bounds << '\n';
       loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
+      DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
+                 << loop_layout_->DebugOutput() << '\n';
     }
   } else {
     return {};
diff --git a/src/op/parallel.h b/src/op/parallel.h
index 3bc15c1e6..5f1f5a887 100644
--- a/src/op/parallel.h
+++ b/src/op/parallel.h
@@ -128,6 +128,7 @@ class ParallelOpNode : public TileOperatorNode {
   void AddPredicate(const PrimExpr &expr) const {
     predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
   }
+  // Allow ParallelLoopNestVisitor to access private members.
   friend class ParallelLoopNestVisitor;
 
@@ -139,6 +140,8 @@ class ParallelOpNode : public TileOperatorNode {
   std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
   // The loop variables for the parallel loop nest.
   Array<IterVar> loop_vars_;
+  // The loop variables for inner (non-parallel) loops of the nest.
+  Map<Var, IterVar> inner_vars_;
   // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
   mutable arith::Analyzer analyzer_;
   // Mapping from buffer to reducer info.
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index 6e3806f1b..ce28e48be 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
            "required for layout inference.";
 
     // Run InferLayout
+    DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
     auto updates = next->InferLayout(
         LayoutInferArgs{target_, thread_bounds, layout_map, &analyzer_,
                         buffer_oob},
         level);
-    // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
+      DLOG(INFO) << "  consider update " << buffer << " as "
+                 << layout->DebugOutput() << '\n';
+
       // Basic validity checks
       ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
       ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
         if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
                                   inner_analyzer)) {
           layout_map.Set(buffer, layout);
+          DLOG(INFO) << "  layout broadcast from "
+                     << src_layout->DebugOutput() << ", accepted" << '\n';
           continue;
         }
       }
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       } else {
         // Otherwise, update map
         layout_map.Set(buffer, layout);
+        DLOG(INFO) << "  new layout accepted" << '\n';
 
         if (!update_queue)
           continue;
@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
         << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
           "length.";
 
+    DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
+    for (int i = 0; i < infer_list_stmt_.size(); ++i) {
+      DLOG(INFO) << "  op " << i << ":" << infer_list_stmt_[i] << '\n';
+    }
+
     // If needed, you can also check that annotated_layout_map_ is not empty, or
     // anything else relevant to your setup.
@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
 
   void InferInFreeMode(LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map) {
+
+    DLOG(INFO) << "Enforced layout maps:" << '\n';
+    for (auto &&[k, v] : layout_map) {
+      DLOG(INFO) << "  " << k << ": " << v->DebugOutput() << '\n';
+    }
+    DLOG(INFO) << '\n';
+
     // Group operators into connected components
     UnionFind uf;
     for (int i = 0; i < infer_list_.size(); i++) {
@@ -505,52 +523,53 @@
     std::vector<bool> in_queue(infer_list_.size(), false);
 
     for (auto &&[root, members] : components) {
+      DLOG(INFO) << "======================= processing component " << root
+                 << '\n';
       decltype(infer_list_) best_infer_list;
       LayoutMap best_layout_map;
       int64_t min_reg_num = INT64_MAX;
+      int min_reg_num_infer_root = -1;
+
+      // Try each member as the root of inference for this component
       for (int attempt_infer_root : members) {
-        // backup infer_list_ in class member
+        DLOG(INFO) << "----------------------- try root " << attempt_infer_root
+                   << '\n';
+        // Backup the current infer_list_ state
         auto back_infer_list = BackupInferList();
-        // create temporarily used layout_map, new handle so that it copies on
-        // write
+        // Copy the current layout_map for temporary use
         LayoutMap tmp_layout_map = layout_map;
-        // infer from attempt_infer_root in free mode
         bool do_update = true;
         try {
+          // Run inference starting from attempt_infer_root
           RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                        tmp_layout_map, strict_layout_map, q, in_queue);
           FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                            q, in_queue);
-          // Silly workaround: we have no clue if single root will iterate over
-          // the entire component, since the InferLayout implementations have
-          // complicated conditioning inside and we know nothing about it.
-          // This would constantly result in incomplete layouts for buffers in
-          // this component. Instead of trying all combinations of root
-          // selection order, we simply go through all other loops in order
-          // after the first search from attempt_infer_root.
+
+          // After the first search, run inference for all other members in
+          // order
           for (int other_infer_root : members) {
             if (other_infer_root != attempt_infer_root) {
               RunInferStep(other_infer_root, InferLevel::kFree, true,
                            tmp_layout_map, strict_layout_map, q, in_queue);
-              // must also be kFree here to avoid conflicts.
               FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                                strict_layout_map, q, in_queue);
             }
           }
-        } catch (LayoutConflictException e) {
-          // such an order fails, try others
+        } catch (const LayoutConflictException &e) {
           do_update = false;
-        } catch (NormalizeIterException e) {
-          // such an order encounters iterators that is not normalizable, try
-          // others e.g. i * 576 % 2048
+          DLOG(INFO) << "attempt failed due to LayoutConflictException "
+                     << e.what() << '\n';
+        } catch (const NormalizeIterException &e) {
           do_update = false;
+          DLOG(INFO) << "attempt failed due to NormalizeIterException "
+                     << e.what() << '\n';
         }
         if (do_update) {
-          // compute total register number
+          // Compute the total register number for this layout
           int64_t reg_num = 0;
-          for (auto &&[buffer, layout] : tmp_layout_map) {
+          for (const auto &[buffer, layout] : tmp_layout_map) {
             if (auto frag = layout.as<Fragment>()) {
               int64_t frag_reg_num = 1;
               for (auto i : frag.value()->OutputShape()) {
@@ -561,21 +580,24 @@
                 reg_num += frag_reg_num;
               }
             }
-          // if it's any better, update the best_* storage
+          // Update the best plan if this one uses fewer registers
           if (reg_num < min_reg_num) {
-            best_infer_list = std::move(infer_list_);
+            best_infer_list =
+                BackupInferList();  // Use backup to avoid moving out infer_list_
             best_layout_map = tmp_layout_map;
             min_reg_num = reg_num;
+            min_reg_num_infer_root = attempt_infer_root;
           }
         }
-        // recover stateful infer_list_, head on next
+        // Restore infer_list_ state for the next attempt
         infer_list_ = std::move(back_infer_list);
       }
-      if (min_reg_num < INT64_MAX) {
-        // now apply the best plan for this component
-        infer_list_ = std::move(best_infer_list);
-        layout_map = best_layout_map;
-      }
+      ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
+      // Apply the best plan for this component
+      infer_list_ = std::move(best_infer_list);
+      layout_map = best_layout_map;
+      DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
+                 << min_reg_num_infer_root << '\n';
     }
   }
 };
@@ -682,20 +704,25 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     // Here, A_local is a register-local buffer held independently by each
     // thread, so explicit thread binding is not required.
     //
-    // We use PostOrderVisit to detect whether the buffer store targets a
-    // "local" buffer, which indicates register usage and justifies skipping
+    // We use PostOrderVisit to detect whether the loop only manipulates
+    // "local" buffers, which indicates register usage and justifies skipping
     // thread binding.
-    bool is_register_store = false;
+    bool local_register_only = true;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (store->buffer.scope() == "local") {
-          is_register_store = true;
+        if (store->buffer.scope() != "local") {
+          local_register_only = false;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        if (load->buffer.scope() != "local") {
+          local_register_only = false;
         }
       }
     });
 
     auto loop_layout = result_.for_map[root];
-    bool parallel_loop = !is_register_store && !skip_thread_partition_;
+    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
+    bool parallel_loop = !skip_thread_partition_ && !local_register_only;
 
     if (parallel_loop) {
       for_node =
diff --git a/src/transform/layout_reducer.cc b/src/transform/layout_reducer.cc
index b216dbfe9..788e72a4d 100644
--- a/src/transform/layout_reducer.cc
+++ b/src/transform/layout_reducer.cc
@@ -178,7 +178,8 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
   Stmt VisitStmt_(const ForNode *op) final {
     // only annotate the outermost loop
     bool should_annotate = false;
-    if (!inside_reducer_range_.empty() && !already_annotated_) {
+    if (!inside_reducer_range_.empty() && !already_annotated_ &&
+        op->kind == ForKind::kParallel) {
       should_annotate = true;
       already_annotated_ = true;
     }
diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc
index 326e56076..e3d667dec 100644
--- a/src/transform/merge_shared_memory_allocations.cc
+++ b/src/transform/merge_shared_memory_allocations.cc
@@ -639,13 +639,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
 };
 
 void PlanAlignment(const Stmt &stmt) {
-  LOG(INFO) << "PlanAlignment";
+  DLOG(INFO) << "PlanAlignment";
   PostOrderVisit(stmt, [&](const ObjectRef &node) {
     if (const auto *call = node.as<CallNode>()) {
       if (call->op.same_as(tl::tl_gemm()) ||
           call->op.same_as(tl::tl_gemm_sp())) {
-        LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
-                  << call->op;
+        DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
+                   << call->op;
       }
     }
   });
diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc
index 9d3d3c661..fe22b783e 100644
--- a/src/transform/storage_rewrite.cc
+++ b/src/transform/storage_rewrite.cc
@@ -1789,8 +1789,8 @@ class VectorTypeRewriter : public StmtExprMutator {
     PrimExpr last_extent = extents[extents.size() - 1];
     extents.Set(extents.size() - 1,
                 last_extent / make_const(last_extent.dtype(), info.factor()));
-    LOG(INFO) << "Allocate with " << new_buffer_var << " and "
-              << info.new_element_dtype << " extents: " << extents;
+    DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
+               << info.new_element_dtype << " extents: " << extents;
     return Allocate(new_buffer_var, info.new_element_dtype, extents,
                     op->condition, op->body);
   }
diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py
index 20d230fa5..6e0485a17 100644
--- a/tilelang/transform/pass_config.py
+++ b/tilelang/transform/pass_config.py
@@ -45,6 +45,9 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
     """Disable safe memory access optimization. Default: False"""
 
+    TL_DISABLE_WGMMA = "tl.disable_wgmma"
+    """Disable usage of Hopper WGMMA. Default: False"""
+
     TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
     """Enable debug information for merge shared memory allocations. Default: False"""
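
Usage note: the new `tl.disable_wgmma` switch registered above is an ordinary TVM pass-config option, so it can be toggled from Python through a `tvm.transform.PassContext`. A minimal sketch of how a user might force the Hopper GEMM lowering to skip WGMMA; the surrounding build step is illustrative only and not part of this patch:

    import tvm
    from tilelang.transform.pass_config import PassConfigKey

    # With tl.disable_wgmma set, GemmNode::GetGemmInst() in src/op/gemm.cc no
    # longer selects the WGMMA path, even on Hopper targets.
    with tvm.transform.PassContext(config={PassConfigKey.TL_DISABLE_WGMMA.value: True}):
        ...  # compile/build the TileLang kernel inside this context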