diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 567bc644b..3b9ad6680 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -138,8 +138,8 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m, const int warp_n, const int element_size) { ICHECK(block_m % warp_m == 0); - // ICHECK(block_n == warp_n); ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; + auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, false); // 16 x N (1 warp) auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, @@ -530,8 +530,8 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) { } Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, - int kfactor) { - if (kfactor == 2) + bool k_inner) { + if (k_inner) return MakeGemmVoltaABLayoutCrosswise(stride, continuous); if (is_a && continuous % 64 == 0) return MakeGemmVoltaALayoutCongruous(stride, continuous); @@ -558,29 +558,29 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, * select specific swizzling strategies. It might be the same as mat_continuous * or different based on tiling or hardware details. * \param element_size The size of each element in the matrix, in bits (e.g., 8, - * 16, 32, 64). \param kfactor An integer factor that influences layout + * 16, 32, 64). \param k_inner Whether the K dimension is in the inner loop. * selection, particularly for fp64 and int8 types. It often relates to how the * K dimension of the GEMM (M x K * K x N) is handled or tiled. * - For fp64 (element_size == 64): - * - kfactor == 1 often implies K is in the "outer" loop (e.g., - * KxN matrix). - * - kfactor == 2 often implies K is in the "inner" loop (e.g., - * NxK matrix). + * - k_inner == false often implies K is in the "outer" loop + * (e.g., KxN matrix). + * - k_inner == true often implies K is in the "inner" loop + * (e.g., NxK matrix). * - For int8 (element_size == 8): - * - kfactor == 1 uses a padded layout. + * - k_inner == false uses a padded layout. * \return A Layout object representing the chosen memory layout. */ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, - int element_size, int kfactor) { + int element_size, bool k_inner) { if (element_size == 64) { - if (kfactor == 1 && continuity % 16 == 0) // float64 KxN + if (!k_inner && continuity % 16 == 0) // float64 KxN return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); - if (kfactor == 2 && continuity % 16 == 0) // float64 NxK + if (k_inner && continuity % 16 == 0) // float64 NxK return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); } int vector_size = 128 / element_size; - if (kfactor == 1 && element_size == 8) // int8 KxN + if (!k_inner && element_size == 8) // int8 KxN return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); else if (mat_continuous % (vector_size * 8) == 0) return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); @@ -592,17 +592,17 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, } Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, - int continuity, int element_size, int kfactor) { + int continuity, int element_size, bool k_inner) { if (element_size == 64) { - if (kfactor == 1 && continuity % 16 == 0) // float64 KxN + if (!k_inner && continuity % 16 == 0) // float64 KxN return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); - if (kfactor == 2 && continuity % 16 == 0) // float64 NxK + if (k_inner && continuity % 16 == 0) // float64 NxK return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, element_size); } int vector_size = 128 / element_size; - if (kfactor == 1 && element_size == 8) // int8 KxN + if (!k_inner && element_size == 8) // int8 KxN return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, element_size); else if (mat_continuous % (vector_size * 8) == 0) diff --git a/src/layout/layout.cc b/src/layout/layout.cc index f682fd3ee..f54802c6d 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -82,7 +82,7 @@ void LayoutNode::RegisterReflection() { } void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const { - for (const auto &[var, dom] : getVarMap()) { + for (const auto &[var, dom] : LayoutNode::getVarMap()) { analyzer->Bind(var, dom); } } @@ -458,6 +458,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Layout layout) { return layout->GetForwardIndex(); }) .def("tl.Layout_forward_vars", [](Layout layout) { return layout->GetForwardVars(); }) + .def("tl.Layout_is_equal", + [](Layout layout, Layout other) { + const LayoutNode *other_node = other.as(); + return layout->IsEqual(other_node); + }) .def_packed("tl.Fragment", [](PackedArgs args, Any *rv) { *rv = Fragment( @@ -466,6 +471,11 @@ TVM_FFI_STATIC_INIT_BLOCK({ /*forward_thread=*/args[2].cast(), /*thread_replicate=*/args[3].cast()); }) + .def("tl.Fragment_is_equal", + [](Fragment fragment, Fragment other) { + const FragmentNode *other_node = other.as(); + return fragment->IsEqual(other_node); + }) .def("tl.Fragment_thread_size", [](Fragment fragment) { return fragment->ThreadExtent(); }) .def("tl.Fragment_thread", @@ -483,9 +493,34 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("tl.Fragment_condense_rep_var", [](Fragment fragment) { return fragment->CondenseReplicateVar(); }) .def("tl.make_swizzled_layout", + [](int stride, int continuous, int element_size, bool k_inner, + bool allow_pad = true) { + if (allow_pad) { + return makeGemmABLayout(stride, continuous, continuous, + element_size, k_inner); + } else { + return makeGemmABLayoutHopper(stride, continuous, continuous, + element_size, k_inner); + } + }) + .def("tl.make_wgmma_swizzled_layout", + [](int stride, int mat_continuous, int continuity, int element_size, + bool k_inner) { + return makeGemmABLayoutHopper(stride, mat_continuous, continuity, + element_size, k_inner); + }) + .def("tl.make_full_bank_swizzled_layout", + [](int stride, int continuous, int element_size) { + return makeFullBankSwizzleLayout(stride, continuous, element_size); + }) + .def("tl.make_half_bank_swizzled_layout", + [](int stride, int continuous, int element_size) { + return makeHalfBankSwizzleLayout(stride, continuous, element_size); + }) + .def("tl.make_quarter_bank_swizzled_layout", [](int stride, int continuous, int element_size) { - return makeGemmABLayout(stride, continuous, continuous, - element_size, 0); + return makeQuarterBankSwizzleLayout(stride, continuous, + element_size); }); }); diff --git a/src/layout/layout.h b/src/layout/layout.h index fe2e809a7..536c4f48b 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -160,11 +160,12 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, Layout makeGemmLayoutLinear(int stride, int continuous); Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, - int element_size, int kfactor); + int element_size, bool k_inner = true); Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, - int continuity, int element_size, int kfactor); + int continuity, int element_size, + bool k_inner = true); Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, - int kfactor); + int kPack); Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n, @@ -173,7 +174,7 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k, const int warp_m, const int warp_n); Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, - int kfactor); + bool k_inner = true); Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size); diff --git a/src/op/builtin.cc b/src/op/builtin.cc index e80867738..c1c0274df 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -80,6 +80,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss) + .set_num_inputs(15) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs) + .set_num_inputs(15) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix) .set_num_inputs(4) .set_attr("TCallEffectKind", @@ -161,5 +171,15 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_TL_BUILTIN(initialize_descriptor) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index aeb68c4e1..05aa0c173 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -161,6 +161,28 @@ TVM_DLL const Op &mbarrier_wait_parity(); */ TVM_DLL const Op &mbarrier_expect_tx(); +/*! + * \brief tvm intrinsic for ptx tensor core wgmma instructions. + * + * void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool + * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv, + * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var + * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool + * scale_in_a, bool scale_in_b); + */ +TVM_DLL const Op &ptx_wgmma_ss(); + +/*! + * \brief tvm intrinsics for ptx tensor core wgmma instructions. + * + * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool + * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv, + * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var + * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool + * scale_in_a, bool scale_in_b); + */ +TVM_DLL const Op &ptx_wgmma_rs(); + /*! * \brief tvm intrinsics for ldmatrix * @@ -319,6 +341,24 @@ TVM_DLL const Op &tl_gemm_sp(); */ TVM_DLL const Op &tl_shuffle_elect(); +/*! + * \brief tilelang intrinsic for initializing a descriptor buffer for + * wgmma/utcmma. + * + * This op is used to represent a descriptor initialization operation in + * tilelang. + */ +TVM_DLL const Op &initialize_descriptor(); + +/*! + * \brief tilelang intrinsic for setting the start address of a descriptor + * buffer for wgmma/utcmma. + * + * This op is used to represent a descriptor start address setting operation in + * tilelang. + */ +TVM_DLL const Op &increase_descriptor_offset(); + } // namespace tl } // namespace tvm diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 94abc12d3..c7142601c 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -42,7 +42,7 @@ using namespace tir; * @param vmap Mapping from access pointer vars to Buffer objects used to * resolve the Buffer corresponding to each pointer argument. * - * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * @note If `kPack` is provided it must be 1; otherwise the constructor * fails with an ICHECK (runtime assertion). No other validation is * performed here. */ @@ -478,7 +478,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, int dim_A = A->shape.size(); results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]), *as_const_int(A->shape[dim_A - 1]), - true, trans_A ? 1 : 2)); + true, !trans_A)); } else if (A.scope() == "local.fragment") { ICHECK(trans_A == false); auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n); @@ -491,7 +491,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, int dim_B = B->shape.size(); results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]), *as_const_int(B->shape[dim_B - 1]), - false, trans_B ? 2 : 1)); + false, trans_B)); } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || TargetIsSM120(T.target)) { auto fragment = @@ -504,7 +504,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - A->dtype.bits(), trans_A ? 1 : 2)); + A->dtype.bits(), !trans_A)); } else if (A.scope() == "local.fragment") { auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits(), trans_A); @@ -518,7 +518,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - B->dtype.bits(), trans_B ? 2 : 1)); + B->dtype.bits(), trans_B)); } else if (B.scope() == "local.fragment") { auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); @@ -542,9 +542,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, auto ABLayout = gemm_inst == GemmInst::kWGMMA ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, - A->dtype.bits(), trans_A ? 1 : 2) + A->dtype.bits(), !trans_A) : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - A->dtype.bits(), trans_A ? 1 : 2); + A->dtype.bits(), !trans_A); results.Set(A, ABLayout); } else { auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, @@ -560,9 +560,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, auto ABLayout = gemm_inst == GemmInst::kWGMMA ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, - B->dtype.bits(), trans_B ? 2 : 1) + B->dtype.bits(), trans_B) : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, - B->dtype.bits(), trans_B ? 2 : 1); + B->dtype.bits(), trans_B); results.Set(B, ABLayout); } else { auto fragment = diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 4d1c31513..0773cc11a 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -106,6 +106,8 @@ GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, return GemmInst::kMMA; } else { ICHECK(0) << "Unsupported target for gemm: " << target->str(); + return GemmInst::kMMA; // This line will never be reached due to ICHECK, but + // satisfies compiler } } @@ -225,8 +227,9 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { - auto prim_func = Downcast( - (*f)(GetRef(this), T.target, T.thread_bounds, T.thread_var)); + auto prim_func = + Downcast((*f)(GetRef(this), T.layout_map, T.target, + T.thread_bounds, T.thread_var)); ICHECK(prim_func->attrs.defined()); auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); ICHECK(global_symbol.defined()); @@ -249,6 +252,8 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { /*name_hint=*/global_symbol.value(), prim_func->body)); } else { LOG(FATAL) << "No lower function found for gemm_py"; + return Stmt(); // This line will never be reached due to LOG(FATAL), but + // satisfies compiler } } @@ -275,5 +280,14 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py) Integer(CallEffectKind::kOpaque)); TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.GemmPyGemmInst", + [](GemmPy gemm_py, int block_size, Target target) { + return gemm_py->GetGemmInst(block_size, target); + }); +}); + } // namespace tl } // namespace tvm diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index fa3e22c1e..55584d2e4 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -105,11 +105,11 @@ class GemmPyNode : public TileOperatorNode { TileOperator Clone() const; -private: // Target GEMM instruction enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA }; GemmInst GetGemmInst(int block_size, Target target) const; +private: mutable bool completed_ = false; }; diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 21dc509cf..b8ed7ec32 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -895,7 +895,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, if (scope.empty()) { scope = GetPtrStorageScope(buffer->data); } - if (scope == "local.var") { + if (scope == "local.var" || scope == "local.descriptor") { os << vid; return os.str(); } @@ -1302,6 +1302,99 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; + } else if (op->op.same_as(tl::ptx_wgmma_ss())) { + // arg 0: dtype + // arg 1: shape + // arg 2: A_layout + // arg 3: B_layout + // arg 4: A_dtype + // arg 5: B_dtype + // arg 6: C_dtype + // arg 7: multiplicand_a + // arg 8: multiplicand_b + // arg 9: accumulator + // arg 10: saturate + ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_ss args is " << op->args; + std::string shape = Downcast(op->args[0])->value; + bool a_is_k_major = Downcast(op->args[1])->value; + bool b_is_k_major = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_desc = this->PrintExpr(op->args[6]); + std::string A_offset = this->PrintExpr(op->args[7]); + std::string b_desc = this->PrintExpr(op->args[8]); + std::string B_offset = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_offset = this->PrintExpr(op->args[11]); + bool scale_out = Downcast(op->args[12])->value; + bool scale_in_a = Downcast(op->args[13])->value; + bool scale_in_b = Downcast(op->args[14])->value; + + const bool a_is_shared = true; + this->PrintIndent(); + std::string asm_code = PrintWGMMAAssembly( + shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset, + b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b, + a_is_shared, "", "", "", false); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n"; + // replace patterns + tl::codegen::Replacer replacer; + replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype)); + replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype)); + replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(tnspA)", a_is_k_major? "false": "true"); + replacer.register_rule("(tnspB)", b_is_k_major? "false": "true"); + replacer.register_rule("(scaleA)", scale_in_a? "1": "-1"); + replacer.register_rule("(scaleB)", scale_in_b? "1": "-1"); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref + " + " + c_offset); + replacer.register_rule("(scale_out)", scale_out ? "true" : "false"); + wgmma_asm_code = replacer.rewrite(wgmma_asm_code); + this->stream << wgmma_asm_code; + } else if (op->op.same_as(tl::ptx_wgmma_rs())) { + // arg 0: dtype + // arg 1: shape + // arg 2: A_layout + // arg 3: B_layout + // arg 4: A_dtype + // arg 5: B_dtype + // arg 6: C_dtype + // arg 7: multiplicand_a + // arg 8: multiplicand_b + // arg 9: accumulator + // arg 10: saturate + ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args; + std::string shape = Downcast(op->args[0])->value; + bool A_layout = Downcast(op->args[1])->value; + bool B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string A_offset = this->PrintExpr(op->args[7]); + std::string b_desc = this->PrintExpr(op->args[8]); + std::string B_offset = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_offset = this->PrintExpr(op->args[11]); + bool scale_out = Downcast(op->args[12])->value; + bool scale_in_a = Downcast(op->args[13])->value; + bool scale_in_b = Downcast(op->args[14])->value; + + const bool a_is_shared = false; + this->PrintIndent(); + std::string asm_code = PrintWGMMAAssembly( + shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset, + b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b, + a_is_shared, "", "", "", false); + this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. @@ -1626,6 +1719,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { op->args, true, os); } else if (op->op.same_as(tl::tl_shuffle_elect())) { os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; + } else if (op->op.same_as(tl::initialize_descriptor())) { + ICHECK(op->args.size() == 5) + << "tl_initialize_descriptor expects 5 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto start_address = op->args[1]; + auto layout_type = op->args[2]; + auto leading_byte_offset = op->args[3]; + auto stride_byte_offset = op->args[4]; + os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", " + << PrintExpr(leading_byte_offset) << ", " + << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", " + << PrintExpr(start_address) << ")"; + } else if (op->op.same_as(tl::increase_descriptor_offset())) { + ICHECK(op->args.size() == 2) + << "tl_increase_descriptor_offset expects 2 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto offset = op->args[1]; + os << "tl::increase_descriptor_offset(" << PrintExpr(descriptor) + << ", " << PrintExpr(offset) << ")"; } else { CodeGenC::VisitExpr_(op, os); } @@ -1692,6 +1806,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { << "Accumulator only support half, float and int type for now"; } PrintWmmaScope(scope, op->dtype, buffer, stream); + } else if (scope == "local.descriptor") { + stream << "tl::GmmaDescriptor " << vid << ";\n"; } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); @@ -1725,7 +1841,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { } else if (scope == "local.var") { stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0)) << ";\n"; - } else { + } else if (scope != "local.descriptor") { ICHECK(false) << "Unsupported scope: " << scope; } } diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 14d1b0460..1d2b4bae6 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -35,39 +35,10 @@ namespace codegen { // PTX related data structures and functions. namespace ptx { -/*! - * \brief PTX data type. - * \note - * PTX fundamental data types: - * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types - * PTX matrix data types: - * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types - */ -enum class DataType : int { - kInt4 = 0, - kUInt4 = 1, - kInt8 = 2, - kUInt8 = 3, - kInt16 = 4, - kUInt16 = 5, - kInt32 = 6, - kUInt32 = 7, - kInt64 = 8, - kUInt64 = 9, - kFloat8_e4m3 = 10, - kFloat8_e5m2 = 11, - kFloat16 = 12, - kBFloat16 = 13, - kFloat16x2 = 14, - kFloat32 = 15, - kTensorFloat32 = 16, - kFloat64 = 17, - kBit1 = 18, - kBit8 = 19, - kBit16 = 20, - kBit32 = 21, - kBit64 = 22 -}; +static const char *enum_to_str[] = { + "kInt4", "kUInt4", "kInt8", "kUInt8", "kInt16", "kUInt16", "kInt32", "kUInt32", + "kInt64", "kUInt64", "kFloat8_e4m3", "kFloat8_e5m2", "kFloat16", "kBFloat16", "kFloat16x2", "kFloat32", + "kTensorFloat32", "kFloat64", "kBit1", "kBit8", "kBit16", "kBit32", "kBit64"}; static const char *dtype_str[] = { ".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32", @@ -80,7 +51,7 @@ static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, /*! * \brief Create PTX data type from string. */ -inline DataType DTypeFromString(const std::string str) { +DataType DTypeFromString(const std::string str) { if (str == "int4" || str == ".s4") { return DataType::kInt4; } else if (str == "uint4" || str == ".u4") { @@ -132,6 +103,15 @@ inline DataType DTypeFromString(const std::string str) { } } + +std::string DTypeEnumToString(const ptx::DataType &dtype) { + return "tl::DataType::" + std::string(enum_to_str[static_cast(dtype)]); +} + +std::string DTypeEnumToString(const std::string &dtype) { + return "tl::DataType::" + std::string(enum_to_str[static_cast(DTypeFromString(dtype))]); +} + /*! * \brief Get the string representation of given PTX data type. */ @@ -146,10 +126,18 @@ inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } +inline bool DTypeIsInteger(DataType dtype) { + return dtype == DataType::kInt4 || dtype == DataType::kInt8 || + dtype == DataType::kInt16 || dtype == DataType::kInt32 || + dtype == DataType::kInt64 || dtype == DataType::kUInt4 || + dtype == DataType::kUInt8 || dtype == DataType::kUInt16 || + dtype == DataType::kUInt32 || dtype == DataType::kUInt64; +} + /*! * \brief Extract the value m, n, k from string m*n*k* */ -inline std::tuple ParseMMAShape(const std::string &str) { +std::tuple ParseMMAShape(const std::string &str) { size_t pos_m = str.find('m'), pos_n = str.find('n'), pos_k = str.find('k'); CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) << "Cannot parse MMA shape " << str; @@ -177,6 +165,17 @@ LayoutType LayoutTypeFromString(const std::string &str) { } } +/*! + * \brief Parse layout type from bool. + */ +LayoutType LayoutTypeFromBool(const bool &layout) { + if (layout) { + return LayoutType::kRowMajor; + } else { + return LayoutType::kColumnMajor; + } +} + static const char *layout_type_str[] = {"row", "col"}; /*! @@ -256,6 +255,450 @@ const MMAConfig valid_mma_configs[] = { MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true), }; +struct WGMMAConfig { + explicit WGMMAConfig(int m, int n, int k, DataType dtype_a, DataType dtype_b, + DataType dtype_c, bool sparse) + : m(m), n(n), k(k), dtype_a(dtype_a), dtype_b(dtype_b), dtype_c(dtype_c), + sparse(sparse) {} + int m, n, k; + DataType dtype_a, dtype_b, dtype_c; + bool sparse; + inline bool operator==(const WGMMAConfig &other) { + return m == other.m && n == other.n && k == other.k && + dtype_a == other.dtype_a && dtype_b == other.dtype_b && + dtype_c == other.dtype_c && sparse == other.sparse; + } +}; + +const WGMMAConfig valid_wgmma_configs[] = { + // Dense FP16 configurations + WGMMAConfig(64, 8, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 64, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, false), + + // Dense FP16 to FP32 accumulation + WGMMAConfig(64, 8, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 16, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, false), + + // Dense BFloat16 configurations + WGMMAConfig(64, 8, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 16, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, false), + + // Dense TF32 configurations + WGMMAConfig(64, 8, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 24, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + WGMMAConfig(64, 40, 8, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, false), + + // Dense INT8 configurations + WGMMAConfig(64, 8, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 16, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 32, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 64, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 96, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 128, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 192, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + WGMMAConfig(64, 256, 32, DataType::kInt8, DataType::kInt8, DataType::kInt32, + false), + + // Dense UINT8 configurations + WGMMAConfig(64, 8, 32, DataType::kUInt8, DataType::kUInt8, DataType::kInt32, + false), + WGMMAConfig(64, 16, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 32, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 64, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 96, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 128, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 192, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + WGMMAConfig(64, 256, 32, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, false), + + // Dense INT4 configurations + WGMMAConfig(64, 8, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 16, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 32, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 64, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 96, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 128, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 192, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + WGMMAConfig(64, 256, 64, DataType::kInt4, DataType::kInt4, DataType::kInt32, + false), + + // Dense UINT4 configurations + WGMMAConfig(64, 8, 64, DataType::kUInt4, DataType::kUInt4, DataType::kInt32, + false), + WGMMAConfig(64, 16, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 32, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 64, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 96, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 128, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 192, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + WGMMAConfig(64, 256, 64, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, false), + + // Dense FP8 E4M3 configurations + WGMMAConfig(64, 8, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, false), + WGMMAConfig(64, 8, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, false), + + // Dense FP8 E5M2 configurations + WGMMAConfig(64, 8, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, false), + WGMMAConfig(64, 8, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 16, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 32, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 64, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 96, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 128, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 192, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + WGMMAConfig(64, 256, 32, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, false), + + // Sparse FP16 configurations (k doubled for sparsity) + WGMMAConfig(64, 8, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat16, true), + + // Sparse FP16 to FP32 accumulation + WGMMAConfig(64, 8, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 32, DataType::kFloat16, DataType::kFloat16, + DataType::kFloat32, true), + + // Sparse BFloat16 configurations + WGMMAConfig(64, 8, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 32, DataType::kBFloat16, DataType::kBFloat16, + DataType::kFloat32, true), + + // Sparse TF32 configurations + WGMMAConfig(64, 8, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 16, DataType::kTensorFloat32, DataType::kTensorFloat32, + DataType::kFloat32, true), + + // Sparse INT8 configurations + WGMMAConfig(64, 8, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 16, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 32, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 64, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 96, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 128, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 192, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + WGMMAConfig(64, 256, 64, DataType::kInt8, DataType::kInt8, DataType::kInt32, + true), + + // Sparse UINT8 configurations + WGMMAConfig(64, 8, 64, DataType::kUInt8, DataType::kUInt8, DataType::kInt32, + true), + WGMMAConfig(64, 16, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 32, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 64, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 96, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 128, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 192, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + WGMMAConfig(64, 256, 64, DataType::kUInt8, DataType::kUInt8, + DataType::kInt32, true), + + // Sparse INT4 configurations + WGMMAConfig(64, 8, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 16, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 32, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 64, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 96, 128, DataType::kInt4, DataType::kInt4, DataType::kInt32, + true), + WGMMAConfig(64, 128, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + WGMMAConfig(64, 192, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + WGMMAConfig(64, 256, 128, DataType::kInt4, DataType::kInt4, + DataType::kInt32, true), + + // Sparse UINT4 configurations + WGMMAConfig(64, 8, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 16, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 32, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 64, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 96, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 128, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 192, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + WGMMAConfig(64, 256, 128, DataType::kUInt4, DataType::kUInt4, + DataType::kInt32, true), + + // Sparse FP8 E4M3 configurations + WGMMAConfig(64, 8, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat16, true), + WGMMAConfig(64, 8, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, + DataType::kFloat32, true), + + // Sparse FP8 E5M2 configurations + WGMMAConfig(64, 8, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat16, true), + WGMMAConfig(64, 8, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 16, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 32, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 64, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 96, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 128, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 192, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true), + WGMMAConfig(64, 256, 64, DataType::kFloat8_e5m2, DataType::kFloat8_e5m2, + DataType::kFloat32, true)}; + /*! * \brief Check whether the multiplicand data type and accumulator data type is * valid for MMA computation. \param dtype_a The data type of multiplicand a. @@ -393,6 +836,27 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, CHECK(match) << "Cannot find matched MMA configurations."; } +void CheckWGMMAConfigValidity(int m, int n, int k, LayoutType layout_a, + LayoutType layout_b, DataType dtype_a, + DataType dtype_b, DataType dtype_c, bool sparse) { + // Same DataType Compatibility as MMA + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); + + // Check if configuration exists in valid_wgmma_configs + WGMMAConfig config(m, n, k, dtype_a, dtype_b, dtype_c, sparse); + bool match = false; + for (const WGMMAConfig &valid_config : valid_wgmma_configs) { + if (config == valid_config) { + match = true; + break; + } + } + CHECK(match) << "Cannot find matched WGMMA configurations for m " << m + << " n " << n << " k " << k << " dtype_a " + << DTypeToString(dtype_a) << " dtype_b " + << DTypeToString(dtype_b) << " dtype_c " + << DTypeToString(dtype_c) << " sparse " << sparse; +} /*! * \brief Fragment attributes */ @@ -439,35 +903,6 @@ inline FragAttrs GetFragAttrs(DataType dtype) { }; // namespace ptx -/*! - * \brief Replace patterns with replacement strings. - * \note should use std::format instead when codebase is ported to C++20. - */ -class Replacer { -public: - void register_rule(const std::string &pattern, - const std::string &replacement) { - _rules.emplace_back(pattern, replacement); - } - std::string rewrite(std::string str) { - for (auto &&rule : _rules) { - auto [pattern, replacement] = rule; - size_t len = pattern.size(); - size_t new_len = replacement.size(); - size_t pos = str.find(pattern); - while (pos != std::string::npos) { - str = str.replace(pos, len, replacement); - pos = str.find(pattern, pos + new_len); - } - } - return str; - } - void empty_rules() { _rules.clear(); } - -private: - std::vector> _rules; -}; - /*! * \brief Get the number of MMA computations for given shape and datatype. */ @@ -566,6 +1001,123 @@ GetMMAOperands(int m, int n, int k, ptx::DataType dtype_a, return std::make_tuple(templates.str(), inputs.str(), outputs.str()); } +inline std::tuple +GetWGMMAOperands(int m, int n, int k, ptx::DataType dtype_a, + ptx::DataType dtype_b, ptx::DataType dtype_c, bool sparse, + bool a_is_shared) { + std::stringstream templates, inputs, outputs, predicate; + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); + constexpr uint32_t warp_size = 32; + const uint32_t threads = + 4 * warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) / + frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_c = + (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + const bool support_ldmatrix_transposed = + ptx::DTypeBits(dtype_a) == 16 && ptx::DTypeBits(dtype_b) == 16; + const bool support_scale_input = + !ptx::DTypeIsInteger(dtype_a) || !ptx::DTypeIsInteger(dtype_b); + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + if (!a_is_shared) { + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + } else { + templates << "}, %" << arg_counter++; + } + + // desc_b + templates << ", " + << "%" << arg_counter++; + + // scale_out + predicate << "%" << arg_counter++; + templates << ", " + << "p"; + + // scale_in_a + if (support_scale_input) { + templates << ", " + << "%" << arg_counter++; + // scale_in_b + templates << ", " + << "%" << arg_counter++; + } + if (support_ldmatrix_transposed) { + if (a_is_shared) { + // trans_a + templates << ", " + << "%" << arg_counter++; + } + // trans_b + templates << ", " + << "%" << arg_counter++; + } + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + LOG(FATAL) << "Sparse WGMMA is not supported yet."; + } + + // generate inputs + if (a_is_shared) { + inputs << "\"l\"(uint64_t((desc_a) + (A_offset)))"; + } else { + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; + } + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type + << "((A)))[" << i << "])"; + } + } + inputs << ", \"l\"(uint64_t((desc_b) + (B_offset)))"; + + // input of metadata for sparse mma. + if (sparse) { + inputs << ", \"r\"(((unsigned *)((E)))[0])"; + } + + inputs << ", \"r\"(int32_t((scale_out)))"; + // scale_in_a + if (support_scale_input) { + inputs << ", \"n\"(int32_t((scale_in_a)))"; + // scale_in_b + inputs << ", \"n\"(int32_t((scale_in_b)))"; + } + if (support_ldmatrix_transposed) { + if (a_is_shared) { + // trans_a + inputs << ", \"n\"(int32_t((trans_a)))"; + } + // trans_b + inputs << ", \"n\"(int32_t((trans_b)))"; + } + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << "\"+" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type + << "((D)))[" << i << "])"; + } + + return std::make_tuple(templates.str(), inputs.str(), outputs.str(), + predicate.str()); +} + std::string PrintMMAAssembly(const std::string &shape, const std::string &A_layout, const std::string &B_layout, const std::string &A_dtype, @@ -631,6 +1183,79 @@ PrintMMAAssembly(const std::string &shape, const std::string &A_layout, return asm_code; } +std::string PrintWGMMAAssembly( + const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major, + const std::string &A_dtype, const std::string &B_dtype, + const std::string &C_dtype, const std::string &a_desc, + const std::string &A_offset, const std::string &b_desc, + const std::string &B_offset, const std::string &c_ptr, + const std::string &c_offset, const bool &scale_out, const bool &scale_in_a, + const bool &scale_in_b, const bool &a_is_shared, + const std::string &metadata, const std::string &metadata_offset, + const std::string &sparsity_selector, bool sparse) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), + dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + if (dtype_a == ptx::DataType::kFloat32) { + dtype_a = ptx::DataType::kTensorFloat32; + } + if (dtype_b == ptx::DataType::kFloat32) { + dtype_b = ptx::DataType::kTensorFloat32; + } + + ptx::LayoutType layout_a = ptx::LayoutTypeFromBool(!a_is_k_major), + layout_b = ptx::LayoutTypeFromBool(b_is_k_major); + auto [m, n, k] = ptx::ParseMMAShape(shape); + CheckWGMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, + dtype_c, sparse); + std::string asm_code = R"( + { + __asm__ __volatile__( + "{.reg .pred p;\n" + "setp.ne.b32 p, {predicate}, 0;\n" + "wgmma.mma_async{.sparse}.sync.aligned{.shape}{.dtype}{.atype}{.btype}" + "{templates};\n}" + : {outputs} + : {inputs}); + } +)"; + auto [templates_str, inputs_str, outputs_str, predicate_str] = + GetWGMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse, a_is_shared); + + // replace patterns + Replacer replacer; + replacer.register_rule("{.sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{.shape}", "." + shape); + replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + replacer.register_rule("{predicate}", predicate_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + if (a_is_shared) { + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + } else { + replacer.register_rule("(A)", a_desc + " + " + A_offset); + } + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ptr + " + " + c_offset); + replacer.register_rule("(D)", c_ptr + " + " + c_offset); + replacer.register_rule("(E)", metadata + " + " + metadata_offset); + replacer.register_rule("(F)", sparsity_selector); + replacer.register_rule("(scale_out)", scale_out ? "1" : "0"); + replacer.register_rule("(scale_in_a)", scale_in_a ? "1" : "-1"); + replacer.register_rule("(scale_in_b)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(trans_a)", a_is_k_major ? "0" : "1"); + replacer.register_rule("(trans_b)", b_is_k_major ? "0" : "1"); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + inline std::tuple GetLoadMatrixOperands(int num, const std::string &local_ptr, const std::string &local_elem_offset) { diff --git a/src/target/ptx.h b/src/target/ptx.h index 15acb96b1..dfc1bd11c 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -32,6 +32,94 @@ namespace tvm::tl { namespace codegen { +namespace ptx { + +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +/*! + * \brief Print ptx data type from string. + */ +DataType DTypeFromString(const std::string str); + + +/*! + * \brief Print ptx data type from enum. + */ +std::string DTypeEnumToString(const DataType &dtype); + +/*! + * \brief Print ptx data type from string. + */ +std::string DTypeEnumToString(const std::string &dtype); + +/*! + * \brief Parse MMA shape from string. + */ +std::tuple ParseMMAShape(const std::string &str); +} // namespace ptx + + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. + */ +class Replacer { +public: + void register_rule(const std::string &pattern, + const std::string &replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto &&rule : _rules) { + auto [pattern, replacement] = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + +private: + std::vector> _rules; +}; + /*! * \brief Print MMA assembly string given parameters. * \param shape The shape string mMnNkK @@ -65,6 +153,26 @@ PrintMMAAssembly(const std::string &shape, const std::string &A_layout, const std::string &sparsity_selector, const std::string &bit_op, bool sparse, bool saturate); +/*! + * \brief Print WGMMA assembly string given parameters. + * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of multiplicand C. + */ +std::string PrintWGMMAAssembly( + const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major, + const std::string &A_dtype, const std::string &B_dtype, + const std::string &C_dtype, const std::string &a_desc, + const std::string &A_offset, const std::string &b_desc, + const std::string &B_offset, const std::string &c_ptr, + const std::string &c_offset, const bool &scale_out, const bool &scale_in_a, + const bool &scale_in_b, const bool &a_is_shared, + const std::string &metadata, const std::string &metadata_offset, + const std::string &sparsity_selector, bool sparse); + /*! * \brief Print ldmatrix assembly string given parameters. * \param trans: whether the matrix is loaded in column major format or not. diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 06f88c4c2..0f9c2ee70 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -5,6 +5,8 @@ #endif #include +#include +#include #include #include #include @@ -13,6 +15,8 @@ using cutlass::bfloat16_t; using cutlass::half_t; using cutlass::tfloat32_t; +using cute::cast_smem_ptr_to_uint; + using int4_t = int4; #define hexp cutlass::fast_exp @@ -49,6 +53,7 @@ using int4_t = int4; } \ } while (0) + // abs function for bfloat_t and half_t since there is no implicit conversion // method TL_PATCH TL_DEVICE half_t __habs(const half_t x) { @@ -299,6 +304,101 @@ TL_DEVICE /** } namespace tl { +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 +}; + +union GmmaDescriptor { + CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(uint64_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr GmmaDescriptor & + operator=(GmmaDescriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 + // brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride fro mthe first 8 cols to the next 8 cols. + uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, + base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { + return desc_; + } + template + CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const { + GmmaDescriptor ret; + ret.reg32_[0] = reg32_[0] + uint32_t(offset); + ret.reg32_[1] = reg32_[1]; + return ret; + } +}; + // Any template TL_DEVICE bool Any(T *a, int size) { for (int i = 0; i < size; i++) { @@ -334,6 +434,25 @@ template TL_DEVICE void __sync_thread_partial() { asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count)); } + +template +TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, + T *start_address) { + descriptor.bitfield.start_address_ = + cute::cast_smem_ptr_to_uint(start_address) >> 4; + descriptor.bitfield.layout_type_ = layout_type; + descriptor.bitfield.base_offset_ = 0; + descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; + descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; +} + +template +TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, + T offset) { + descriptor.reg32_[0] += (offset >> 4); +} + } // namespace tl namespace cutlass { diff --git a/src/tl_templates/cuda/gemm.h b/src/tl_templates/cuda/gemm.h index 41a026290..20af35d78 100644 --- a/src/tl_templates/cuda/gemm.h +++ b/src/tl_templates/cuda/gemm.h @@ -2,6 +2,7 @@ #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)) #include "gemm_sm120.h" #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "./instruction/wgmma.h" #include "gemm_sm90.h" #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)) #include "gemm_sm89.h" diff --git a/src/tl_templates/cuda/instruction/wgmma.h b/src/tl_templates/cuda/instruction/wgmma.h new file mode 100644 index 000000000..e6dc3e54e --- /dev/null +++ b/src/tl_templates/cuda/instruction/wgmma.h @@ -0,0 +1,627 @@ +#pragma once +#include "../common.h" +#include "cute/arch/mma_sm90_gmma.hpp" + +namespace tl { + +template inline constexpr bool always_false_v = false; + +// 主类模板 - 移除默认参数,因为特化不能有默认参数 +template < + DataType A_type, + DataType B_type, + DataType C_type, + int M, + int N, + int K, + bool tnspA, + bool tnspB, + int scaleA, + int scaleB +> +struct WgmmaSSImpl { + TL_DEVICE static void execute( + uint64_t desc_a, + uint64_t desc_b, + uint32_t *c, + bool scale_out + ) { + printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, scaleB=%d\n", + (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, K, (int)tnspA, (int)tnspB, scaleA, scaleB); + // 暂时注释掉 static_assert 来看调试输出 + // static_assert(always_false_v, + // "wgmma_ss: No specialization available for given template parameters!"); + }; +}; + +// ================================= F16 x F16 -> F16 ================================= + +// M64N8K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N16K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N32K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N64K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15}," + " %16, %17, p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N96K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23}, " + "%24, %25, p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]), + "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N128K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31}, " + "%32, %33, p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]), + "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), + "+r"(c[24]), "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), + "+r"(c[28]), "+r"(c[29]), "+r"(c[30]), "+r"(c[31]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N192K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31, " + "%32, %33, %34, %35, %36, %37, %38, %39, " + "%40, %41, %42, %43, %44, %45, %46, %47}, " + "%48, %49, p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]), + "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), + "+r"(c[24]), "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), + "+r"(c[28]), "+r"(c[29]), "+r"(c[30]), "+r"(c[31]), + "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), "+r"(c[35]), + "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), + "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), + "+r"(c[44]), "+r"(c[45]), "+r"(c[46]), "+r"(c[47]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N256K16 F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31, " + "%32, %33, %34, %35, %36, %37, %38, %39, " + "%40, %41, %42, %43, %44, %45, %46, %47, " + "%48, %49, %50, %51, %52, %53, %54, %55, " + "%56, %57, %58, %59, %60, %61, %62, %63}, " + "%64, %65, p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]), + "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), + "+r"(c[24]), "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), + "+r"(c[28]), "+r"(c[29]), "+r"(c[30]), "+r"(c[31]), + "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), "+r"(c[35]), + "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), + "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), + "+r"(c[44]), "+r"(c[45]), "+r"(c[46]), "+r"(c[47]), + "+r"(c[48]), "+r"(c[49]), "+r"(c[50]), "+r"(c[51]), + "+r"(c[52]), "+r"(c[53]), "+r"(c[54]), "+r"(c[55]), + "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]), + "+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= F16 x F16 -> F32 ================================= + +// M64N8K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N16K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N32K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15}, " + "%16, %17, p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N64K16 F16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + "%8, %9, %10, %11, %12, %13, %14, %15, " + "%16, %17, %18, %19, %20, %21, %22, %23, " + "%24, %25, %26, %27, %28, %29, %30, %31}, " + "%32, %33, p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), + "+r"(c[8]), "+r"(c[9]), "+r"(c[10]), "+r"(c[11]), + "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), "+r"(c[15]), + "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), + "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), + "+r"(c[24]), "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), + "+r"(c[28]), "+r"(c[29]), "+r"(c[30]), "+r"(c[31]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= BF16 x BF16 -> F32 ================================= + +// M64N8K16 BF16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N16K16 BF16->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= TF32 x TF32 -> F32 ================================= + +// M64N8K8 TF32->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N16K8 TF32->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), + "+r"(c[4]), "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= INT8 x INT8 -> INT32 ================================= + +// M64N8K32 S8->S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N16K32 S8->S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= FP8 x FP8 -> F16/F32 ================================= + +// M64N8K32 E4M3->F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// M64N8K32 E4M3->F32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// 函数模板委托给类模板 +template < + DataType A_type, + DataType B_type, + DataType C_type, + int M, + int N, + int K, + bool tnspA, + bool tnspB, + int scaleA = 1, + int scaleB = 1 +> +TL_DEVICE void wgmma_ss( + uint64_t desc_a, + uint64_t desc_b, + uint32_t *c, + bool scale_out +) { + WgmmaSSImpl::execute( + desc_a, desc_b, c, scale_out + ); +} + +// ================================= Mixed Precision Support ================================= + +// Mixed precision: S8 x U8 -> S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// Mixed precision: U8 x S8 -> S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// Mixed precision: U8 x U8 -> S32 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// Mixed precision FP8: E4M3 x E5M2 -> F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// Mixed precision FP8: E5M2 x E4M3 -> F16 +template +struct WgmmaSSImpl { + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, bool scale_out) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(c[0]), "+r"(c[1]) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), + "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)) + ); + } +}; + +// ================================= Convenience Templates ================================= + +// Type trait to determine the number of output registers needed +template +struct WgmmaOutputRegs { + static constexpr int value = (M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8); +}; + +// Type trait to get element size in bits +template +struct ElementBits { + static constexpr int value = + (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 || dtype == DataType::kInt32) ? 32 : + (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 || dtype == DataType::kInt16 || dtype == DataType::kUInt16) ? 16 : + (dtype == DataType::kInt8 || dtype == DataType::kUInt8 || dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2) ? 8 : + (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4 : 8; +}; + +} // namespace tl \ No newline at end of file diff --git a/src/transform/lower_device_storage_access_info.cc b/src/transform/lower_device_storage_access_info.cc index be5c41fa9..635a3fdb8 100644 --- a/src/transform/lower_device_storage_access_info.cc +++ b/src/transform/lower_device_storage_access_info.cc @@ -45,7 +45,7 @@ class StorageAccessInfoLower : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode *op) final { auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" && - scope.tag != ".barrier") { + scope.tag != ".barrier" && scope.tag != ".descriptor") { auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc index d86817d9e..6c4ae427c 100644 --- a/src/transform/storage_rewrite.cc +++ b/src/transform/storage_rewrite.cc @@ -674,7 +674,7 @@ class StoragePlanRewriter : public StmtExprMutator { bool IsSpecialTaggedMemory(const StorageScope &scope) { return !scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".barrier" && scope.tag != ".workspace" && - scope.tag != ".vtcm"; + scope.tag != ".vtcm" && scope.tag != ".descriptor"; } // Allocate entry of node. @@ -844,7 +844,8 @@ class StoragePlanRewriter : public StmtExprMutator { // allocate with element type. ICHECK_NE(e->const_nbits, 0U); MemoryInfo info; - if (e->scope.tag != ".barrier" && e->scope.tag != ".var") { + if (e->scope.tag != ".barrier" && e->scope.tag != ".var" && + e->scope.tag != ".descriptor") { info = GetMemoryInfo(e->scope.to_string()); } uint64_t total_bits = e->const_nbits; diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py new file mode 100644 index 000000000..68f4b04cc --- /dev/null +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -0,0 +1,508 @@ +import tilelang.language as T +from enum import IntEnum +from typing import Optional, Callable +from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter +from tvm import DataType +from tvm.tir import PrimExpr, Buffer, Var, IndexMap +from tilelang.utils import is_fragment +from tilelang.layout import ( + Layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, +) +from tvm.runtime import convert +from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a) + +lift = convert + + +class SwizzleMode(IntEnum): + # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + NONE = 0 + SWIZZLE_128B = 1 + SWIZZLE_64B = 2 + SWIZZLE_32B = 3 + + def is_none(self) -> bool: + return self == SwizzleMode.NONE + + def is_swizzle_32b(self) -> bool: + return self == SwizzleMode.SWIZZLE_32B + + def is_swizzle_64b(self) -> bool: + return self == SwizzleMode.SWIZZLE_64B + + def is_swizzle_128b(self) -> bool: + return self == SwizzleMode.SWIZZLE_128B + + def swizzle_byte_size(self) -> int: + if self.is_swizzle_32b(): + return 32 + elif self.is_swizzle_64b(): + return 64 + elif self.is_swizzle_128b(): + return 128 + else: + return 1 + + def swizzle_atom_size(self) -> int: + if self.is_swizzle_32b(): + return 32 // 16 + elif self.is_swizzle_64b(): + return 64 // 16 + elif self.is_swizzle_128b(): + return 128 // 16 + else: + return 1 + + +# derive from MMAIntrinEmitter as some layouts are the same +class TensorCoreIntrinEmitter(MMAIntrinEmitter): + """ + To eliminate Python syntax within TIR Macro. + """ + + # should be rewritten to support dynamic k_dim + wgmma_prefix: str + + a_shared_layout: Layout = None + b_shared_layout: Layout = None + + def __init__( + self, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: Optional[bool] = False, + thread_var: Optional[Var] = None, + ): + super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, + block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, + num_elems_per_byte, is_m_first, thread_var) + self._initialize_wgmma_prefix(self.n_dim) + + def _assign_a_shared_layout(self, layout: Layout): + self.a_shared_layout = layout + return self + + def _assign_b_shared_layout(self, layout: Layout): + self.b_shared_layout = layout + return self + + def _initialize_wgmma_prefix(self, n_dim: int = 16): + inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles + # 256 bits per instruction + inst_k = 256 // DataType(self.a_dtype).bits + self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}" + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + # four warps per block + self.warp_rows = warp_row_tiles // m_dim + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + self.micro_size_k = k_dim + + def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: + if layout is None: + return SwizzleMode.NONE + elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_32B + elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_64B + elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + else: + raise ValueError(f"Unsupported swizzle mode: {layout}") + + def wgmma(self, + A_buf: Buffer, + B_buf: Buffer, + C_local_buf: Buffer, + clear_accum: PrimExpr = False): + + if is_fragment(A_buf): + return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum) + + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_out = not clear_accum + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + + a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + + elems_in_bytes = DataType(self.a_dtype).bits // 8 + + # by default, we utilize non-swizzle layout offset + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * + elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * + elems_in_bytes) + + if not a_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if a_is_k_major: + a_leading_byte_offset = 16 + else: + # MN Major + # LBO represents the distance between two atoms along the M dimension + # SBO represents the distance between two atoms along the K dimension + a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size() + a_stride_byte_offset = 8 * 64 * elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * + elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * + elems_in_bytes) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + else: + # MN Major, K * N + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + + + print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}") + print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}") + + print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}") + print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}") + print(f"m_dim: {m_dim}") + print(f"n_dim: {n_dim}") + print(f"k_dim: {k_dim}") + print(f"micro_size_k: {micro_size_k}") + print(f"a_leading_byte_offset: {a_leading_byte_offset}") + print(f"a_stride_byte_offset: {a_stride_byte_offset}") + print(f"b_leading_byte_offset: {b_leading_byte_offset}") + print(f"b_stride_byte_offset: {b_stride_byte_offset}") + # exit() + @T.macro + def _warp_mma(A_buf, B_buf, C_local_buf): + desc_a = T.alloc_descriptor() + desc_b = T.alloc_descriptor() + T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode, + int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) + T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + for ki in T.serial(0, (k_dim // micro_size_k)): + for i in T.serial(m_dim // 64): + k_dim_offset = ki * micro_size_k + A_offset = i * 64 * A_buf.shape[ + -1] + k_dim_offset if a_is_k_major else ki * micro_size_k * 64 + i * 64 * k_dim + B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + C_offset = i * warp_cols * local_size_out # 4 warps as an unit + T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, + b_is_k_major, a_dtype_abbrv, b_dtype_abbrv, + accum_dtype_abbrv, desc_a.data, (A_offset * elems_in_bytes) >> 4, + desc_b.data, (B_offset * elems_in_bytes) >> 4, C_local_buf.data, + C_offset, scale_out, scale_in_a, scale_in_b) + + return _warp_mma(A_buf, B_buf, C_local_buf) + + def wgmma_rs(self, + A_buf: Buffer, + B_buf: Buffer, + C_local_buf: Buffer, + clear_accum: PrimExpr = False): + local_size_a = self.local_size_a + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_rows, warp_cols = self.warp_rows, self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_out = not clear_accum + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + elems_in_bytes = DataType(self.a_dtype).bits // 8 + + b_is_k_major = self.b_transposed + + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * + elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * + elems_in_bytes) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + else: + # MN Major + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + + @T.macro + def _warp_mma(A_buf, B_buf, C_local_buf): + desc_b = T.alloc_descriptor() + T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + for ki in T.serial(0, (k_dim // micro_size_k)): + for i in T.serial(m_dim // 64): + k_dim_offset = ki * micro_size_k + A_offset = ki * warp_rows * local_size_a + i * local_size_a + B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1] + C_offset = i * warp_cols * local_size_out # 4 warps as an unit + T.ptx_wgmma_rs( + accum_dtype, + wgmma_prefix, + self.a_transposed, + not self.b_transposed, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf.data, + A_offset, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_local_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + return _warp_mma(A_buf, B_buf, C_local_buf) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + from tilelang.utils import is_fragment + assert matrix in ["A"], "matrix should be A for WGMMA" + dtype = self.a_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed + + # s represents spatial axis + # r represents reduction axis + # sr represents the two dims are spatial + reduction + # rs represents the two dims are reduction + spatial + # sr also can represent a non-transposed basic layout + # then rs also can represent a transposed basic layout + transform_func_sr_a: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_conditions = [False] + is_sr_conditions.append(not transposed) + is_sr_axis_order = any(is_sr_conditions) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) + + assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format( + local_buf.scope()) + + micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k + + block_row_warps, block_col_warps = ( + self.block_row_warps, + self.block_col_warps, + ) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + """ + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows = self.warp_rows + chunk = self.chunk + + warp_s = warp_rows + warp_r = chunk // micro_size_r + block_s = block_row_warps + replicate = block_col_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([block_s, 1], + repeat_on_thread=True, + lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_s, warp_r], + repeat_on_thread=False, + lower_dim_first=False) + else: + # rs condition, transposed_a matrix + warp_fragment = base_fragment.repeat([1, block_s], + repeat_on_thread=True, + lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_r, warp_s], + repeat_on_thread=False, + lower_dim_first=True) + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. + """ + lane_id, _ = inverse_mma_store_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. + """ + _, local_id = inverse_mma_store_layout.map_indices([i, j]) + return local_id + + # reproduce src/layout/gemm_layouts.cc::makeGemmFragmentCHopper + base_fragment = T.Fragment( + [micro_size_x, micro_size_y], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + warp_n_layout = base_fragment.repeat([1, warp_cols], False, False) + block_layout = warp_n_layout.repeat([block_row_warps, block_col_warps], True, False) + warp_m_layout = block_layout.repeat([warp_rows, 1], False, False) + return warp_m_layout diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 9d52ae602..9d7334433 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -42,6 +42,7 @@ alloc_fragment, # noqa: F401 alloc_barrier, # noqa: F401 alloc_reducer, # noqa: F401 + alloc_descriptor, # noqa: F401 ) from .copy import copy, c2d_im2col # noqa: F401 from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401 @@ -80,11 +81,11 @@ def symbolic(name: str, dtype: str = "int32"): """ Create a TIR symbolic variable. - + Parameters: name (str): Identifier for the variable in generated TIR. dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32". - + Returns: tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels. """ @@ -108,7 +109,7 @@ def annotate_layout(layout_map: Dict): Returns: block_attr: a block attribute - + Example: @T.prim_func def main( @@ -149,7 +150,7 @@ def annotate_padding(padding_map: Dict): Returns: block_attr: a block attribute - + Example: @T.prim_func def main( diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 3601102ad..e456f4320 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -124,3 +124,12 @@ def alloc_reducer(shape, dtype, op="sum", replication=None): TL.block_attr({"reducer_info": {reducer.data: {"rep": replication, "op": op}}}) return reducer + + +def alloc_descriptor(dtype="uint64", scope="local.descriptor"): + """Allocate a descriptor buffer for wgmma and utcmma. + + Returns: + T.Buffer: A TVM buffer object allocated as a descriptor + """ + return T.alloc_buffer([1], dtype, scope=scope) diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index e49e6d5c3..0948cdfa7 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -1892,6 +1892,8 @@ def wrapped(*args, **kwargs): call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) ptx_mma = _dtype_forward(_tir_op.ptx_mma) ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) +ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) @@ -2141,6 +2143,8 @@ def wrapped(*args, **kwargs): "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", + "ptx_wgmma_ss", + "ptx_wgmma_rs", "ptx_ldmatrix", "ptx_cp_async", "ptx_cp_async_bulk", diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index bfee1d2e3..6c247a6aa 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -6,7 +6,7 @@ from tilelang.utils.target import check_hip_availability from tvm import tir from typing import Union, Any -from tvm.tir import PrimExpr, Var, Call +from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad _IS_HIP_AVAILABLE = check_hip_availability() @@ -350,3 +350,62 @@ def sync_grid(): """Synchronize all threads in a grid. """ return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) + + +def initialize_descriptor(descriptor: Buffer, + start_address: PrimExpr, + layout_type_: int = 0, + leading_byte_offset: int = 0, + stride_byte_offset: int = 0) -> PrimExpr: + """ + Initialize a memory descriptor with the given parameters. + + Parameters: + descriptor (Buffer): The memory descriptor to initialize. + start_address (PrimExpr): The starting address of the memory region. + layout_type_ (int, optional): Layout type identifier. Defaults to 0. + leading_byte_offset (int, optional): Leading byte offset. Defaults to 0. + stride_byte_offset (int, optional): Stride byte offset. Defaults to 0. + + Returns: + PrimExpr: A handle representing the initialized descriptor. + """ + + if not isinstance(descriptor, (BufferLoad, Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") + + if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( + descriptor, [0]) + + return evaluate( + tir.call_intrin("handle", tir.op.Op.get("tl.initialize_descriptor"), descriptor, + start_address, layout_type_, int(leading_byte_offset), + int(stride_byte_offset))) + + +def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: + """ + Increase the offset of a memory descriptor. + + Parameters: + descriptor (PrimExpr): The memory descriptor to modify. + offset (PrimExpr): The offset value to increase. + + Returns: + PrimExpr: A handle representing the modified descriptor. + """ + if not isinstance(descriptor, (BufferLoad, Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") + + if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( + descriptor, [0]) + + return evaluate( + tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, + offset)) diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 9ea0ebc3a..5f801a0c2 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -20,18 +20,18 @@ def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): """ Create a tile memory-region descriptor for a BufferLoad. - + Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents. - + Parameters: buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices. access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access. *args (tir.PrimExpr): Extent expressions for each region dimension. - + Returns: tir.Call: A call to the `tl.region` intrinsic describing the memory region. - + Raises: KeyError: If access_type is not one of 'r', 'w', or 'rw'. """ @@ -83,15 +83,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, extents: List[PrimExpr]): """ Create a tl region descriptor for the given BufferRegion. - + Parameters: buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents. access_type (str): Access mode: "r", "w", or "rw". extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region. - + Returns: tir.Call: A tile-region descriptor (tl.region) covering the buffer_region. - + Raises: AssertionError: If the number of extents in buffer_region.region is smaller than len(extents). """ @@ -107,15 +107,15 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: """ Perform an atomic maximum on the value stored at dst with an optional memory-order. - + If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern. - + Parameters: dst (Buffer): Destination buffer/address to apply the atomic max. value (PrimExpr): Value to compare/store atomically. memory_order (str | None): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst"). If provided, it is translated to the corresponding numeric memory-order id before the call. - + Returns: PrimExpr: A handle/expression representing the issued atomic maximum operation. """ @@ -129,14 +129,14 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: """ Atomically update the value at dst to the minimum of its current value and value. - + If memory_order is provided, it selects the memory-order semantic used by the underlying extern call; allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument. - + Parameters: memory_order (str | None): Optional memory-order name controlling the atomic operation's ordering. - + Returns: PrimExpr: A handle expression representing the atomic-min operation. """ @@ -150,9 +150,9 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: """ Atomically add `value` into `dst`, returning a handle to the operation. - + Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`. - + Returns: PrimExpr: A handle representing the atomic addition operation. """ @@ -160,11 +160,11 @@ def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> def get_extent(data): """ Return the inferred extent (shape) of a buffer-like object. - + If `data` is a Var bound to a let value, the let value is resolved before inspection. Parameters: data: A Var, Buffer, or BufferRegion to inspect. - + Returns: The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined. """ @@ -252,12 +252,12 @@ def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr: def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr: """Clamps the input value dst between [min_val, max_val] - + Args: dst: Input value to be clamped min_val: Minimum value max_val: Maximum value - + Returns: Value clamped to the specified range """ @@ -268,7 +268,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr: def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: """Reshapes the input buffer to the specified shape. - + Args: src (Buffer): Input buffer to be reshaped shape (List[PrimExpr]): New shape for the buffer @@ -284,7 +284,7 @@ def view(src: Buffer, dtype: Union[str, None] = None) -> Buffer: """ Return a Tensor view of the input buffer with an optional new shape and dtype. - + If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy). """ if shape is None: @@ -297,7 +297,7 @@ def view(src: Buffer, def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: """ Load a value from the given buffer using the specified atomic memory ordering. - + Performs an atomic load from `src` and returns a PrimExpr representing the loaded value. memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire", "release", "acq_rel", or "seq_cst" (default). @@ -310,17 +310,17 @@ def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr: """ Perform an atomic store of `src` into `dst` with the given memory ordering. - + Parameters: dst (Buffer): Destination buffer to store into. src (PrimExpr): Value to store. memory_order (str, optional): Memory ordering name; one of "relaxed", "consume", "acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst". The name is mapped to an internal numeric ID used by the underlying runtime. - + Returns: PrimExpr: A handle representing the issued atomic store operation. - + Raises: KeyError: If `memory_order` is not one of the supported names. """ diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index cbce46f22..1143f2a9e 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -291,6 +291,8 @@ def wrapped(*args, **kwargs): call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) ptx_mma = _dtype_forward(_tir_op.ptx_mma) ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) +ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) +ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 302de9d19..10ca7ca93 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1061,6 +1061,88 @@ def ptx_mma_sp( ) +def ptx_wgmma_ss( + dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + """TVM intrinsic for ptx tensor core wmma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-wmma + """ + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_ss"), + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + +def ptx_wgmma_rs( + dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_rs"), + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): """TVM intrinsic for storing the result of PTX MMA into a destination pointer diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index 5269c199a..25fcda753 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -3,5 +3,11 @@ from .layout import Layout # noqa: F401 from .fragment import Fragment # noqa: F401 -from .swizzle import make_swizzled_layout # noqa: F401 -from .gemm_sp import make_metadata_layout # noqa: F401 \ No newline at end of file +from .swizzle import ( + make_swizzled_layout, # noqa: F401 + make_wgmma_swizzled_layout, # noqa: F401 + make_full_bank_swizzled_layout, # noqa: F401 + make_half_bank_swizzled_layout, # noqa: F401 + make_quarter_bank_swizzled_layout, # noqa: F401 +) +from .gemm_sp import make_metadata_layout # noqa: F401 diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index 2cd64563e..b018a9f11 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -204,13 +204,10 @@ def __repr__(self): str A string showing the thread dimension and the index dimension. """ - return f"Fragment" + return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" - -def make_swizzled_layout(buffer: tvm.tir.Buffer): - assert len(buffer.shape) == 2 - return _ffi_api.make_swizzled_layout( - int(buffer.shape[0]), - int(buffer.shape[1]), - int(tvm.DataType(buffer.dtype).bits), - ) + def is_equal(self, other: "Fragment") -> bool: + """ + Check if the current fragment is equal to another fragment. + """ + return _ffi_api.Fragment_is_equal(self, other) diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index ee0bd8ea3..fd8e31225 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -89,6 +89,9 @@ def get_forward_vars(self): """ return _ffi_api.Layout_forward_vars(self) + def get_forward_index(self): + return self.index + def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr: """ Compute the forward index mapping for a given set of input indices. @@ -129,3 +132,17 @@ def inverse(self) -> "Layout": A new Layout object representing the inverse transformation. """ return _ffi_api.Layout_inverse(self) + + def is_equal(self, other: "Layout") -> bool: + """ + Check if the current layout is equal to another layout. + + Parameters + ---------- + other : Layout + The layout to compare with. + """ + return _ffi_api.Layout_is_equal(self, other) + + def __repr__(self): + return f"Layout<{self.get_input_shape()}->{self.get_output_shape()}, {self.get_forward_vars()} -> {self.get_forward_index()}>" diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 9fd2582b3..e193eaf6b 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -7,10 +7,103 @@ # Use a stable swizzled layout to ensure consistent memory access patterns. # Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied. -def make_swizzled_layout(buffer: tvm.tir.Buffer): +def make_swizzled_layout(buffer: tvm.tir.Buffer, k_major: bool = True, allow_pad: bool = True): assert len(buffer.shape) == 2 return _ffi_api.make_swizzled_layout( int(buffer.shape[0]), int(buffer.shape[1]), int(tvm.DataType(buffer.dtype).bits), + k_major, + allow_pad, + ) + + +# for WGMMA Intrinsics +def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer, + continuity: int = None, + k_major: bool = True): + assert len(buffer.shape) == 2 + if continuity is None: + continuity = int(buffer.shape[1]) + return _ffi_api.make_wgmma_swizzled_layout( + int(buffer.shape[0]), + int(buffer.shape[1]), + continuity, + int(tvm.DataType(buffer.dtype).bits), + k_major, + ) + + +# swizzle 128B +# args: buffer or (stride, continuous, element_size) +def make_full_bank_swizzled_layout(*args): + """ + Args: + args: buffer or (stride, continuous, element_size) + Examples: + make_full_bank_swizzled_layout(buffer) + make_full_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + buffer = args[0] + stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) + element_size = int(tvm.DataType(buffer.dtype).bits) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_full_bank_swizzled_layout( + stride, + continuous, + element_size, + ) + + +# swizzle 64B +# args: buffer or (stride, continuous, element_size) +def make_half_bank_swizzled_layout(*args): + """ + Args: + args: buffer or (stride, continuous, element_size) + Examples: + make_half_bank_swizzled_layout(buffer) + make_half_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + buffer = args[0] + stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) + element_size = int(tvm.DataType(buffer.dtype).bits) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_half_bank_swizzled_layout( + stride, + continuous, + element_size, + ) + + +# swizzle 32B +# args: buffer or (stride, continuous, element_size) +def make_quarter_bank_swizzled_layout(*args): + """ + Args: + args: buffer or (stride, continuous, element_size) + Examples: + make_quarter_bank_swizzled_layout(buffer) + make_quarter_bank_swizzled_layout(stride, continuous, element_size) + """ + if len(args) == 1: + buffer = args[0] + stride, continuous = int(buffer.shape[0]), int(buffer.shape[1]) + element_size = int(tvm.DataType(buffer.dtype).bits) + elif len(args) == 3: + stride, continuous, element_size = args + else: + raise ValueError(f"Invalid arguments: {args}") + return _ffi_api.make_quarter_bank_swizzled_layout( + stride, + continuous, + element_size, ) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 1c8ca8652..4d2b3625e 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -1,13 +1,14 @@ +from enum import IntEnum from tilelang import tvm as tvm from tvm import tir -from tilelang.utils.target import ( - target_is_cuda,) from tvm.target import Target from tvm.ir.base import Node from tvm.runtime import Scriptable import tvm.ffi from tilelang.ir import GemmWarpPolicy from .gemm_mma import GemmMMA +from .gemm_wgmma import GemmWGMMA +from tilelang import _ffi_api @tvm.ffi.register_func("tl.gemm_py.infer_layout") @@ -17,12 +18,29 @@ def gemm_py_infer_layout(gemm_py, target, thread_bounds): @tvm.ffi.register_func("tl.gemm_py.lower") -def gemm_py_lower(gemm_py, target, thread_bounds, thread_var): +def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): thread_nums = thread_bounds.extent - stmt = gemm_py.lower(target, thread_nums, thread_var) + stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) return stmt +# TODO(lei): support Volta and WMMA? +# same definition with src/op/gemm_py.h +class GemmInst(IntEnum): + MMA = 0 + WGMMMA = 1 + MFMA = 2 + + def is_mma(self) -> bool: + return self == GemmInst.MMA + + def is_wgmma(self) -> bool: + return self == GemmInst.WGMMMA + + def is_mfma(self) -> bool: + return self == GemmInst.MFMA + + @tvm.ffi.register_object("tl.GemmPy") class GemmPy(Node, Scriptable): A: tir.Buffer @@ -50,16 +68,53 @@ class GemmPy(Node, Scriptable): policy: GemmWarpPolicy def infer_layout(self, target: Target, thread_nums: int): - if target_is_cuda(target): - # TODO(lei): Support more cuda architectures, now mma only - return GemmMMA(self).infer_layout(target, thread_nums) - else: - raise ValueError(f"Unsupported target: {target}") + """Infer the layout for the GEMM operation based on target architecture.""" + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst) + return impl_class(self).infer_layout(target, thread_nums) + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + """Lower the GEMM operation to TIR statements based on target architecture.""" + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst) + return impl_class(self).lower(layout_map, target, thread_nums, thread_var) + + def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst: + """Select the appropriate GEMM instruction based on target and thread configuration. + + The selection logic follows this priority: + 1. WGMMA for Hopper architecture with sufficient matrix size and warp count + 2. MFMA for CDNA (AMD) architecture + 3. MMA for CUDA architecture + 4. Fallback to MMA for other cases + + Args: + thread_nums: Number of threads in the block + target: Target architecture + + Returns: + GemmInst: The selected GEMM instruction type + """ + return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target)) - def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): - if target_is_cuda(target): - # TODO(lei): Support more cuda architectures, now mma only - # Now only implement ssr layout - return GemmMMA(self).lower(target, thread_nums, thread_var) + def _get_implementation_class(self, gemm_inst: GemmInst): + """Get the appropriate implementation class for the given GEMM instruction. + + Args: + gemm_inst: The selected GEMM instruction type + + Returns: + The implementation class for the instruction type + + Raises: + NotImplementedError: If the instruction type is not supported + ValueError: If the instruction type is unknown + """ + if gemm_inst.is_mma(): + return GemmMMA + elif gemm_inst.is_wgmma(): + return GemmWGMMA + elif gemm_inst.is_mfma(): + raise NotImplementedError("MFMA is not implemented") else: - raise ValueError(f"Unsupported target: {target}") + raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 724187205..849b6d33a 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -5,6 +5,7 @@ from tilelang.utils.language import is_shared, is_fragment from tilelang.ir import GemmWarpPolicy from tvm.ir.base import Node +from tvm.ir import PrimExpr @dataclass @@ -103,7 +104,7 @@ def offset_B(self) -> int: return self.gemm_node.offset_B @property - def clear_accum(self) -> bool: + def clear_accum(self) -> PrimExpr: return self.gemm_node.clear_accum @property diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index a046ee126..42abe376a 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -57,7 +57,7 @@ def infer_layout(self, target: Target, thread_nums: int): raise ValueError( f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") - def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False) warp_row_tiles = int(self.M // m_warp) @@ -87,6 +87,8 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): B_shared = self.B C_local = self.C + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + if self.is_gemm_ss(): @T.prim_func diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/tileop/gemm/gemm_wgmma.py new file mode 100644 index 000000000..6186d7925 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_wgmma.py @@ -0,0 +1,137 @@ +from .gemm_base import GemmBase +from tilelang.layout import make_wgmma_swizzled_layout +from tilelang.intrinsics.wgmma_macro_generator import ( + TensorCoreIntrinEmitter,) +from tilelang.utils.language import is_shared, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmWGMMA(GemmBase): + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + + if self.is_gemm_ss(): + a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp + b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp + return { + # WGMMA does not support padding + self.A: + make_wgmma_swizzled_layout( + self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: + make_wgmma_swizzled_layout( + self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: + mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp + return { + self.A: + mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: + make_wgmma_swizzled_layout( + self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: + mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError( + f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + True) + + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + if self.A in layout_map: + mma_emitter._assign_a_shared_layout(layout_map[self.A]) + if self.B in layout_map: + mma_emitter._assign_b_shared_layout(layout_map[self.B]) + + A_shared = self.A + B_shared = self.B + C_local = self.C + clear_accum = self.clear_accum + + if self.is_gemm_ss(): + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + # Perform Matrix Multiplication + mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_rs(): + A_local = self.A + + @T.prim_func + def _gemm_rsr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_rsr, inline_let=True) + raise ValueError( + f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B)