-
Notifications
You must be signed in to change notification settings - Fork 156
[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop #844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop #844
Changes from all commits
5e76a22
50e9d3c
6e54530
4552b00
a12ba41
d16f083
51c6006
5f079ae
41c903d
475b2ba
c3e4492
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
"required for layout inference."; | ||
|
||
// Run InferLayout | ||
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n'; | ||
auto updates = | ||
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, | ||
&analyzer_, buffer_oob}, | ||
level); | ||
|
||
// Process the returned updates | ||
for (const auto &[buffer, layout] : updates) { | ||
DLOG(INFO) << " consider update " << buffer << " as " | ||
<< layout->DebugOutput() << '\n'; | ||
|
||
// Basic validity checks | ||
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer."; | ||
ICHECK(layout.defined()) << "InferLayout returned an undefined layout."; | ||
|
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
if (ProveFragmentContains(src_layout, dst_layout, indices, indices, | ||
inner_analyzer)) { | ||
layout_map.Set(buffer, layout); | ||
DLOG(INFO) << " layout broadcast from " | ||
<< src_layout->DebugOutput() << ", accepted" << '\n'; | ||
continue; | ||
} | ||
} | ||
|
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
} else { | ||
// Otherwise, update map | ||
layout_map.Set(buffer, layout); | ||
DLOG(INFO) << " new layout accepted" << '\n'; | ||
if (!update_queue) | ||
continue; | ||
|
||
|
@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " | ||
"length."; | ||
|
||
DLOG(INFO) << "[InferLayout] all participating operators:" << '\n'; | ||
for (int i = 0; i < infer_list_stmt_.size(); ++i) { | ||
DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n'; | ||
} | ||
|
||
// If needed, you can also check that annotated_layout_map_ is not empty, or | ||
// anything else relevant to your setup. | ||
|
||
|
@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
|
||
void InferInFreeMode(LayoutMap &layout_map, | ||
const LayoutMap &strict_layout_map) { | ||
|
||
DLOG(INFO) << "Enforced layout maps:" << '\n'; | ||
for (auto &&[k, v] : layout_map) { | ||
DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n'; | ||
} | ||
DLOG(INFO) << '\n'; | ||
|
||
// Group operators into connected components | ||
UnionFind<int> uf; | ||
for (int i = 0; i < infer_list_.size(); i++) { | ||
|
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
std::vector<bool> in_queue(infer_list_.size(), false); | ||
|
||
for (auto &&[root, members] : components) { | ||
DLOG(INFO) << "======================= processing component " << root | ||
<< '\n'; | ||
decltype(infer_list_) best_infer_list; | ||
LayoutMap best_layout_map; | ||
int64_t min_reg_num = INT64_MAX; | ||
int min_reg_num_infer_root = -1; | ||
|
||
// Try each member as the root of inference for this component | ||
for (int attempt_infer_root : members) { | ||
// backup infer_list_ in class member | ||
DLOG(INFO) << "----------------------- try root " << attempt_infer_root | ||
<< '\n'; | ||
// Backup the current infer_list_ state | ||
auto back_infer_list = BackupInferList(); | ||
// create temporarily used layout_map, new handle so that it copies on | ||
// write | ||
// Copy the current layout_map for temporary use | ||
LayoutMap tmp_layout_map = layout_map; | ||
// infer from attempt_infer_root in free mode | ||
bool do_update = true; | ||
try { | ||
// Run inference starting from attempt_infer_root | ||
RunInferStep(attempt_infer_root, InferLevel::kFree, true, | ||
tmp_layout_map, strict_layout_map, q, in_queue); | ||
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map, | ||
q, in_queue); | ||
// Silly workaround: we have no clue if single root will iterate over | ||
// the entire component, since the InferLayout implementations have | ||
// complicated conditioning inside and we know nothing about it. | ||
// This would constantly result in incomplete layouts for buffers in | ||
// this component. Instead of trying all combinations of root | ||
// selection order, we simply go through all other loops in order | ||
// after the first search from attempt_infer_root. | ||
|
||
// After the first search, run inference for all other members in | ||
// order | ||
for (int other_infer_root : members) { | ||
if (other_infer_root != attempt_infer_root) { | ||
RunInferStep(other_infer_root, InferLevel::kFree, true, | ||
tmp_layout_map, strict_layout_map, q, in_queue); | ||
// must also be kFree here to avoid conflicts. | ||
FinishInferQueue(InferLevel::kFree, tmp_layout_map, | ||
strict_layout_map, q, in_queue); | ||
} | ||
} | ||
} catch (LayoutConflictException e) { | ||
// such an order fails, try others | ||
} catch (const LayoutConflictException &e) { | ||
do_update = false; | ||
} catch (NormalizeIterException e) { | ||
// such an order encounters iterators that is not normalizable, try | ||
// others e.g. i * 576 % 2048 | ||
DLOG(INFO) << "attempt failed due to LayoutConflictException " | ||
<< e.what() << '\n'; | ||
} catch (const NormalizeIterException &e) { | ||
do_update = false; | ||
DLOG(INFO) << "attempt failed due to NormalizeIterException " | ||
<< e.what() << '\n'; | ||
} | ||
|
||
if (do_update) { | ||
// compute total register number | ||
// Compute the total register number for this layout | ||
int64_t reg_num = 0; | ||
for (auto &&[buffer, layout] : tmp_layout_map) { | ||
for (const auto &[buffer, layout] : tmp_layout_map) { | ||
if (auto frag = layout.as<Fragment>()) { | ||
int64_t frag_reg_num = 1; | ||
for (auto i : frag.value()->OutputShape()) { | ||
|
@@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { | |
reg_num += frag_reg_num; | ||
} | ||
} | ||
// if it's any better, update the best_* storage | ||
// Update the best plan if this one uses fewer registers | ||
if (reg_num < min_reg_num) { | ||
best_infer_list = std::move(infer_list_); | ||
best_infer_list = | ||
BackupInferList(); // Use backup to avoid moving out infer_list_ | ||
best_layout_map = tmp_layout_map; | ||
min_reg_num = reg_num; | ||
min_reg_num_infer_root = attempt_infer_root; | ||
} | ||
} | ||
// recover stateful infer_list_, head on next | ||
// Restore infer_list_ state for the next attempt | ||
infer_list_ = std::move(back_infer_list); | ||
} | ||
if (min_reg_num < INT64_MAX) { | ||
// now apply the best plan for this component | ||
infer_list_ = std::move(best_infer_list); | ||
layout_map = best_layout_map; | ||
} | ||
ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n'; | ||
// Apply the best plan for this component | ||
infer_list_ = std::move(best_infer_list); | ||
layout_map = best_layout_map; | ||
DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = " | ||
<< min_reg_num_infer_root << '\n'; | ||
} | ||
} | ||
}; | ||
|
@@ -695,7 +717,8 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { | |
}); | ||
|
||
auto loop_layout = result_.for_map[root]; | ||
bool parallel_loop = !is_register_store && !skip_thread_partition_; | ||
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart | ||
bool parallel_loop = !skip_thread_partition_; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The condition for `parallel_loop` has changed: it no longer includes the `!is_register_store` check. Could you elaborate on the implications of this change? Forcing partitioning on loops that only use thread-local registers might be unnecessary if they are not shared. It would be helpful to understand the plan to address the `FIXME`. |
||
|
||
if (parallel_loop) { | ||
for_node = | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The call to `CheckWGMMA()` has been removed from the condition for enabling WGMMA. This function performs crucial checks for data types, K-dimension divisibility, and transpose flags, which are required for the correctness of WGMMA on the Hopper architecture. Without these checks, the compiler might generate WGMMA instructions for unsupported configurations, potentially leading to compilation errors or runtime failures.