Skip to content
Open
34 changes: 17 additions & 17 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;

auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
Expand Down Expand Up @@ -530,8 +530,8 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
}

Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor) {
if (kfactor == 2)
bool k_inner) {
if (k_inner)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
Expand All @@ -558,29 +558,29 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
* select specific swizzling strategies. It might be the same as mat_continuous
* or different based on tiling or hardware details.
* \param element_size The size of each element in the matrix, in bits (e.g., 8,
* 16, 32, 64). \param kfactor An integer factor that influences layout
* 16, 32, 64). \param k_inner Whether the K dimension is in the inner loop;
* it influences layout
* selection, particularly for fp64 and int8 types. It often relates to how the
* K dimension of the GEMM (M x K * K x N) is handled or tiled.
* - For fp64 (element_size == 64):
* - kfactor == 1 often implies K is in the "outer" loop (e.g.,
* KxN matrix).
* - kfactor == 2 often implies K is in the "inner" loop (e.g.,
* NxK matrix).
* - k_inner == false often implies K is in the "outer" loop
* (e.g., KxN matrix).
* - k_inner == true often implies K is in the "inner" loop
* (e.g., NxK matrix).
* - For int8 (element_size == 8):
* - kfactor == 1 uses a padded layout.
* - k_inner == false uses a padded layout.
* \return A Layout object representing the chosen memory layout.
*/
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
int element_size, bool k_inner) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (!k_inner && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
Expand All @@ -592,17 +592,17 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
}

Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor) {
int continuity, int element_size, bool k_inner) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (!k_inner && element_size == 8) // int8 KxN
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 8) == 0)
Expand Down
41 changes: 38 additions & 3 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void LayoutNode::RegisterReflection() {
}

void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
for (const auto &[var, dom] : getVarMap()) {
for (const auto &[var, dom] : LayoutNode::getVarMap()) {
analyzer->Bind(var, dom);
}
}
Expand Down Expand Up @@ -458,6 +458,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Layout layout) { return layout->GetForwardIndex(); })
.def("tl.Layout_forward_vars",
[](Layout layout) { return layout->GetForwardVars(); })
.def("tl.Layout_is_equal",
[](Layout layout, Layout other) {
const LayoutNode *other_node = other.as<LayoutNode>();
return layout->IsEqual(other_node);
})
.def_packed("tl.Fragment",
[](PackedArgs args, Any *rv) {
*rv = Fragment(
Expand All @@ -466,6 +471,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
/*forward_thread=*/args[2].cast<PrimExpr>(),
/*thread_replicate=*/args[3].cast<IterVar>());
})
.def("tl.Fragment_is_equal",
[](Fragment fragment, Fragment other) {
const FragmentNode *other_node = other.as<FragmentNode>();
return fragment->IsEqual(other_node);
})
.def("tl.Fragment_thread_size",
[](Fragment fragment) { return fragment->ThreadExtent(); })
.def("tl.Fragment_thread",
Expand All @@ -483,9 +493,34 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.Fragment_condense_rep_var",
[](Fragment fragment) { return fragment->CondenseReplicateVar(); })
.def("tl.make_swizzled_layout",
[](int stride, int continuous, int element_size, bool k_inner,
bool allow_pad = true) {
if (allow_pad) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, k_inner);
} else {
return makeGemmABLayoutHopper(stride, continuous, continuous,
element_size, k_inner);
}
})
.def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_half_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeHalfBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_quarter_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, 0);
return makeQuarterBankSwizzleLayout(stride, continuous,
element_size);
});
});

Expand Down
9 changes: 5 additions & 4 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
int continuity, int element_size,
bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
int kPack);

Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
Expand All @@ -173,7 +174,7 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor);
bool k_inner = true);

Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Expand Down
20 changes: 20 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Defines the tl builtin op `ptx_wgmma_ss`: 15 inputs, with its call effect
// marked opaque.
TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Defines the tl builtin op `ptx_wgmma_rs`: 15 inputs, with its call effect
// marked opaque (same registration shape as `ptx_wgmma_ss`).
TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -161,5 +171,15 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

