@@ -37,30 +37,25 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3737 fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, idx1);
3838 auto i32v = b.bitcast (fp8x4Vec, i32_ty);
3939
40- auto resType = i32_ty;
41- auto dstType = f32_ty;
40+ Type resElemType;
4241 if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Fp8Op> ||
4342 std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
44- resType = i64_ty;
45- dstType = f32_ty;
43+ resElemType = f32_ty;
4644 } else if constexpr (std::is_same_v<ConvertOp,
4745 ROCDL::CvtScaleF32PkF16Fp8Op> ||
4846 std::is_same_v<ConvertOp,
4947 ROCDL::CvtScaleF32PkF16Bf8Op>) {
50- resType = i32_ty;
51- dstType = f16_ty;
48+ resElemType = f16_ty;
5249 } else {
53- resType = i32_ty;
54- dstType = bf16_ty;
50+ resElemType = bf16_ty;
5551 }
52+ Type resType = vec_ty (resElemType, 2 );
5653 Value scale = b.f32_val (1 );
57- Value select = b.false_val ();
58- auto result = rewriter.create <ConvertOp>(loc, resType, i32v, scale, select);
59- auto retVecTy = vec_ty (dstType, 2 );
60- auto retVec = b.bitcast (result, retVecTy);
54+ auto result = rewriter.create <ConvertOp>(loc, resType, i32v, scale,
55+ /* srcLoHiSel=*/ false );
6156 SmallVector<Value> ret (2 );
62- ret[0 ] = b.extract_element (dstType, retVec , idx0);
63- ret[1 ] = b.extract_element (dstType, retVec , idx1);
57+ ret[0 ] = b.extract_element (resElemType, result , idx0);
58+ ret[1 ] = b.extract_element (resElemType, result , idx1);
6459 return ret;
6560}
6661
@@ -73,13 +68,12 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7368 Type v2I16Ty = vec_ty (i16_ty, 2 );
7469 Value v2I16Vec = b.undef (v2I16Ty);
7570 Value scale = b.f32_val (1 );
76- Value select = b.false_val ();
7771
7872 Value result;
7973 if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkFp8F32Op> ||
8074 std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkBf8F32Op>) {
8175 result = rewriter.create <ConvertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
82- select );
76+ /* dstLoHiSel= */ false );
8377 } else {
8478 Type v2F16Ty = vec_ty (v0.getType (), 2 );
8579 Value srcVec = b.undef (v2F16Ty);
@@ -88,7 +82,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
8882 srcVec = b.insert_element (v2F16Ty, srcVec, v0, idx0);
8983 srcVec = b.insert_element (v2F16Ty, srcVec, v1, idx1);
9084 result = rewriter.create <ConvertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
91- select );
85+ /* dstLoHiSel= */ false );
9286 }
9387 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
9488 auto fp8x4Vec = b.bitcast (result, fp8x4VecTy);
@@ -312,8 +306,8 @@ static SmallVector<Value> cvtPkF8ToFp32(Location loc,
312306 auto resType = i64_ty;
313307 auto dstType = f32_ty;
314308
315- Value select = b. false_val ();
316- auto result = rewriter.create <ConvertOp>(loc, resType, i32v, select );
309+ auto result =
310+ rewriter.create <ConvertOp>(loc, resType, i32v, /* wordSel= */ false );
317311 auto f32x2VecTy = vec_ty (dstType, 2 );
318312 auto retVec = b.bitcast (result, f32x2VecTy);
319313 SmallVector<Value> ret (2 );
@@ -330,10 +324,10 @@ static SmallVector<Value> cvtPkFp32ToF8(Location loc,
330324 auto b = TritonLLVMOpBuilder (loc, rewriter);
331325 Type v2I16Ty = vec_ty (i16_ty, 2 );
332326 Value old = b.undef (i32_ty);
333- Value select = b.false_val ();
334327
335328 Value result;
336- result = rewriter.create <ConvertOp>(loc, v2I16Ty, v0, v1, old, select);
329+ result =
330+ rewriter.create <ConvertOp>(loc, v2I16Ty, v0, v1, old, /* wordSel=*/ false );
337331 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
338332 auto fp8x4Vec = b.bitcast (result, fp8x4VecTy);
339333 SmallVector<Value> ret (2 );
0 commit comments