Skip to content

Commit abf18e4

Browse files
authored
Fix case when both device & dtype are given in .to (#9583)
1 parent e9a1c5f commit abf18e4

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

torchax/test/test_misc.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""If you don't know which file a test should go, and don't want to make a new file
2+
for a small test. PUt it here
3+
"""
4+
import torch
5+
import unittest
6+
import torchax
7+
import jax
8+
import jax.numpy as jnp
9+
10+
11+
class MiscTest(unittest.TestCase):
12+
13+
def test_extract_jax_kwargs(self):
14+
15+
class M(torch.nn.Module):
16+
17+
def forward(self, a, b):
18+
return torch.sin(a) + torch.cos(b)
19+
20+
weights, func = torchax.extract_jax(M())
21+
res = func(
22+
weights,
23+
args=(),
24+
kwargs={
25+
'a': jnp.array([1, 2, 3]),
26+
'b': jnp.array([3, 4, 5])
27+
})
28+
self.assertTrue(
29+
jnp.allclose(
30+
res,
31+
jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5]))))
32+
33+
def test_to_device(self):
34+
env = torchax.default_env()
35+
env.config.debug_print_each_op = True
36+
with env:
37+
step1 = torch.ones(
38+
100,
39+
100,
40+
)
41+
step2 = torch.triu(step1, diagonal=1)
42+
step3 = step2.to(dtype=torch.bool, device='jax')
43+
self.assertEqual(step3.device.type, 'jax')
44+
45+
46+
if __name__ == '__main__':
47+
unittest.main()

torchax/torchax/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,12 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
469469
arr = self.t2j_copy(the_tensor)
470470
res = Tensor(arr, self, the_tensor.requires_grad)
471471

472-
if new_dtype is not None and new_dtype != the_tensor.dtype:
473-
if isinstance(the_tensor, Tensor):
472+
if new_dtype is not None and new_dtype != res.dtype:
473+
if isinstance(res, Tensor):
474474
res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
475475
else:
476476
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
477-
return the_tensor.to(device=new_device, dtype=new_dtype)
477+
return res.to(device=new_device, dtype=new_dtype)
478478
return res
479479

480480
def get_and_rotate_prng_key(self,

0 commit comments

Comments
 (0)