
Commit 85657b5

tlrmchlsmth authored
[Kernel] Factor out epilogues from cutlass kernels (#5391)
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: zifeitong <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent 0ce7b95 commit 85657b5

File tree: 12 files changed, +274 −232 lines


CMakeLists.txt

Lines changed: 4 additions & 4 deletions
@@ -179,17 +179,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/custom_all_reduce.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
+    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")

 #
 # The CUTLASS kernels for Hopper require sm90a to be enabled.
 # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
 # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
   set_source_files_properties(
-    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
     PROPERTIES
     COMPILE_FLAGS
     "-gencode arch=compute_90a,code=sm_90a")

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 1 addition & 5 deletions
@@ -76,11 +76,7 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
 def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                  scale_b: torch.tensor,
                  out_dtype: torch.dtype) -> torch.tensor:
-    return ops.cutlass_scaled_mm_dq(a,
-                                    b,
-                                    scale_a,
-                                    scale_b,
-                                    out_dtype=out_dtype)
+    return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)


 # bench

csrc/ops.h

Lines changed: 3 additions & 3 deletions
@@ -90,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);

-void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
-                          torch::Tensor const& b, torch::Tensor const& a_scales,
-                          torch::Tensor const& b_scales);
+void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales);

 #endif
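The entry point keeps its signature under the new name; only the `_dq` suffix is gone. For orientation, a hedged sketch of a caller built against this header: the shapes, per-tensor scales, and the row-major-A / column-major-B layout are assumptions of the sketch (consistent with the dtype checks visible in the kernels below), not something this header itself enforces.

#include <torch/torch.h>
#include "ops.h"  // declares cutlass_scaled_mm

// Hedged sketch: int8 x int8 -> bf16 with per-tensor scales (numel() == 1).
void example_scaled_mm() {
  auto i8 = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA);
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto bf16 = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);

  torch::Tensor a = torch::randint(-8, 8, {16, 32}, i8);      // row-major A
  torch::Tensor b = torch::randint(-8, 8, {64, 32}, i8).t();  // column-major B
  torch::Tensor a_scales = torch::full({1}, 0.5, f32);
  torch::Tensor b_scales = torch::full({1}, 0.25, f32);
  torch::Tensor out = torch::empty({16, 64}, bf16);

  // D = (a_scales * A) (b_scales * B), scales broadcast elementwise.
  cutlass_scaled_mm(out, a, b, a_scales, b_scales);
}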

csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu renamed to csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 111 additions & 88 deletions
@@ -29,21 +29,14 @@
 using namespace cute;

 /*
-   This defines a quantized GEMM operation with dequantized output, similar to
-   torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
+   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
    NVIDIA GPUs with SM versions prior to sm90 (Hopper).

-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
+   Epilogue functions can be defined to post-process the output before it is
+   written to GPU memory.
+   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+   as well as a static prepare_args function that constructs an
+   EVTCompute::Arguments struct.
 */

 namespace {
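The rewritten header comment states the epilogue contract: a public EVTCompute type built on Sm80EVT, plus a static prepare_args that returns an EVTCompute::Arguments. To make that contract concrete, here is a minimal conforming epilogue. It is a hypothetical sketch, not part of this commit, and it assumes that VisitorCompute instantiated with cutlass::epilogue::thread::Identity converts the accumulator to ElementD with no further post-processing.

// Hypothetical example only: the simplest epilogue satisfying the contract.
template <typename ElementD, typename OutputTileThreadMap>
struct IdentityEpilogue {
 private:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  using Convert = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::epilogue::thread::Identity, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  // Required: a public EVTCompute of type Sm80EVT ...
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Convert, Accum>;
  using ArgumentType = typename EVTCompute::Arguments;

  // ... and a static prepare_args building EVTCompute::Arguments.
  // Neither node carries runtime state, so the tree default-constructs.
  static ArgumentType prepare_args() { return {}; }
};

The ScaledEpilogue defined in the hunks below follows the same pattern, with a two-level visitor tree that loads and applies the A and B scales.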
@@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
   }
 };

-template <typename Arch, template <typename> typename ArchGuard,
-          typename ElementAB_, typename ElementD_, typename TileShape,
-          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
-struct cutlass_2x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-
-  using ElementAcc =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-                                float>::type;
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.

-  using Operator =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
-                                cutlass::arch::OpMultiplyAddSaturate,
-                                cutlass::arch::OpMultiplyAdd>::type;
+   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).

-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
-          >;
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.

+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue {
+ private:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -123,14 +114,56 @@
       cutlass::multiplies, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

-  using EVTCompute1 =
+ public:
+  using EVTCompute =
       cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
+    return evt_compute_args;
+  }
+};
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_,
+          template <typename, typename> typename Epilogue_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
+struct cutlass_2x_gemm {
+  using ElementAB = ElementAB_;
+  using ElementD = ElementD_;
+
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+
+  using Operator =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
+                                cutlass::arch::OpMultiplyAddSaturate,
+                                cutlass::arch::OpMultiplyAdd>::type;
+
+  using OutputTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
+          >;
+
+  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
+  using EVTCompute = typename Epilogue::EVTCompute;

   using D = cutlass::epilogue::threadblock::VisitorAuxStore<
       OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
       Stride<int64_t, Int<1>, Int<0>>>;

-  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
+  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
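A point worth noting in prepare_args above: Sm80EVT argument structs nest in the same shape as the visitor tree, children first, with stateless nodes default-initialized. Schematically, for EVTCompute = Sm80EVT<Compute1, ScaleA, EVTCompute0> (a sketch of the aggregate layout, restating the code above, not new code in the commit):

// Shape of the argument tree built by ScaledEpilogue::prepare_args:
//
//   EVTCompute::Arguments{
//     a_args,                  // ScaleA: a_scales pointer + "is vector" flag
//     EVTCompute0::Arguments{  // nested subtree applying the B scales
//       b_args,                //   ScaleB: b_scales pointer + "is vector" flag
//     },                       //   (Accum and Compute0 take no arguments)
//   }                          // (Compute1 takes no arguments)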
@@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
   using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
 };

-template <typename Gemm>
-void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     torch::Tensor const& a_scales,
-                                     torch::Tensor const& b_scales) {
+template <typename Gemm, typename... EpilogueArgs>
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                         torch::Tensor const& b,
+                         EpilogueArgs&&... epilogue_params) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;

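With the scale parameters replaced by a perfect-forwarded pack, the caller no longer hard-codes the epilogue's argument list; any arity flows through to Epilogue::prepare_args unchanged. A hedged usage sketch, where the Gemm aliases stand in for full cutlass_2x_gemm instantiations:

// Two epilogue arguments, forwarded to ScaledEpilogue::prepare_args:
//   cutlass_gemm_caller<ScaledGemm>(out, a, b, a_scales, b_scales);
// Zero epilogue arguments, for an epilogue whose prepare_args takes none
// (e.g. the hypothetical IdentityEpilogue sketched earlier):
//   cutlass_gemm_caller<IdentityGemm>(out, a, b);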
@@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
   auto c_ptr = static_cast<ElementD*>(out.data_ptr());

-  auto a_scales_ptr = a_scales.data_ptr<float>();
-  auto b_scales_ptr = b_scales.data_ptr<float>();
-
-  using ScaleAArgs = typename Gemm::ScaleA::Arguments;
-  using ScaleBArgs = typename Gemm::ScaleB::Arguments;
-
-  ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-  ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-
-  typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
-
-  typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
-                                                          evt0_compute_args};
   typename Gemm::D::Arguments d_args{c_ptr, c_stride};

+  using Epilogue = typename Gemm::Epilogue;
+  auto evt_args =
+      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
+
   typename Gemm::EVTD::Arguments epilogue_args{
-      evt1_compute_args,
+      evt_args,
       d_args,
   };

@@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,

 }  // namespace

-void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
   }
 }

-void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
-        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
-        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
   }
 }

-void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(b.dtype() == torch::kInt8);

     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
     }
   }
 }
