Skip to content

Commit de0f754

Browse files
authored
[AMD] Use LLVM ops for fp16<->fp32 casts (triton-lang#5859)
Inline assembly can be a blocker for LLVM backend to optimize.
1 parent d827851 commit de0f754

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
55
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
66
tt.func @f16_to_f32(%arg0: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>) {
7-
// CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}v_cvt_f32_f16 {{.*}}: (f16) -> f32
7+
// CHECK-COUNT-8: llvm.fpext %{{.+}} : f16 to f32
88
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
99
tt.return
1010
}
@@ -21,3 +21,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
2121
tt.return
2222
}
2323
}
24+
25+
// -----
26+
27+
// CHECK-LABEL: f32_to_f16
28+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
29+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
30+
tt.func @f32_to_f16(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
31+
// CHECK-COUNT-8: llvm.intr.experimental.constrained.fptrunc %{{.+}} tonearest ignore : f32 to f16
32+
%0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
33+
// CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}s_setreg_imm32_b32{{.+}}v_cvt_f16_f32{{.+}}s_setreg_imm32_b32{{.+}} : (f32) -> f16
34+
35+
%1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
36+
tt.return
37+
}
38+
}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "triton/Analysis/Allocation.h"
66
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
77
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
8+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
89
#include "triton/Dialect/Triton/IR/Dialect.h"
910

1011
using namespace mlir;
@@ -173,12 +174,8 @@ Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
173174

174175
static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
175176
const Value &v) {
176-
GCNBuilder builder;
177-
auto &cvt = *builder.create("v_cvt_f32_f16");
178-
auto res = builder.newOperand("=v");
179-
auto operand = builder.newOperand(v, "v");
180-
cvt(res, operand);
181-
return builder.launch(rewriter, loc, f32_ty, false);
177+
TritonLLVMOpBuilder b(loc, rewriter);
178+
return b.fpext(f32_ty, v);
182179
}
183180

184181
// convert fp8 to fp32

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -532,20 +532,25 @@ int32_t getCtrlBitsForCacheModifierOnTarget(
532532

533533
Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v,
534534
triton::RoundingMode rounding) {
535+
if (rounding == triton::RoundingMode::RTNE) {
536+
LLVM::RoundingMode rm = LLVM::RoundingMode::NearestTiesToEven;
537+
return rewriter.create<LLVM::ConstrainedFPTruncIntr>(
538+
loc, f16_ty, v, rm, LLVM::FPExceptionBehavior::Ignore);
539+
}
540+
541+
// TODO: Figure out the test failure with RTZ LLVM::ConstrainedFPTruncIntr and
542+
// switch to not use inline assembly too.
543+
assert(rounding == triton::RoundingMode::RTZ);
535544
GCNBuilder builder;
536545

537546
auto &cvt = *builder.create("v_cvt_f16_f32");
538547
auto res = builder.newOperand("=v");
539548
auto operand = builder.newOperand(v, "v");
540-
if (rounding == triton::RoundingMode::RTZ) {
541-
auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc");
542-
setRTZ();
543-
}
549+
auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc");
550+
setRTZ();
544551
cvt(res, operand);
545-
if (rounding == triton::RoundingMode::RTZ) {
546-
auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0");
547-
resetRTZ();
548-
}
552+
auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0");
553+
resetRTZ();
549554
return builder.launch(rewriter, loc, f16_ty, false);
550555
}
551556

0 commit comments

Comments
 (0)