using namespace cute;

/*
-   This defines a quantized GEMM operation with dequantized output, similar to
-   torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
+   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
    NVIDIA GPUs with SM versions prior to sm90 (Hopper).

-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
+   Epilogue functions can be defined to post-process the output before it is
+   written to GPU memory.
+   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+   as well as a static prepare_args function that constructs an
+   EVTCompute::Arguments struct.
*/
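To make the contract above concrete, here is a minimal sketch (not part of this diff) of another conforming epilogue that applies a single per-tensor or per-row scale. The struct name is hypothetical; it reuses only visitor types that appear elsewhere in this file and mirrors the argument pattern of ScaledEpilogue below.

template <typename ElementD, typename OutputTileThreadMap>
struct SingleScaleEpilogue {
 private:
  // Fetch the raw accumulator values.
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  // Broadcast one value per output row (a column vector), or a single
  // scalar to the whole tile, mirroring ScaleA in ScaledEpilogue below.
  using Scale = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

  using Compute = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  // The two public pieces the contract requires: EVTCompute and prepare_args.
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute, Scale, Accum>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& scales) {
    typename Scale::Arguments scale_args{
        scales.data_ptr<float>(), scales.numel() != 1, {}};
    return ArgumentType{scale_args};
  }
};
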
namespace {
@@ -83,27 +76,25 @@ struct enable_sm89_to_sm90 : Kernel {
  }
};
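The guard types named in the hunk header (enable_sm75_to_sm80, enable_sm80_to_sm89, enable_sm89_to_sm90) are mostly collapsed by the diff. As a sketch, assuming they follow the usual __CUDA_ARCH__-gating pattern, each wraps a kernel so its body only executes when compiled for the intended SM range:

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    // Only run the wrapped kernel when compiled for sm89.
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
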

-template <typename Arch, template <typename> typename ArchGuard,
-          typename ElementAB_, typename ElementD_, typename TileShape,
-          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
-struct cutlass_2x_gemm {
-  using ElementAB = ElementAB_;
-  using ElementD = ElementD_;
-
-  using ElementAcc =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-                                float>::type;
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.

-  using Operator =
-      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
-                                cutlass::arch::OpMultiplyAddSaturate,
-                                cutlass::arch::OpMultiplyAdd>::type;
+   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).

-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
-          >;
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.

+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue {
+ private:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -123,14 +114,56 @@ struct cutlass_2x_gemm {
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

-  using EVTCompute1 =
+ public:
+  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
+    return evt_compute_args;
+  }
+};
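As a reading aid: the EVT tree applies ScaleB to the accumulator first (EVTCompute0, whose definition is collapsed in the hunk above) and then applies ScaleA. A plain host-side reference of the resulting semantics, with a hypothetical function name and row-major indexing for simplicity:

// Hypothetical reference (not part of this diff) for
// D = (a_scales * A) (b_scales * B) with numpy-style broadcasting.
// a_scales holds 1 value (per-tensor) or M values (per-row);
// b_scales holds 1 value (per-tensor) or N values (per-column).
void scaled_mm_reference(float* D, int8_t const* A, int8_t const* B,
                         float const* a_scales, int64_t num_a_scales,
                         float const* b_scales, int64_t num_b_scales,
                         int64_t M, int64_t N, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = 0;  // int8 inputs accumulate in int32 (ElementAcc)
      for (int64_t k = 0; k < K; ++k) {
        acc += int32_t(A[m * K + k]) * int32_t(B[k * N + n]);
      }
      float sa = a_scales[num_a_scales == 1 ? 0 : m];
      float sb = b_scales[num_b_scales == 1 ? 0 : n];
      D[m * N + n] = sa * (sb * float(acc));  // ScaleB first, then ScaleA
    }
  }
}
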
+
+template <typename Arch, template <typename> typename ArchGuard,
+          typename ElementAB_, typename ElementD_,
+          template <typename, typename> typename Epilogue_, typename TileShape,
+          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
+struct cutlass_2x_gemm {
+  using ElementAB = ElementAB_;
+  using ElementD = ElementD_;
+
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+
+  using Operator =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
+                                cutlass::arch::OpMultiplyAddSaturate,
+                                cutlass::arch::OpMultiplyAdd>::type;
+
+  using OutputTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
+          >;
+
+  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
+  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

-  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
+  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
@@ -153,11 +186,10 @@ struct cutlass_2x_gemm {
  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
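For reference, the new Epilogue_ template-template parameter is filled in by the dispatch functions below; spelled out with the sm89 shapes (the alias name is illustrative), an instantiation reads:

using Sm89Int8Bf16ScaledGemm = cutlass_2x_gemm<
    cutlass::arch::Sm89, enable_sm89_to_sm90,  // Arch, ArchGuard
    int8_t, cutlass::bfloat16_t,               // ElementAB_, ElementD_
    ScaledEpilogue,                            // Epilogue_
    cutlass::gemm::GemmShape<128, 128, 64>,    // TileShape
    cutlass::gemm::GemmShape<64, 64, 64>,      // WarpShape
    cutlass::gemm::GemmShape<16, 8, 32>,       // InstructionShape
    5>;                                        // MainLoopStages
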

-template <typename Gemm>
-void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     torch::Tensor const& a_scales,
-                                     torch::Tensor const& b_scales) {
+template <typename Gemm, typename... EpilogueArgs>
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                         torch::Tensor const& b,
+                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
@@ -177,23 +209,14 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

-  auto a_scales_ptr = a_scales.data_ptr<float>();
-  auto b_scales_ptr = b_scales.data_ptr<float>();
-
-  using ScaleAArgs = typename Gemm::ScaleA::Arguments;
-  using ScaleBArgs = typename Gemm::ScaleB::Arguments;
-
-  ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-  ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-
-  typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
-
-  typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
-                                                          evt0_compute_args};
  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

+  using Epilogue = typename Gemm::Epilogue;
+  auto evt_args =
+      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
+
  typename Gemm::EVTD::Arguments epilogue_args{
-      evt1_compute_args,
+      evt_args,
      d_args,
  };
@@ -229,10 +252,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,

}  // namespace

-void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -243,23 +266,23 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

  if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, a_scales, b_scales);
  }
}

-void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -270,23 +293,23 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  if (out.dtype() == torch::kBFloat16) {
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
-        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
-        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                    b_scales);
+        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, a_scales, b_scales);
  }
}

-void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
-                               torch::Tensor const& b,
-                               torch::Tensor const& a_scales,
-                               torch::Tensor const& b_scales) {
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales) {
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -298,32 +321,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
    } else {
      assert(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
-                                                      b_scales);
+          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, a_scales, b_scales);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
+      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
+          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
    }
  }
}
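Finally, a rough host-side usage sketch for the renamed entry points. The tensor shapes, options, and the column-major handling of B are assumptions (the layout setup is collapsed in this diff); the entry-point signature is the one defined above.

// Hypothetical caller: int8 x int8 -> bf16 on sm80, per-row scales for A
// and per-column scales for B.
void example_scaled_mm(int64_t M, int64_t N, int64_t K) {
  auto i8 = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA);
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto bf16 = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);

  torch::Tensor a = torch::randint(-128, 127, {M, K}, i8);
  // Assumed: B is consumed column-major, so build [N, K] and transpose.
  torch::Tensor b = torch::randint(-128, 127, {N, K}, i8).t();
  torch::Tensor a_scales = torch::rand({M, 1}, f32);  // per-row
  torch::Tensor b_scales = torch::rand({1, N}, f32);  // per-column
  torch::Tensor out = torch::empty({M, N}, bf16);

  cutlass_scaled_mm_sm80(out, a, b, a_scales, b_scales);
}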