Skip to content

Commit 7444438

Browse files
authored
[NVIDIA] Use native bf16 ops (#5732)
The custom fma codegen for Ampere has been upstreamed to NVPTX, so we no longer need custom conversion code. As a bonus, we now codegen vectorized bf16 ops for free.
1 parent b24ec52 commit 7444438

File tree

4 files changed

+122
-179
lines changed

4 files changed

+122
-179
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,27 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
208208
ModuleAxisInfoAnalysis &axisAnalysisPass;
209209
};
210210

211+
// Trivial case where we map elementwise to an existing LLVM operator
212+
template <typename SourceOp, typename DestOp>
213+
struct ElementwiseOpConversion
214+
: public ElementwiseOpConversionBase<
215+
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
216+
using Base =
217+
ElementwiseOpConversionBase<SourceOp,
218+
ElementwiseOpConversion<SourceOp, DestOp>>;
219+
using Base::Base;
220+
using OpAdaptor = typename Base::OpAdaptor;
221+
222+
// An interface to support variant DestOp builders.
223+
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
224+
ConversionPatternRewriter &rewriter,
225+
Type elemTy, MultipleOperandsRange operands,
226+
Location loc) const {
227+
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
228+
adaptor.getAttributes().getValue())};
229+
}
230+
};
231+
211232
} // namespace gpu
212233

213234
} // namespace mlir::triton

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -215,26 +215,6 @@ struct ExternElementwiseOpConversion
215215
}
216216
};
217217

