6
6
import torch .utils .dlpack as torchdl
7
7
import torch .utils ._mode_utils as mode_utils
8
8
9
+ NUMPY_UNSUPPORTED_DTYPES = {
10
+ torch .bfloat16 : jnp .bfloat16 ,
11
+ torch .float8_e4m3fn : jnp .float8_e4m3fn ,
12
+ torch .float8_e4m3fnuz : jnp .float8_e4m3fnuz ,
13
+ torch .float8_e5m2 : jnp .float8_e5m2 ,
14
+ torch .float8_e5m2fnuz : jnp .float8_e5m2fnuz ,
15
+ }
16
+
9
17
10
18
def t2j (t , use_dlpack = True ):
11
19
is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
28
36
if res is None :
29
37
# https://github.com/google/jax/issues/7657
30
38
# https://github.com/google/jax/issues/17784
31
- if t .dtype == torch . bfloat16 :
39
+ if t .dtype in NUMPY_UNSUPPORTED_DTYPES :
32
40
nparray = (t .cpu ().detach ().to (torch .float32 ).numpy ()
33
- ) # numpy don't support bfloat16
41
+ ) # handle dtypes not supported by numpy
34
42
else :
35
43
nparray = t .cpu ().detach ().numpy ()
36
44
res = jnp .asarray (nparray )
37
- if t .dtype == torch . bfloat16 :
38
- res = res .astype (jnp . bfloat16 )
45
+ if t .dtype in NUMPY_UNSUPPORTED_DTYPES :
46
+ res = res .astype (NUMPY_UNSUPPORTED_DTYPES [ t . dtype ] )
39
47
40
48
if is_bool :
41
49
res = res .astype (jnp .bool_ )
0 commit comments