File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -132,6 +132,16 @@ def forward(self, x):
132
132
def test_to_jax_device (self ):
133
133
a = torch .ones (3 , 3 )
134
134
135
+ if is_tpu_available ():
136
+ # by default if tpu is available, to jax will be to tpu
137
+ e = a .to ("jax" )
138
+ self .assertEqual (e .jax_device .platform , "tpu" )
139
+ self .assertEqual (e .device .type , "jax" )
140
+ else :
141
+ e = a .to ("jax" )
142
+ self .assertEqual (e .jax_device .platform , "cpu" )
143
+ self .assertEqual (e .device .type , "jax" )
144
+
135
145
with jax_device ("cpu" ):
136
146
# move torch.tensor to torchax.tensor CPU
137
147
b = a .to ("jax" )
Original file line number Diff line number Diff line change @@ -330,7 +330,7 @@ def __init__(self, configuration=None):
330
330
self ._prng_key = mutable_array (
331
331
jax .random .key (torch .initial_seed () % (1 << 63 )))
332
332
self .autocast_dtype = None
333
- self ._target_device = "cpu"
333
+ self ._target_device = jax . local_devices ()[ 0 ]. platform
334
334
335
335
@property
336
336
def target_device (self ):
You can’t perform that action at this time.
0 commit comments