@@ -182,6 +182,7 @@ Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
182182
183183static Value cvtFp16ToFp32 (Location loc, ConversionPatternRewriter &rewriter,
184184 const Value &v) {
185+
185186 TritonLLVMOpBuilder b (loc, rewriter);
186187 return b.fpext (f32_ty, v);
187188}
@@ -259,7 +260,6 @@ convert_val_Fp16_to_Fp8(Location loc, ConversionPatternRewriter &rewriter,
259260static SmallVector<Value>
260261convert_val_Fp8_to_Fp16 (Location loc, ConversionPatternRewriter &rewriter,
261262 Value v0, Value v1, const std::string &fp8_format) {
262-
263263 // Convert fp8 to fp32
264264 SmallVector<Value> ret = cvtFp8ToFp32 (loc, rewriter, v0, v1, fp8_format);
265265
@@ -270,6 +270,82 @@ convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
270270 return ret;
271271}
272272
273+ template <typename convertOp>
274+ static SmallVector<Value> cvtScaleFp8ToFp32 (Location loc,
275+ ConversionPatternRewriter &rewriter,
276+ Value v0, Value v1) {
277+ auto b = TritonLLVMOpBuilder (loc, rewriter);
278+ auto fp8x4VecTy = vec_ty (i8_ty, 4 );
279+ Value fp8x4Vec = b.undef (fp8x4VecTy);
280+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v0, b.i32_val (0 ));
281+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, b.i32_val (1 ));
282+ auto i32v = b.bitcast (fp8x4Vec, i32_ty);
283+
284+ Value scale = b.f32_val (1 );
285+ Value select = b.false_val ();
286+ auto result = rewriter.create <convertOp>(loc, i64_ty, i32v, scale, select);
287+ auto f32x2VecTy = vec_ty (f32_ty, 2 );
288+ auto f32x2Vec = b.bitcast (result, f32x2VecTy);
289+ SmallVector<Value> ret (2 );
290+ auto idx0 = b.i32_val (0 );
291+ auto idx1 = b.i32_val (1 );
292+ ret[0 ] = b.extract_element (f32_ty, f32x2Vec, idx0);
293+ ret[1 ] = b.extract_element (f32_ty, f32x2Vec, idx1);
294+ return ret;
295+ }
296+
297+ static SmallVector<Value> Fp8E4M3FN_to_Fp32 (Location loc,
298+ ConversionPatternRewriter &rewriter,
299+ const SmallVector<Value> &v) {
300+ assert (v.size () == 2 );
301+ return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Fp8>(loc, rewriter, v[0 ], v[1 ]);
302+ }
303+
304+ static SmallVector<Value> Fp8E5M2_to_Fp32 (Location loc,
305+ ConversionPatternRewriter &rewriter,
306+ const SmallVector<Value> &v) {
307+ assert (v.size () == 2 );
308+ return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Bf8>(loc, rewriter, v[0 ], v[1 ]);
309+ }
310+
311+ template <typename convertOp>
312+ static SmallVector<Value> cvtScaleFp32ToFp8 (Location loc,
313+ ConversionPatternRewriter &rewriter,
314+ Value v0, Value v1) {
315+ auto b = TritonLLVMOpBuilder (loc, rewriter);
316+ Type v2I16Ty = vec_ty (i16_ty, 2 );
317+ Value v2I16Vec = b.undef (v2I16Ty);
318+ Value scale = b.f32_val (1 );
319+ Value select = b.false_val ();
320+ Value result;
321+ result =
322+ rewriter.create <convertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale, select);
323+ auto fp8x4VecTy = vec_ty (i8_ty, 4 );
324+ auto fp8x4Vec = b.bitcast (result, fp8x4VecTy);
325+ SmallVector<Value> ret (2 );
326+ auto idx0 = b.i32_val (0 );
327+ auto idx1 = b.i32_val (1 );
328+ ret[0 ] = b.extract_element (i8_ty, fp8x4Vec, idx0);
329+ ret[1 ] = b.extract_element (i8_ty, fp8x4Vec, idx1);
330+ return ret;
331+ }
332+
333+ static SmallVector<Value> Fp32_to_Fp8E4M3FN (Location loc,
334+ ConversionPatternRewriter &rewriter,
335+ const SmallVector<Value> &v) {
336+ assert (v.size () == 2 );
337+ return cvtScaleFp32ToFp8<ROCDL::CvtScaleF32PkFp8F32>(loc, rewriter, v[0 ],
338+ v[1 ]);
339+ }
340+
341+ static SmallVector<Value> Fp32_to_Fp8E5M2 (Location loc,
342+ ConversionPatternRewriter &rewriter,
343+ const SmallVector<Value> &v) {
344+ assert (v.size () == 2 );
345+ return cvtScaleFp32ToFp8<ROCDL::CvtScaleF32PkBf8F32>(loc, rewriter, v[0 ],
346+ v[1 ]);
347+ }
348+
273349static SmallVector<Value>
274350Fp32_to_Fp8E5M2FNUZ (Location loc, ConversionPatternRewriter &rewriter,
275351 const SmallVector<Value> &v) {
@@ -950,8 +1026,12 @@ struct FpToFpOpConversion
9501026 Fp32_to_Fp8E4M3FNUZ},
9511027 {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
9521028 Fp32_to_Fp8E5M2FNUZ},
1029+ {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1030+ {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2},
9531031 {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
9541032 {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
1033+ {{F8E4M3FNTyID, F32TyID, undefRounding}, Fp8E4M3FN_to_Fp32},
1034+ {{F8E5M2TyID, F32TyID, undefRounding}, Fp8E5M2_to_Fp32},
9551035 };
9561036 std::tuple<TypeID, TypeID, RoundingMode> key = {
9571037 srcTy.getTypeID (), dstTy.getTypeID (),
@@ -969,8 +1049,8 @@ struct FpToFpOpConversion
9691049 auto b = TritonLLVMOpBuilder (loc, rewriter);
9701050 auto srcElementType = getElementType (op.getSrc ());
9711051 auto dstElementType = getElementType (op.getResult ());
972- auto roundingMode = op.getRounding ();
9731052
1053+ auto roundingMode = op.getRounding ();
9741054 if (srcElementType.isF32 () && dstElementType.isF16 ()) {
9751055 assert (roundingMode.has_value () &&
9761056 " rounding mode must be specified for fp32->fp16 conversion" );
@@ -994,20 +1074,46 @@ struct FpToFpOpConversion
9941074 }
9951075 return outVals;
9961076 }
997- size_t numElements = 4 ;
998- if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
999- srcElementType) ||
1000- llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
1001- dstElementType)) {
1002- numElements = 2 ;
1077+
1078+ // numElements = 4 for conversions:
1079+ // ocp bf8->fp16, ocp bf8->bf16, ocp bf8->fp32 on non-CDNA4
1080+ // fp16->ocp bf8, bf16->ocp bf8, fp32->ocp bf8 on non-CDNA4
1081+ size_t numElements = 2 ;
1082+ if (llvm::isa<Float8E5M2Type>(srcElementType) &&
1083+ !llvm::isa<Float32Type>(dstElementType) ||
1084+ llvm::isa<Float8E5M2Type>(srcElementType) &&
1085+ isaFamily != AMD::ISAFamily::CDNA4 ||
1086+ !llvm::isa<Float32Type>(srcElementType) &&
1087+ llvm::isa<Float8E5M2Type>(dstElementType) ||
1088+ llvm::isa<Float32Type>(srcElementType) &&
1089+ llvm::isa<Float8E5M2Type>(dstElementType) &&
1090+ isaFamily != AMD::ISAFamily::CDNA4) {
1091+ numElements = 4 ;
10031092 }
1093+
1094+ // f32->fp8/bf8, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1095+ // done in two steps: f32->fp16 with rtne and fp16->fp8/bf8 with rtne
10041096 bool useFP16IntermediateSrc =
10051097 srcElementType.isF32 () &&
1098+ !(isaFamily == AMD::ISAFamily::CDNA4 &&
1099+ (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType)) &&
1100+ roundingMode == RoundingMode::RTNE) &&
10061101 !(isaFamily == AMD::ISAFamily::CDNA3 &&
10071102 (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
1103+
1104+ // fp8/bf8->f32, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1105+ // done in two steps: fp8/bf8->fp16 and fp16->fp32
10081106 bool isDstFP32 = dstElementType.isF32 ();
1107+ bool useFP16IntermediateDst =
1108+ (isDstFP32 &&
1109+ !(isaFamily == AMD::ISAFamily::CDNA4 &&
1110+ (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementType))) &&
1111+ !(isaFamily == AMD::ISAFamily::CDNA3 &&
1112+ (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
1113+ srcElementType))));
1114+
10091115 Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
1010- Type dstType = isDstFP32 ? f16_ty : dstElementType;
1116+ Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
10111117 SmallVector<Value> inVals;
10121118 inVals.reserve (std::min (numElements, operands.size ()));
10131119 for (unsigned i = 0 ; i < std::min (numElements, operands.size ()); i++) {
@@ -1052,7 +1158,7 @@ struct FpToFpOpConversion
10521158
10531159 assert (outVals.size () == inVals.size ());
10541160 outVals.resize (std::min (numElements, operands.size ()));
1055- if (isDstFP32)
1161+ if (isDstFP32 && dstType == f16_ty )
10561162 for (Value &v : outVals)
10571163 v = convertFp16ToFp32 (loc, rewriter, v);
10581164 // Pack values
0 commit comments