Skip to content

Commit 708b5f1

Browse files
authored
[AMD] Use ROCDL ops to replace inline assembly for conversions (#6313)
This commit replaces some inline assembly we used for type conversion with ROCDL ops. This makes it easier for the LLVM AMDGPU backend to optimize.
1 parent 1f98379 commit 708b5f1

File tree

2 files changed

+94
-93
lines changed

2 files changed

+94
-93
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
6363

6464
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.fp8.bf16
6565
%5 = tt.fp_to_fp %arg2, rounding = rtne : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
66+
67+
// CHECK-COUNT-4: rocdl.cvt.pk.bf8.f32
68+
%6 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
69+
70+
// CHECK-COUNT-4: rocdl.cvt.pk.fp8.f32
71+
%7 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
6672
tt.return
6773
}
6874
}
@@ -73,7 +79,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
7379
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
7480
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
7581
tt.func @upcast_from_f8(%arg0: tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
76-
%arg1: tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
82+
%arg1: tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
83+
%arg2: tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
84+
%arg3: tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
7785
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f32.bf8
7886
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7987

@@ -91,6 +99,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
9199

92100
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf16.fp8
93101
%5 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
102+
103+
// CHECK-COUNT-4: rocdl.cvt.pk.f32.bf8
104+
%6 = tt.fp_to_fp %arg2 : tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
105+
106+
// CHECK-COUNT-4: rocdl.cvt.pk.f32.fp8
107+
%7 = tt.fp_to_fp %arg3 : tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
94108
tt.return
95109
}
96110
}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 79 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace {
2424
// Data type conversion utility functions
2525
//===----------------------------------------------------------------------===//
2626
// Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
27-
template <typename convertOp>
27+
template <typename ConvertOp>
2828
static SmallVector<Value>
2929
cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3030
Value v0, Value v1) {
@@ -39,13 +39,13 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3939

4040
auto resType = i32_ty;
4141
auto dstType = f32_ty;
42-
if constexpr (std::is_same_v<convertOp, ROCDL::CvtScaleF32PkF32Fp8Op> ||
43-
std::is_same_v<convertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
42+
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Fp8Op> ||
43+
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
4444
resType = i64_ty;
4545
dstType = f32_ty;
46-
} else if constexpr (std::is_same_v<convertOp,
46+
} else if constexpr (std::is_same_v<ConvertOp,
4747
ROCDL::CvtScaleF32PkF16Fp8Op> ||
48-
std::is_same_v<convertOp,
48+
std::is_same_v<ConvertOp,
4949
ROCDL::CvtScaleF32PkF16Bf8Op>) {
5050
resType = i32_ty;
5151
dstType = f16_ty;
@@ -55,7 +55,7 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
5555
}
5656
Value scale = b.f32_val(1);
5757
Value select = b.false_val();
58-
auto result = rewriter.create<convertOp>(loc, resType, i32v, scale, select);
58+
auto result = rewriter.create<ConvertOp>(loc, resType, i32v, scale, select);
5959
auto retVecTy = vec_ty(dstType, 2);
6060
auto retVec = b.bitcast(result, retVecTy);
6161
SmallVector<Value> ret(2);
@@ -65,7 +65,7 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
6565
}
6666

6767
// Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4
68-
template <typename convertOp>
68+
template <typename ConvertOp>
6969
static SmallVector<Value>
7070
cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7171
Value v0, Value v1) {
@@ -76,9 +76,9 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7676
Value select = b.false_val();
7777

7878
Value result;
79-
if constexpr (std::is_same_v<convertOp, ROCDL::CvtScaleF32PkFp8F32Op> ||
80-
std::is_same_v<convertOp, ROCDL::CvtScaleF32PkBf8F32Op>) {
81-
result = rewriter.create<convertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
79+
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkFp8F32Op> ||
80+
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkBf8F32Op>) {
81+
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
8282
select);
8383
} else {
8484
Type v2F16Ty = vec_ty(v0.getType(), 2);
@@ -87,7 +87,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
8787
auto idx1 = b.i32_val(1);
8888
srcVec = b.insert_element(v2F16Ty, srcVec, v0, idx0);
8989
srcVec = b.insert_element(v2F16Ty, srcVec, v1, idx1);
90-
result = rewriter.create<convertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
90+
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
9191
select);
9292
}
9393
auto fp8x4VecTy = vec_ty(i8_ty, 4);
@@ -295,89 +295,52 @@ static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
295295
return b.fpext(f32_ty, v);
296296
}
297297

