1 change: 1 addition & 0 deletions .clang-tidy
@@ -42,6 +42,7 @@ Checks: >
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,
-clang-analyzer-deadcode.DeadStores,

WarningsAsErrors: '*'

2 changes: 1 addition & 1 deletion examples/cast/example_group_per_split_token_cast_to_fp8.py
@@ -29,7 +29,7 @@ def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_local((1,), "int32")
row_offset = T.alloc_fragment((1,), "int32")

T.annotate_layout({
y_local:
1 change: 1 addition & 0 deletions src/op/builtin.cc
@@ -28,6 +28,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);

DataType cuTensorMapType() { return DataType::UInt(8, 128); }
1 change: 1 addition & 0 deletions src/op/builtin.h
@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
* \brief Whether to disable dynamic tail split
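
For reference, below is a minimal sketch of how the new `tl.disable_wgmma` option could be toggled from Python, assuming it is consumed through TVM's standard PassContext config mechanism (as the `TVM_REGISTER_PASS_CONFIG_OPTION` registration above suggests). How TileLang surfaces pass configs to end users is not shown in this diff, so the build call is a placeholder.

```python
# Hedged usage sketch: only the config key "tl.disable_wgmma" comes from this
# PR; the surrounding build flow is illustrative.
import tvm

with tvm.transform.PassContext(config={"tl.disable_wgmma": True}):
    # Lower/compile the TileLang kernel inside this scope. With the flag set,
    # GemmNode::GetGemmInst() will not select GemmInst::kWGMMA on Hopper and
    # falls back to the other GEMM instruction paths.
    mod = build_kernel()  # placeholder for the actual TileLang build entry point
```
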
7 changes: 5 additions & 2 deletions src/op/gemm.cc
@@ -92,10 +92,13 @@ TileOperator GemmNode::Clone() const {
}

GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();

int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
bool allow_wgmma =
!ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0);

Reviewer comment on lines +99 to +101 (severity: critical):

The call to CheckWGMMA() has been removed from the condition for enabling WGMMA. This function performs crucial checks for data types, K-dimension divisibility, and transpose flags, which are required for the correctness of WGMMA on the Hopper architecture. Without these checks, the compiler might generate WGMMA instructions for unsupported configurations, potentially leading to compilation errors or runtime failures.

Suggested change
- bool allow_wgmma =
-     !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
-     TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0);
+ bool allow_wgmma =
+     !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
+     TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && CheckWGMMA();

if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
43 changes: 37 additions & 6 deletions src/op/parallel.cc
@@ -128,9 +128,13 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
* visitor's reducer_info_map_. Continues traversal into the loop body.
*/
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
if (op->kind == ForKind::kParallel)
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kDataPar));
else
p->inner_vars_.Set(op->loop_var,
IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kOrdered));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
auto reducer_info_map =
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
Fragment result;
if (IsCommonAccessIndice(buffer)) {
return src_layout;
result = src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
}
DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
<< result->DebugOutput() << '\n';
return result;
};
if (source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);

DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';

PrimExpr loop_total_size = 1;
for (Stmt l = root_; l.as<For>().has_value();
l = l.as<For>().value()->body)
loop_total_size = loop_total_size * l.as<For>().value()->extent;
DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
<< '\n';
while (!analyzer_.CanProve(
floormod(loop_total_size,
T.thread_bounds->extent * vector_size) == 0) &&
vector_size > 1)
vector_size /= 2;
DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
<< vector_size << '\n';

// Check if coalesced_width is defined
if (auto coalesced_width =
@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
LOG(FATAL) << "coalesced_width should be an IntImmNode.";
}
}
DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
<< " ############# vector_size = " << vector_size
<< ", thread_bounds = " << T.thread_bounds << '\n';
loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
} else {
return {};
3 changes: 3 additions & 0 deletions src/op/parallel.h
@@ -128,6 +128,7 @@ class ParallelOpNode : public TileOperatorNode {
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}

// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;

@@ -139,6 +140,8 @@
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_;
// Loop variables of non-parallel (serial) loops nested inside the parallel loop nest.
Map<Var, IterVar> inner_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
// Mapping from buffer to reducer info.
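
To make the new bookkeeping concrete, here is a hedged TileLang-style sketch of the loop shape that `inner_vars_` is meant to capture; the buffer names are made up. Serial loops nested inside a `T.Parallel` loop are now recorded as inner (ordered) iteration variables, and `ParallelOpNode::InferLayout` raises `LayoutConflictException` if a candidate loop layout would map threads using one of them.

```python
# Hypothetical fragment buffers, for illustration only.
for i in T.Parallel(block_M):      # parallel loop var, goes into loop_vars_
    for k in T.serial(4):          # serial inner loop var, goes into inner_vars_
        acc[i] += A_frag[i, k]
# Mapping threads by `i` is fine; if inferring the loop layout from a source
# buffer would make the thread index depend on `k`, the new PostOrderVisit
# check rejects that layout instead of silently producing an invalid one.
```
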
81 changes: 52 additions & 29 deletions src/transform/layout_inference.cc
@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
"required for layout inference.";

// Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
level);

// Process the returned updates
for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';

// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) {
layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from "
<< src_layout->DebugOutput() << ", accepted" << '\n';
continue;
}
}
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n';
if (!update_queue)
continue;

@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";

DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
for (int i = 0; i < infer_list_stmt_.size(); ++i) {
DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n';
}

// If needed, you can also check that annotated_layout_map_ is not empty, or
// anything else relevant to your setup.

@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {

void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) {

DLOG(INFO) << "Enforced layout maps:" << '\n';
for (auto &&[k, v] : layout_map) {
DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n';
}
DLOG(INFO) << '\n';

// Group operators into connected components
UnionFind<int> uf;
for (int i = 0; i < infer_list_.size(); i++) {
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
std::vector<bool> in_queue(infer_list_.size(), false);

for (auto &&[root, members] : components) {
DLOG(INFO) << "======================= processing component " << root
<< '\n';
decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX;
int min_reg_num_infer_root = -1;

// Try each member as the root of inference for this component
for (int attempt_infer_root : members) {
// backup infer_list_ in class member
DLOG(INFO) << "----------------------- try root " << attempt_infer_root
<< '\n';
// Backup the current infer_list_ state
auto back_infer_list = BackupInferList();
// create temporarily used layout_map, new handle so that it copies on
// write
// Copy the current layout_map for temporary use
LayoutMap tmp_layout_map = layout_map;
// infer from attempt_infer_root in free mode
bool do_update = true;
try {
// Run inference starting from attempt_infer_root
RunInferStep(attempt_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue);
// Silly workaround: we have no clue if single root will iterate over
// the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it.
// This would constantly result in incomplete layouts for buffers in
// this component. Instead of trying all combinations of root
// selection order, we simply go through all other loops in order
// after the first search from attempt_infer_root.

// After the first search, run inference for all other members in
// order
for (int other_infer_root : members) {
if (other_infer_root != attempt_infer_root) {
RunInferStep(other_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
// must also be kFree here to avoid conflicts.
FinishInferQueue(InferLevel::kFree, tmp_layout_map,
strict_layout_map, q, in_queue);
}
}
} catch (LayoutConflictException e) {
// such an order fails, try others
} catch (const LayoutConflictException &e) {
do_update = false;
} catch (NormalizeIterException e) {
// such an order encounters iterators that is not normalizable, try
// others e.g. i * 576 % 2048
DLOG(INFO) << "attempt failed due to LayoutConflictException "
<< e.what() << '\n';
} catch (const NormalizeIterException &e) {
do_update = false;
DLOG(INFO) << "attempt failed due to NormalizeIterException "
<< e.what() << '\n';
}

if (do_update) {
// compute total register number
// Compute the total register number for this layout
int64_t reg_num = 0;
for (auto &&[buffer, layout] : tmp_layout_map) {
for (const auto &[buffer, layout] : tmp_layout_map) {
if (auto frag = layout.as<Fragment>()) {
int64_t frag_reg_num = 1;
for (auto i : frag.value()->OutputShape()) {
@@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
reg_num += frag_reg_num;
}
}
// if it's any better, update the best_* storage
// Update the best plan if this one uses fewer registers
if (reg_num < min_reg_num) {
best_infer_list = std::move(infer_list_);
best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map;
min_reg_num = reg_num;
min_reg_num_infer_root = attempt_infer_root;
}
}
// recover stateful infer_list_, head on next
// Restore infer_list_ state for the next attempt
infer_list_ = std::move(back_infer_list);
}
if (min_reg_num < INT64_MAX) {
// now apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
}
ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
// Apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
<< min_reg_num_infer_root << '\n';
}
}
};
@@ -695,7 +717,8 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
});

auto loop_layout = result_.for_map[root];
bool parallel_loop = !is_register_store && !skip_thread_partition_;
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart
bool parallel_loop = !skip_thread_partition_;

Reviewer comment (severity: medium):

The condition for parallel_loop no longer includes !is_register_store. This will cause loops that only store to register-local buffers ("local" scope) to be partitioned across threads. While this seems intended to support local buffers in T.Parallel, the FIXME on the preceding line indicates this may not be a complete solution.

Could you elaborate on the implications of this change? Forcing partitioning on loops that only use thread-local registers might be unnecessary if they are not shared. It would be helpful to understand the plan to address the FIXME and correctly differentiate between local buffers that require partitioning and those that do not.
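
As a concrete (hypothetical) illustration of the case being raised: a `T.Parallel` loop whose only store target is a "local"-scope buffer. With `!is_register_store` dropped from the condition, such a loop now also goes through thread partitioning; the buffer names below are made up.

```python
# Hedged sketch, not taken from this PR.
tmp = T.alloc_local((blk_m,), "int32")   # register-local, per-thread buffer
for i in T.Parallel(blk_m):
    tmp[i] = idx[i] * 2                  # only writes a "local" buffer; previously
                                         # left unpartitioned, now partitioned
                                         # across threads like any parallel loop
```
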


if (parallel_loop) {
for_node =
3 changes: 2 additions & 1 deletion src/transform/layout_reducer.cc
@@ -178,7 +178,8 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop
bool should_annotate = false;
if (!inside_reducer_range_.empty() && !already_annotated_) {
if (!inside_reducer_range_.empty() && !already_annotated_ &&
op->kind == ForKind::kParallel) {
should_annotate = true;
already_annotated_ = true;
}
6 changes: 3 additions & 3 deletions src/transform/merge_shared_memory_allocations.cc
@@ -639,13 +639,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
};

void PlanAlignment(const Stmt &stmt) {
LOG(INFO) << "PlanAlignment";
DLOG(INFO) << "PlanAlignment";
PostOrderVisit(stmt, [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(tl::tl_gemm()) ||
call->op.same_as(tl::tl_gemm_sp())) {
LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
}
}
});
4 changes: 2 additions & 2 deletions src/transform/storage_rewrite.cc
@@ -1789,8 +1789,8 @@ class VectorTypeRewriter : public StmtExprMutator {
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
LOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
}