1 change: 1 addition & 0 deletions .clang-tidy
@@ -42,6 +42,7 @@ Checks: >
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,
-clang-analyzer-deadcode.DeadStores,

WarningsAsErrors: '*'

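The newly suppressed check, clang-analyzer-deadcode.DeadStores, reports assignments whose value is never read; the debug-only bookkeeping added elsewhere in this change set presumably trips it. A minimal illustration of the pattern the check flags (not taken from this PR; the function name is made up):

// Illustration only: the suppressed check reports
// "Value stored to 'count' is never read" for the first assignment below.
int CountSomething() {
  int count = 42;  // dead store: overwritten before it is ever read
  count = 7;
  return count;
}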
2 changes: 1 addition & 1 deletion examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -29,7 +29,7 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_local((1,), "int32")
row_offset = T.alloc_fragment((1,), "int32")

T.annotate_layout({
y_local:
1 change: 1 addition & 0 deletions src/op/builtin.cc
@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);

DataType cuTensorMapType() { return DataType::UInt(8, 128); }
1 change: 1 addition & 0 deletions src/op/builtin.h
@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
* \brief Whether to disable dynamic tail split
8 changes: 6 additions & 2 deletions src/op/gemm.cc
@@ -92,10 +92,14 @@ TileOperator GemmNode::Clone() const {
}

GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
bool allow_wgmma =
!ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
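Together, the builtin.{h,cc} and gemm.cc changes expose a new pass-config switch, tl.disable_wgmma, that makes GetGemmInst skip the WGMMA path even on Hopper. A minimal sketch of how the flag could be scoped from C++, assuming TVM's standard PassContext API; BuildWithoutWgmma and LowerTileLangModule are hypothetical stand-ins:

// Sketch only: scoping the new tl.disable_wgmma option via TVM's PassContext.
// PassContext::Create() and With<PassContext> are standard TVM APIs; the
// LowerTileLangModule() call is a placeholder for the real lowering entry point.
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <tvm/support/with.h>

void BuildWithoutWgmma() {
  tvm::transform::PassContext ctx = tvm::transform::PassContext::Create();
  ctx->config.Set("tl.disable_wgmma", tvm::Bool(true));
  tvm::With<tvm::transform::PassContext> scope(ctx);
  // Within this scope, GemmNode::GetGemmInst() reads kDisableWGMMA via
  // PassContext::Current() and falls back to non-WGMMA instructions.
  // LowerTileLangModule(mod);  // hypothetical
}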
43 changes: 37 additions & 6 deletions src/op/parallel.cc
@@ -128,9 +128,13 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
* visitor's reducer_info_map_. Continues traversal into the loop body.
*/
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
if (op->kind == ForKind::kParallel)
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kDataPar));
else
p->inner_vars_.Set(op->loop_var,
IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kOrdered));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
auto reducer_info_map =
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
Fragment result;
if (IsCommonAccessIndice(buffer)) {
return src_layout;
result = src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
}
DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
<< result->DebugOutput() << '\n';
return result;
};
if (source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);

DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';

PrimExpr loop_total_size = 1;
for (Stmt l = root_; l.as<For>().has_value();
l = l.as<For>().value()->body)
loop_total_size = loop_total_size * l.as<For>().value()->extent;
DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
<< '\n';
while (!analyzer_.CanProve(
floormod(loop_total_size,
T.thread_bounds->extent * vector_size) == 0) &&
vector_size > 1)
vector_size /= 2;
DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
<< vector_size << '\n';

// Check if coalesced_width is defined
if (auto coalesced_width =
@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
LOG(FATAL) << "coalesced_width should be an IntImmNode.";
}
}
DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
<< " ############# vector_size = " << vector_size
<< ", thread_bounds = " << T.thread_bounds << '\n';
loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
} else {
return {};
3 changes: 3 additions & 0 deletions src/op/parallel.h
@@ -128,6 +128,7 @@ class ParallelOpNode : public TileOperatorNode {
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}

// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;

@@ -139,6 +140,8 @@ class ParallelOpNode : public TileOperatorNode {
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_;
// Loop variables of non-parallel inner loops in the nest; thread mappings
// that depend on them are rejected during layout inference.
Map<Var, IterVar> inner_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
// Mapping from buffer to reducer info.
95 changes: 61 additions & 34 deletions src/transform/layout_inference.cc
@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
"required for layout inference.";

// Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
level);

// Process the returned updates
for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';

// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) {
layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from "
<< src_layout->DebugOutput() << ", accepted" << '\n';
continue;
}
}
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n';
if (!update_queue)
continue;

@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";

DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
for (int i = 0; i < infer_list_stmt_.size(); ++i) {
DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n';
}

// If needed, you can also check that annotated_layout_map_ is not empty, or
// anything else relevant to your setup.

@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {

void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) {

DLOG(INFO) << "Enforced layout maps:" << '\n';
for (auto &&[k, v] : layout_map) {
DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n';
}
DLOG(INFO) << '\n';

// Group operators into connected components
UnionFind<int> uf;
for (int i = 0; i < infer_list_.size(); i++) {
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
std::vector<bool> in_queue(infer_list_.size(), false);

for (auto &&[root, members] : components) {
DLOG(INFO) << "======================= processing component " << root
<< '\n';
decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX;
int min_reg_num_infer_root = -1;

// Try each member as the root of inference for this component
for (int attempt_infer_root : members) {
// backup infer_list_ in class member
DLOG(INFO) << "----------------------- try root " << attempt_infer_root
<< '\n';
// Backup the current infer_list_ state
auto back_infer_list = BackupInferList();
// create temporarily used layout_map, new handle so that it copies on
// write
// Copy the current layout_map for temporary use
LayoutMap tmp_layout_map = layout_map;
// infer from attempt_infer_root in free mode
bool do_update = true;
try {
// Run inference starting from attempt_infer_root
RunInferStep(attempt_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue);
// Silly workaround: we have no clue if single root will iterate over
// the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it.
// This would constantly result in incomplete layouts for buffers in
// this component. Instead of trying all combinations of root
// selection order, we simply go through all other loops in order
// after the first search from attempt_infer_root.

// After the first search, run inference for all other members in
// order
for (int other_infer_root : members) {
if (other_infer_root != attempt_infer_root) {
RunInferStep(other_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
// must also be kFree here to avoid conflicts.
FinishInferQueue(InferLevel::kFree, tmp_layout_map,
strict_layout_map, q, in_queue);
}
}
} catch (LayoutConflictException e) {
// such an order fails, try others
} catch (const LayoutConflictException &e) {
do_update = false;
} catch (NormalizeIterException e) {
// such an order encounters iterators that is not normalizable, try
// others e.g. i * 576 % 2048
DLOG(INFO) << "attempt failed due to LayoutConflictException "
<< e.what() << '\n';
} catch (const NormalizeIterException &e) {
do_update = false;
DLOG(INFO) << "attempt failed due to NormalizeIterException "
<< e.what() << '\n';
}

if (do_update) {
// compute total register number
// Compute the total register number for this layout
int64_t reg_num = 0;
for (auto &&[buffer, layout] : tmp_layout_map) {
for (const auto &[buffer, layout] : tmp_layout_map) {
if (auto frag = layout.as<Fragment>()) {
int64_t frag_reg_num = 1;
for (auto i : frag.value()->OutputShape()) {
Expand All @@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
reg_num += frag_reg_num;
}
}
// if it's any better, update the best_* storage
// Update the best plan if this one uses fewer registers
if (reg_num < min_reg_num) {
best_infer_list = std::move(infer_list_);
best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map;
min_reg_num = reg_num;
min_reg_num_infer_root = attempt_infer_root;
}
}
// recover stateful infer_list_, head on next
// Restore infer_list_ state for the next attempt
infer_list_ = std::move(back_infer_list);
}
if (min_reg_num < INT64_MAX) {
// now apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
}
ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
// Apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
<< min_reg_num_infer_root << '\n';
}
}
};
@@ -682,20 +704,25 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
// Here, A_local is a register-local buffer held independently by each
// thread, so explicit thread binding is not required.
//
// We use PostOrderVisit to detect whether the buffer store targets a
// "local" buffer, which indicates register usage and justifies skipping
// We use PostOrderVisit to detect whether the loop only manipulates
// "local" buffers, which indicates register usage and justifies skipping
// thread binding.
bool is_register_store = false;
bool local_register_only = true;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (store->buffer.scope() == "local") {
is_register_store = true;
if (store->buffer.scope() != "local") {
local_register_only = false;
}
} else if (const auto *load = obj.as<BufferLoadNode>()) {
if (load->buffer.scope() != "local") {
local_register_only = false;
}
}
});
Comment on lines 709 to 721
⚠️ Potential issue | 🟠 Major

Cover every thread-local storage scope when detecting register-only loops.

local_register_only currently flips to false for anything except scope "local", but we also emit "local.var" and "local.fragment" for per-thread allocations (see IsLocalBuffer in src/transform/legalize_safe_memory_access.cc). Loops that only touch those buffers will still be treated as needing thread partitioning, so the regression this PR is trying to fix persists for those cases. Please treat all thread-local scopes as register-only before skipping partition.

You can address it along these lines:

-      bool local_register_only = true;
-      PostOrderVisit(root, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          if (store->buffer.scope() != "local") {
-            local_register_only = false;
-          }
-        } else if (const auto *load = obj.as<BufferLoadNode>()) {
-          if (load->buffer.scope() != "local") {
-            local_register_only = false;
-          }
-        }
-      });
+      auto is_thread_local_scope = [](const String &scope) {
+        return scope == "local" || scope == "local.fragment" ||
+               scope == "local.var";
+      };
+      bool local_register_only = true;
+      PostOrderVisit(root, [&](const ObjectRef &obj) {
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          if (!is_thread_local_scope(store->buffer.scope())) {
+            local_register_only = false;
+          }
+        } else if (const auto *load = obj.as<BufferLoadNode>()) {
+          if (!is_thread_local_scope(load->buffer.scope())) {
+            local_register_only = false;
+          }
+        }
+      });

auto loop_layout = result_.for_map[root];
bool parallel_loop = !is_register_store && !skip_thread_partition_;
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart
bool parallel_loop = !skip_thread_partition_ && !local_register_only;

if (parallel_loop) {
for_node =
3 changes: 2 additions & 1 deletion src/transform/layout_reducer.cc
@@ -178,7 +178,8 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop
bool should_annotate = false;
if (!inside_reducer_range_.empty() && !already_annotated_) {
if (!inside_reducer_range_.empty() && !already_annotated_ &&
op->kind == ForKind::kParallel) {
should_annotate = true;
already_annotated_ = true;
}
6 changes: 3 additions & 3 deletions src/transform/merge_shared_memory_allocations.cc
@@ -639,13 +639,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
};

void PlanAlignment(const Stmt &stmt) {
LOG(INFO) << "PlanAlignment";
DLOG(INFO) << "PlanAlignment";
PostOrderVisit(stmt, [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(tl::tl_gemm()) ||
call->op.same_as(tl::tl_gemm_sp())) {
LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
}
}
});
4 changes: 2 additions & 2 deletions src/transform/storage_rewrite.cc
@@ -1789,8 +1789,8 @@ class VectorTypeRewriter : public StmtExprMutator {
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
LOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
}