@@ -24,7 +24,7 @@ namespace {
2424// Data type conversion utility functions
2525// ===----------------------------------------------------------------------===//
2626// Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
27- template <typename convertOp >
27+ template <typename ConvertOp >
2828static SmallVector<Value>
2929cvtScalePkUpcastFromFp8 (Location loc, ConversionPatternRewriter &rewriter,
3030 Value v0, Value v1) {
@@ -39,13 +39,13 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3939
4040 auto resType = i32_ty;
4141 auto dstType = f32_ty;
42- if constexpr (std::is_same_v<convertOp , ROCDL::CvtScaleF32PkF32Fp8Op> ||
43- std::is_same_v<convertOp , ROCDL::CvtScaleF32PkF32Bf8Op>) {
42+ if constexpr (std::is_same_v<ConvertOp , ROCDL::CvtScaleF32PkF32Fp8Op> ||
43+ std::is_same_v<ConvertOp , ROCDL::CvtScaleF32PkF32Bf8Op>) {
4444 resType = i64_ty;
4545 dstType = f32_ty;
46- } else if constexpr (std::is_same_v<convertOp ,
46+ } else if constexpr (std::is_same_v<ConvertOp ,
4747 ROCDL::CvtScaleF32PkF16Fp8Op> ||
48- std::is_same_v<convertOp ,
48+ std::is_same_v<ConvertOp ,
4949 ROCDL::CvtScaleF32PkF16Bf8Op>) {
5050 resType = i32_ty;
5151 dstType = f16_ty;
@@ -55,7 +55,7 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
5555 }
5656 Value scale = b.f32_val (1 );
5757 Value select = b.false_val ();
58- auto result = rewriter.create <convertOp >(loc, resType, i32v, scale, select);
58+ auto result = rewriter.create <ConvertOp >(loc, resType, i32v, scale, select);
5959 auto retVecTy = vec_ty (dstType, 2 );
6060 auto retVec = b.bitcast (result, retVecTy);
6161 SmallVector<Value> ret (2 );
@@ -65,7 +65,7 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
6565}
6666
6767// Convert Fp16/Bf16/Fp32 to OCP Fp8/Bf8 on CDNA4
68- template <typename convertOp >
68+ template <typename ConvertOp >
6969static SmallVector<Value>
7070cvtScalePkDowncastToFp8 (Location loc, ConversionPatternRewriter &rewriter,
7171 Value v0, Value v1) {
@@ -76,9 +76,9 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7676 Value select = b.false_val ();
7777
7878 Value result;
79- if constexpr (std::is_same_v<convertOp , ROCDL::CvtScaleF32PkFp8F32Op> ||
80- std::is_same_v<convertOp , ROCDL::CvtScaleF32PkBf8F32Op>) {
81- result = rewriter.create <convertOp >(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
79+ if constexpr (std::is_same_v<ConvertOp , ROCDL::CvtScaleF32PkFp8F32Op> ||
80+ std::is_same_v<ConvertOp , ROCDL::CvtScaleF32PkBf8F32Op>) {
81+ result = rewriter.create <ConvertOp >(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
8282 select);
8383 } else {
8484 Type v2F16Ty = vec_ty (v0.getType (), 2 );
@@ -87,7 +87,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
8787 auto idx1 = b.i32_val (1 );
8888 srcVec = b.insert_element (v2F16Ty, srcVec, v0, idx0);
8989 srcVec = b.insert_element (v2F16Ty, srcVec, v1, idx1);
90- result = rewriter.create <convertOp >(loc, v2I16Ty, v2I16Vec, srcVec, scale,
90+ result = rewriter.create <ConvertOp >(loc, v2I16Ty, v2I16Vec, srcVec, scale,
9191 select);
9292 }
9393 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
@@ -295,89 +295,52 @@ static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
295295 return b.fpext (f32_ty, v);
296296}
297297
298- // Convert Fp8 to Fp32 on CDNA3
299- static SmallVector<Value> cvtFp8ToFp32 (Location loc,
300- ConversionPatternRewriter &rewriter ,
301- Value v0, Value v1 ,
302- const std::string &fp8_format ) {
298+ // Convert Bf8/ Fp8 to Fp32 on CDNA3
299+ template < typename ConvertOp>
300+ static SmallVector<Value> cvtPkF8ToFp32 (Location loc ,
301+ ConversionPatternRewriter &rewriter ,
302+ Value v0, Value v1 ) {
303303 auto b = TritonLLVMOpBuilder (loc, rewriter);
304- assert (fp8_format == " fp8" || fp8_format == " bf8" );
305- std::string ins_str = " v_cvt_pk_f32_" + fp8_format;
306-
307304 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
308305 Value fp8x4Vec = b.undef (fp8x4VecTy);
309- fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v0, b.i32_val (0 ));
310- fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, b.i32_val (1 ));
306+ auto idx0 = b.i32_val (0 );
307+ auto idx1 = b.i32_val (1 );
308+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v0, idx0);
309+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, idx1);
311310 auto i32v = b.bitcast (fp8x4Vec, i32_ty);
312311
313- GCNBuilder builder1;
314- auto &cvt = *builder1.create (ins_str);
315- auto res = builder1.newOperand (" =v" );
316- auto operand = builder1.newOperand (i32v, " v" );
317- cvt (res, operand);
318- auto i64v = builder1.launch (rewriter, loc, i64_ty, false );
319- auto fp32x2VecTy = vec_ty (f32_ty, 2 );
320- auto fp32x2Vec = b.bitcast (i64v, fp32x2VecTy);
312+ auto resType = i64_ty;
313+ auto dstType = f32_ty;
321314
315+ Value select = b.false_val ();
316+ auto result = rewriter.create <ConvertOp>(loc, resType, i32v, select);
317+ auto f32x2VecTy = vec_ty (dstType, 2 );
318+ auto retVec = b.bitcast (result, f32x2VecTy);
322319 SmallVector<Value> ret (2 );
323- ret[0 ] = b.extract_element (f32_ty, fp32x2Vec, b.i32_val (0 ));
324- ret[1 ] = b.extract_element (f32_ty, fp32x2Vec, b.i32_val (1 ));
325-
320+ ret[0 ] = b.extract_element (dstType, retVec, idx0);
321+ ret[1 ] = b.extract_element (dstType, retVec, idx1);
326322 return ret;
327323}
328324
329- // Convert Fp32 to Fp8 on CDNA3
330- static SmallVector<Value> cvtFp32ToFp8 (Location loc,
331- ConversionPatternRewriter &rewriter ,
332- Value v0, Value v1 ,
333- const std::string &fp8_format ) {
325+ // Convert Fp32 to Bf8/ Fp8 on CDNA3
326+ template < typename ConvertOp>
327+ static SmallVector<Value> cvtPkFp32ToF8 (Location loc ,
328+ ConversionPatternRewriter &rewriter ,
329+ Value v0, Value v1 ) {
334330 auto b = TritonLLVMOpBuilder (loc, rewriter);
335- assert (fp8_format == " fp8" || fp8_format == " bf8" );
336- std::string ins_str = " v_cvt_pk_" + fp8_format + " _f32" ;
337-
338- GCNBuilder builder;
339- auto &cvt = *builder.create (ins_str);
340- auto res = builder.newOperand (" =v" );
341- auto operand0 = builder.newOperand (v0, " v" );
342- auto operand1 = builder.newOperand (v1, " v" );
343- cvt (res, operand0, operand1);
344- auto fp8x4Vec = builder.launch (rewriter, loc, i32_ty, false );
331+ Type v2I16Ty = vec_ty (i16_ty, 2 );
332+ Value old = b.undef (i32_ty);
333+ Value select = b.false_val ();
345334
335+ Value result;
336+ result = rewriter.create <ConvertOp>(loc, v2I16Ty, v0, v1, old, select);
346337 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
347- auto a1 = b.bitcast (fp8x4Vec, fp8x4VecTy);
348-
338+ auto fp8x4Vec = b.bitcast (result, fp8x4VecTy);
349339 SmallVector<Value> ret (2 );
350- ret[0 ] = b.extract_element (i8_ty, a1, b.i32_val (0 ));
351- ret[1 ] = b.extract_element (i8_ty, a1, b.i32_val (1 ));
352-
353- return ret;
354- }
355-
356- // Convert Fp16 to Fp8 on CDNA3
357- static SmallVector<Value>
358- convert_val_Fp16_to_Fp8 (Location loc, ConversionPatternRewriter &rewriter,
359- Value v0, Value v1, const std::string &fp8_format) {
360- assert (fp8_format == " fp8" || fp8_format == " bf8" );
361- std::string ins_str = " v_cvt_pk_" + fp8_format + " _f32" ;
362-
363- auto f32_0 = cvtFp16ToFp32 (loc, rewriter, v0);
364- auto f32_1 = cvtFp16ToFp32 (loc, rewriter, v1);
365-
366- // Convert fp32 to fp8
367- return cvtFp32ToFp8 (loc, rewriter, f32_0, f32_1, fp8_format);
368- }
369-
370- // Convert Fp8 to Fp16 on CDNA3
371- static SmallVector<Value>
372- convert_val_Fp8_to_Fp16 (Location loc, ConversionPatternRewriter &rewriter,
373- Value v0, Value v1, const std::string &fp8_format) {
374- // Convert fp8 to fp32
375- SmallVector<Value> ret = cvtFp8ToFp32 (loc, rewriter, v0, v1, fp8_format);
376-
377- // Convert fp32 to fp16
378- ret[0 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[0 ], RoundingMode::RTNE);
379- ret[1 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[1 ], RoundingMode::RTNE);
380-
340+ auto idx0 = b.i32_val (0 );
341+ auto idx1 = b.i32_val (1 );
342+ ret[0 ] = b.extract_element (i8_ty, fp8x4Vec, idx0);
343+ ret[1 ] = b.extract_element (i8_ty, fp8x4Vec, idx1);
381344 return ret;
382345}
383346
@@ -422,31 +385,31 @@ static SmallVector<Value>
422385Fp32_to_Fp8E5M2FNUZ (Location loc, ConversionPatternRewriter &rewriter,
423386 const SmallVector<Value> &v) {
424387 assert (v.size () == 2 );
425- return cvtFp32ToFp8 (loc, rewriter, v[0 ], v[1 ], " bf8 " );
388+ return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op> (loc, rewriter, v[0 ], v[1 ]);
426389}
427390
428391// Fp32 -> Nanoo Fp8 on CDNA3
429392static SmallVector<Value>
430393Fp32_to_Fp8E4M3FNUZ (Location loc, ConversionPatternRewriter &rewriter,
431394 const SmallVector<Value> &v) {
432395 assert (v.size () == 2 );
433- return cvtFp32ToFp8 (loc, rewriter, v[0 ], v[1 ], " fp8 " );
396+ return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op> (loc, rewriter, v[0 ], v[1 ]);
434397}
435398
436399// Nanoo Bf8 -> Fp32 on CDNA3
437400static SmallVector<Value>
438401Fp8E5M2FNUZ_to_Fp32 (Location loc, ConversionPatternRewriter &rewriter,
439402 const SmallVector<Value> &v) {
440403 assert (v.size () == 2 );
441- return cvtFp8ToFp32 (loc, rewriter, v[0 ], v[1 ], " bf8 " );
404+ return cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op> (loc, rewriter, v[0 ], v[1 ]);
442405}
443406
444407// Nanoo Fp8 -> Fp32 on CDNA3
445408static SmallVector<Value>
446409Fp8E4M3FNUZ_to_Fp32 (Location loc, ConversionPatternRewriter &rewriter,
447410 const SmallVector<Value> &v) {
448411 assert (v.size () == 2 );
449- return cvtFp8ToFp32 (loc, rewriter, v[0 ], v[1 ], " fp8 " );
412+ return cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op> (loc, rewriter, v[0 ], v[1 ]);
450413}
451414
452415// Depend on whether we focus more on performance, we may skip
@@ -492,7 +455,11 @@ Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
492455static SmallVector<Value>
493456Fp16_to_Fp8E5M2FNUZ_HW (Location loc, ConversionPatternRewriter &rewriter,
494457 const SmallVector<Value> &v) {
495- return convert_val_Fp16_to_Fp8 (loc, rewriter, v[0 ], v[1 ], " bf8" );
458+ auto f32_0 = cvtFp16ToFp32 (loc, rewriter, v[0 ]);
459+ auto f32_1 = cvtFp16ToFp32 (loc, rewriter, v[1 ]);
460+
461+ // Convert fp32 to bf8
462+ return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op>(loc, rewriter, f32_0, f32_1);
496463}
497464
498465ConverterT Fp16_to_Fp8E5M2FNUZ (AMD::ISAFamily isaFamily) {
@@ -698,7 +665,15 @@ Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
698665static SmallVector<Value>
699666Fp8E5M2FNUZ_to_Fp16_HW (Location loc, ConversionPatternRewriter &rewriter,
700667 const SmallVector<Value> &v) {
701- return convert_val_Fp8_to_Fp16 (loc, rewriter, v[0 ], v[1 ], " bf8" );
668+ // Convert Bf8 to fp32
669+ SmallVector<Value> ret =
670+ cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0 ], v[1 ]);
671+
672+ // Convert fp32 to fp16
673+ ret[0 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[0 ], RoundingMode::RTNE);
674+ ret[1 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[1 ], RoundingMode::RTNE);
675+
676+ return ret;
702677}
703678
704679ConverterT Fp8E5M2FNUZ_to_Fp16 (AMD::ISAFamily isaFamily) {
@@ -944,7 +919,7 @@ static SmallVector<Value>
944919Fp8E4M3FNUZ_to_Bf16 (Location loc, ConversionPatternRewriter &rewriter,
945920 const SmallVector<Value> &v) {
946921 assert (v.size () == 2 );
947- auto ret = cvtFp8ToFp32 (loc, rewriter, v[0 ], v[1 ], " fp8 " );
922+ auto ret = cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op> (loc, rewriter, v[0 ], v[1 ]);
948923 ret[0 ] = convertFp32ToBf16 (loc, rewriter, ret[0 ], RoundingMode::RTZ);
949924 ret[1 ] = convertFp32ToBf16 (loc, rewriter, ret[1 ], RoundingMode::RTZ);
950925 return ret;
@@ -957,15 +932,15 @@ Bf16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
957932 assert (v.size () == 2 );
958933 auto v0 = convertBf16ToFp32 (loc, rewriter, v[0 ]);
959934 auto v1 = convertBf16ToFp32 (loc, rewriter, v[1 ]);
960- return cvtFp32ToFp8 (loc, rewriter, v0, v1, " fp8 " );
935+ return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op> (loc, rewriter, v0, v1);
961936}
962937
963938// fp8e5m2fnuz to bf16
964939static SmallVector<Value>
965940Fp8E5M2FNUZ_to_Bf16 (Location loc, ConversionPatternRewriter &rewriter,
966941 const SmallVector<Value> &v) {
967942 assert (v.size () == 2 );
968- auto ret = cvtFp8ToFp32 (loc, rewriter, v[0 ], v[1 ], " bf8 " );
943+ auto ret = cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op> (loc, rewriter, v[0 ], v[1 ]);
969944 ret[0 ] = convertFp32ToBf16 (loc, rewriter, ret[0 ], RoundingMode::RTZ);
970945 ret[1 ] = convertFp32ToBf16 (loc, rewriter, ret[1 ], RoundingMode::RTZ);
971946 return ret;
@@ -978,7 +953,7 @@ Bf16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
978953 assert (v.size () == 2 );
979954 auto v0 = convertBf16ToFp32 (loc, rewriter, v[0 ]);
980955 auto v1 = convertBf16ToFp32 (loc, rewriter, v[1 ]);
981- return cvtFp32ToFp8 (loc, rewriter, v0, v1, " bf8 " );
956+ return cvtPkFp32ToF8<ROCDL::CvtPkBf8F32Op> (loc, rewriter, v0, v1);
982957}
983958
984959static Value Fp8E4M3FNUZ_to_Fp16_oneValue (Location loc,
@@ -1026,7 +1001,15 @@ Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
10261001static SmallVector<Value>
10271002Fp8E4M3FNUZ_to_Fp16_HW (Location loc, ConversionPatternRewriter &rewriter,
10281003 const SmallVector<Value> &v) {
1029- return convert_val_Fp8_to_Fp16 (loc, rewriter, v[0 ], v[1 ], " fp8" );
1004+ // Convert fp8 to fp32
1005+ SmallVector<Value> ret =
1006+ cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0 ], v[1 ]);
1007+
1008+ // Convert fp32 to fp16
1009+ ret[0 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[0 ], RoundingMode::RTNE);
1010+ ret[1 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[1 ], RoundingMode::RTNE);
1011+
1012+ return ret;
10301013}
10311014
10321015static ConverterT Fp8E4M3FNUZ_to_Fp16 (AMD::ISAFamily isaFamily) {
@@ -1082,7 +1065,11 @@ Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
10821065static SmallVector<Value>
10831066Fp16_to_Fp8E4M3FNUZ_HW (Location loc, ConversionPatternRewriter &rewriter,
10841067 const SmallVector<Value> &v) {
1085- return convert_val_Fp16_to_Fp8 (loc, rewriter, v[0 ], v[1 ], " fp8" );
1068+ auto f32_0 = cvtFp16ToFp32 (loc, rewriter, v[0 ]);
1069+ auto f32_1 = cvtFp16ToFp32 (loc, rewriter, v[1 ]);
1070+
1071+ // Convert fp32 to fp8
1072+ return cvtPkFp32ToF8<ROCDL::CvtPkFp8F32Op>(loc, rewriter, f32_0, f32_1);
10861073}
10871074
10881075static ConverterT Fp16_to_Fp8E4M3FNUZ (AMD::ISAFamily isaFamily) {
0 commit comments