
Commit 7d8317a
Author: Manish Gupta
Support for Mixed Input TensorOp (PaddlePaddle#1084)
* Passing warp-level mixed input F16 * (S8/U8) tests
* Passing device-level mixed input F16 * (S8/U8) tests
* Add to profiler - I8 (111 TFLOP/s), U8 (123 TFLOP/s)
* Fast numeric conversions (I8 = 132 TFLOP/s, U8 = 148 TFLOP/s)
* Speed up reference compilation (REVERT THIS COMMIT)
* wider_add.u32_packed_sub.f16x2 (I8 = 132 TFLOP/s, U8 = 170 TFLOP/s)
* Improve s8->f16 conversion and support bf16 * u8 at 158 TFLOP/s
* BF16 * S8 (142 TFLOP/s)
* Handle mixed-input upcast on operand A (support [S8|U8] * [F16|BF16])
* Rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast
* Add device-level test and profiler support for upcast on operand A
* Move shfl before the cvt and reduce the number of shfls by 1/2
* Fix smem usage calculation for mixed-input types
* Uncomment code in preparation for merge
* Profiler changes and mixed-input reference
* Mixed-input reference kernels are in a new file
* Use platform instead of std
* Comments and typo fixes only
* Use CreateGemmOperator and delete CreateMixedInputGemmOperator
* Copyright headers for new files
* Rebase follow-up
Parent: 5cd735c

26 files changed (+2065, -14 lines)
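Several of the commit-message items above concern fast numeric conversion of 8-bit operands to F16 inside the mainloop ("fast numeric conversions", "wider_add.u32_packed_sub.f16x2"). The sketch below is not the CUTLASS code; it is a minimal CUDA illustration, with a hypothetical function name, of the magic-number style of conversion those items refer to: each byte is spliced into the mantissa of an F16 lane biased by 1024, then one packed f16x2 subtract recovers both values without per-element cvt instructions.

#include <cuda_fp16.h>
#include <cstdint>

// Convert two u8 values (bytes 0 and 1 of 'packed') to a __half2.
// Hypothetical illustration; the in-tree converters differ in detail.
__device__ __half2 u8x2_to_half2(uint32_t packed) {
  // Splice each byte into the mantissa of an f16 lane whose bits encode
  // 1024.0 (0x6400), so lane i holds the value (1024 + b_i).
  uint32_t biased = __byte_perm(packed, 0x64006400u, 0x7170);
  __half2 h = *reinterpret_cast<__half2 const *>(&biased);
  // One packed f16x2 subtract removes the 1024 bias from both lanes.
  return __hsub2(h, __half2half2(__ushort_as_half(0x6400u)));
}

// For s8, one would first flip the sign bit (b ^ 0x80) and subtract
// 1152 = 1024 + 128 instead; the commit's numbers (I8 at 132 TFLOP/s vs.
// U8 at 170 TFLOP/s) reflect that extra work on the signed path.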

include/cutlass/arch/mma.h

Lines changed: 10 additions & 0 deletions
@@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {};
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Tag indicating the input data types are mixed and the narrower type is
+/// upcasted to the wider type
+struct OpMultiplyAddMixedInputUpcast {};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Tag indicating the input is converted to 2 (big and small) TF32 components
 // Perform 3xTF32 or 4xTF32 for every F32 output element
 struct OpMultiplyAddFastF32 {};
 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Tag indicating the input is converted to 2 (big and small) TF32 components
 // Perform 3xTF32 or 4xTF32 for every complex<F32> output element
 struct OpMultiplyAddComplexFastF32 {};
 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 /// Helper for determining whether staged accumulation should be used for a given operator
 template <typename Operator>
 struct UseStagedAccumulation {
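
The new tag is an empty struct used purely as a compile-time selector: downstream templates partially specialize on it to pick the mixed-input warp-level MMA, as the next file shows. A self-contained sketch of that dispatch pattern, with hypothetical names (SelectMma, kIsMixed) and nothing assumed beyond the standard library:

#include <type_traits>

// Operator tags: empty structs that exist only to steer specialization.
struct OpMultiplyAdd {};
struct OpMultiplyAddMixedInputUpcast {};

// Primary template: same-type operands, ordinary multiply-add.
template <typename ElementA, typename ElementB, typename Operator>
struct SelectMma {
  static constexpr bool kIsMixed = false;
};

// Partial specialization chosen only when the mixed-input tag is passed.
template <typename ElementA, typename ElementB>
struct SelectMma<ElementA, ElementB, OpMultiplyAddMixedInputUpcast> {
  static_assert(!std::is_same<ElementA, ElementB>::value,
                "Mixed-input upcast requires distinct A and B element types");
  static constexpr bool kIsMixed = true;
};

static_assert(!SelectMma<float, float, OpMultiplyAdd>::kIsMixed, "");
static_assert(SelectMma<short, signed char,
                        OpMultiplyAddMixedInputUpcast>::kIsMixed, "");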

include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h

Lines changed: 67 additions & 0 deletions
@@ -38,6 +38,7 @@
 #include "cutlass/numeric_types.h"
 #include "cutlass/arch/mma.h"
 #include "cutlass/gemm/warp/mma_tensor_op.h"
+#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"
 #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
 #include "cutlass/gemm/warp/default_mma_tensor_op.h"
@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+/// Partial specialization - inputs are mixed types - uses wider datatype internally.
+/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
+template <
+    /// Shape of one matrix product operation (concept: GemmShape)
+    typename WarpShape_,
+    /// Element type of A matrix
+    typename ElementA,
+    /// Layout of A matrix (concept: MatrixLayout)
+    typename LayoutA,
+    /// Element type of B matrix
+    typename ElementB,
+    /// Layout of B matrix (concept: MatrixLayout)
+    typename LayoutB,
+    /// Element type of C matrix
+    typename ElementC,
+    /// Layout of C matrix (concept: MatrixLayout)
+    typename LayoutC,
+    /// Number of partitions along K dimension
+    int PartitionsK,
+    /// Store the accumulators in row major or column major. Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor>
+struct DefaultMmaTensorOp<
+    WarpShape_,
+    GemmShape<16, 8, 16>,                 // InstructionShape
+    ElementA,                             // Element type of A matrix in global memory
+    LayoutA,                              // Layout of A matrix in global memory
+    ElementB,                             // Element type of B matrix in global memory
+    LayoutB,                              // Layout of B matrix in global memory
+    ElementC,                             // Element type of C matrix in global memory
+    LayoutC,                              // Layout of C matrix in global memory
+    arch::OpMultiplyAddMixedInputUpcast,  // Tag indicating mixed-input datatypes; the narrower datatype is upcast to the wider one
+    PartitionsK, AccumulatorsInRowMajor> {
+
+  // Check that ElementA and ElementB are of different data types
+  static_assert(!platform::is_same<ElementA, ElementB>::value,
+                "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast requires ElementA and ElementB to be of different data types");
+
+  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
+  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
+                                                        ElementA, ElementB>::type;
+
+  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
+  using MmaElementA = ElementOperand;
+  using MmaElementB = ElementOperand;
+  using MmaElementC = ElementC;
+
+  // Define the warp-level op policy around the underlying 16x8x16 mma.sync instruction
+  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
+      cutlass::arch::Mma<
+          GemmShape<16, 8, 16>,
+          32,
+          MmaElementA, cutlass::layout::RowMajor,
+          MmaElementB, cutlass::layout::ColumnMajor,
+          MmaElementC, cutlass::layout::RowMajor,
+          arch::OpMultiplyAdd>,
+      cutlass::MatrixShape<1, 1> >;
+
+  // Define the warp-level tensor op
+  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
+      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
+      Policy, PartitionsK, AccumulatorsInRowMajor>;
+};
+
 } // namespace warp
 } // namespace gemm
 } // namespace cutlass
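
The core of the specialization above is the platform::conditional that routes both mma.sync operands through the wider element type. Below is a standalone sketch of that selection, substituting std::conditional (same semantics as CUTLASS's platform::conditional) and plain integer stand-ins for half_t/int8_t:

#include <cstdint>
#include <type_traits>

// Stand-ins: a 2-byte type for half_t/bfloat16_t, a 1-byte type for int8_t.
using F16 = uint16_t;  // hypothetical stand-in, same size as cutlass::half_t
using S8  = int8_t;

// Both internal mma.sync operands use the wider of the two input types.
template <typename ElementA, typename ElementB>
using ElementOperand =
    typename std::conditional<(sizeof(ElementA) > sizeof(ElementB)),
                              ElementA, ElementB>::type;

// F16 x S8 computes in F16: the S8 fragments are upcast before mma.sync.
static_assert(std::is_same<ElementOperand<F16, S8>, F16>::value, "");
// The same holds when the narrower type is on operand A.
static_assert(std::is_same<ElementOperand<S8, F16>, F16>::value, "");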