218-
template <typename SourceOp, typename DestOp>
219-
struct ElementwiseOpConversion
220-
: public ElementwiseOpConversionBase<
221-
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
222-
using Base =
223-
ElementwiseOpConversionBase<SourceOp,
224-
ElementwiseOpConversion<SourceOp, DestOp>>;
225-
using Base::Base;
226-
using OpAdaptor = typename Base::OpAdaptor;
227-
228-
// An interface to support variant DestOp builder.
229-
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
230-
ConversionPatternRewriter &rewriter,
231-
Type elemTy, MultipleOperandsRange operands,
232-
Location loc) const {
233-
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
234-
adaptor.getAttributes().getValue())};
235-
}
236-
};
237-
238218
struct ElementwiseInlineAsmOpConversion
239219
: public ConvertOpToLLVMPattern<ElementwiseInlineAsmOp> {
240220
using Base = ConvertOpToLLVMPattern<ElementwiseInlineAsmOp>;
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx83 | FileCheck --check-prefixes CHECK,SM90 --dump-input-context=20 %s
2+
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=80 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_80 -mattr=+ptx83 | FileCheck --check-prefixes CHECK,SM80 --dump-input-context=20 %s
3+
4+
5+
#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
6+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} {
7+
tt.func public @add_bf16(%ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) {
8+
// CHECK-LABEL: add_bf16
9+
// SM80-COUNT-4: fma.rn.bf16x2
10+
// SM90-COUNT-4: add.rn.bf16x2
11+
%0 = arith.addf %arg0, %arg1 : tensor<256xbf16, #blocked>
12+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
13+
%2 = tt.splat %ptr : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked>
14+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<bf16>, #blocked>, tensor<256xi32, #blocked>
15+
tt.store %3, %0 : tensor<256x!tt.ptr<bf16>, #blocked>
16+
tt.return
17+
}
18+
19+
tt.func public @sub_bf16(%ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) {
20+
// CHECK-LABEL: sub_bf16
21+
// SM80-COUNT-4: fma.rn.bf16x2
22+
// SM90-COUNT-4: sub.rn.bf16x2
23+
%0 = arith.subf %arg0, %arg1 : tensor<256xbf16, #blocked>
24+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
25+
%2 = tt.splat %ptr : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked>
26+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<bf16>, #blocked>, tensor<256xi32, #blocked>
27+
tt.store %3, %0 : tensor<256x!tt.ptr<bf16>, #blocked>
28+
tt.return
29+
}
30+
31+
tt.func public @mul_bf16(%ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>, %arg1: tensor<256xbf16, #blocked>) {
32+
// CHECK-LABEL: mul_bf16
33+
// SM80-COUNT-4: fma.rn.bf16x2
34+
// SM90-COUNT-4: mul.rn.bf16x2
35+
%0 = arith.mulf %arg0, %arg1 : tensor<256xbf16, #blocked>
36+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
37+
%2 = tt.splat %ptr : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked>
38+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<bf16>, #blocked>, tensor<256xi32, #blocked>
39+
tt.store %3, %0 : tensor<256x!tt.ptr<bf16>, #blocked>
40+
tt.return
41+
}
42+
43+
tt.func public @extf_bf16(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg0: tensor<256xbf16, #blocked>) {
44+
// CHECK-LABEL: extf_bf16
45+
// CHECK-COUNT-8: cvt.f32.bf16
46+
%0 = arith.extf %arg0 : tensor<256xbf16, #blocked> to tensor<256xf32, #blocked>
47+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
48+
%2 = tt.splat %ptr : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked>
49+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xi32, #blocked>
50+
tt.store %3, %0 : tensor<256x!tt.ptr<f32>, #blocked>
51+
tt.return
52+
}
53+
54+
tt.func public @truncf_bf16(%ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg0: tensor<256xf32, #blocked>) {
55+
// CHECK-LABEL: truncf_bf16
56+
// CHECK-COUNT-4: cvt.rn.bf16x2.f32
57+
%0 = arith.truncf %arg0 : tensor<256xf32, #blocked> to tensor<256xbf16, #blocked>
58+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
59+
%2 = tt.splat %ptr : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked>
60+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<bf16>, #blocked>, tensor<256xi32, #blocked>
61+
tt.store %3, %0 : tensor<256x!tt.ptr<bf16>, #blocked>
62+
tt.return
63+
}
64+
65+
tt.func public @extf_f16(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg0: tensor<256xf16, #blocked>) {
66+
// CHECK-LABEL: extf_f16
67+
// CHECK-COUNT-8: cvt.f32.f16
68+
%0 = arith.extf %arg0 : tensor<256xf16, #blocked> to tensor<256xf32, #blocked>
69+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
70+
%2 = tt.splat %ptr : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked>
71+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xi32, #blocked>
72+
tt.store %3, %0 : tensor<256x!tt.ptr<f32>, #blocked>
73+
tt.return
74+
}
75+
76+
tt.func public @truncf_f16(%ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg0: tensor<256xf32, #blocked>) {
77+
// CHECK-LABEL: truncf_f16
78+
// CHECK-COUNT-4: cvt.rn.f16x2.f32
79+
%0 = arith.truncf %arg0 : tensor<256xf32, #blocked> to tensor<256xf16, #blocked>
80+
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
81+
%2 = tt.splat %ptr : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked>
82+
%3 = tt.addptr %2, %1 : tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xi32, #blocked>
83+
tt.store %3, %0 : tensor<256x!tt.ptr<f16>, #blocked>
84+
tt.return
85+
}
86+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 15 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "TargetInfo.h"
33
#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
44
#include "Utility.h"
5+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
56
#include "mlir/Support/LLVM.h"
67
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
78
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
@@ -350,26 +351,10 @@ struct FpToFpOpConversion
350351
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
351352
computeCapability(computeCapability) {}
352353

353-
static Value convertBf16ToFp32(Location loc,
354-
ConversionPatternRewriter &rewriter,
355-
const Value &v) {
356-
PTXBuilder builder;
357-
auto &cvt = *builder.create("cvt.f32.bf16");
358-
auto res = builder.newOperand("=r");
359-
auto operand = builder.newOperand(v, "h");
360-
cvt(res, operand);
361-
return builder.launch(rewriter, loc, f32_ty, false);
362-
}
363-
364354
static Value convertFp16ToFp32(Location loc,
365355
ConversionPatternRewriter &rewriter,
366356
const Value &v) {
367-
PTXBuilder builder;
368-
auto &cvt = *builder.create("cvt.f32.f16");
369-
auto res = builder.newOperand("=r");
370-
auto operand = builder.newOperand(v, "h");
371-
cvt(res, operand);
372-
return builder.launch(rewriter, loc, f32_ty, false);
357+
return rewriter.create<LLVM::FPExtOp>(loc, f32_ty, v);
373358
}
374359

375360
static Value convertFp32ToBf16(Location loc,
@@ -590,96 +575,6 @@ struct FDivOpConversion
590575
}
591576
};
592577

