@@ -750,15 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
      return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (llvm::isa<Float8E5M2Type>(aElemTy) ||
-           llvm::isa<Float8E4M3FNType>(aElemTy) || aElemTy.isInteger(8) ||
-           aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) {
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+           aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (llvm::isa<Float8E5M2Type>(aElemTy) ||
-         llvm::isa<Float8E4M3FNType>(aElemTy)) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -779,10 +778,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = llvm::isa<Float8E5M2Type>(elemTy) ||
-               llvm::isa<Float8E4M3FNType>(elemTy) ||
-               llvm::isa<Float8E5M2FNUZType>(elemTy) ||
-               llvm::isa<Float8E4M3FNUZType>(elemTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
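
For context, both hunks rely on the variadic overload of llvm::isa<>, which returns true when the value is an instance of any of the listed types; it is equivalent to chaining single-type isa<> checks with ||. A minimal sketch of the pattern (assuming an MLIR build environment; the helper name isFp8Like is hypothetical and not part of this change):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

// Hypothetical helper: one variadic isa<> call replaces four single-type
// checks joined with ||, exactly as in the diff above.
static bool isFp8Like(mlir::Type elemTy) {
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
                   mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
}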