|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "cuda_utils.h" |
| 4 | +#include "cutlass/cutlass.h" |
| 5 | +#include "cutlass/numeric_types.h" |
| 6 | + |
| 7 | +#include "cute/tensor.hpp" |
| 8 | +#include "cutlass/tensor_ref.h" |
| 9 | +#include "cutlass/gemm/dispatch_policy.hpp" |
| 10 | +#include "cutlass/gemm/collective/collective_builder.hpp" |
| 11 | +#include "cutlass/gemm/device/gemm_universal_adapter.h" |
| 12 | +#include "cutlass/gemm/kernel/gemm_universal.hpp" |
| 13 | +#include "cutlass/gemm/kernel/tile_scheduler_params.h" |
| 14 | +#include "cutlass/epilogue/dispatch_policy.hpp" |
| 15 | +#include "cutlass/epilogue/collective/collective_builder.hpp" |
| 16 | + |
| 17 | +#include "cutlass_extensions/gemm/dispatch_policy.hpp" |
| 18 | +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" |
| 19 | + |
| 20 | +#include "cutlass_gemm_caller.cuh" |
| 21 | + |
| 22 | +namespace vllm { |
| 23 | + |
| 24 | +using namespace cute; |
| 25 | + |
| 26 | +// clang-format off |
| 27 | +template <class OutType, int ScaleGranularityM, |
| 28 | + int ScaleGranularityN, int ScaleGranularityK, |
| 29 | + class MmaTileShape, class ClusterShape, |
| 30 | + class EpilogueScheduler, class MainloopScheduler> |
| 31 | +struct cutlass_3x_gemm_fp8_blockwise { |
| 32 | + using ElementAB = cutlass::float_e4m3_t; |
| 33 | + |
| 34 | + using ElementA = ElementAB; |
| 35 | + using LayoutA = cutlass::layout::RowMajor; |
| 36 | + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type; |
| 37 | + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; |
| 38 | + |
| 39 | + using ElementB = ElementAB; |
| 40 | + // ColumnMajor is used for B to match the CUTLASS convention. |
| 41 | + using LayoutB = cutlass::layout::ColumnMajor; |
| 42 | + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type; |
| 43 | + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; |
| 44 | + |
| 45 | + using ElementD = OutType; |
| 46 | + using LayoutD = cutlass::layout::RowMajor; |
| 47 | + using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type; |
| 48 | + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; |
| 49 | + |
| 50 | + using ElementC = void; // TODO: support bias |
| 51 | + using LayoutC = LayoutD; |
| 52 | + using LayoutC_Transpose = LayoutD_Transpose; |
| 53 | + static constexpr int AlignmentC = AlignmentD; |
| 54 | + |
| 55 | + using ElementAccumulator = float; |
| 56 | + using ElementCompute = float; |
| 57 | + using ElementBlockScale = float; |
| 58 | + |
| 59 | + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig< |
| 60 | + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, |
| 61 | + cute::UMMA::Major::MN, cute::UMMA::Major::K>; |
| 62 | + |
| 63 | + // layout_SFA and layout_SFB cannot be swapped since they are deduced. |
| 64 | + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); |
| 65 | + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); |
| 66 | + |
| 67 | + using ArchTag = cutlass::arch::Sm120; |
| 68 | + using OperatorClass = cutlass::arch::OpClassTensorOp; |
| 69 | + |
| 70 | + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; |
| 71 | + using ElementScalar = float; |
| 72 | + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>; |
| 73 | + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< |
| 74 | + ArchTag, |
| 75 | + OperatorClass, |
| 76 | + MmaTileShape, |
| 77 | + ClusterShape, |
| 78 | + cutlass::epilogue::collective::EpilogueTileAuto, |
| 79 | + ElementAccumulator, |
| 80 | + ElementCompute, |
| 81 | + ElementC, |
| 82 | + LayoutC, |
| 83 | + AlignmentC, |
| 84 | + ElementD, |
| 85 | + LayoutD, |
| 86 | + AlignmentD, |
| 87 | + EpilogueScheduler, |
| 88 | + DefaultOperation |
| 89 | + >::CollectiveOp; |
| 90 | + |
| 91 | + using StageCountType = cutlass::gemm::collective::StageCountAuto; |
| 92 | + using CollectiveMainloop = |
| 93 | + typename cutlass::gemm::collective::CollectiveBuilder< |
| 94 | + ArchTag, |
| 95 | + OperatorClass, |
| 96 | + ElementA, |
| 97 | + cute::tuple<LayoutA, LayoutSFA>, |
| 98 | + AlignmentA, |
| 99 | + ElementB, |
| 100 | + cute::tuple<LayoutB, LayoutSFB>, |
| 101 | + AlignmentB, |
| 102 | + ElementAccumulator, |
| 103 | + MmaTileShape, |
| 104 | + ClusterShape, |
| 105 | + cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, |
| 106 | + MainloopScheduler |
| 107 | + >::CollectiveOp; |
| 108 | + |
| 109 | + using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal< |
| 110 | + Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>; |
| 111 | + |
| 112 | + struct GemmKernel : public KernelType {}; |
| 113 | +}; |
| 114 | + |
| 115 | +template <typename Gemm> |
| 116 | +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, |
| 117 | + torch::Tensor const& b, |
| 118 | + torch::Tensor const& a_scales, |
| 119 | + torch::Tensor const& b_scales) { |
| 120 | + using GemmKernel = typename Gemm::GemmKernel; |
| 121 | + using StrideA = typename Gemm::GemmKernel::StrideA; |
| 122 | + using StrideB = typename Gemm::GemmKernel::StrideB; |
| 123 | + using StrideD = typename Gemm::GemmKernel::StrideD; |
| 124 | + using StrideC = typename Gemm::GemmKernel::StrideC; |
| 125 | + using LayoutSFA = typename Gemm::LayoutSFA; |
| 126 | + using LayoutSFB = typename Gemm::LayoutSFB; |
| 127 | + using ScaleConfig = typename Gemm::ScaleConfig; |
| 128 | + |
| 129 | + using ElementAB = typename Gemm::ElementAB; |
| 130 | + using ElementD = typename Gemm::ElementD; |
| 131 | + |
| 132 | + int32_t m = a.size(0), n = b.size(1), k = a.size(1); |
| 133 | + |
| 134 | + StrideA a_stride; |
| 135 | + StrideB b_stride; |
| 136 | + StrideC c_stride; |
| 137 | + a_stride = |
| 138 | + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); |
| 139 | + b_stride = |
| 140 | + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); |
| 141 | + c_stride = |
| 142 | + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); |
| 143 | + |
| 144 | + LayoutSFA layout_SFA = |
| 145 | + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); |
| 146 | + LayoutSFB layout_SFB = |
| 147 | + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); |
| 148 | + |
| 149 | + auto a_ptr = static_cast<ElementAB*>(a.data_ptr()); |
| 150 | + auto b_ptr = static_cast<ElementAB*>(b.data_ptr()); |
| 151 | + auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr()); |
| 152 | + auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr()); |
| 153 | + |
| 154 | + auto mainloop_args = [&](){ |
| 155 | + return typename GemmKernel::MainloopArguments{ |
| 156 | + a_ptr, a_stride, b_ptr, b_stride, |
| 157 | + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB |
| 158 | + }; |
| 159 | + }(); |
| 160 | + auto prob_shape = cute::make_shape(m, n, k, 1); |
| 161 | + |
| 162 | + auto c_ptr = static_cast<ElementD*>(out.data_ptr()); |
| 163 | + typename GemmKernel::EpilogueArguments epilogue_args{ |
| 164 | + {}, c_ptr, c_stride, c_ptr, c_stride}; |
| 165 | + c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args, |
| 166 | + epilogue_args); |
| 167 | +} |
| 168 | + |
| 169 | +template <typename OutType> |
| 170 | +void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, |
| 171 | + torch::Tensor const& a, |
| 172 | + torch::Tensor const& b, |
| 173 | + torch::Tensor const& a_scales, |
| 174 | + torch::Tensor const& b_scales) { |
| 175 | + // TODO: better heuristics |
| 176 | + cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
| 177 | + OutType, 1, 128, 128, Shape<_128, _128, _128>, |
| 178 | + Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto, |
| 179 | + cutlass::gemm::collective::KernelScheduleAuto>>( |
| 180 | + out, a, b, a_scales, b_scales); |
| 181 | +} |
| 182 | + |
| 183 | +} // namespace vllm |
0 commit comments