
Commit fba2dd6

address comments
1 parent 41cc335 commit fba2dd6

File tree

8 files changed: +26 additions, -43 deletions


include/triton/Conversion/MLIRTypes.h

Lines changed: 6 additions & 10 deletions
@@ -26,19 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
+}
+
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
          type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
-         llvm::isa<Float8E4M3FNType>(type) ||
-         llvm::isa<Float8E4M3FNUZType>(type) ||
-         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
-}
-
-inline bool isFloat8(Type type) {
-  return llvm::isa<Float8E4M3B11FNUZType>(type) ||
-         llvm::isa<Float8E4M3FNType>(type) ||
-         llvm::isa<Float8E4M3FNUZType>(type) ||
-         llvm::isa<Float8E5M2Type>(type) || llvm::isa<Float8E5M2FNUZType>(type);
+         isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
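
Every file in this commit applies the same mechanical rewrite: llvm::isa accepts a variadic list of types and returns true when the value is an instance of any of them, so a chain of single-type checks joined with || collapses into one call. A minimal, self-contained sketch of that equivalence (illustration only, not code from this commit; the isAnyNvidiaFp8 name is made up):

#include <cassert>
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"

// Illustration only: both spellings classify the type identically.
static bool isAnyNvidiaFp8(mlir::Type ty) {
  // Chained form, as the code read before this commit.
  bool chained = llvm::isa<mlir::Float8E5M2Type>(ty) ||
                 llvm::isa<mlir::Float8E4M3FNType>(ty);
  // Variadic form, as rewritten throughout this commit.
  bool variadic = llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(ty);
  assert(chained == variadic);
  return variadic;
}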

lib/Analysis/Utility.cpp

Lines changed: 6 additions & 9 deletions
@@ -732,15 +732,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (llvm::isa<Float8E5M2Type>(aElemTy) ||
-         llvm::isa<Float8E4M3FNType>(aElemTy) || aElemTy.isInteger(8) ||
-         aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) {
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+         aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+         aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (llvm::isa<Float8E5M2Type>(aElemTy) ||
-       llvm::isa<Float8E4M3FNType>(aElemTy)) &&
+      (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }

@@ -761,10 +760,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = llvm::isa<Float8E5M2Type>(elemTy) ||
-               llvm::isa<Float8E4M3FNType>(elemTy) ||
-               llvm::isa<Float8E5M2FNUZType>(elemTy) ||
-               llvm::isa<Float8E4M3FNUZType>(elemTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+               Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
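
Worth noting: the type lists in supportMMA stay narrower than the new isFloat8 helper from MLIRTypes.h. The first hunk admits only the two FP8 formats the MMA path consumes directly, while the second hunk also counts the FNUZ variants because, per the comment above it, FP8 can always be promoted to fp16. A self-contained sketch of the two subsets (illustration only; the function name is made up):

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"

// Illustration only: the two FP8 subsets used by supportMMA after this commit.
static void demoFp8Subsets(mlir::Type elemTy) {
  // Subset fed directly to MMA (first hunk).
  bool nativeMmaFp8 =
      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(elemTy);
  // Wider subset that is still supported after promotion to fp16 (second hunk).
  bool promotableFp8 =
      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
                mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
  (void)nativeMmaFp8;
  (void)promotableFp8;
}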

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 2 deletions
@@ -344,8 +344,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = llvm::isa<Float8E5M2Type>(AElType) ||
-                         llvm::isa<Float8E4M3FNType>(AElType);
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 4 deletions
@@ -44,10 +44,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (llvm::isa<Float8E5M2Type>(eltType) ||
-      llvm::isa<Float8E4M3FNType>(eltType) ||
-      llvm::isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
-      eltType.isBF16() || eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                    80, 72, 64, 56, 48, 40, 32, 24, 16, 8});

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 2 additions & 4 deletions
@@ -77,10 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = llvm::isa<Float8E5M2Type>(aElTy) ||
-               llvm::isa<Float8E4M3FNType>(aElTy) ||
-               llvm::isa<Float8E5M2FNUZType>(aElTy) ||
-               llvm::isa<Float8E4M3FNUZType>(aElTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+               Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 5 additions & 8 deletions
@@ -1106,19 +1106,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (llvm::isa<Float8E4M3FNType>(srcElementType) ||
-        llvm::isa<Float8E4M3FNType>(dstElementType) ||
-        llvm::isa<Float8E4M3FNUZType>(srcElementType) ||
-        llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
-        llvm::isa<Float8E5M2FNUZType>(srcElementType) ||
-        llvm::isa<Float8E5M2FNUZType>(dstElementType)) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
         !(isaFamily == AMD::ISAFamily::CDNA3 &&
-          (llvm::isa<Float8E4M3FNUZType>(dstElementType) ||
-           llvm::isa<Float8E5M2FNUZType>(dstElementType)));
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 2 deletions
@@ -416,8 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = llvm::isa<Float8E5M2FNUZType>(aElemTy) ||
-                 llvm::isa<Float8E4M3FNUZType>(aElemTy);
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
@@ -489,8 +489,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (llvm::isa<Float8E5M2Type>(dstElementType) ||
-        llvm::isa<Float8E4M3FNType>(dstElementType)) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -528,8 +527,7 @@ struct FpToFpOpConversion
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
         (!(computeCapability >= 90 &&
-           (llvm::isa<Float8E4M3FNType>(dstElementType) ||
-            llvm::isa<Float8E5M2Type>(dstElementType))) ||
+           (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) ||
          roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
