|
2 | 2 | #include "TargetInfo.h" |
3 | 3 | #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" |
4 | 4 | #include "Utility.h" |
| 5 | +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
5 | 6 | #include "mlir/Support/LLVM.h" |
6 | 7 | #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" |
7 | 8 | #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" |
@@ -350,26 +351,10 @@ struct FpToFpOpConversion |
350 | 351 | : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), |
351 | 352 | computeCapability(computeCapability) {} |
352 | 353 |
|
// Reviewer note: every row of this helper carries the "-" marker, i.e. this
// diff deletes convertBf16ToFp32.
// The deleted helper built one inline-PTX instruction, "cvt.f32.bf16", with a
// fresh output operand (constraint "=r") and the input value `v` bound with
// constraint "h", invoked it as cvt(res, operand), and returned the value
// produced by builder.launch(...) with result type f32_ty.
// NOTE(review): the meaning of the trailing `false` passed to launch() is not
// visible in this hunk -- confirm against PTXBuilder before relying on it.
353 | | - static Value convertBf16ToFp32(Location loc, |
354 | | - ConversionPatternRewriter &rewriter, |
355 | | - const Value &v) { |
356 | | - PTXBuilder builder; |
357 | | - auto &cvt = *builder.create("cvt.f32.bf16"); |
358 | | - auto res = builder.newOperand("=r"); |
359 | | - auto operand = builder.newOperand(v, "h"); |
360 | | - cvt(res, operand); |
361 | | - return builder.launch(rewriter, loc, f32_ty, false); |
362 | | - } |
363 | | - |
// Reviewer note: convertFp16ToFp32 keeps its signature (loc, rewriter, v) but
// its body is replaced by this diff.
// Old body ("-" rows): built the inline-PTX instruction "cvt.f32.f16" with an
// "=r" output operand and `v` as an "h" input operand, then returned
// builder.launch(...) typed as f32_ty.
// New body ("+" row): creates a single LLVM::FPExtOp at `loc` with result type
// f32_ty and operand `v`, and returns it; no PTXBuilder is involved.
364 | 354 | static Value convertFp16ToFp32(Location loc, |
365 | 355 | ConversionPatternRewriter &rewriter, |
366 | 356 | const Value &v) { |
367 | | - PTXBuilder builder; |
368 | | - auto &cvt = *builder.create("cvt.f32.f16"); |
369 | | - auto res = builder.newOperand("=r"); |
370 | | - auto operand = builder.newOperand(v, "h"); |
371 | | - cvt(res, operand); |
372 | | - return builder.launch(rewriter, loc, f32_ty, false); |
| 357 | + return rewriter.create<LLVM::FPExtOp>(loc, f32_ty, v); |
373 | 358 | } |
374 | 359 |
|
375 | 360 | static Value convertFp32ToBf16(Location loc, |
@@ -590,96 +575,6 @@ struct FDivOpConversion |
590 | 575 | } |
591 | 576 | }; |
592 | 577 |
|
// Reviewer note: every row of FMulOpConversion carries the "-" marker, i.e.
// this diff deletes the pattern.
// The deleted createDestOps had two paths:
//   * both operand element types are bf16: emit an inline-PTX block that loads
//     the 16-bit constant 0x8000 into register `c` and issues
//     "fma.rn.bf16 $0, $1, $2, c", with operands passed as {res, lhs, rhs}
//     (constraints "=h", "h", "h"); the result is launched with type bf16_ty.
//   * otherwise: create LLVM::FMulOp on operands[0][0] and operands[0][1] with
//     result type elemTy.
// NOTE(review): 0x8000 is the bf16 bit pattern with only the sign bit set,
// i.e. -0.0, so the "// 0.0" remark on the mov.b16 row is imprecise.
// Presumably -0.0 was chosen so the addend never changes the product
// (including a -0.0 product) -- confirm against the PTX fma documentation.
// NOTE(review): assumes operands bind to $0/$1/$2 in the order they are passed
// to fMul -- confirm against PTXBuilder.
593 | | -struct FMulOpConversion |
594 | | - : ElementwiseOpConversionBase<arith::MulFOp, FMulOpConversion> { |
595 | | - using Base = ElementwiseOpConversionBase<arith::MulFOp, FMulOpConversion>; |
596 | | - using Base::Base; |
597 | | - using Adaptor = typename Base::OpAdaptor; |
598 | | - |
599 | | - SmallVector<Value> createDestOps(arith::MulFOp op, OpAdaptor adaptor, |
600 | | - ConversionPatternRewriter &rewriter, |
601 | | - Type elemTy, MultipleOperandsRange operands, |
602 | | - Location loc) const { |
603 | | - auto lhsElemTy = getElementType(op.getLhs()); |
604 | | - auto rhsElemTy = getElementType(op.getRhs()); |
605 | | - if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { |
606 | | - PTXBuilder builder; |
607 | | - auto ptxAsm = " { .reg .b16 c; \n" |
608 | | - " mov.b16 c, 0x8000U; \n" // 0.0 |
609 | | - " fma.rn.bf16 $0, $1, $2, c; } \n"; |
610 | | - auto &fMul = *builder.create<PTXInstr>(ptxAsm); |
611 | | - auto res = builder.newOperand("=h"); |
612 | | - auto lhs = builder.newOperand(operands[0][0], "h"); |
613 | | - auto rhs = builder.newOperand(operands[0][1], "h"); |
614 | | - fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); |
615 | | - return {builder.launch(rewriter, loc, bf16_ty, false)}; |
616 | | - } else { |
617 | | - return {rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0][0], |
618 | | - operands[0][1])}; |
619 | | - } |
620 | | - } |
621 | | -}; |
622 | | - |
// Reviewer note: every row of FAddOpConversion carries the "-" marker, i.e.
// this diff deletes the pattern.
// The deleted createDestOps had two paths:
//   * both operand element types are bf16: emit an inline-PTX block that loads
//     the 16-bit constant 0x3f80 (annotated as 1.0 on its row) into register
//     `c` and issues "fma.rn.bf16 $0, $1, c, $2", i.e. $1 * c + $2, with
//     operands passed as {res, lhs, rhs} (constraints "=h", "h", "h"); the
//     result is launched with type bf16_ty.
//   * otherwise: create LLVM::FAddOp on operands[0][0] and operands[0][1] with
//     result type elemTy.
// NOTE(review): assumes operands bind to $0/$1/$2 in the order they are passed
// to fAdd -- confirm against PTXBuilder.
623 | | -struct FAddOpConversion |
624 | | - : ElementwiseOpConversionBase<arith::AddFOp, FAddOpConversion> { |
625 | | - using Base = ElementwiseOpConversionBase<arith::AddFOp, FAddOpConversion>; |
626 | | - using Base::Base; |
627 | | - using Adaptor = typename Base::OpAdaptor; |
628 | | - |
629 | | - SmallVector<Value> createDestOps(arith::AddFOp op, OpAdaptor adaptor, |
630 | | - ConversionPatternRewriter &rewriter, |
631 | | - Type elemTy, MultipleOperandsRange operands, |
632 | | - Location loc) const { |
633 | | - auto lhsElemTy = getElementType(op.getLhs()); |
634 | | - auto rhsElemTy = getElementType(op.getRhs()); |
635 | | - if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { |
636 | | - PTXBuilder builder; |
637 | | - auto ptxAsm = "{ .reg .b16 c; \n" |
638 | | - " mov.b16 c, 0x3f80U; \n" // 1.0 |
639 | | - " fma.rn.bf16 $0, $1, c, $2; } \n"; |
640 | | - auto &fAdd = *builder.create<PTXInstr>(ptxAsm); |
641 | | - auto res = builder.newOperand("=h"); |
642 | | - auto lhs = builder.newOperand(operands[0][0], "h"); |
643 | | - auto rhs = builder.newOperand(operands[0][1], "h"); |
644 | | - fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); |
645 | | - return {builder.launch(rewriter, loc, bf16_ty, false)}; |
646 | | - } else { |
647 | | - return {rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0][0], |
648 | | - operands[0][1])}; |
649 | | - } |
650 | | - } |
651 | | -}; |
652 | | - |
// Reviewer note: every row of FSubOpConversion carries the "-" marker, i.e.
// this diff deletes the pattern.
// The deleted createDestOps had two paths:
//   * both operand element types are bf16: emit an inline-PTX block that loads
//     the 16-bit constant 0xbf80 (annotated as -1.0 on its row) into register
//     `c` and issues "fma.rn.bf16 $0, $2, c, $1", i.e. $2 * c + $1, with
//     operands passed as {res, lhs, rhs} (constraints "=h", "h", "h"); the
//     result is launched with type bf16_ty.
//   * otherwise: create LLVM::FSubOp on operands[0][0] and operands[0][1] with
//     result type elemTy.
// NOTE(review): assumes operands bind to $0/$1/$2 in the order they are passed
// to fSub, which would make the PTX compute lhs - rhs -- confirm against
// PTXBuilder.
653 | | -struct FSubOpConversion |
654 | | - : ElementwiseOpConversionBase<arith::SubFOp, FSubOpConversion> { |
655 | | - using Base = ElementwiseOpConversionBase<arith::SubFOp, FSubOpConversion>; |
656 | | - using Base::Base; |
657 | | - using Adaptor = typename Base::OpAdaptor; |
658 | | - |
659 | | - SmallVector<Value> createDestOps(arith::SubFOp op, OpAdaptor adaptor, |
660 | | - ConversionPatternRewriter &rewriter, |
661 | | - Type elemTy, MultipleOperandsRange operands, |
662 | | - Location loc) const { |
663 | | - auto lhsElemTy = getElementType(op.getLhs()); |
664 | | - auto rhsElemTy = getElementType(op.getRhs()); |
665 | | - if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { |
666 | | - PTXBuilder builder; |
667 | | - auto ptxAsm = " { .reg .b16 c; \n" |
668 | | - " mov.b16 c, 0xbf80U; \n" // -1.0 |
669 | | - " fma.rn.bf16 $0, $2, c, $1;} \n"; |
670 | | - auto &fSub = *builder.create<PTXInstr>(ptxAsm); |
671 | | - auto res = builder.newOperand("=h"); |
672 | | - auto lhs = builder.newOperand(operands[0][0], "h"); |
673 | | - auto rhs = builder.newOperand(operands[0][1], "h"); |
674 | | - fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); |
675 | | - return {builder.launch(rewriter, loc, bf16_ty, false)}; |
676 | | - } else { |
677 | | - return {rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0][0], |
678 | | - operands[0][1])}; |
679 | | - } |
680 | | - } |
681 | | -}; |
682 | | - |
683 | 578 | // Uses inline ptx to convert s8/u8 to bf16. NOTE(review): the original comment broke off after "since the"; the rationale is not stated here. |
684 | 579 | struct SIToFPOpConversion |
685 | 580 | : ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion> { |
@@ -733,51 +628,6 @@ struct FPToSIOpConversion |
733 | 628 | } |
734 | 629 | }; |
735 | 630 |
|
// Reviewer note: every row of ExtFOpConversion carries the "-" marker, i.e.
// this diff deletes the pattern.
// The deleted createDestOps had two paths:
//   * input element type is bf16: assert that the output element type is f32
//     ("unsupported conversion" otherwise) and delegate to
//     FpToFpOpConversion::convertBf16ToFp32 on operands[0][0].
//   * otherwise: create LLVM::FPExtOp on operands[0][0] with result type
//     elemTy.
736 | | -struct ExtFOpConversion |
737 | | - : ElementwiseOpConversionBase<arith::ExtFOp, ExtFOpConversion> { |
738 | | - using Base = ElementwiseOpConversionBase<arith::ExtFOp, ExtFOpConversion>; |
739 | | - using Base::Base; |
740 | | - using Adaptor = typename Base::OpAdaptor; |
741 | | - |
742 | | - SmallVector<Value> createDestOps(arith::ExtFOp op, OpAdaptor adaptor, |
743 | | - ConversionPatternRewriter &rewriter, |
744 | | - Type elemTy, MultipleOperandsRange operands, |
745 | | - Location loc) const { |
746 | | - auto inElemTy = getElementType(op.getIn()); |
747 | | - if (inElemTy.isBF16()) { |
748 | | - auto outElemTy = getElementType(op.getOut()); |
749 | | - assert(outElemTy.isF32() && "unsupported conversion"); |
750 | | - return { |
751 | | - FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0])}; |
752 | | - } else { |
753 | | - return {rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0][0])}; |
754 | | - } |
755 | | - } |
756 | | -}; |
757 | | - |
// Reviewer note: every row of TruncFOpConversion carries the "-" marker, i.e.
// this diff deletes the pattern.
// The deleted createDestOps had two paths:
//   * output element type is bf16: assert that the input element type is f32
//     ("unsupported conversion" otherwise) and delegate to
//     FpToFpOpConversion::convertFp32ToBf16 on operands[0][0] with
//     RoundingMode::RTNE.
//   * otherwise: create LLVM::FPTruncOp on operands[0][0] with result type
//     elemTy.
758 | | -struct TruncFOpConversion |
759 | | - : ElementwiseOpConversionBase<arith::TruncFOp, TruncFOpConversion> { |
760 | | - using Base = ElementwiseOpConversionBase<arith::TruncFOp, TruncFOpConversion>; |
761 | | - using Base::Base; |
762 | | - using Adaptor = typename Base::OpAdaptor; |
763 | | - |
764 | | - SmallVector<Value> createDestOps(arith::TruncFOp op, OpAdaptor adaptor, |
765 | | - ConversionPatternRewriter &rewriter, |
766 | | - Type elemTy, MultipleOperandsRange operands, |
767 | | - Location loc) const { |
768 | | - auto outElemTy = getElementType(op.getOut()); |
769 | | - if (outElemTy.isBF16()) { |
770 | | - auto inElemTy = getElementType(op.getIn()); |
771 | | - assert(inElemTy.isF32() && "unsupported conversion"); |
772 | | - return {// Trunc uses the default rounding mode: RTNE |
773 | | - FpToFpOpConversion::convertFp32ToBf16( |
774 | | - loc, rewriter, operands[0][0], RoundingMode::RTNE)}; |
775 | | - } else { |
776 | | - return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])}; |
777 | | - } |
778 | | - } |
779 | | -}; |
780 | | - |
781 | 631 | struct ExpOpConversionApprox |
782 | 632 | : ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox> { |
783 | 633 | using Base = ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox>; |
@@ -961,15 +811,21 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( |
961 | 811 | mlir::triton::populateElementwiseOpToLLVMPatterns( |
962 | 812 | typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); |
963 | 813 |
|
964 | | - patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
965 | | - patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
966 | | - patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
967 | | - patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
| 814 | +#define POPULATE_OP(SRC_OP, DST_OP) \ |
| 815 | + patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \ |
| 816 | + typeConverter, axisInfoAnalysis, benefit) |
968 | 817 |
|
969 | | - patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
970 | | - patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
971 | | - patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
| 818 | + POPULATE_OP(arith::SubFOp, LLVM::FSubOp); |
| 819 | + POPULATE_OP(arith::AddFOp, LLVM::FAddOp); |
| 820 | + POPULATE_OP(arith::MulFOp, LLVM::FMulOp); |
972 | 821 |
|
| 822 | + POPULATE_OP(arith::ExtFOp, LLVM::FPExtOp); |
| 823 | + POPULATE_OP(arith::TruncFOp, LLVM::FPTruncOp); |
| 824 | + |
| 825 | +#undef POPULATE_OP |
| 826 | + |
| 827 | + patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
| 828 | + patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit); |
973 | 829 | patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, |
974 | 830 | computeCapability, benefit); |
975 | 831 | patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis, |
|
0 commit comments