// Defines the tl builtin op `initialize_descriptor`: 5 inputs, with its call
// effect marked opaque.
TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Defines the tl builtin op `increase_descriptor_offset`: 2 inputs, with its
// call effect marked opaque.
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
40 changes: 40 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,28 @@ TVM_DLL const Op &mbarrier_wait_parity();
*/
TVM_DLL const Op &mbarrier_expect_tx();

/*!
* \brief tvm intrinsic for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
* scale_in_a, bool scale_in_b);
*
* NOTE(review): the list above names 16 parameters, while the op is registered
* with set_num_inputs(15) in builtin.cc. Presumably the leading accum_dtype is
* the call's result dtype rather than an argument -- confirm and reconcile.
*/
TVM_DLL const Op &ptx_wgmma_ss();

/*!
* \brief tvm intrinsic for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
* scale_in_a, bool scale_in_b);
*
* NOTE(review): the list above names 16 parameters, while the op is registered
* with set_num_inputs(15) in builtin.cc. Presumably the leading accum_dtype is
* the call's result dtype rather than an argument -- confirm and reconcile.
* NOTE(review): this parameter list is textually identical to ptx_wgmma_ss;
* confirm that the A operand really is described the same way for this variant.
*/
TVM_DLL const Op &ptx_wgmma_rs();

/*!
* \brief tvm intrinsics for ldmatrix
*
Expand Down Expand Up @@ -319,6 +341,24 @@ TVM_DLL const Op &tl_gemm_sp();
*/
TVM_DLL const Op &tl_shuffle_elect();

/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* wgmma/utcmma.
*
* This op is used to represent a descriptor initialization operation in
* tilelang. It is registered in builtin.cc with 5 inputs and an opaque call
* effect.
*/
TVM_DLL const Op &initialize_descriptor();

/*!
* \brief tilelang intrinsic for increasing the offset of a descriptor
* buffer for wgmma/utcmma.
*
* This op is used to represent a descriptor offset increment operation in
* tilelang. It is registered in builtin.cc with 2 inputs and an opaque call
* effect.
*/
TVM_DLL const Op &increase_descriptor_offset();

} // namespace tl
} // namespace tvm

Expand Down
18 changes: 9 additions & 9 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using namespace tir;
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* @note If `kPack` is provided it must be 1; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
Expand Down Expand Up @@ -478,7 +478,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]),
true, trans_A ? 1 : 2));
true, !trans_A));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
Expand All @@ -491,7 +491,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_B = B->shape.size();
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
false, trans_B));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
TargetIsSM120(T.target)) {
auto fragment =
Expand All @@ -504,7 +504,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2));
A->dtype.bits(), !trans_A));
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
Expand All @@ -518,7 +518,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
B->dtype.bits(), trans_B));
} else if (B.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
Expand All @@ -542,9 +542,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
A->dtype.bits(), !trans_A)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2);
A->dtype.bits(), !trans_A);
results.Set(A, ABLayout);
} else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
Expand All @@ -560,9 +560,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
B->dtype.bits(), trans_B)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1);
B->dtype.bits(), trans_B);
results.Set(B, ABLayout);
} else {
auto fragment =
Expand Down
18 changes: 16 additions & 2 deletions src/op/gemm_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
return GemmInst::kMMA; // This line will never be reached due to ICHECK, but
// satisfies compiler
}
}

Expand Down Expand Up @@ -225,8 +227,9 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = Downcast<PrimFunc>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
auto prim_func =
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
Expand All @@ -249,6 +252,8 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
/*name_hint=*/global_symbol.value(), prim_func->body));
} else {
LOG(FATAL) << "No lower function found for gemm_py";
return Stmt(); // This line will never be reached due to LOG(FATAL), but
// satisfies compiler
}
}

Expand All @@ -275,5 +280,14 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });

// Registers the global FFI function "tl.GemmPyGemmInst". The registered
// callback takes a GemmPy object, a block size and a target, and returns the
// result of calling GetGemmInst(block_size, target) on that object.
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target);
});
});

} // namespace tl
} // namespace tvm
2 changes: 1 addition & 1 deletion src/op/gemm_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ class GemmPyNode : public TileOperatorNode {

TileOperator Clone() const;

private:
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;

private:
mutable bool completed_ = false;
};

Expand Down
Loading