From c422f4c17c697a3dc2509580198538d5c95f2ffd Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko
Date: Sun, 26 Jan 2025 11:52:28 +0100
Subject: [PATCH] Replace isF...() LLVM API calls with the corresponding isa<...>()

The isF...() methods have been removed in the main LLVM branch:
https://github.com/llvm/llvm-project/pull/123326
---
 include/triton/Conversion/MLIRTypes.h                | 12 ++++++------
 lib/Analysis/Utility.cpp                             |  9 +++++----
 .../TritonGPU/Transforms/AccelerateMatmul.cpp        |  3 ++-
 lib/Dialect/TritonGPU/Transforms/Utility.cpp         |  6 +++---
 lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp               |  4 ++--
 .../TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp       | 15 ++++++++-------
 .../AccelerateAMDMatmul.cpp                          |  3 ++-
 .../amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp     | 14 +++++++++-----
 .../TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp      | 16 ++++++++--------
 .../TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp      |  4 ++--
 .../ElementwiseOpToLLVM.cpp                          |  9 +++++----
 11 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index afa1aa989e6e..1fa8543a1415 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -28,15 +28,15 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() || isa<Float8E4M3B11FNUZType>(type) ||
+         isa<Float8E4M3FNType>(type) || isa<Float8E4M3FNUZType>(type) ||
+         isa<Float8E5M2Type>(type) || isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return isa<Float8E4M3B11FNUZType>(type) || isa<Float8E4M3FNType>(type) ||
+         isa<Float8E4M3FNUZType>(type) || isa<Float8E5M2Type>(type) ||
+         isa<Float8E5M2FNUZType>(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 9527096bed8f..c09de6cab00a 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -732,14 +732,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+          (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy) ||
           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
       return false;
     }
    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
    if (op.getMaxNumImpreciseAcc() < 32 &&
-       (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+       (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy)) &&
        cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
      return false;
    }
@@ -760,8 +760,9 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = isa<Float8E5M2Type>(elemTy) || isa<Float8E4M3FNType>(elemTy) ||
+               isa<Float8E5M2FNUZType>(elemTy) ||
+               isa<Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 5bc3ae9cb713..78405cb06adb 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -344,7 +344,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          isa<Float8E5M2Type>(AElType) || isa<Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 46785b1e89b8..3bc5d1057c2d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -44,9 +44,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
     SmallVector<unsigned> validN;
 
     // MMAv3 with larger instruction shape is preferred.
-    if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-        eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-        eltType.isF32()) {
+    if (isa<Float8E5M2Type>(eltType) || isa<Float8E4M3FNType>(eltType) ||
+        isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
+        eltType.isBF16() || eltType.isF32()) {
       validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                      168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80,
                      72, 64, 56, 48, 40, 32, 24, 16, 8});
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index 2267e3b7c251..f7fb6278b52f 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = isa<Float8E5M2Type>(aElTy) || isa<Float8E4M3FNType>(aElTy) ||
+               isa<Float8E5M2FNUZType>(aElTy) || isa<Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index 4dfaf4881b67..01466699320e 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -985,17 +985,18 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (isa<Float8E4M3FNType>(srcElementType) ||
+        isa<Float8E4M3FNType>(dstElementType) ||
+        isa<Float8E4M3FNUZType>(srcElementType) ||
+        isa<Float8E4M3FNUZType>(dstElementType) ||
+        isa<Float8E5M2FNUZType>(srcElementType) ||
+        isa<Float8E5M2FNUZType>(dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
         srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+                                    (isa<Float8E4M3FNUZType>(dstElementType) ||
+                                     isa<Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 7ea13142a76c..f7057266d7cd 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -416,7 +416,8 @@ class BlockedToMFMA : public OpRewritePattern<triton::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 =
+        isa<Float8E5M2FNUZType>(aElemTy) || isa<Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
index 4979ee005b9f..59c102c1ed0d 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
@@ -20,19 +20,23 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
+      isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
+      isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
+      isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
+      isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (isa<Float8E5M2Type>(dataTypeA) && isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index c5ec00097d93..06901280d001 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
     return TensorCoreType::FP32_FP16_FP16_FP32;
   if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
     return TensorCoreType::FP32_BF16_BF16_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (isa<Float8E5M2Type>(aTy.getElementType()) &&
+      isa<Float8E5M2Type>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E5M2() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (isa<Float8E5M2Type>(aTy.getElementType()) &&
+      isa<Float8E4M3FNType>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E5M2())
+  if (isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      isa<Float8E5M2Type>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-  if (aTy.getElementType().isFloat8E4M3FN() &&
-      bTy.getElementType().isFloat8E4M3FN())
+  if (isa<Float8E4M3FNType>(aTy.getElementType()) &&
+      isa<Float8E4M3FNType>(bTy.getElementType()))
     return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
   if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
       op.getInputPrecision() == InputPrecision::TF32)
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index 6b23915fbbab..5b48b841869d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -57,9 +57,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::tf32;
   } else if (aTy.isInteger(8)) {
     return triton::nvgpu::WGMMAEltType::s8;
-  } else if (aTy.isFloat8E5M2()) {
+  } else if (isa<Float8E5M2Type>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FN()) {
+  } else if (isa<Float8E4M3FNType>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
index d489d0a1b1f4..8e37a4ad10fe 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -467,7 +467,7 @@ struct FpToFpOpConversion
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
     if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
+        (isa<Float8E4M3FNType>(srcTy) || isa<Float8E4M3FNType>(dstTy))) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -489,7 +489,8 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (isa<Float8E5M2Type>(dstElementType) ||
+        isa<Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -526,8 +527,8 @@ struct FpToFpOpConversion
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
-                                       dstElementType.isFloat8E5M2())) ||
+        (!(computeCapability >= 90 && (isa<Float8E4M3FNType>(dstElementType) ||
+                                       isa<Float8E5M2Type>(dstElementType))) ||
         roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
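
For reference, every hunk above follows the same migration pattern: a member query such as type.isFloat8E5M2() becomes an isa<...>() check against the corresponding MLIR builtin float type. A minimal sketch of that pattern, assuming the builtin float8 types from mlir/IR/BuiltinTypes.h (the isFp8E5M2 helper below is illustrative only and not part of the patch):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

// Old API (removed by https://github.com/llvm/llvm-project/pull/123326):
//   bool isE5M2 = type.isFloat8E5M2();
// New API: query the concrete builtin type via the casting infrastructure.
static bool isFp8E5M2(mlir::Type type) {
  return llvm::isa<mlir::Float8E5M2Type>(type);
}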