From 8d41b94add00c0f10629211c7be87880ee3e2c7b Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Thu, 23 Jan 2025 13:35:17 -0500
Subject: [PATCH 1/6] [BACKEND] bump llvm to
 1c28b9237382b093f477479c993c80181922ca6a

---
 cmake/llvm-hash.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt
index e793a5b69976..65c4c19898ae 100644
--- a/cmake/llvm-hash.txt
+++ b/cmake/llvm-hash.txt
@@ -1 +1 @@
-e2402615a5a76d46a433dfcc1de10b38a1263c9d
+1c28b9237382b093f477479c993c80181922ca6a

From 32ae54e2e98cf5c19d64120742b5b5f205e026b9 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Tue, 21 Jan 2025 13:28:48 -0500
Subject: [PATCH 2/6] catch up to f8 llvm change

---
 include/triton/Conversion/MLIRTypes.h          | 14 ++++++------
 lib/Analysis/Utility.cpp                       | 15 +++++++++------
 .../TritonGPU/Transforms/AccelerateMatmul.cpp  |  2 +-
 lib/Dialect/TritonGPU/Transforms/Utility.cpp   |  7 ++++---
 lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp         |  6 ++++--
 .../TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp | 18 ++++++++++--------
 .../AccelerateAMDMatmul.cpp                    |  3 ++-
 .../lib/TritonAMDGPUTransforms/MfmaGroup.cpp   | 15 ++++++++++-----
 .../DotOpToLLVM/MMAv2.cpp                      | 16 ++++++++--------
 .../DotOpToLLVM/WGMMA.cpp                      |  4 ++--
 .../ElementwiseOpToLLVM.cpp                    | 12 +++++++-----
 11 files changed, 65 insertions(+), 47 deletions(-)

diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index afa1aa989e6e..3dcfddefe1c0 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -28,15 +28,17 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         llvm::isa<Float8E4M3FNType>(type) ||
+         llvm::isa<Float8E4M3FNUZType>(type) ||
+         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         llvm::isa<Float8E4M3FNType>(type) ||
+         llvm::isa<Float8E4M3FNUZType>(type) ||
+         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 01c2aef7d431..f8591e8ec0b0 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -750,14 +750,15 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
-           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
-           aElemTy.isF32()))) {
+          (llvm::isa<Float8E5M2Type>(aElemTy) ||
+           llvm::isa<Float8E4M3FNType>(aElemTy) || aElemTy.isInteger(8) ||
+           aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) {
      return false;
    }
    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
    if (op.getMaxNumImpreciseAcc() < 32 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+        (llvm::isa<Float8E5M2Type>(aElemTy) ||
+         llvm::isa<Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -778,8 +779,10 @@ bool supportMMA(Value value, int version) {
       cast<TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type>(elemTy) ||
+               llvm::isa<Float8E4M3FNType>(elemTy) ||
+               llvm::isa<Float8E5M2FNUZType>(elemTy) ||
+               llvm::isa<Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 6d7632b1b788..3a6fccb04f55 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -632,7 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type>(AElType) || llvm::isa<Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index ce3282dc4521..9521933569ae 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,9 +45,10 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-      eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-      eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type>(eltType) ||
+      llvm::isa<Float8E4M3FNType>(eltType) ||
+      llvm::isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
+      eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168,
                    160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64,
                    56, 48, 40, 32, 24, 16, 8});

diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index a171d8933996..e2c205944227 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -77,8 +77,10 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
   auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type>(aElTy) ||
+               llvm::isa<Float8E4M3FNType>(aElTy) ||
+               llvm::isa<Float8E5M2FNUZType>(aElTy) ||
+               llvm::isa<Float8E4M3FNUZType>(aElTy);
   bool accFP32 = cast<TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index 994a81d58ca1..f1d70fd29081 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1043,17 +1043,19 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (llvm::isa<Float8E4M3FNType>(srcElementType) ||
+        llvm::isa<Float8E4M3FNType>(dstElementType) ||
+        llvm::isa<Float8E4M3FNUZType>(srcElementType) ||
+        llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
+        llvm::isa<Float8E5M2FNUZType>(srcElementType) ||
+        llvm::isa<Float8E5M2FNUZType>(dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
-        srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+        srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA3 &&
+          (llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
+           llvm::isa<Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 7ea13142a76c..ec99b5adca81 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -416,7 +416,8 @@ class BlockedToMFMA : public OpRewritePattern<triton::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType>(aElemTy) ||
+                 llvm::isa<Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
index 4979ee005b9f..74306ce241ba 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
@@ -20,19 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (llvm::isa<Float8E5M2Type>(dataTypeA) &&
+      llvm::isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 24defdf1975e..36fa804e60f4 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -303,17 +303,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
       return TensorCoreType::FP32_FP16_FP16_FP32;
     if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
       return TensorCoreType::FP32_BF16_BF16_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
      return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
      return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
    if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
        op.getInputPrecision() == InputPrecision::TF32)

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index 66a8ff3069f3..2c20942e155f 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -59,9 +59,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::tf32;
   } else if (aTy.isInteger(8)) {
     return triton::nvgpu::WGMMAEltType::s8;
-  } else if (aTy.isFloat8E5M2()) {
+  } else if (llvm::isa<Float8E5M2Type>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FN()) {
+  } else if (llvm::isa<Float8E4M3FNType>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index ebcc8c399186..86ad53ab2aa1 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -451,8 +451,8 @@ struct FpToFpOpConversion
       llvm::errs() << "\n";
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
-    if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
+    if (computeCapability < 89 && (llvm::isa<Float8E4M3FNType>(srcTy) ||
+                                   llvm::isa<Float8E4M3FNType>(dstTy))) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -475,7 +475,8 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E5M2Type>(dstElementType) ||
+        llvm::isa<Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -512,8 +513,9 @@ struct FpToFpOpConversion
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
-                                       dstElementType.isFloat8E5M2())) ||
+        (!(computeCapability >= 90 &&
+           (llvm::isa<Float8E4M3FNType>(dstElementType) ||
+            llvm::isa<Float8E5M2Type>(dstElementType))) ||
         roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
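Note: patch 2 is a mechanical migration forced by the LLVM bump, which removed the `Type::isFloat8*()` member predicates; each FP8 format is now a distinct builtin type class checked with `llvm::isa`. A minimal standalone sketch of the before/after idiom — the helper name `isNativeNvidiaFP8` is invented for illustration, and only `mlir/IR/BuiltinTypes.h` plus `llvm/Support/Casting.h` are assumed:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/Casting.h"

    using namespace mlir;

    // Hypothetical helper: true for the two FP8 formats NVIDIA MMA consumes.
    bool isNativeNvidiaFP8(Type t) {
      // Pre-bump (removed): t.isFloat8E5M2() || t.isFloat8E4M3FN()
      return llvm::isa<Float8E5M2Type>(t) || llvm::isa<Float8E4M3FNType>(t);
    }

    int main() {
      MLIRContext ctx;
      return isNativeNvidiaFP8(Float8E5M2Type::get(&ctx)) ? 0 : 1;
    }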
From 5d77cca664b5da2aec42e81a92fd8edeea4ae7b3 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Thu, 23 Jan 2025 14:03:07 -0500
Subject: [PATCH 3/6] move NVVMDialect.h

---
 third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp | 4 +++-
 .../lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp     | 4 +++-
 third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 4 +++-
 third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp    | 4 +++-
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
index c333de6162f4..9c9c7c673e5f 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
@@ -1,8 +1,10 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
+// clang-format off
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
+// clang-format on
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 8abb7131ef9e..a54eacde95e5 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,8 +1,10 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
+// clang-format off
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
+// clang-format on
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index e1437ee34570..b20e6d348447 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -1,10 +1,12 @@
 #include "TargetInfo.h"
 #include "Dialect/NVGPU/IR/Dialect.h"
 #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
+// clang-format off
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
+// clang-format on
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "llvm/Support/MathExtras.h"
 
 using namespace mlir;

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
index 94c472a4314e..090b441cd397 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
@@ -1,6 +1,8 @@
+// clang-format off
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
+// clang-format on
 #include "Dialect/NVGPU/IR/Dialect.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 
 namespace mlir {
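Note: patch 3 pins an include order rather than fixing a header: `Utility.h` apparently relies on declarations that `NVVMDialect.h` happens to pull in, so clang-format's include sorting is suspended around the pair to keep it from re-alphabetizing them. The general shape of the workaround, with the file names taken from the patch:

    // clang-format off
    #include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // must precede Utility.h
    #include "Utility.h"
    // clang-format on

The more durable fix is to make `Utility.h` self-contained (include what it uses); patch 6 below drops the pin again, presumably once the underlying dependency was resolved.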
From c26179d7348a3538861083622a54e8d049546adc Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Thu, 23 Jan 2025 17:46:59 -0500
Subject: [PATCH 4/6] format and bump hash again

---
 cmake/llvm-hash.txt                                   | 2 +-
 lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt
index 65c4c19898ae..07f4e83593b5 100644
--- a/cmake/llvm-hash.txt
+++ b/cmake/llvm-hash.txt
@@ -1 +1 @@
-1c28b9237382b093f477479c993c80181922ca6a
+c118864223c6309378cd704f3406533474c2759f

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 3a6fccb04f55..2e1e520718bd 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -632,7 +632,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = llvm::isa<Float8E5M2Type>(AElType) || llvm::isa<Float8E4M3FNType>(AElType);
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type>(AElType) ||
+                         llvm::isa<Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

From 47552ddfd9a4218ea285693d2274d1459132780e Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Mon, 27 Jan 2025 13:05:48 -0500
Subject: [PATCH 5/6] address comments

---
 include/triton/Conversion/MLIRTypes.h          | 16 ++++++----------
 lib/Analysis/Utility.cpp                       | 15 ++++++---------
 .../TritonGPU/Transforms/AccelerateMatmul.cpp  |  3 +--
 lib/Dialect/TritonGPU/Transforms/Utility.cpp   |  7 +++----
 lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp         |  6 ++----
 .../TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp | 13 +++++--------
 .../AccelerateAMDMatmul.cpp                    |  3 +--
 .../ElementwiseOpToLLVM.cpp                    |  6 ++----
 8 files changed, 26 insertions(+), 43 deletions(-)

diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index 3dcfddefe1c0..dd8d4be4c259 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -26,19 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
+}
+
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
-         llvm::isa<Float8E4M3FNType>(type) ||
-         llvm::isa<Float8E4M3FNUZType>(type) ||
-         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
-}
-
-inline bool isFloat8(Type type) {
-  return llvm::isa<Float8E4M3B11FNUZType>(type) ||
-         llvm::isa<Float8E4M3FNType>(type) ||
-         llvm::isa<Float8E4M3FNUZType>(type) ||
-         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
+         type.isBF16() || isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index f8591e8ec0b0..8af5369fcdf1 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -750,15 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (llvm::isa<Float8E5M2Type>(aElemTy) ||
-           llvm::isa<Float8E4M3FNType>(aElemTy) || aElemTy.isInteger(8) ||
-           aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) {
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+           aElemTy.isF32()))) {
      return false;
    }
    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
    if (op.getMaxNumImpreciseAcc() < 32 &&
-        (llvm::isa<Float8E5M2Type>(aElemTy) ||
-         llvm::isa<Float8E4M3FNType>(aElemTy)) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -779,10 +778,8 @@ bool supportMMA(Value value, int version) {
       cast<TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = llvm::isa<Float8E5M2Type>(elemTy) ||
-               llvm::isa<Float8E4M3FNType>(elemTy) ||
-               llvm::isa<Float8E5M2FNUZType>(elemTy) ||
-               llvm::isa<Float8E4M3FNUZType>(elemTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 2e1e520718bd..f32891aceb5f 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -632,8 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = llvm::isa<Float8E5M2Type>(AElType) ||
-                         llvm::isa<Float8E4M3FNType>(AElType);
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 9521933569ae..724cf3512c17 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,10 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (llvm::isa<Float8E5M2Type>(eltType) ||
-      llvm::isa<Float8E4M3FNType>(eltType) ||
-      llvm::isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
-      eltType.isBF16() || eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168,
                    160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64,
                    56, 48, 40, 32, 24, 16, 8});

diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index e2c205944227..f49a2555c7f9 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -77,10 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
   auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = llvm::isa<Float8E5M2Type>(aElTy) ||
-               llvm::isa<Float8E4M3FNType>(aElTy) ||
-               llvm::isa<Float8E5M2FNUZType>(aElTy) ||
-               llvm::isa<Float8E4M3FNUZType>(aElTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
   bool accFP32 = cast<TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index f1d70fd29081..74179823ad72 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1043,19 +1043,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (llvm::isa<Float8E4M3FNType>(srcElementType) ||
-        llvm::isa<Float8E4M3FNType>(dstElementType) ||
-        llvm::isa<Float8E4M3FNUZType>(srcElementType) ||
-        llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
-        llvm::isa<Float8E5M2FNUZType>(srcElementType) ||
-        llvm::isa<Float8E5M2FNUZType>(dstElementType)) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
         !(isaFamily == AMD::ISAFamily::CDNA3 &&
-          (llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
-           llvm::isa<Float8E5M2FNUZType>(dstElementType)));
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index ec99b5adca81..005089aaf7ac 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -416,8 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<triton::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = llvm::isa<Float8E5M2FNUZType>(aElemTy) ||
-                 llvm::isa<Float8E4M3FNUZType>(aElemTy);
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index 86ad53ab2aa1..7b641349385e 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -475,8 +475,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (llvm::isa<Float8E5M2Type>(dstElementType) ||
-        llvm::isa<Float8E4M3FNType>(dstElementType)) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -513,9 +512,8 @@ struct FpToFpOpConversion
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
         (!(computeCapability >= 90 &&
-           (llvm::isa<Float8E4M3FNType>(dstElementType) ||
-            llvm::isa<Float8E5M2Type>(dstElementType))) ||
+           (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) ||
         roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
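Note: the review pass in patch 5 collapses chains of single-type checks into one call using the variadic form of `llvm::isa`, which accepts any number of candidate types and returns true if the value is an instance of any of them. A standalone sketch — the helper name `isAnyMmaFP8` is invented for illustration:

    #include "mlir/IR/BuiltinTypes.h"
    #include "llvm/Support/Casting.h"

    using namespace mlir;

    // Variadic isa: one call replaces a four-way || of single-type checks.
    bool isAnyMmaFP8(Type t) {
      return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                       Float8E4M3FNUZType>(t);
    }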
From 226fb65bcb1ba5ec5813febfcea0bf1a481624fc Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Mon, 27 Jan 2025 13:13:52 -0500
Subject: [PATCH 6/6] Added new AMDGPU gfx950 features in Rocdl dialect. New
 MFMA variants. ds.read.r*.* and global.load.lds

---
 .../amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp      | 4 +---
 .../lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp  | 4 +---
 .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp      | 9 +++++----
 .../nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp      | 4 +---
 third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 4 +---
 5 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
index 9c9c7c673e5f..c333de6162f4 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
@@ -1,10 +1,8 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
-// clang-format off
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
-// clang-format on
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index a54eacde95e5..8abb7131ef9e 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,10 +1,8 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
-// clang-format off
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
-// clang-format on
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
index 74aeb1062508..eff746d2a0fa 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp
@@ -61,7 +61,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 };
 inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB,
                             Type scaleAType, Type scaleBType) {
   if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) {
-    if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E4M3FNType>(scaleAType) &&
+        llvm::isa<Float8E4M3FNType>(scaleBType)) {
       return mxfpKind::mxf4nvf4;
     }
     return mxfpKind::mxf4;
@@ -102,9 +103,9 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
       return 1;
     if (type.isF32())
       return 2;
-    if (type.isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(type))
       return 0;
-    if (type.isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(type))
       return 1;
     llvm_unreachable("Unsupported type.");
   };
@@ -227,7 +228,7 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
     opcode += "f16";
   else if (srcElementTy.isF32())
     opcode += "tf32";
-  else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2())
+  else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
     opcode += "f8f6f4";
   else
     assert(0 && "Unsupported type.");

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index b20e6d348447..e1437ee34570 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -1,12 +1,10 @@
 #include "TargetInfo.h"
 #include "Dialect/NVGPU/IR/Dialect.h"
 #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
-// clang-format off
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
-// clang-format on
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "llvm/Support/MathExtras.h"
 
 using namespace mlir;

diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
index 090b441cd397..94c472a4314e 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
@@ -1,8 +1,6 @@
-// clang-format off
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "Utility.h"
-// clang-format on
 #include "Dialect/NVGPU/IR/Dialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
 
 namespace mlir {
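Note: after the series, the FP8 predicates in include/triton/Conversion/MLIRTypes.h reduce to a single variadic check that the generic float test reuses. A standalone sketch of the final shape with a small usage example — the `main` is illustrative only, and only the builtin FP8 type classes from `mlir/IR/BuiltinTypes.h` are assumed:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/Casting.h"

    using namespace mlir;

    // One variadic isa covers all five FP8 formats Triton cares about.
    inline bool isFloat8(Type type) {
      return llvm::isa<Float8E4M3B11FNUZType, Float8E4M3FNType,
                       Float8E4M3FNUZType, Float8E5M2Type,
                       Float8E5M2FNUZType>(type);
    }

    // The generic float test composes with the FP8 helper.
    inline bool isFloat(Type type) {
      return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
             type.isBF16() || isFloat8(type);
    }

    int main() {
      MLIRContext ctx;
      // Both FP8 and the wider formats count as floats.
      return (isFloat(Float8E5M2Type::get(&ctx)) &&
              isFloat(Float16Type::get(&ctx)))
                 ? 0
                 : 1;
    }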