Commit 3303f13

[Kernel] Add support for block FP8 on SM120 (NVIDIA 5090 and RTX PRO 6000) (#22131)

Signed-off-by: Junhao Li <[email protected]>
1 parent b2c8ce5 commit 3303f13

File tree

6 files changed: +229 -18 lines

- CMakeLists.txt
- csrc/cutlass_extensions/common.hpp
- csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
- csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
- csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
- csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -427,6 +427,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   set(SRCS
     "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
   )
   set_gencode_flags_for_srcs(
     SRCS "${SRCS}"

csrc/cutlass_extensions/common.hpp

Lines changed: 10 additions & 0 deletions

@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
 #endif
   }
 };
+
+template <typename Kernel>
+struct enable_sm120_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
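
The new enable_sm120_only guard mirrors the enable_sm100_only wrapper directly above it: the wrapped kernel's device body is compiled only when __CUDA_ARCH__ is exactly 1200 (compute capability 12.0) and collapses to an empty operator() in every other pass of a multi-arch build, so fatbins for older GPUs still link. A minimal usage sketch; KernelBody is a hypothetical functor standing in for a real CUTLASS kernel:

#include "cutlass_extensions/common.hpp"  // defines enable_sm120_only

// Hypothetical device functor used only to illustrate the guard.
struct KernelBody {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&...) { /* device work */ }
};

// On the sm_120 compilation pass, operator() forwards to KernelBody; on any
// other pass the preprocessor strips the body, so no SM120-only instructions
// are emitted for that architecture.
using GuardedBody = enable_sm120_only<KernelBody>;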

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu (new file)

Lines changed: 23 additions & 0 deletions

#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales) {
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
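
For orientation, a hedged caller sketch of this new entry point. Every shape and layout choice below is an assumption on my part, mirroring the row-major-A / column-major-B convention and the 1x128 (A) / 128x128 (B) scale granularity used in the dispatch header that follows:

#include <torch/torch.h>
#include "scaled_mm_kernels.hpp"  // declares the entry point

// Hedged usage sketch of vllm::cutlass_scaled_mm_blockwise_sm120_fp8.
void blockwise_mm_sketch(int64_t m, int64_t n, int64_t k) {
  auto cuda = torch::TensorOptions().device(torch::kCUDA);
  auto a = torch::randn({m, k}, cuda.dtype(torch::kFloat))
               .to(torch::kFloat8_e4m3fn);          // row-major m x k
  auto b = torch::randn({n, k}, cuda.dtype(torch::kFloat))
               .to(torch::kFloat8_e4m3fn)
               .t();                                // column-major k x n view
  // Assumed scale shapes: one fp32 per 1x128 block of A, per 128x128 block of B.
  auto a_scales = torch::ones({m, (k + 127) / 128}, cuda.dtype(torch::kFloat));
  auto b_scales = torch::ones({(k + 127) / 128, (n + 127) / 128},
                              cuda.dtype(torch::kFloat));
  auto out = torch::empty({m, n}, cuda.dtype(torch::kBFloat16));
  vllm::cutlass_scaled_mm_blockwise_sm120_fp8(out, a, b, a_scales, b_scales);
}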

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh (new file)

Lines changed: 183 additions & 0 deletions

#pragma once

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

// clang-format off
template <class OutType, int ScaleGranularityM,
          int ScaleGranularityN, int ScaleGranularityK,
          class MmaTileShape, class ClusterShape,
          class EpilogueScheduler, class MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
  using ElementAB = cutlass::float_e4m3_t;

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  // ColumnMajor is used for B to match the CUTLASS convention.
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ElementC = void;  // TODO: support bias
  using LayoutC = LayoutD;
  using LayoutC_Transpose = LayoutD_Transpose;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
      cute::UMMA::Major::MN, cute::UMMA::Major::K>;

  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  using ArchTag = cutlass::arch::Sm120;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      MmaTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementA,
          cute::tuple<LayoutA, LayoutSFA>,
          AlignmentA,
          ElementB,
          cute::tuple<LayoutB, LayoutSFB>,
          AlignmentB,
          ElementAccumulator,
          MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler
      >::CollectiveOp;

  using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);

  StrideA a_stride;
  StrideB b_stride;
  StrideC c_stride;
  a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));

  LayoutSFA layout_SFA =
      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB =
      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  auto mainloop_args = [&](){
    return typename GemmKernel::MainloopArguments{
      a_ptr, a_stride, b_ptr, b_stride,
      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
    };
  }();
  auto prob_shape = cute::make_shape(m, n, k, 1);

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};
  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  // TODO: better heuristics
  cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
      OutType, 1, 128, 128, Shape<_128, _128, _128>,
      Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
      cutlass::gemm::collective::KernelScheduleAuto>>(
      out, a, b, a_scales, b_scales);
}

}  // namespace vllm
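
The dispatch pins a single configuration: scale granularity (1, 128, 128), a 128x128x128 MMA tile, and a 1x1x1 cluster, hence the TODO about heuristics. One scale then covers a 1x128 slice of A and a 128x128 block of B. A small sketch of the scale-element counts this granularity implies; the helper below is mine, and the authoritative layouts come from ScaleConfig::deduce_layoutSFA/deduce_layoutSFB:

#include <cstdint>

// Sketch (assumption): scale-element counts for granularity (1, 128, 128).
constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct BlockScaleExtents {
  int64_t a_scales;  // one float per 1x128 block of A: m * ceil(k / 128)
  int64_t b_scales;  // one float per 128x128 block of B: ceil(k / 128) * ceil(n / 128)
};

constexpr BlockScaleExtents blockwise_scale_extents(int64_t m, int64_t n,
                                                    int64_t k) {
  return {m * ceil_div(k, 128), ceil_div(k, 128) * ceil_div(n, 128)};
}

// e.g. m = 16, n = k = 4096: ceil(4096 / 128) = 32, so A carries 16 * 32 = 512
// scales and B carries 32 * 32 = 1024.
static_assert(blockwise_scale_extents(16, 4096, 4096).a_scales == 512);
static_assert(blockwise_scale_extents(16, 4096, 4096).b_scales == 1024);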

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

Lines changed: 6 additions & 0 deletions

@@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales);
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm

csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu

Lines changed: 6 additions & 18 deletions

@@ -1,11 +1,9 @@
-#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"

-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm120 (Blackwell Geforce).
+   NVIDIA GPUs with sm120 (Blackwell).
 */

 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm120_fp8,
+                     nullptr,  // int8 not supported on SM120
+                     vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
 }

 #endif
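
The refactor drops the inline dtype/shape checks in favor of the shared dispatch_scaled_mm helper from c3x/scaled_mm_helper.hpp, passing nullptr for the unsupported int8 path. A speculative sketch of the routing such a helper would perform, inferred only from the call site above (the real helper may differ):

#include <optional>
#include <torch/torch.h>

// Function-pointer shapes inferred from the call site above.
using ScaledMmFn = void (*)(torch::Tensor&, torch::Tensor const&,
                            torch::Tensor const&, torch::Tensor const&,
                            torch::Tensor const&,
                            std::optional<torch::Tensor> const&);
using BlockwiseFn = void (*)(torch::Tensor&, torch::Tensor const&,
                             torch::Tensor const&, torch::Tensor const&,
                             torch::Tensor const&);

// Speculative routing: int8 inputs go to the int8 kernel (absent on SM120),
// 2-D per-block scales go to the blockwise kernel, and everything else to the
// standard per-tensor/per-token/per-channel fp8 kernel.
inline void dispatch_scaled_mm_sketch(
    torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    std::optional<torch::Tensor> const& bias, ScaledMmFn fp8_fn,
    ScaledMmFn int8_fn, BlockwiseFn blockwise_fn) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(int8_fn != nullptr, "int8 GEMM not supported on this arch");
    int8_fn(c, a, b, a_scales, b_scales, bias);
  } else if (a_scales.dim() == 2 && a_scales.numel() > a.size(0)) {
    TORCH_CHECK(!bias.has_value(), "bias unsupported for blockwise fp8");
    blockwise_fn(c, a, b, a_scales, b_scales);
  } else {
    fp8_fn(c, a, b, a_scales, b_scales, bias);
  }
}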
