2929using namespace cute ;
3030
3131/*
32- This defines a quantized GEMM operation with dequantized output, similar to
33- torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
32+ This file defines quantized GEMM operations using the CUTLASS 2.x API, for
3433 NVIDIA GPUs with SM versions prior to sm90 (Hopper).
3534
36- A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
37- per-row. B can be quantized per-tensor or per-column.
38- Any combination of per-tensor and per-row or column is supported.
39- A and B must have symmetric quantization (zero point == 0).
40-
41- So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
42- scales are applied elementwise with numpy-style broadcasting.
43-
44- ScaleA and ScaleB define the epilogue functions that apply the scales for
45- the A and B operands respectively. These scales may be either per-tensor or
46- per row or column.
35+ Epilogue functions can be defined to post-process the output before it is
36+ written to GPU memory.
37+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
38+ as well as a static prepare_args function that constructs an
39+ EVTCompute::Arguments struct.
4740*/
4841
4942namespace {
@@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
8376 }
8477};
8578
86- template <typename Arch, template <typename > typename ArchGuard,
87- typename ElementAB_, typename ElementD_, typename TileShape,
88- typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
89- struct cutlass_2x_gemm {
90- using ElementAB = ElementAB_;
91- using ElementD = ElementD_;
92-
93- using ElementAcc =
94- typename std::conditional<std::is_same_v<ElementAB, int8_t >, int32_t ,
95- float >::type;
79+ /*
80+ This epilogue function defines a quantized GEMM operation similar to
81+ torch._scaled_mm.
9682
97- using Operator =
98- typename std::conditional<std::is_same_v<ElementAB, int8_t >,
99- cutlass::arch::OpMultiplyAddSaturate,
100- cutlass::arch::OpMultiplyAdd>::type;
83+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
84+ per-row. B can be quantized per-tensor or per-column.
85+ Any combination of per-tensor and per-row or column is supported.
86+ A and B must have symmetric quantization (zero point == 0).
10187
102- using OutputTileThreadMap =
103- cutlass::epilogue::threadblock::OutputTileThreadLayout<
104- TileShape, WarpShape, float , 4 , 1 /* epilogue stages */
105- >;
88+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
89+ scales are applied elementwise with numpy-style broadcasting.
10690
91+ ScaleA and ScaleB define the epilogue functions that apply the scales for
92+ the A and B operands respectively. These scales may be either per-tensor or
93+ per-row or per-column.
94+ */
95+ template <typename ElementD, typename OutputTileThreadMap>
96+ struct ScaledEpilogue {
97+ private:
10798 using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
10899
109100 using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
123114 cutlass::multiplies, ElementD, float ,
124115 cutlass::FloatRoundStyle::round_to_nearest>;
125116
126- using EVTCompute1 =
117+ public:
118+ using EVTCompute =
127119 cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
120+ using ArgumentType = typename EVTCompute::Arguments;
121+
122+ static ArgumentType prepare_args (torch::Tensor const & a_scales,
123+ torch::Tensor const & b_scales) {
124+ using ScaleAArgs = typename ScaleA::Arguments;
125+ using ScaleBArgs = typename ScaleB::Arguments;
126+
127+ ScaleBArgs b_args{b_scales.data_ptr <float >(), b_scales.numel () != 1 , {}};
128+ ScaleAArgs a_args{a_scales.data_ptr <float >(), a_scales.numel () != 1 , {}};
129+
130+ typename EVTCompute0::Arguments evt0_compute_args{b_args};
131+
132+ typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
133+ return evt_compute_args;
134+ }
135+ };
136+
137+ template <typename Arch, template <typename > typename ArchGuard,
138+ typename ElementAB_, typename ElementD_,
139+ template <typename , typename > typename Epilogue_, typename TileShape,
140+ typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
141+ struct cutlass_2x_gemm {
142+ using ElementAB = ElementAB_;
143+ using ElementD = ElementD_;
144+
145+ using ElementAcc =
146+ typename std::conditional<std::is_same_v<ElementAB, int8_t >, int32_t ,
147+ float >::type;
148+
149+ using Operator =
150+ typename std::conditional<std::is_same_v<ElementAB, int8_t >,
151+ cutlass::arch::OpMultiplyAddSaturate,
152+ cutlass::arch::OpMultiplyAdd>::type;
153+
154+ using OutputTileThreadMap =
155+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
156+ TileShape, WarpShape, float , 4 , 1 /* epilogue stages */
157+ >;
158+
159+ using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
160+ using EVTCompute = typename Epilogue::EVTCompute;
128161
129162 using D = cutlass::epilogue::threadblock::VisitorAuxStore<
130163 OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
131164 Stride<int64_t , Int<1 >, Int<0 >>>;
132165
133- using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1 >;
166+ using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute >;
134167
135168 // clang-format off
136169 using RowMajor = typename cutlass::layout::RowMajor;
@@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
153186 using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
154187};
155188
156- template <typename Gemm>
157- void cutlass_scaled_mm_dq_dispatcher (torch::Tensor& out, torch::Tensor const & a,
158- torch::Tensor const & b,
159- torch::Tensor const & a_scales,
160- torch::Tensor const & b_scales) {
189+ template <typename Gemm, typename ... EpilogueArgs>
190+ void cutlass_gemm_caller (torch::Tensor& out, torch::Tensor const & a,
191+ torch::Tensor const & b,
192+ EpilogueArgs&&... epilogue_params) {
161193 using ElementAB = typename Gemm::ElementAB;
162194 using ElementD = typename Gemm::ElementD;
163195
@@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
177209 auto b_ptr = static_cast <ElementAB const *>(b.data_ptr ());
178210 auto c_ptr = static_cast <ElementD*>(out.data_ptr ());
179211
180- auto a_scales_ptr = a_scales.data_ptr <float >();
181- auto b_scales_ptr = b_scales.data_ptr <float >();
182-
183- using ScaleAArgs = typename Gemm::ScaleA::Arguments;
184- using ScaleBArgs = typename Gemm::ScaleB::Arguments;
185-
186- ScaleBArgs b_args{b_scales.data_ptr <float >(), b_scales.numel () != 1 , {}};
187- ScaleAArgs a_args{a_scales.data_ptr <float >(), a_scales.numel () != 1 , {}};
188-
189- typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
190-
191- typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
192- evt0_compute_args};
193212 typename Gemm::D::Arguments d_args{c_ptr, c_stride};
194213
214+ using Epilogue = typename Gemm::Epilogue;
215+ auto evt_args =
216+ Epilogue::prepare_args (std::forward<EpilogueArgs>(epilogue_params)...);
217+
195218 typename Gemm::EVTD::Arguments epilogue_args{
196- evt1_compute_args ,
219+ evt_args ,
197220 d_args,
198221 };
199222
@@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
229252
230253} // namespace
231254
232- void cutlass_scaled_mm_dq_sm75 (torch::Tensor& out, torch::Tensor const & a,
233- torch::Tensor const & b,
234- torch::Tensor const & a_scales,
235- torch::Tensor const & b_scales) {
255+ void cutlass_scaled_mm_sm75 (torch::Tensor& out, torch::Tensor const & a,
256+ torch::Tensor const & b,
257+ torch::Tensor const & a_scales,
258+ torch::Tensor const & b_scales) {
236259 TORCH_CHECK (a.dtype () == torch::kInt8 );
237260 TORCH_CHECK (b.dtype () == torch::kInt8 );
238261 TORCH_CHECK (a_scales.dtype () == torch::kFloat32 );
@@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
243266 using InstructionShape = typename cutlass::gemm::GemmShape<8 , 8 , 16 >;
244267
245268 if (out.dtype () == torch::kBFloat16 ) {
246- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
269+ return cutlass_gemm_caller <cutlass_2x_gemm<
247270 cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t , cutlass::bfloat16_t ,
248- TileShape, WarpShape, InstructionShape, 2 >>(out, a, b, a_scales,
249- b_scales);
271+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2 >>(
272+ out, a, b, a_scales, b_scales);
250273 } else {
251274 TORCH_CHECK (out.dtype () == torch::kFloat16 );
252- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
275+ return cutlass_gemm_caller <cutlass_2x_gemm<
253276 cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t , cutlass::half_t ,
254- TileShape, WarpShape, InstructionShape, 2 >>(out, a, b, a_scales,
255- b_scales);
277+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2 >>(
278+ out, a, b, a_scales, b_scales);
256279 }
257280}
258281
259- void cutlass_scaled_mm_dq_sm80 (torch::Tensor& out, torch::Tensor const & a,
260- torch::Tensor const & b,
261- torch::Tensor const & a_scales,
262- torch::Tensor const & b_scales) {
282+ void cutlass_scaled_mm_sm80 (torch::Tensor& out, torch::Tensor const & a,
283+ torch::Tensor const & b,
284+ torch::Tensor const & a_scales,
285+ torch::Tensor const & b_scales) {
263286 TORCH_CHECK (a.dtype () == torch::kInt8 );
264287 TORCH_CHECK (b.dtype () == torch::kInt8 );
265288 TORCH_CHECK (a_scales.dtype () == torch::kFloat32 );
@@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
270293 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
271294
272295 if (out.dtype () == torch::kBFloat16 ) {
273- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
296+ return cutlass_gemm_caller <cutlass_2x_gemm<
274297 cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t , cutlass::bfloat16_t ,
275- TileShape, WarpShape, InstructionShape, 5 >>(out, a, b, a_scales,
276- b_scales);
298+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
299+ out, a, b, a_scales, b_scales);
277300 } else {
278301 TORCH_CHECK (out.dtype () == torch::kFloat16 );
279- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
302+ return cutlass_gemm_caller <cutlass_2x_gemm<
280303 cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t , cutlass::half_t ,
281- TileShape, WarpShape, InstructionShape, 5 >>(out, a, b, a_scales,
282- b_scales);
304+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
305+ out, a, b, a_scales, b_scales);
283306 }
284307}
285308
286- void cutlass_scaled_mm_dq_sm89 (torch::Tensor& out, torch::Tensor const & a,
287- torch::Tensor const & b,
288- torch::Tensor const & a_scales,
289- torch::Tensor const & b_scales) {
309+ void cutlass_scaled_mm_sm89 (torch::Tensor& out, torch::Tensor const & a,
310+ torch::Tensor const & b,
311+ torch::Tensor const & a_scales,
312+ torch::Tensor const & b_scales) {
290313 using TileShape = typename cutlass::gemm::GemmShape<128 , 128 , 64 >;
291314 using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
292315 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
@@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
298321 TORCH_CHECK (b.dtype () == torch::kInt8 );
299322
300323 if (out.dtype () == torch::kBFloat16 ) {
301- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
324+ return cutlass_gemm_caller <cutlass_2x_gemm<
302325 cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t , cutlass::bfloat16_t ,
303- TileShape, WarpShape, InstructionShape, 5 >>(out, a, b, a_scales,
304- b_scales);
326+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
327+ out, a, b, a_scales, b_scales);
305328 } else {
306329 assert (out.dtype () == torch::kFloat16 );
307- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
330+ return cutlass_gemm_caller <cutlass_2x_gemm<
308331 cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t , cutlass::half_t ,
309- TileShape, WarpShape, InstructionShape, 5 >>(out, a, b, a_scales,
310- b_scales);
332+ ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5 >>(
333+ out, a, b, a_scales, b_scales);
311334 }
312335 } else {
313336 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
314337 TORCH_CHECK (b.dtype () == torch::kFloat8_e4m3fn );
315338
316339 if (out.dtype () == torch::kBFloat16 ) {
317- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
340+ return cutlass_gemm_caller <cutlass_2x_gemm<
318341 cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t ,
319- cutlass::bfloat16_t , TileShape, WarpShape, InstructionShape, 5 >>(
320- out, a, b, a_scales, b_scales);
342+ cutlass::bfloat16_t , ScaledEpilogue, TileShape, WarpShape,
343+ InstructionShape, 5 >>( out, a, b, a_scales, b_scales);
321344 } else {
322345 TORCH_CHECK (out.dtype () == torch::kFloat16 );
323- return cutlass_scaled_mm_dq_dispatcher <cutlass_2x_gemm<
346+ return cutlass_gemm_caller <cutlass_2x_gemm<
324347 cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t ,
325- cutlass::half_t , TileShape, WarpShape, InstructionShape, 5 >>(
326- out, a, b, a_scales, b_scales);
348+ cutlass::half_t , ScaledEpilogue, TileShape, WarpShape,
349+ InstructionShape, 5 >>( out, a, b, a_scales, b_scales);
327350 }
328351 }
329352}
0 commit comments