298-
// Convert Fp8 to Fp32 on CDNA3
299-
static SmallVector<Value> cvtFp8ToFp32(Location loc,
300-
ConversionPatternRewriter &rewriter,
301-
Value v0, Value v1,
302-
const std::string &fp8_format) {
298+
// Convert Bf8/Fp8 to Fp32 on CDNA3
299+
template <typename ConvertOp>
300+
static SmallVector<Value> cvtPkF8ToFp32(Location loc,
301+
ConversionPatternRewriter &rewriter,
302+
Value v0, Value v1) {
303303
auto b = TritonLLVMOpBuilder(loc, rewriter);
304-
assert(fp8_format == "fp8" || fp8_format == "bf8");
305-
std::string ins_str = "v_cvt_pk_f32_" + fp8_format;
306-
307304
auto fp8x4VecTy = vec_ty(i8_ty, 4);
308305
Value fp8x4Vec = b.undef(fp8x4VecTy);
309-
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, b.i32_val(0));
310-
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, b.i32_val(1));
306+
auto idx0 = b.i32_val(0);
307+
auto idx1 = b.i32_val(1);
308+
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, idx0);
309+
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, idx1);
311310
auto i32v = b.bitcast(fp8x4Vec, i32_ty);
312311

313-
GCNBuilder builder1;
314-
auto &cvt = *builder1.create(ins_str);
315-
auto res = builder1.newOperand("=v");
316-
auto operand = builder1.newOperand(i32v, "v");
317-
cvt(res, operand);
318-
auto i64v = builder1.launch(rewriter, loc, i64_ty, false);
319-
auto fp32x2VecTy = vec_ty(f32_ty, 2);
320-
auto fp32x2Vec = b.bitcast(i64v, fp32x2VecTy);
312+
auto resType = i64_ty;
313+
auto dstType = f32_ty;
321314

315+
Value select = b.false_val();
316+
auto result = rewriter.create<ConvertOp>(loc, resType, i32v, select);
317+
auto f32x2VecTy = vec_ty(dstType, 2);
318+
auto retVec = b.bitcast(result, f32x2VecTy);
322319
SmallVector<Value> ret(2);
323-
ret[0] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(0));
324-
ret[1] = b.extract_element(f32_ty, fp32x2Vec, b.i32_val(1));
325-
320+
ret[0] = b.extract_element(dstType, retVec, idx0);
321+
ret[1] = b.extract_element(dstType, retVec, idx1);
326322
return ret;
327323
}
328324

