|
11 | 11 | float8e5m2fnuz, |
12 | 12 | ) |
13 | 13 | except ImportError: |
| 14 | + bfloat16 = None |
14 | 15 | from onnx.reference.ops.op_cast import cast_to |
15 | 16 | from ...helpers.onnx_helper import np_dtype_to_tensor_dtype |
16 | 17 |
|
17 | 18 |
|
18 | 19 | def _cast_like(x, y, saturate): |
19 | | - if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": |
20 | | - # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 |
21 | | - to = TensorProto.BFLOAT16 |
22 | | - elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn": |
23 | | - to = TensorProto.FLOAT8E4M3FN |
24 | | - elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz": |
25 | | - to = TensorProto.FLOAT8E4M3FNUZ |
26 | | - elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2": |
27 | | - to = TensorProto.FLOAT8E5M2 |
28 | | - elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz": |
29 | | - to = TensorProto.FLOAT8E5M2FNUZ |
| 20 | + if bfloat16 is not None: |
| 21 | + if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": |
| 22 | + # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 |
| 23 | + to = TensorProto.BFLOAT16 |
| 24 | + elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn": |
| 25 | + to = TensorProto.FLOAT8E4M3FN |
| 26 | + elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz": |
| 27 | + to = TensorProto.FLOAT8E4M3FNUZ |
| 28 | + elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2": |
| 29 | + to = TensorProto.FLOAT8E5M2 |
| 30 | + elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz": |
| 31 | + to = TensorProto.FLOAT8E5M2FNUZ |
| 32 | + else: |
| 33 | + to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore |
30 | 34 | else: |
31 | 35 | to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore |
32 | 36 | return (cast_to(x, to, saturate),) |
|
0 commit comments