593-
struct FMulOpConversion
594-
: ElementwiseOpConversionBase<arith::MulFOp, FMulOpConversion> {
595-
using Base = ElementwiseOpConversionBase<arith::MulFOp, FMulOpConversion>;
596-
using Base::Base;
597-
using Adaptor = typename Base::OpAdaptor;
598-
599-
SmallVector<Value> createDestOps(arith::MulFOp op, OpAdaptor adaptor,
600-
ConversionPatternRewriter &rewriter,
601-
Type elemTy, MultipleOperandsRange operands,
602-
Location loc) const {
603-
auto lhsElemTy = getElementType(op.getLhs());
604-
auto rhsElemTy = getElementType(op.getRhs());
605-
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
606-
PTXBuilder builder;
607-
auto ptxAsm = " { .reg .b16 c; \n"
608-
" mov.b16 c, 0x8000U; \n" // 0.0
609-
" fma.rn.bf16 $0, $1, $2, c; } \n";
610-
auto &fMul = *builder.create<PTXInstr>(ptxAsm);
611-
auto res = builder.newOperand("=h");
612-
auto lhs = builder.newOperand(operands[0][0], "h");
613-
auto rhs = builder.newOperand(operands[0][1], "h");
614-
fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
615-
return {builder.launch(rewriter, loc, bf16_ty, false)};
616-
} else {
617-
return {rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0][0],
618-
operands[0][1])};
619-
}
620-
}
621-
};
622-
623-
struct FAddOpConversion
624-
: ElementwiseOpConversionBase<arith::AddFOp, FAddOpConversion> {
625-
using Base = ElementwiseOpConversionBase<arith::AddFOp, FAddOpConversion>;
626-
using Base::Base;
627-
using Adaptor = typename Base::OpAdaptor;
628-
629-
SmallVector<Value> createDestOps(arith::AddFOp op, OpAdaptor adaptor,
630-
ConversionPatternRewriter &rewriter,
631-
Type elemTy, MultipleOperandsRange operands,
632-
Location loc) const {
633-
auto lhsElemTy = getElementType(op.getLhs());
634-
auto rhsElemTy = getElementType(op.getRhs());
635-
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
636-
PTXBuilder builder;
637-
auto ptxAsm = "{ .reg .b16 c; \n"
638-
" mov.b16 c, 0x3f80U; \n" // 1.0
639-
" fma.rn.bf16 $0, $1, c, $2; } \n";
640-
auto &fAdd = *builder.create<PTXInstr>(ptxAsm);
641-
auto res = builder.newOperand("=h");
642-
auto lhs = builder.newOperand(operands[0][0], "h");
643-
auto rhs = builder.newOperand(operands[0][1], "h");
644-
fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
645-
return {builder.launch(rewriter, loc, bf16_ty, false)};
646-
} else {
647-
return {rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0][0],
648-
operands[0][1])};
649-
}
650-
}
651-
};
652-
653-
struct FSubOpConversion
654-
: ElementwiseOpConversionBase<arith::SubFOp, FSubOpConversion> {
655-
using Base = ElementwiseOpConversionBase<arith::SubFOp, FSubOpConversion>;
656-
using Base::Base;
657-
using Adaptor = typename Base::OpAdaptor;
658-
659-
SmallVector<Value> createDestOps(arith::SubFOp op, OpAdaptor adaptor,
660-
ConversionPatternRewriter &rewriter,
661-
Type elemTy, MultipleOperandsRange operands,
662-
Location loc) const {
663-
auto lhsElemTy = getElementType(op.getLhs());
664-
auto rhsElemTy = getElementType(op.getRhs());
665-
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
666-
PTXBuilder builder;
667-
auto ptxAsm = " { .reg .b16 c; \n"
668-
" mov.b16 c, 0xbf80U; \n" // -1.0
669-
" fma.rn.bf16 $0, $2, c, $1;} \n";
670-
auto &fSub = *builder.create<PTXInstr>(ptxAsm);
671-
auto res = builder.newOperand("=h");
672-
auto lhs = builder.newOperand(operands[0][0], "h");
673-
auto rhs = builder.newOperand(operands[0][1], "h");
674-
fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
675-
return {builder.launch(rewriter, loc, bf16_ty, false)};
676-
} else {
677-
return {rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0][0],
678-
operands[0][1])};
679-
}
680-
}
681-
};
682-
683578
// Uses inline ptx to convert s8/u8 to bf16 (presumably because the generic
// LLVM lowering does not handle this conversion natively — TODO confirm).
684579
struct SIToFPOpConversion
685580
: ElementwiseOpConversionBase<arith::SIToFPOp, SIToFPOpConversion> {
@@ -733,51 +628,6 @@ struct FPToSIOpConversion
733628
}
734629
};
735630

