@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
         " required for layout inference.";
 
     // Run InferLayout
+    DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
     auto updates =
         next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
                                           &analyzer_, buffer_oob},
                           level);
-
     // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
+      DLOG(INFO) << "consider update " << buffer << " as "
+                 << layout->DebugOutput() << '\n';
+
       // Basic validity checks
       ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
       ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
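Note on the added logging: assuming glog-style `DLOG(INFO)` semantics (the macro compiles to a no-op unless debug logging is enabled), this tracing costs nothing in release builds. A minimal sketch of that gating, using a hypothetical `DLOG_SKETCH` macro rather than the project's real logging header:

#include <iostream>

// Debug-only logging sketch: active only when NDEBUG is not defined,
// mirroring the assumed compile-time gating of DLOG(INFO).
#ifndef NDEBUG
#define DLOG_SKETCH(msg) (std::cerr << msg << '\n')
#else
#define DLOG_SKETCH(msg) ((void)0)
#endif

int main() {
  int cur_infer_id = 42; // placeholder value for illustration
  DLOG_SKETCH("[RunInferStep] working on " << cur_infer_id);
  return 0;
}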
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
           if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
                                     inner_analyzer)) {
             layout_map.Set(buffer, layout);
+            DLOG(INFO) << "layout broadcast from "
+                       << src_layout->DebugOutput() << ", accepted" << '\n';
             continue;
           }
         }
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       } else {
         // Otherwise, update map
         layout_map.Set(buffer, layout);
+        DLOG(INFO) << "new layout accepted" << '\n';
         if (!update_queue)
           continue;
 
@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
         << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
           "length.";
 
+    DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
+    for (int i = 0; i < infer_list_stmt_.size(); ++i) {
+      DLOG(INFO) << "op " << i << ":" << infer_list_stmt_[i] << '\n';
+    }
+
     // If needed, you can also check that annotated_layout_map_ is not empty, or
     // anything else relevant to your setup.
 
@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
 
   void InferInFreeMode(LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map) {
+
+    DLOG(INFO) << "Enforced layout maps:" << '\n';
+    for (auto &&[k, v] : layout_map) {
+      DLOG(INFO) << "  " << k << ": " << v->DebugOutput() << '\n';
+    }
+    DLOG(INFO) << '\n';
+
     // Group operators into connected components
     UnionFind<int> uf;
     for (int i = 0; i < infer_list_.size(); i++) {
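The context here shows components being built with a union-find over operator indices; presumably operators that reference a common buffer are united, and free-mode inference then runs once per connected component. A minimal disjoint-set sketch with the usual interface (the real `UnionFind<int>` may differ):

#include <numeric>
#include <vector>

// Disjoint-set with path compression: after uniting every pair of
// operators that share a buffer, Find() yields one representative
// per connected component.
struct UnionFindSketch {
  std::vector<int> parent;
  explicit UnionFindSketch(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) {
    return parent[x] == x ? x : parent[x] = Find(parent[x]);
  }
  void Unite(int a, int b) { parent[Find(a)] = parent[Find(b)]; }
};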
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     std::vector<bool> in_queue(infer_list_.size(), false);
 
     for (auto &&[root, members] : components) {
+      DLOG(INFO) << "======================= processing component " << root
+                 << '\n';
       decltype(infer_list_) best_infer_list;
       LayoutMap best_layout_map;
       int64_t min_reg_num = INT64_MAX;
+      int min_reg_num_infer_root = -1;
 
+      // Try each member as the root of inference for this component
       for (int attempt_infer_root : members) {
-        // backup infer_list_ in class member
+        DLOG(INFO) << "----------------------- try root " << attempt_infer_root
+                   << '\n';
+        // Backup the current infer_list_ state
         auto back_infer_list = BackupInferList();
-        // create temporarily used layout_map, new handle so that it copies on
-        // write
+        // Copy the current layout_map for temporary use
         LayoutMap tmp_layout_map = layout_map;
-        // infer from attempt_infer_root in free mode
         bool do_update = true;
         try {
+          // Run inference starting from attempt_infer_root
           RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                        tmp_layout_map, strict_layout_map, q, in_queue);
           FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                            q, in_queue);
-          // Silly workaround: we have no clue if single root will iterate over
-          // the entire component, since the InferLayout implementations have
-          // complicated conditioning inside and we know nothing about it.
-          // This would constantly result in incomplete layouts for buffers in
-          // this component. Instead of trying all combinations of root
-          // selection order, we simply go through all other loops in order
-          // after the first search from attempt_infer_root.
+
+          // After the first search, run inference for all other members in
+          // order
           for (int other_infer_root : members) {
            if (other_infer_root != attempt_infer_root) {
               RunInferStep(other_infer_root, InferLevel::kFree, true,
                            tmp_layout_map, strict_layout_map, q, in_queue);
-              // must also be kFree here to avoid conflicts.
               FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                                strict_layout_map, q, in_queue);
             }
           }
-        } catch (LayoutConflictException e) {
-          // such an order fails, try others
+        } catch (const LayoutConflictException &e) {
           do_update = false;
-        } catch (NormalizeIterException e) {
-          // such an order encounters iterators that is not normalizable, try
-          // others e.g. i * 576 % 2048
+          DLOG(INFO) << "attempt failed due to LayoutConflictException "
+                     << e.what() << '\n';
+        } catch (const NormalizeIterException &e) {
           do_update = false;
+          DLOG(INFO) << "attempt failed due to NormalizeIterException "
+                     << e.what() << '\n';
         }
 
         if (do_update) {
-          // compute total register number
+          // Compute the total register number for this layout
          int64_t reg_num = 0;
-          for (auto &&[buffer, layout] : tmp_layout_map) {
+          for (const auto &[buffer, layout] : tmp_layout_map) {
             if (auto frag = layout.as<Fragment>()) {
               int64_t frag_reg_num = 1;
               for (auto i : frag.value()->OutputShape()) {
@@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
               reg_num += frag_reg_num;
             }
           }
-          // if it's any better, update the best_* storage
+          // Update the best plan if this one uses fewer registers
           if (reg_num < min_reg_num) {
-            best_infer_list = std::move(infer_list_);
+            best_infer_list =
+                BackupInferList(); // Use backup to avoid moving out infer_list_
             best_layout_map = tmp_layout_map;
             min_reg_num = reg_num;
+            min_reg_num_infer_root = attempt_infer_root;
           }
         }
-        // recover stateful infer_list_, head on next
+        // Restore infer_list_ state for the next attempt
         infer_list_ = std::move(back_infer_list);
       }
-      if (min_reg_num < INT64_MAX) {
-        // now apply the best plan for this component
-        infer_list_ = std::move(best_infer_list);
-        layout_map = best_layout_map;
-      }
+      ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
+      // Apply the best plan for this component
+      infer_list_ = std::move(best_infer_list);
+      layout_map = best_layout_map;
+      DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
+                 << min_reg_num_infer_root << '\n';
     }
   }
 };
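The loop above is a try-every-candidate search: each member of the component is attempted as the inference root on a scratch copy of the state, the result is costed by total fragment register count, and the cheapest feasible plan wins; orders that conflict surface as exceptions and are skipped. The new ICHECK turns "no feasible root at all" into a hard error instead of a silent fallthrough. A stripped-down sketch of the pattern, where `State`, `Infer`, and `CostOf` are placeholders rather than the real API:

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

struct State { /* layouts, queues, ... */ };

int64_t CostOf(const State &) { return 0; }              // e.g. register count
State Infer(State s, int root) { (void)root; return s; } // may throw on conflict

State PickBestRoot(const State &initial, const std::vector<int> &members) {
  int64_t best_cost = std::numeric_limits<int64_t>::max();
  State best;
  for (int root : members) {
    try {
      // Work on a copy so a failed attempt leaves the input untouched.
      State candidate = Infer(initial, root);
      if (int64_t c = CostOf(candidate); c < best_cost) {
        best_cost = c;
        best = candidate;
      }
    } catch (const std::exception &) {
      // This root ordering failed; try the next one.
    }
  }
  if (best_cost == std::numeric_limits<int64_t>::max())
    throw std::runtime_error("no available layout found");
  return best;
}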
@@ -682,20 +704,25 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
     // Here, A_local is a register-local buffer held independently by each
     // thread, so explicit thread binding is not required.
     //
-    // We use PostOrderVisit to detect whether the buffer store targets a
-    // "local" buffer, which indicates register usage and justifies skipping
+    // We use PostOrderVisit to detect whether the loop only manipulates
+    // "local" buffers, which indicates register usage and justifies skipping
     // thread binding.
-    bool is_register_store = false;
+    bool local_register_only = true;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (store->buffer.scope() == "local") {
-          is_register_store = true;
+        if (store->buffer.scope() != "local") {
+          local_register_only = false;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        if (load->buffer.scope() != "local") {
+          local_register_only = false;
         }
       }
     });
 
     auto loop_layout = result_.for_map[root];
-    bool parallel_loop = !is_register_store && !skip_thread_partition_;
+    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
+    bool parallel_loop = !skip_thread_partition_ && !local_register_only;
 
     if (parallel_loop) {
       for_node =
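The reworked check inverts the old logic: rather than marking the loop as register-only when some store hits a `local` buffer, it starts optimistic and falsifies as soon as any access, load or store, touches a non-`local` scope. A simplified sketch of that predicate over a flattened access list (`Access` is a stand-in for the visited BufferLoad/BufferStore nodes):

#include <string>
#include <vector>

struct Access {
  std::string scope; // e.g. "local", "shared", "global"
};

// True only if every access in the loop body stays in register-local
// scope; a single non-"local" load or store disqualifies the loop.
bool LocalRegisterOnly(const std::vector<Access> &accesses) {
  for (const Access &a : accesses)
    if (a.scope != "local") return false;
  return true;
}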