Skip to content

Commit c382dcb

Browse files
LeiWang1999 and Huanqi Cao
authored
[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop (#844)
* Support T.serial and local buffers inside T.Parallel loop. * Fix reducer layout in T.Parallel nested inside other loops * Debug output with LOG(INFO) * Add disable option for WGMMA. * fix * Use DLOG; fix missing registration for new pass config * bug fix * lint fix * Enhance GEMM instruction set with UTCMMA and improve local buffer handling in casting example * Update format.sh shebang, improve logging in layout inference, and enhance buffer store wrapper with detailed comments * Enhance GEMM instantiation logic and improve layout inference for local buffer detection - Updated the GEMM instantiation logic to include a check for WGMMA compatibility, ensuring that the conditions for using WGMMA are more robust. - Refined the layout inference process to better identify when loops manipulate only local buffers, improving the accuracy of thread binding decisions in parallel loops. --------- Co-authored-by: Huanqi Cao <[email protected]>
1 parent bf67fb1 commit c382dcb

File tree

12 files changed

+121
-49
lines changed

12 files changed

+121
-49
lines changed

.clang-tidy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Checks: >
4242
-cppcoreguidelines-pro-type-static-cast-downcast,
4343
-performance-unnecessary-value-param,
4444
-performance-enum-size,
45+
-clang-analyzer-deadcode.DeadStores,
4546
4647
WarningsAsErrors: '*'
4748

examples/cast/example_group_per_split_token_cast_to_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor
2929
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
3030
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
3131
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
32-
row_offset = T.alloc_local((1,), "int32")
32+
row_offset = T.alloc_fragment((1,), "int32")
3333

3434
T.annotate_layout({
3535
y_local:

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
2929
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
3030
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
3131
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
32+
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
3233
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
3334

3435
DataType cuTensorMapType() { return DataType::UInt(8, 128); }

src/op/builtin.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
4545
"tl.ptxas_register_usage_level";
4646
static constexpr const char *kEnablePTXASVerboseOutput =
4747
"tl.enable_ptxas_verbose_output";
48+
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
4849
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
4950
/*!
5051
* \brief Whether to disable dynamic tail split

src/op/gemm.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,14 @@ TileOperator GemmNode::Clone() const {
9292
}
9393

9494
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
95+
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
96+
9597
int warp_size = TargetGetWarpSize(target);
9698
int num_warps = block_size / warp_size;
97-
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
98-
(num_warps % 4 == 0) && CheckWGMMA();
99+
bool allow_wgmma =
100+
!ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
101+
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
102+
CheckWGMMA();
99103
if (allow_wgmma) {
100104
return GemmInst::kWGMMA;
101105
} else if (TargetIsCDNA(target)) {

src/op/parallel.cc

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,13 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
128128
* visitor's reducer_info_map_. Continues traversal into the loop body.
129129
*/
130130
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
131-
ICHECK(op->kind == ForKind::kParallel);
132-
p->loop_vars_.push_back(
133-
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
131+
if (op->kind == ForKind::kParallel)
132+
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
133+
IterVarType::kDataPar));
134+
else
135+
p->inner_vars_.Set(op->loop_var,
136+
IterVar(Range(op->min, op->extent), op->loop_var,
137+
IterVarType::kOrdered));
134138
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
135139
auto reducer_info_map =
136140
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
244248
}
245249
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
246250
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
251+
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
252+
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
253+
Fragment result;
247254
if (IsCommonAccessIndice(buffer)) {
248-
return src_layout;
255+
result = src_layout;
249256
} else {
250257
Var rep;
251258
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
252259
IterVarType::kDataPar);
253260
PrimExpr loop_var_to_thread =
254261
src_layout->ForwardThread(indice_map_[buffer], rep);
255-
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
256-
->BindThreadRange(T.thread_bounds);
262+
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
263+
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
264+
if (auto opt_var = objref.as<Var>();
265+
opt_var && inner_vars_.count(*opt_var)) {
266+
std::ostringstream oss;
267+
oss << "loop_var_to_thread = " << loop_var_to_thread
268+
<< "contains inner var" << *opt_var;
269+
throw LayoutConflictException(oss.str());
270+
}
271+
});
272+
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
273+
->BindThreadRange(T.thread_bounds);
257274
}
275+
DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
276+
<< result->DebugOutput() << '\n';
277+
return result;
258278
};
259279
if (source_buffer.defined()) {
260280
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
317337
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
318338
int vector_size = GetVectorizeSize(maybe_remapped_root_);
319339

340+
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
341+
320342
PrimExpr loop_total_size = 1;
321343
for (Stmt l = root_; l.as<For>().has_value();
322344
l = l.as<For>().value()->body)
323345
loop_total_size = loop_total_size * l.as<For>().value()->extent;
346+
DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
347+
<< '\n';
324348
while (!analyzer_.CanProve(
325349
floormod(loop_total_size,
326350
T.thread_bounds->extent * vector_size) == 0) &&
327351
vector_size > 1)
328352
vector_size /= 2;
353+
DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
354+
<< vector_size << '\n';
329355

330356
// Check if coalesced_width is defined
331357
if (auto coalesced_width =
@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
342368
LOG(FATAL) << "coalesced_width should be an IntImmNode.";
343369
}
344370
}
371+
DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
372+
<< " ############# vector_size = " << vector_size
373+
<< ", thread_bounds = " << T.thread_bounds << '\n';
345374
loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
375+
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
376+
<< loop_layout_->DebugOutput() << '\n';
346377
}
347378
} else {
348379
return {};

src/op/parallel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class ParallelOpNode : public TileOperatorNode {
128128
void AddPredicate(const PrimExpr &expr) const {
129129
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
130130
}
131+
131132
// Allow ParallelLoopNestVisitor to access private members.
132133
friend class ParallelLoopNestVisitor;
133134

@@ -139,6 +140,8 @@ class ParallelOpNode : public TileOperatorNode {
139140
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
140141
// The loop variables for the parallel loop nest.
141142
Array<IterVar> loop_vars_;
143+
// The inner_vars_
144+
Map<Var, IterVar> inner_vars_;
142145
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
143146
mutable arith::Analyzer analyzer_;
144147
// Mapping from buffer to reducer info.

src/transform/layout_inference.cc

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
105105
"required for layout inference.";
106106

107107
// Run InferLayout
108+
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
108109
auto updates =
109110
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
110111
&analyzer_, buffer_oob},
111112
level);
112-
113113
// Process the returned updates
114114
for (const auto &[buffer, layout] : updates) {
115+
DLOG(INFO) << " consider update " << buffer << " as "
116+
<< layout->DebugOutput() << '\n';
117+
115118
// Basic validity checks
116119
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
117120
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
140143
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
141144
inner_analyzer)) {
142145
layout_map.Set(buffer, layout);
146+
DLOG(INFO) << " layout broadcast from "
147+
<< src_layout->DebugOutput() << ", accepted" << '\n';
143148
continue;
144149
}
145150
}
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
151156
} else {
152157
// Otherwise, update map
153158
layout_map.Set(buffer, layout);
159+
DLOG(INFO) << " new layout accepted" << '\n';
154160
if (!update_queue)
155161
continue;
156162

@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
210216
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
211217
"length.";
212218

219+
DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
220+
for (int i = 0; i < infer_list_stmt_.size(); ++i) {
221+
DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n';
222+
}
223+
213224
// If needed, you can also check that annotated_layout_map_ is not empty, or
214225
// anything else relevant to your setup.
215226

@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
470481

471482
void InferInFreeMode(LayoutMap &layout_map,
472483
const LayoutMap &strict_layout_map) {
484+
485+
DLOG(INFO) << "Enforced layout maps:" << '\n';
486+
for (auto &&[k, v] : layout_map) {
487+
DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n';
488+
}
489+
DLOG(INFO) << '\n';
490+
473491
// Group operators into connected components
474492
UnionFind<int> uf;
475493
for (int i = 0; i < infer_list_.size(); i++) {
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
505523
std::vector<bool> in_queue(infer_list_.size(), false);
506524

507525
for (auto &&[root, members] : components) {
526+
DLOG(INFO) << "======================= processing component " << root
527+
<< '\n';
508528
decltype(infer_list_) best_infer_list;
509529
LayoutMap best_layout_map;
510530
int64_t min_reg_num = INT64_MAX;
531+
int min_reg_num_infer_root = -1;
511532

533+
// Try each member as the root of inference for this component
512534
for (int attempt_infer_root : members) {
513-
// backup infer_list_ in class member
535+
DLOG(INFO) << "----------------------- try root " << attempt_infer_root
536+
<< '\n';
537+
// Backup the current infer_list_ state
514538
auto back_infer_list = BackupInferList();
515-
// create temporarily used layout_map, new handle so that it copies on
516-
// write
539+
// Copy the current layout_map for temporary use
517540
LayoutMap tmp_layout_map = layout_map;
518-
// infer from attempt_infer_root in free mode
519541
bool do_update = true;
520542
try {
543+
// Run inference starting from attempt_infer_root
521544
RunInferStep(attempt_infer_root, InferLevel::kFree, true,
522545
tmp_layout_map, strict_layout_map, q, in_queue);
523546
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
524547
q, in_queue);
525-
// Silly workaround: we have no clue if single root will iterate over
526-
// the entire component, since the InferLayout implementations have
527-
// complicated conditioning inside and we know nothing about it.
528-
// This would constantly result in incomplete layouts for buffers in
529-
// this component. Instead of trying all combinations of root
530-
// selection order, we simply go through all other loops in order
531-
// after the first search from attempt_infer_root.
548+
549+
// After the first search, run inference for all other members in
550+
// order
532551
for (int other_infer_root : members) {
533552
if (other_infer_root != attempt_infer_root) {
534553
RunInferStep(other_infer_root, InferLevel::kFree, true,
535554
tmp_layout_map, strict_layout_map, q, in_queue);
536-
// must also be kFree here to avoid conflicts.
537555
FinishInferQueue(InferLevel::kFree, tmp_layout_map,
538556
strict_layout_map, q, in_queue);
539557
}
540558
}
541-
} catch (LayoutConflictException e) {
542-
// such an order fails, try others
559+
} catch (const LayoutConflictException &e) {
543560
do_update = false;
544-
} catch (NormalizeIterException e) {
545-
// such an order encounters iterators that is not normalizable, try
546-
// others e.g. i * 576 % 2048
561+
DLOG(INFO) << "attempt failed due to LayoutConflictException "
562+
<< e.what() << '\n';
563+
} catch (const NormalizeIterException &e) {
547564
do_update = false;
565+
DLOG(INFO) << "attempt failed due to NormalizeIterException "
566+
<< e.what() << '\n';
548567
}
549568

550569
if (do_update) {
551-
// compute total register number
570+
// Compute the total register number for this layout
552571
int64_t reg_num = 0;
553-
for (auto &&[buffer, layout] : tmp_layout_map) {
572+
for (const auto &[buffer, layout] : tmp_layout_map) {
554573
if (auto frag = layout.as<Fragment>()) {
555574
int64_t frag_reg_num = 1;
556575
for (auto i : frag.value()->OutputShape()) {
@@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
561580
reg_num += frag_reg_num;
562581
}
563582
}
564-
// if it's any better, update the best_* storage
583+
// Update the best plan if this one uses fewer registers
565584
if (reg_num < min_reg_num) {
566-
best_infer_list = std::move(infer_list_);
585+
best_infer_list =
586+
BackupInferList(); // Use backup to avoid moving out infer_list_
567587
best_layout_map = tmp_layout_map;
568588
min_reg_num = reg_num;
589+
min_reg_num_infer_root = attempt_infer_root;
569590
}
570591
}
571-
// recover stateful infer_list_, head on next
592+
// Restore infer_list_ state for the next attempt
572593
infer_list_ = std::move(back_infer_list);
573594
}
574-
if (min_reg_num < INT64_MAX) {
575-
// now apply the best plan for this component
576-
infer_list_ = std::move(best_infer_list);
577-
layout_map = best_layout_map;
578-
}
595+
ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
596+
// Apply the best plan for this component
597+
infer_list_ = std::move(best_infer_list);
598+
layout_map = best_layout_map;
599+
DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
600+
<< min_reg_num_infer_root << '\n';
579601
}
580602
}
581603
};
@@ -682,20 +704,25 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
682704
// Here, A_local is a register-local buffer held independently by each
683705
// thread, so explicit thread binding is not required.
684706
//
685-
// We use PostOrderVisit to detect whether the buffer store targets a
686-
// "local" buffer, which indicates register usage and justifies skipping
707+
// We use PostOrderVisit to detect whether the loop only manipulates
708+
// "local" buffers, which indicates register usage and justifies skipping
687709
// thread binding.
688-
bool is_register_store = false;
710+
bool local_register_only = true;
689711
PostOrderVisit(root, [&](const ObjectRef &obj) {
690712
if (const auto *store = obj.as<BufferStoreNode>()) {
691-
if (store->buffer.scope() == "local") {
692-
is_register_store = true;
713+
if (store->buffer.scope() != "local") {
714+
local_register_only = false;
715+
}
716+
} else if (const auto *load = obj.as<BufferLoadNode>()) {
717+
if (load->buffer.scope() != "local") {
718+
local_register_only = false;
693719
}
694720
}
695721
});
696722

697723
auto loop_layout = result_.for_map[root];
698-
bool parallel_loop = !is_register_store && !skip_thread_partition_;
724+
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart
725+
bool parallel_loop = !skip_thread_partition_ && !local_register_only;
699726

700727
if (parallel_loop) {
701728
for_node =

src/transform/layout_reducer.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
178178
Stmt VisitStmt_(const ForNode *op) final {
179179
// only annotate the outermost loop
180180
bool should_annotate = false;
181-
if (!inside_reducer_range_.empty() && !already_annotated_) {
181+
if (!inside_reducer_range_.empty() && !already_annotated_ &&
182+
op->kind == ForKind::kParallel) {
182183
should_annotate = true;
183184
already_annotated_ = true;
184185
}

src/transform/merge_shared_memory_allocations.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,13 +639,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
639639
};
640640

641641
void PlanAlignment(const Stmt &stmt) {
642-
LOG(INFO) << "PlanAlignment";
642+
DLOG(INFO) << "PlanAlignment";
643643
PostOrderVisit(stmt, [&](const ObjectRef &node) {
644644
if (const auto *call = node.as<CallNode>()) {
645645
if (call->op.same_as(tl::tl_gemm()) ||
646646
call->op.same_as(tl::tl_gemm_sp())) {
647-
LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
648-
<< call->op;
647+
DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
648+
<< call->op;
649649
}
650650
}
651651
});

0 commit comments

Comments
 (0)