Skip to content

Commit 1c926f7

Browse files
authored
By default, to("jax") should go to TPU (#9468)
1 parent cf156c6 commit 1c926f7

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

torchax/test/test_interop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ def forward(self, x):
132132
def test_to_jax_device(self):
133133
a = torch.ones(3, 3)
134134

135+
if is_tpu_available():
136+
# by default if tpu is available, to jax will be to tpu
137+
e = a.to("jax")
138+
self.assertEqual(e.jax_device.platform, "tpu")
139+
self.assertEqual(e.device.type, "jax")
140+
else:
141+
e = a.to("jax")
142+
self.assertEqual(e.jax_device.platform, "cpu")
143+
self.assertEqual(e.device.type, "jax")
144+
135145
with jax_device("cpu"):
136146
# move torch.tensor to torchax.tensor CPU
137147
b = a.to("jax")

torchax/torchax/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def __init__(self, configuration=None):
330330
self._prng_key = mutable_array(
331331
jax.random.key(torch.initial_seed() % (1 << 63)))
332332
self.autocast_dtype = None
333-
self._target_device = "cpu"
333+
self._target_device = jax.local_devices()[0].platform
334334

335335
@property
336336
def target_device(self):

0 commit comments

Comments
 (0)