|
38 | 38 | #include "cutlass/numeric_types.h"
|
39 | 39 | #include "cutlass/arch/mma.h"
|
40 | 40 | #include "cutlass/gemm/warp/mma_tensor_op.h"
|
| 41 | +#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h" |
41 | 42 | #include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
|
42 | 43 | #include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
43 | 44 |
|
@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<
|
227 | 228 |
|
228 | 229 | /////////////////////////////////////////////////////////////////////////////////////////////////
|
229 | 230 |
|
| 231 | +/// Partial Specialization - inputs are mixed types - uses wider datatype internally. |
| 232 | +/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32) |
| 233 | +template < |
| 234 | + /// Shape of one matrix production operation (concept: GemmShape) |
| 235 | + typename WarpShape_, |
| 236 | + /// Element type of A matrix |
| 237 | + typename ElementA, |
| 238 | + /// Layout of A matrix (concept: MatrixLayout) |
| 239 | + typename LayoutA, |
| 240 | + /// Element type of B matrix |
| 241 | + typename ElementB, |
| 242 | + /// Layout of B matrix (concept: MatrixLayout) |
| 243 | + typename LayoutB, |
| 244 | + /// Element type of C matrix |
| 245 | + typename ElementC, |
| 246 | + /// Layout of C matrix (concept: MatrixLayout) |
| 247 | + typename LayoutC, |
| 248 | + /// Number of partitions along K dimension |
| 249 | + int PartitionsK, |
| 250 | + /// Store the accumulators in row major or column major. Row major is used |
| 251 | + /// when output layout is interleaved. |
| 252 | + bool AccumulatorsInRowMajor> |
| 253 | +struct DefaultMmaTensorOp< |
| 254 | + WarpShape_, |
| 255 | + GemmShape<16, 8, 16>, // InstructionShape |
| 256 | + ElementA, // Element type of A matrix in Global Memory |
| 257 | + LayoutA, // Layout of A matrix in Global Memory |
| 258 | + ElementB, // Element type of B matrix in Global Memory |
| 259 | + LayoutB, // Layout of B matrix in Global Memory |
| 260 | + ElementC, // Element type of C matrix in Global Memory |
| 261 | + LayoutC, // Layout of C matrix in Global Memory |
| 262 | + arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype |
| 263 | + PartitionsK, AccumulatorsInRowMajor> { |
| 264 | + |
| 265 | + |
| 266 | + // Check if the ElementA and ElementB are of different data types |
| 267 | + static_assert(!platform::is_same<ElementA, ElementB>::value, |
| 268 | + "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type"); |
| 269 | + |
| 270 | + // Data type used for internal computation - use the wider of the two data types for mma.sync operands |
| 271 | + using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)), |
| 272 | + ElementA, ElementB>::type; |
| 273 | + |
| 274 | + // Operand datatypes in the internal MMA instruction - use the wider of the two data types |
| 275 | + using MmaElementA = ElementOperand; |
| 276 | + using MmaElementB = ElementOperand; |
| 277 | + using MmaElementC = ElementC; |
| 278 | + |
| 279 | + // Uses |
| 280 | + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< |
| 281 | + cutlass::arch::Mma< |
| 282 | + GemmShape<16, 8, 16>, |
| 283 | + 32, |
| 284 | + MmaElementA, cutlass::layout::RowMajor, |
| 285 | + MmaElementB, cutlass::layout::ColumnMajor, |
| 286 | + MmaElementC, cutlass::layout::RowMajor, |
| 287 | + arch::OpMultiplyAdd |
| 288 | + >, |
| 289 | + cutlass::MatrixShape<1, 1> >; |
| 290 | + |
| 291 | + // Define the warp-level tensor op |
| 292 | + using Type = cutlass::gemm::warp::MmaMixedInputTensorOp< |
| 293 | + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, |
| 294 | + Policy, PartitionsK, AccumulatorsInRowMajor>; |
| 295 | +}; |
| 296 | + |
230 | 297 | } // namespace warp
|
231 | 298 | } // namespace gemm
|
232 | 299 | } // namespace cutlass
|
|
0 commit comments