@@ -732,15 +732,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (llvm::isa<Float8E5M2Type>(aElemTy) ||
-           llvm::isa<Float8E4M3FNType>(aElemTy) || aElemTy.isInteger(8) ||
-           aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) {
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+           aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (llvm::isa<Float8E5M2Type>(aElemTy) ||
-         llvm::isa<Float8E4M3FNType>(aElemTy)) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -761,10 +760,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = llvm::isa<Float8E5M2Type>(elemTy) ||
-               llvm::isa<Float8E4M3FNType>(elemTy) ||
-               llvm::isa<Float8E5M2FNUZType>(elemTy) ||
-               llvm::isa<Float8E4M3FNUZType>(elemTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
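
For context on the pattern adopted in the `+` lines: `llvm::isa` accepts a list of candidate types and returns true when the value is an instance of any of them, so the chained `||` checks collapse into a single call. A minimal standalone sketch of the same idiom (the `isAnyFp8` helper name is hypothetical and not part of this patch):

```cpp
#include "mlir/IR/BuiltinTypes.h" // Float8E5M2Type, Float8E4M3FNType, ...
#include "mlir/IR/Types.h"
#include "llvm/Support/Casting.h"

// Hypothetical helper: equivalent to chaining llvm::isa<T>(elemTy) with ||.
static bool isAnyFp8(mlir::Type elemTy) {
  // Variadic llvm::isa returns true if elemTy is an instance of any listed type.
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
                   mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
}
```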