Skip to content

Commit cb5504b

Browse files
committed
fix a couple of issues
1 parent 50d2d8b commit cb5504b

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

_unittests/ut_helpers/test_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def test_size_type_onnx(self):
333333
"FLOAT8E5M2",
334334
"FLOAT8E4M3FN",
335335
"FLOAT8E4M3FNUZ",
336+
"FLOAT8E8M0",
336337
}:
337338
onnx_dtype_to_torch_dtype(i)
338339
onnx_dtype_to_np_dtype(i)

onnx_diagnostic/helpers/helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ def size_type(dtype: Any) -> int:
3636
TensorProto.FLOAT8E4M3FNUZ,
3737
TensorProto.FLOAT8E5M2,
3838
TensorProto.FLOAT8E5M2FNUZ,
39+
getattr(TensorProto, "FLOAT8E8M0", None),
3940
}:
4041
return 1
4142
if dtype in {TensorProto.COMPLEX128}:
4243
return 16
43-
from .helpers.onnx_helper import onnx_dtype_name
44+
from .onnx_helper import onnx_dtype_name
4445

4546
raise AssertionError(
4647
f"Unable to return the element size for type {onnx_dtype_name(dtype)}"

onnx_diagnostic/reference/ops/op_cast_like.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,26 @@
1111
float8e5m2fnuz,
1212
)
1313
except ImportError:
14+
bfloat16 = None
1415
from onnx.reference.ops.op_cast import cast_to
1516
from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
1617

1718

1819
def _cast_like(x, y, saturate):
19-
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
20-
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
21-
to = TensorProto.BFLOAT16
22-
elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
23-
to = TensorProto.FLOAT8E4M3FN
24-
elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
25-
to = TensorProto.FLOAT8E4M3FNUZ
26-
elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
27-
to = TensorProto.FLOAT8E5M2
28-
elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
29-
to = TensorProto.FLOAT8E5M2FNUZ
20+
if bfloat16 is not None:
21+
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
22+
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
23+
to = TensorProto.BFLOAT16
24+
elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
25+
to = TensorProto.FLOAT8E4M3FN
26+
elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
27+
to = TensorProto.FLOAT8E4M3FNUZ
28+
elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
29+
to = TensorProto.FLOAT8E5M2
30+
elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
31+
to = TensorProto.FLOAT8E5M2FNUZ
32+
else:
33+
to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore
3034
else:
3135
to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore
3236
return (cast_to(x, to, saturate),)

0 commit comments

Comments (0)