|
| 1 | +#include <cudaTypedefs.h> |
| 2 | + |
| 3 | +#include <c10/cuda/CUDAGuard.h> |
| 4 | +#include <torch/all.h> |
| 5 | + |
| 6 | +#include "cutlass/cutlass.h" |
| 7 | +#include "grouped_mm_c3x.cuh" |
| 8 | + |
| 9 | +using namespace cute; |
| 10 | + |
| 11 | +namespace { |
| 12 | + |
| 13 | +template <typename InType, typename OutType, |
| 14 | + template <typename, typename, typename> typename Epilogue> |
| 15 | +struct sm100_fp8_config_default { |
| 16 | + static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
| 17 | + using KernelSchedule = |
| 18 | + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; |
| 19 | + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; |
| 20 | + using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>; |
| 21 | + using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>; |
| 22 | + using ArchTag = cutlass::arch::Sm100; |
| 23 | + |
| 24 | + using Cutlass3xGemm = |
| 25 | + cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape, |
| 26 | + ClusterShape, KernelSchedule, EpilogueSchedule>; |
| 27 | +}; |
| 28 | + |
| 29 | +template <typename InType, typename OutType, |
| 30 | + template <typename, typename, typename> typename Epilogue> |
| 31 | +struct sm100_fp8_config_M64 { |
| 32 | + // M in [1,64] |
| 33 | + static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
| 34 | + using KernelSchedule = |
| 35 | + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; |
| 36 | + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; |
| 37 | + using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>; |
| 38 | + using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>; |
| 39 | + using ArchTag = cutlass::arch::Sm100; |
| 40 | + |
| 41 | + using Cutlass3xGemm = |
| 42 | + cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape, |
| 43 | + ClusterShape, KernelSchedule, EpilogueSchedule, |
| 44 | + true>; |
| 45 | +}; |
| 46 | + |
| 47 | +template <typename InType, typename OutType, |
| 48 | + template <typename, typename, typename> typename Epilogue> |
| 49 | +struct sm100_fp8_config_N8192 { |
| 50 | + // N in [8192, inf) |
| 51 | + static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
| 52 | + using KernelSchedule = |
| 53 | + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; |
| 54 | + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; |
| 55 | + using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>; |
| 56 | + using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>; |
| 57 | + using ArchTag = cutlass::arch::Sm100; |
| 58 | + |
| 59 | + using Cutlass3xGemm = |
| 60 | + cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape, |
| 61 | + ClusterShape, KernelSchedule, EpilogueSchedule>; |
| 62 | +}; |
| 63 | + |
| 64 | +template <typename InType, typename OutType> |
| 65 | +void run_cutlass_moe_mm_sm100( |
| 66 | + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, |
| 67 | + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, |
| 68 | + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, |
| 69 | + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, |
| 70 | + torch::Tensor const& b_strides, torch::Tensor const& c_strides, |
| 71 | + bool per_act_token, bool per_out_ch) { |
| 72 | + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); |
| 73 | + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); |
| 74 | + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); |
| 75 | + |
| 76 | + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, |
| 77 | + "A tensors must be of type float8_e4m3fn."); |
| 78 | + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, |
| 79 | + "B tensors must be of type float8_e4m3fn."); |
| 80 | + |
| 81 | + using Cutlass3xGemmDefault = typename sm100_fp8_config_default< |
| 82 | + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; |
| 83 | + using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192< |
| 84 | + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; |
| 85 | + using Cutlass3xGemmM64 = typename sm100_fp8_config_M64< |
| 86 | + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; |
| 87 | + |
| 88 | + uint32_t const m = a_tensors.size(0); |
| 89 | + uint32_t const n = out_tensors.size(1); |
| 90 | + |
| 91 | + if (m <= 64) { |
| 92 | + cutlass_group_gemm_caller<Cutlass3xGemmM64>( |
| 93 | + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, |
| 94 | + problem_sizes, a_strides, b_strides, c_strides, per_act_token, |
| 95 | + per_out_ch); |
| 96 | + } else if (n >= 8192) { |
| 97 | + cutlass_group_gemm_caller<Cutlass3xGemmN8192>( |
| 98 | + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, |
| 99 | + problem_sizes, a_strides, b_strides, c_strides, per_act_token, |
| 100 | + per_out_ch); |
| 101 | + } else { |
| 102 | + cutlass_group_gemm_caller<Cutlass3xGemmDefault>( |
| 103 | + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, |
| 104 | + problem_sizes, a_strides, b_strides, c_strides, per_act_token, |
| 105 | + per_out_ch); |
| 106 | + } |
| 107 | +} |
| 108 | +} // namespace |
| 109 | + |
| 110 | +void dispatch_moe_mm_sm100( |
| 111 | + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, |
| 112 | + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, |
| 113 | + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, |
| 114 | + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, |
| 115 | + torch::Tensor const& b_strides, torch::Tensor const& c_strides, |
| 116 | + bool per_act_token, bool per_out_ch) { |
| 117 | + if (out_tensors.dtype() == torch::kBFloat16) { |
| 118 | + run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>( |
| 119 | + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, |
| 120 | + problem_sizes, a_strides, b_strides, c_strides, per_act_token, |
| 121 | + per_out_ch); |
| 122 | + } else { |
| 123 | + run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>( |
| 124 | + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, |
| 125 | + problem_sizes, a_strides, b_strides, c_strides, per_act_token, |
| 126 | + per_out_ch); |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +void cutlass_moe_mm_sm100( |
| 131 | + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, |
| 132 | + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, |
| 133 | + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, |
| 134 | + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, |
| 135 | + torch::Tensor const& b_strides, torch::Tensor const& c_strides, |
| 136 | + bool per_act_token, bool per_out_ch) { |
| 137 | + dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, |
| 138 | + expert_offsets, problem_sizes, a_strides, b_strides, |
| 139 | + c_strides, per_act_token, per_out_ch); |
| 140 | +} |
0 commit comments