736-
struct ExtFOpConversion
737-
: ElementwiseOpConversionBase<arith::ExtFOp, ExtFOpConversion> {
738-
using Base = ElementwiseOpConversionBase<arith::ExtFOp, ExtFOpConversion>;
739-
using Base::Base;
740-
using Adaptor = typename Base::OpAdaptor;
741-
742-
SmallVector<Value> createDestOps(arith::ExtFOp op, OpAdaptor adaptor,
743-
ConversionPatternRewriter &rewriter,
744-
Type elemTy, MultipleOperandsRange operands,
745-
Location loc) const {
746-
auto inElemTy = getElementType(op.getIn());
747-
if (inElemTy.isBF16()) {
748-
auto outElemTy = getElementType(op.getOut());
749-
assert(outElemTy.isF32() && "unsupported conversion");
750-
return {
751-
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0])};
752-
} else {
753-
return {rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0][0])};
754-
}
755-
}
756-
};
757-
758-
struct TruncFOpConversion
759-
: ElementwiseOpConversionBase<arith::TruncFOp, TruncFOpConversion> {
760-
using Base = ElementwiseOpConversionBase<arith::TruncFOp, TruncFOpConversion>;
761-
using Base::Base;
762-
using Adaptor = typename Base::OpAdaptor;
763-
764-
SmallVector<Value> createDestOps(arith::TruncFOp op, OpAdaptor adaptor,
765-
ConversionPatternRewriter &rewriter,
766-
Type elemTy, MultipleOperandsRange operands,
767-
Location loc) const {
768-
auto outElemTy = getElementType(op.getOut());
769-
if (outElemTy.isBF16()) {
770-
auto inElemTy = getElementType(op.getIn());
771-
assert(inElemTy.isF32() && "unsupported conversion");
772-
return {// Trunc uses the default rounding mode: RTNE
773-
FpToFpOpConversion::convertFp32ToBf16(
774-
loc, rewriter, operands[0][0], RoundingMode::RTNE)};
775-
} else {
776-
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
777-
}
778-
}
779-
};
780-
781631
struct ExpOpConversionApprox
782632
: ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox> {
783633
using Base = ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox>;
@@ -961,15 +811,21 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
961811
mlir::triton::populateElementwiseOpToLLVMPatterns(
962812
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
963813

964-
patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
965-
patterns.add<FSubOpConversion>(typeConverter, axisInfoAnalysis, benefit);
966-
patterns.add<FAddOpConversion>(typeConverter, axisInfoAnalysis, benefit);
967-
patterns.add<FMulOpConversion>(typeConverter, axisInfoAnalysis, benefit);
814+
#define POPULATE_OP(SRC_OP, DST_OP) \
815+
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
816+
typeConverter, axisInfoAnalysis, benefit)
968817

969-
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
970-
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
971-
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
818+
POPULATE_OP(arith::SubFOp, LLVM::FSubOp);
819+
POPULATE_OP(arith::AddFOp, LLVM::FAddOp);
820+
POPULATE_OP(arith::MulFOp, LLVM::FMulOp);
972821

822+
POPULATE_OP(arith::ExtFOp, LLVM::FPExtOp);
823+
POPULATE_OP(arith::TruncFOp, LLVM::FPTruncOp);
824+
825+
#undef POPULATE_OP
826+
827+
patterns.add<FDivOpConversion>(typeConverter, axisInfoAnalysis, benefit);
828+
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
973829
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis,
974830
computeCapability, benefit);
975831
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,

0 commit comments

Comments
 (0)