Commit 2a650c2
[AMD] Add fp32<->OCP fp8/bf8 conversions on mi350 (#6110)
Implemented type conversions between fp32 and OCP fp8/bf8 using ROCDL intrinsic wrappers.
1 parent 5049304 commit 2a650c2
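For context, the direct path added here is exercised whenever a Triton kernel downcasts fp32 to an OCP fp8 format with RTNE rounding on MI350 (CDNA4). A minimal sketch of such a kernel follows; the kernel and buffer names are illustrative, and it assumes an MI350/gfx950 device and a PyTorch build with torch.float8_e4m3fn (on other AMD targets the same cast takes the two-step fp32->fp16->fp8 fallback shown in the C++ changes below):

import torch
import triton
import triton.language as tl

@triton.jit
def downcast_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(src_ptr + offs)  # fp32 input
    # On CDNA4 this should lower to a single packed ROCDL conversion;
    # elsewhere it goes through the fp32->fp16->fp8 two-step path.
    y = x.to(tl.float8e4nv, fp_downcast_rounding="rtne")
    tl.store(dst_ptr + offs, y)

x = torch.randn(1024, device="cuda", dtype=torch.float32)
y = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fn)
downcast_kernel[(1,)](x, y, BLOCK=1024)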

File tree: 2 files changed, +131 −20 lines

python/test/unit/language/test_conversions.py

Lines changed: 15 additions & 10 deletions
@@ -7,7 +7,7 @@
 import triton
 import triton.language as tl
 
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350
 
 
 def matching_int(dtype):
@@ -272,7 +272,8 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 ])
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
 
-    # On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300.
+    # On HIP, fp8e4nv upcasting to fp32 is only supported on MI350, and
+    # fp8e4nv upcasting to bf16 and fp16 is only supported on MI300 and MI350.
     if is_cuda():
         if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9))
             or src_dtype in ('float8e4b8', 'float8e5b16')):
@@ -281,10 +282,11 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
         launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
         return
     elif is_hip():
-        if src_dtype == 'float8e4nv' and (
-                dst_dtype == 'float32' or ((dst_dtype in ('bfloat16')) and not is_hip_mi300())):
+        if src_dtype == 'float8e4nv' and dst_dtype == 'float32' and not is_hip_mi350():
             pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
-        if (src_dtype in ('float8e4b15') or
+        if src_dtype == 'float8e4nv' and not (is_hip_mi300() or is_hip_mi350()):
+            pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
+        if (src_dtype in ('float8e4b15') or
             (src_dtype in ('float8e4b8', 'float8e5b16') and not is_hip_mi300())):
             # If the dtype should error out in the given device, we assert that and return
             with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
@@ -341,11 +343,14 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")
 
     if is_hip():
-        if dst_dtype == 'float8e5' and rounding == 'rtne':
-            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
-
-        if dst_dtype == 'float8e4nv' and not (src_dtype == 'float16' and rounding == 'rtne' and is_hip_mi300()):
-            pytest.skip("float8e4nv downcast tests only supported from float16, with RTNE rounding, and on AMDGPU MI300")
+        if dst_dtype == 'float8e5' and rounding == 'rtne' and not (src_dtype == 'float32' and is_hip_mi350()):
+            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests on HIP only supported from float32 on MI350")
+
+        if dst_dtype == 'float8e4nv':
+            if rounding != 'rtne':
+                pytest.skip("float8e4nv downcast tests only supported with RTNE rounding on AMDGPU")
+            if not (is_hip_mi300() and src_dtype == 'float16' or is_hip_mi350() and src_dtype == 'float32'):
+                pytest.skip("float8e4nv downcast tests only supported from float16 on MI300 or from float32 on MI350")
 
     if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_mi300():
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 116 additions & 10 deletions
@@ -182,6 +182,7 @@ Fp16_to_Fp8E4M3FN_RTNE(Location loc, ConversionPatternRewriter &rewriter,
 
 static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                            const Value &v) {
+
   TritonLLVMOpBuilder b(loc, rewriter);
   return b.fpext(f32_ty, v);
 }
@@ -259,7 +260,6 @@ convert_val_Fp16_to_Fp8(Location loc, ConversionPatternRewriter &rewriter,
 static SmallVector<Value>
 convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
                         Value v0, Value v1, const std::string &fp8_format) {
-
   // Convert fp8 to fp32
   SmallVector<Value> ret = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format);
 
@@ -270,6 +270,82 @@ convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
   return ret;
 }
 
+template <typename convertOp>
+static SmallVector<Value> cvtScaleFp8ToFp32(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            Value v0, Value v1) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  auto fp8x4VecTy = vec_ty(i8_ty, 4);
+  Value fp8x4Vec = b.undef(fp8x4VecTy);
+  fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, b.i32_val(0));
+  fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, b.i32_val(1));
+  auto i32v = b.bitcast(fp8x4Vec, i32_ty);
+
+  Value scale = b.f32_val(1);
+  Value select = b.false_val();
+  auto result = rewriter.create<convertOp>(loc, i64_ty, i32v, scale, select);
+  auto f32x2VecTy = vec_ty(f32_ty, 2);
+  auto f32x2Vec = b.bitcast(result, f32x2VecTy);
+  SmallVector<Value> ret(2);
+  auto idx0 = b.i32_val(0);
+  auto idx1 = b.i32_val(1);
+  ret[0] = b.extract_element(f32_ty, f32x2Vec, idx0);
+  ret[1] = b.extract_element(f32_ty, f32x2Vec, idx1);
+  return ret;
+}
+
+static SmallVector<Value> Fp8E4M3FN_to_Fp32(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const SmallVector<Value> &v) {
+  assert(v.size() == 2);
+  return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Fp8>(loc, rewriter, v[0], v[1]);
+}
+
+static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
+                                          ConversionPatternRewriter &rewriter,
+                                          const SmallVector<Value> &v) {
+  assert(v.size() == 2);
+  return cvtScaleFp8ToFp32<ROCDL::CvtScalePkF32Bf8>(loc, rewriter, v[0], v[1]);
+}
+
+template <typename convertOp>
+static SmallVector<Value> cvtScaleFp32ToFp8(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            Value v0, Value v1) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  Type v2I16Ty = vec_ty(i16_ty, 2);
+  Value v2I16Vec = b.undef(v2I16Ty);
+  Value scale = b.f32_val(1);
+  Value select = b.false_val();
+  Value result;
+  result =
+      rewriter.create<convertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale, select);
+  auto fp8x4VecTy = vec_ty(i8_ty, 4);
+  auto fp8x4Vec = b.bitcast(result, fp8x4VecTy);
+  SmallVector<Value> ret(2);
+  auto idx0 = b.i32_val(0);
+  auto idx1 = b.i32_val(1);
+  ret[0] = b.extract_element(i8_ty, fp8x4Vec, idx0);
+  ret[1] = b.extract_element(i8_ty, fp8x4Vec, idx1);
+  return ret;
+}
+
+static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const SmallVector<Value> &v) {
+  assert(v.size() == 2);
+  return cvtScaleFp32ToFp8<ROCDL::CvtScaleF32PkFp8F32>(loc, rewriter, v[0],
+                                                       v[1]);
+}
+
+static SmallVector<Value> Fp32_to_Fp8E5M2(Location loc,
+                                          ConversionPatternRewriter &rewriter,
+                                          const SmallVector<Value> &v) {
+  assert(v.size() == 2);
+  return cvtScaleFp32ToFp8<ROCDL::CvtScaleF32PkBf8F32>(loc, rewriter, v[0],
+                                                       v[1]);
+}
+
 static SmallVector<Value>
 Fp32_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
@@ -950,8 +1026,12 @@ struct FpToFpOpConversion
          Fp32_to_Fp8E4M3FNUZ},
         {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
          Fp32_to_Fp8E5M2FNUZ},
+        {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
+        {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2},
         {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
         {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
+        {{F8E4M3FNTyID, F32TyID, undefRounding}, Fp8E4M3FN_to_Fp32},
+        {{F8E5M2TyID, F32TyID, undefRounding}, Fp8E5M2_to_Fp32},
     };
     std::tuple<TypeID, TypeID, RoundingMode> key = {
         srcTy.getTypeID(), dstTy.getTypeID(),
@@ -969,8 +1049,8 @@ struct FpToFpOpConversion
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcElementType = getElementType(op.getSrc());
     auto dstElementType = getElementType(op.getResult());
-    auto roundingMode = op.getRounding();
 
+    auto roundingMode = op.getRounding();
     if (srcElementType.isF32() && dstElementType.isF16()) {
       assert(roundingMode.has_value() &&
              "rounding mode must be specified for fp32->fp16 conversion");
@@ -994,20 +1074,46 @@ struct FpToFpOpConversion
       }
       return outVals;
     }
-    size_t numElements = 4;
-    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
-            srcElementType) ||
-        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
-            dstElementType)) {
-      numElements = 2;
+
+    // numElements = 4 for conversions:
+    //   ocp bf8->fp16, ocp bf8->bf16, ocp bf8->fp32 on non-CDNA4
+    //   fp16->ocp bf8, bf16->ocp bf8, fp32->ocp bf8 on non-CDNA4
+    size_t numElements = 2;
+    if (llvm::isa<Float8E5M2Type>(srcElementType) &&
+            !llvm::isa<Float32Type>(dstElementType) ||
+        llvm::isa<Float8E5M2Type>(srcElementType) &&
+            isaFamily != AMD::ISAFamily::CDNA4 ||
+        !llvm::isa<Float32Type>(srcElementType) &&
+            llvm::isa<Float8E5M2Type>(dstElementType) ||
+        llvm::isa<Float32Type>(srcElementType) &&
+            llvm::isa<Float8E5M2Type>(dstElementType) &&
+            isaFamily != AMD::ISAFamily::CDNA4) {
+      numElements = 4;
     }
+
+    // f32->fp8/bf8, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
+    // done in two steps: f32->fp16 with rtne and fp16->fp8/bf8 with rtne
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA4 &&
+          (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType)) &&
+          roundingMode == RoundingMode::RTNE) &&
         !(isaFamily == AMD::ISAFamily::CDNA3 &&
          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
+
+    // fp8/bf8->f32, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
+    // done in two steps: fp8/bf8->fp16 and fp16->fp32
     bool isDstFP32 = dstElementType.isF32();
+    bool useFP16IntermediateDst =
+        (isDstFP32 &&
+         !(isaFamily == AMD::ISAFamily::CDNA4 &&
+           (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementType))) &&
+         !(isaFamily == AMD::ISAFamily::CDNA3 &&
+           (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
+               srcElementType))));
+
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
-    Type dstType = isDstFP32 ? f16_ty : dstElementType;
+    Type dstType = useFP16IntermediateDst ? f16_ty : dstElementType;
     SmallVector<Value> inVals;
     inVals.reserve(std::min(numElements, operands.size()));
     for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
@@ -1052,7 +1158,7 @@ struct FpToFpOpConversion
 
     assert(outVals.size() == inVals.size());
     outVals.resize(std::min(numElements, operands.size()));
-    if (isDstFP32)
+    if (isDstFP32 && dstType == f16_ty)
      for (Value &v : outVals)
        v = convertFp16ToFp32(loc, rewriter, v);
     // Pack values
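The useFP16IntermediateSrc logic above keeps the two-step fp32 -> fp16 (RTNE) -> fp8 (RTNE) fallback for targets without a direct instruction. Double rounding through fp16 can land on a different fp8 value than a single RTNE step, which is one reason the direct CDNA4 path matters. A small numpy illustration; round_f32_to_e4m3 is a hypothetical reference helper (normal range only), not Triton code:

import numpy as np

def round_f32_to_e4m3(x: float) -> float:
    # Round-to-nearest-even onto the OCP E4M3 grid (normal range only;
    # ignores saturation, subnormals, and NaN).
    m, e = np.frexp(np.float32(x))      # x = m * 2**e with 0.5 <= m < 1
    return float(np.ldexp(np.round(m * 16) / 16, e))

x = 1.06254  # just above 1.0625, the midpoint between fp8 values 1.0 and 1.125
direct = round_f32_to_e4m3(x)                        # one RTNE step -> 1.125
via_fp16 = round_f32_to_e4m3(float(np.float16(x)))   # fp16 lands exactly on the
                                                     # midpoint; ties-to-even -> 1.0
print(direct, via_fp16)  # 1.125 1.0

With the direct conversions in this commit, CDNA4 avoids that intermediate rounding in both directions.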
