Skip to content

Commit f8b44e2

Browse files
authored
Create mapping for FP8 torch dtypes (#9573)
Fix a bug when using `t2j` with fp8 dtypes.
1 parent 748ac9b commit f8b44e2

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torchax/torchax/ops/mappings.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
import torch.utils.dlpack as torchdl
77
import torch.utils._mode_utils as mode_utils
88

9+
NUMPY_UNSUPPORTED_DTYPES = {
10+
torch.bfloat16: jnp.bfloat16,
11+
torch.float8_e4m3fn: jnp.float8_e4m3fn,
12+
torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
13+
torch.float8_e5m2: jnp.float8_e5m2,
14+
torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
15+
}
16+
917

1018
def t2j(t, use_dlpack=True):
1119
is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
2836
if res is None:
2937
# https://github.com/google/jax/issues/7657
3038
# https://github.com/google/jax/issues/17784
31-
if t.dtype == torch.bfloat16:
39+
if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
3240
nparray = (t.cpu().detach().to(torch.float32).numpy()
33-
) # numpy don't support bfloat16
41+
) # handle dtypes not supported by numpy
3442
else:
3543
nparray = t.cpu().detach().numpy()
3644
res = jnp.asarray(nparray)
37-
if t.dtype == torch.bfloat16:
38-
res = res.astype(jnp.bfloat16)
45+
if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
46+
res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
3947

4048
if is_bool:
4149
res = res.astype(jnp.bool_)

0 commit comments

Comments
 (0)