329-
// Convert Fp32 to Fp8 on CDNA3
330-
static SmallVector<Value> cvtFp32ToFp8(Location loc,
331-
ConversionPatternRewriter &rewriter,
332-
Value v0, Value v1,
333-
const std::string &fp8_format) {
325+
// Convert Fp32 to Bf8/Fp8 on CDNA3
326+
template <typename ConvertOp>
327+
static SmallVector<Value> cvtPkFp32ToF8(Location loc,
328+
ConversionPatternRewriter &rewriter,
329+
Value v0, Value v1) {
334330
auto b = TritonLLVMOpBuilder(loc, rewriter);
335-
assert(fp8_format == "fp8" || fp8_format == "bf8");
336-
std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
337-
338-
GCNBuilder builder;
339-
auto &cvt = *builder.create(ins_str);
340-
auto res = builder.newOperand("=v");
341-
auto operand0 = builder.newOperand(v0, "v");
342-
auto operand1 = builder.newOperand(v1, "v");
343-
cvt(res, operand0, operand1);
344-
auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);
331+
Type v2I16Ty = vec_ty(i16_ty, 2);
332+
Value old = b.undef(i32_ty);
333+
Value select = b.false_val();
345334

335+
Value result;
336+
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v0, v1, old, select);
346337
auto fp8x4VecTy = vec_ty(i8_ty, 4);
347-
auto a1 = b.bitcast(fp8x4Vec, fp8x4VecTy);
348-
338+
auto fp8x4Vec = b.bitcast(result, fp8x4VecTy);
349339
SmallVector<Value> ret(2);
350-
ret[0] = b.extract_element(i8_ty, a1, b.i32_val(0));
351-
ret[1] = b.extract_element(i8_ty, a1, b.i32_val(1));
352-
353-
return ret;
354-
}
355-
356-
// Convert Fp16 to Fp8 on CDNA3
357-
static SmallVector<Value>
358-
convert_val_Fp16_to_Fp8(Location loc, ConversionPatternRewriter &rewriter,
359-
Value v0, Value v1, const std::string &fp8_format) {
360-
assert(fp8_format == "fp8" || fp8_format == "bf8");
361-
std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
362-
363-
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
364-
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);
365-
366-
// Convert fp32 to fp8
367-
return cvtFp32ToFp8(loc, rewriter, f32_0, f32_1, fp8_format);
368-
}
369-
370-
// Convert Fp8 to Fp16 on CDNA3
371-
static SmallVector<Value>
372-
convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
373-
Value v0, Value v1, const std::string &fp8_format) {
374-
// Convert fp8 to fp32
375-
SmallVector<Value> ret = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format);
376-
377-
// Convert fp32 to fp16
378-
ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE);
379-
ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE);
380-
340+
auto idx0 = b.i32_val(0);
341+
auto idx1 = b.i32_val(1);
342+
ret[0] = b.extract_element(i8_ty, fp8x4Vec, idx0);
343+
ret[1] = b.extract_element(i8_ty, fp8x4Vec, idx1);
381344
return ret;
382345
}
383346

@@ -422,31 +385,31 @@ static SmallVector<Value>
422385
Fp32_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
423386
const SmallVector<Value> &v) {
424387
assert(v.size() == 2);
425-
return cvtFp32ToFp8(loc, rewriter, v[0], v[1], "bf8");
388+
return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op>(loc, rewriter, v[0], v[1]);
426389
}
427390

428391
// Fp32 -> Nanoo Fp8 on CDNA3
429392
static SmallVector<Value>
430393
Fp32_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
431394
const SmallVector<Value> &v) {
432395
assert(v.size() == 2);
433-
return cvtFp32ToFp8(loc, rewriter, v[0], v[1], "fp8");
396+
return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op>(loc, rewriter, v[0], v[1]);
434397
}
435398

436399
// Nanoo Bf8 -> Fp32 on CDNA3
437400
static SmallVector<Value>
438401
Fp8E5M2FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
439402
const SmallVector<Value> &v) {
440403
assert(v.size() == 2);
441-
return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8");
404+
return cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0], v[1]);
442405
}
443406

444407
// Nanoo Fp8 -> Fp32 on CDNA3
445408
static SmallVector<Value>
446409
Fp8E4M3FNUZ_to_Fp32(Location loc, ConversionPatternRewriter &rewriter,
447410
const SmallVector<Value> &v) {
448411
assert(v.size() == 2);
449-
return cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8");
412+
return cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0], v[1]);
450413
}
451414

452415
// Depend on whether we focus more on performance, we may skip
@@ -492,7 +455,11 @@ Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
492455
static SmallVector<Value>
493456
Fp16_to_Fp8E5M2FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
494457
const SmallVector<Value> &v) {
495-
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8");
458+
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v[0]);
459+
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v[1]);
460+
461+
// Convert fp32 to bf8
462+
return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op>(loc, rewriter, f32_0, f32_1);
496463
}
497464

