66//
77// ===----------------------------------------------------------------------===//
88
9- #include " mlir/Dialect/Arith/Transforms/Passes.h"
10-
119#include " mlir/Dialect/Arith/IR/Arith.h"
10+ #include " mlir/Dialect/Arith/Transforms/Passes.h"
1211#include " mlir/Dialect/Vector/IR/VectorOps.h"
12+ #include " mlir/IR/BuiltinTypeInterfaces.h"
1313#include " mlir/IR/ImplicitLocOpBuilder.h"
1414#include " mlir/IR/TypeUtilities.h"
1515#include " mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
3131 return rewriter.create <arith::ConstantOp>(
3232 loc, DenseElementsAttr::get (shapedTy, attr));
3333 }
34-
3534 return rewriter.create <arith::ConstantOp>(loc, attr);
3635}
3736
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
357356 f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
358357 Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
359358 if (resultETy.getIntOrFloatBitWidth () < 32 ) {
360- result = b.create <arith::TruncFOp>(resultTy, result);
359+ result = b.create <arith::TruncFOp>(resultTy, result, nullptr ,
360+ op.getFastmathAttr ());
361361 } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
362- result = b.create <arith::ExtFOp>(resultTy, result);
362+ result = b.create <arith::ExtFOp>(resultTy, result, op. getFastmathAttr () );
363363 }
364364 rewriter.replaceOp (op, result);
365365 return success ();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
395395 Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396396
397397 if (operandETy.getIntOrFloatBitWidth () < 32 ) {
398- operand = b.create <arith::ExtFOp>(f32Ty, operand);
398+ operand = b.create <arith::ExtFOp>(f32Ty, operand, op. getFastmathAttr () );
399399 } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
400- operand = b.create <arith::TruncFOp>(f32Ty, operand);
400+ operand = b.create <arith::TruncFOp>(
401+ f32Ty, operand, op.getRoundingmodeAttr (), op.getFastmathAttr ());
401402 }
402403 Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
403404 Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
409410 }
410411};
411412
413+ struct ScalingExtFOpConverter : public OpRewritePattern <arith::ScalingExtFOp> {
414+ using OpRewritePattern::OpRewritePattern;
415+ LogicalResult matchAndRewrite (arith::ScalingExtFOp op,
416+ PatternRewriter &rewriter) const final {
417+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
418+ Value inputOperand = op.getIn ();
419+ Value scaleOperand = op.getScale ();
420+ Type scaleTy = scaleOperand.getType ();
421+ Type scaleETy = getElementTypeOrSelf (scaleOperand);
422+ // allow implicit exponent extraction from 16/32 bits floats
423+ if (scaleETy.getIntOrFloatBitWidth () >= 16 ) {
424+ scaleETy = b.getF8E8M0Type ();
425+ scaleTy = cloneToShapedType (scaleTy, scaleETy);
426+ scaleOperand = b.create <arith::TruncFOp>(scaleTy, scaleOperand, nullptr ,
427+ op.getFastmathAttr ());
428+ }
429+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
430+ return rewriter.notifyMatchFailure (
431+ op, " scaling_extf is using scales of type which can not be converted "
432+ " to f8E8M0FNU" );
433+ }
434+ Type resultTy = op.getType ();
435+ // extf on scale will essentially create floating point number
436+ // of type resulTy that is 2^scale and will also propagate NaNs
437+ Value scaleExt =
438+ b.create <arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr ());
439+ Value inputExt =
440+ b.create <arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr ());
441+ Value result =
442+ b.create <arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr ());
443+ rewriter.replaceOp (op, result);
444+ return success ();
445+ }
446+ };
447+
448+ /*
449+ Expands arith.ScalingTruncFOp(in, scale) into
450+ scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
451+ result = arith.truncf(in / (2^scale))
452+ */
453+ struct ScalingTruncFOpConverter
454+ : public OpRewritePattern<arith::ScalingTruncFOp> {
455+ using OpRewritePattern::OpRewritePattern;
456+ LogicalResult matchAndRewrite (arith::ScalingTruncFOp op,
457+ PatternRewriter &rewriter) const final {
458+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
459+ Value inputOperand = op.getIn ();
460+ Value scaleOperand = op.getScale ();
461+ Type scaleTy = scaleOperand.getType ();
462+ Type scaleETy = getElementTypeOrSelf (scaleOperand);
463+ // allow implicit exponent extraction from 16/32 bits floats
464+ if (scaleETy.getIntOrFloatBitWidth () >= 16 ) {
465+ scaleETy = b.getF8E8M0Type ();
466+ scaleTy = cloneToShapedType (scaleTy, scaleETy);
467+ scaleOperand = b.create <arith::TruncFOp>(scaleTy, scaleOperand, nullptr ,
468+ op.getFastmathAttr ());
469+ }
470+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
471+ return rewriter.notifyMatchFailure (
472+ op, " scaling_truncf is using scales type which can not be converted "
473+ " to f8E8M0FNU" );
474+ }
475+ Type resultTy = op.getType ();
476+ Type inputTy = inputOperand.getType ();
477+ // this will create a floating point number of type
478+ // inputTy that is 2^scale and will also propagate NaNs
479+ scaleOperand =
480+ b.create <arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr ());
481+ Value result = b.create <arith::DivFOp>(inputOperand, scaleOperand,
482+ op.getFastmathAttr ());
483+ Value resultCast = b.create <arith::TruncFOp>(
484+ resultTy, result, op.getRoundingmodeAttr (), op.getFastmathAttr ());
485+ rewriter.replaceOp (op, resultCast);
486+ return success ();
487+ }
488+ };
489+
412490struct ArithExpandOpsPass
413491 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
414492 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
432510 arith::MaximumFOp,
433511 arith::MinimumFOp,
434512 arith::MaxNumFOp,
435- arith::MinNumFOp
513+ arith::MinNumFOp,
514+ arith::ScalingExtFOp,
515+ arith::ScalingTruncFOp
436516 >();
437517
438518 if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
492572 patterns.getContext ());
493573}
494574
575+ void mlir::arith::populateExpandScalingExtTruncPatterns (
576+ RewritePatternSet &patterns) {
577+ patterns.add <ScalingExtFOpConverter, ScalingTruncFOpConverter>(
578+ patterns.getContext ());
579+ }
580+
495581void mlir::arith::populateArithExpandOpsPatterns (RewritePatternSet &patterns) {
496582 populateCeilFloorDivExpandOpsPatterns (patterns);
583+ populateExpandScalingExtTruncPatterns (patterns);
497584 // clang-format off
498585 patterns.add <
499586 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
503590 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
504591 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
505592 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
506- MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
593+ MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
507594 >(patterns.getContext ());
508595 // clang-format on
509596}
0 commit comments