498465
ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) {
@@ -698,7 +665,15 @@ Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
698665
static SmallVector<Value>
699666
Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
700667
const SmallVector<Value> &v) {
701-
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8");
668+
// Convert Bf8 to fp32
669+
SmallVector<Value> ret =
670+
cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0], v[1]);
671+
672+
// Convert fp32 to fp16
673+
ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE);
674+
ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE);
675+
676+
return ret;
702677
}
703678

704679
ConverterT Fp8E5M2FNUZ_to_Fp16(AMD::ISAFamily isaFamily) {
@@ -944,7 +919,7 @@ static SmallVector<Value>
944919
Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
945920
const SmallVector<Value> &v) {
946921
assert(v.size() == 2);
947-
auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "fp8");
922+
auto ret = cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0], v[1]);
948923
ret[0] = convertFp32ToBf16(loc, rewriter, ret[0], RoundingMode::RTZ);
949924
ret[1] = convertFp32ToBf16(loc, rewriter, ret[1], RoundingMode::RTZ);
950925
return ret;
@@ -957,15 +932,15 @@ Bf16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
957932
assert(v.size() == 2);
958933
auto v0 = convertBf16ToFp32(loc, rewriter, v[0]);
959934
auto v1 = convertBf16ToFp32(loc, rewriter, v[1]);
960-
return cvtFp32ToFp8(loc, rewriter, v0, v1, "fp8");
935+
return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op>(loc, rewriter, v0, v1);
961936
}
962937

963938
// fp8e5m2fnuz to bf16
964939
static SmallVector<Value>
965940
Fp8E5M2FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
966941
const SmallVector<Value> &v) {
967942
assert(v.size() == 2);
968-
auto ret = cvtFp8ToFp32(loc, rewriter, v[0], v[1], "bf8");
943+
auto ret = cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0], v[1]);
969944
ret[0] = convertFp32ToBf16(loc, rewriter, ret[0], RoundingMode::RTZ);
970945
ret[1] = convertFp32ToBf16(loc, rewriter, ret[1], RoundingMode::RTZ);
971946
return ret;
@@ -978,7 +953,7 @@ Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
978953
assert(v.size() == 2);
979954
auto v0 = convertBf16ToFp32(loc, rewriter, v[0]);
980955
auto v1 = convertBf16ToFp32(loc, rewriter, v[1]);
981-
return cvtFp32ToFp8(loc, rewriter, v0, v1, "bf8");
956+
return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op>(loc, rewriter, v0, v1);
982957
}
983958

984959
static Value Fp8E4M3FNUZ_to_Fp16_oneValue(Location loc,
@@ -1026,7 +1001,15 @@ Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10261001
static SmallVector<Value>
10271002
Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
10281003
const SmallVector<Value> &v) {
1029-
return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8");
1004+
// Convert fp8 to fp32
1005+
SmallVector<Value> ret =
1006+
cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0], v[1]);
1007+
1008+
// Convert fp32 to fp16
1009+
ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE);
1010+
ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE);
1011+
1012+
return ret;
10301013
}
10311014

10321015
static ConverterT Fp8E4M3FNUZ_to_Fp16(AMD::ISAFamily isaFamily) {
@@ -1082,7 +1065,11 @@ Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
10821065
static SmallVector<Value>
10831066
Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter,
10841067
const SmallVector<Value> &v) {
1085-
return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8");
1068+
auto f32_0 = cvtFp16ToFp32(loc, rewriter, v[0]);
1069+
auto f32_1 = cvtFp16ToFp32(loc, rewriter, v[1]);
1070+
1071+
// Convert fp32 to fp8
1072+
return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op>(loc, rewriter, f32_0, f32_1);
10861073
}
10871074

10881075
static ConverterT Fp16_to_Fp8E4M3FNUZ(AMD::ISAFamily isaFamily) {

0 commit comments

Comments
 (0)