Skip to content

Commit 1f787f1

Browse files
authored
Bug fixes (#9554)
1 parent 23158fd commit 1f787f1

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

torchax/test/test_mutations.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,44 @@ class TestMutations(TestCase):
88

99
def setUp(self):
1010
self.env = torchax.tensor.Environment()
11+
self.env.config.debug_print_each_op = True
1112

1213
def test_add(self):
1314
with self.env:
14-
x = torch.tensor([1, 2, 3], dtype=torch.int32)
15-
y = torch.tensor([4, 5, 6], dtype=torch.int32)
15+
x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
16+
y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
1617
x.add_(y)
17-
self.assertEqual(x, torch.tensor([5, 7, 9], dtype=torch.int32))
18+
torch.testing.assert_close(x.cpu(),
19+
torch.tensor([5, 7, 9], dtype=torch.int32))
1820

1921
def test_sub(self):
2022
with self.env:
21-
x = torch.tensor([1, 2, 3], dtype=torch.int32)
22-
y = torch.tensor([4, 5, 6], dtype=torch.int32)
23+
x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
24+
y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
2325
x.sub_(y)
24-
self.assertEqual(x, torch.tensor([-3, -3, -3], dtype=torch.int32))
26+
torch.testing.assert_close(x.cpu(),
27+
torch.tensor([-3, -3, -3], dtype=torch.int32))
2528

2629
def test_mul(self):
2730
with self.env:
28-
x = torch.tensor([1, 2, 3], dtype=torch.int32)
29-
y = torch.tensor([4, 5, 6], dtype=torch.int32)
31+
x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32)
32+
y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32)
3033

3134
x.mul_(y)
32-
self.assertEqual(x, torch.tensor([4, 10, 18], dtype=torch.int32))
35+
torch.testing.assert_close(x.cpu(),
36+
torch.tensor([4, 10, 18], dtype=torch.int32))
37+
38+
def test_index_copy(self):
39+
with self.env:
40+
x = torch.zeros(5, 3, device='jax')
41+
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
42+
device='jax',
43+
dtype=torch.float)
44+
index = torch.tensor([0, 4, 2], device='jax')
45+
x.index_copy_(0, index, t)
46+
expected = torch.tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.],
47+
[0., 0., 0.], [4., 5., 6.]])
48+
torch.testing.assert_close(x.cpu(), expected)
3349

3450

3551
if __name__ == '__main__':

torchax/torchax/ops/jaten.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
736736
return jnp.empty(sizes, dtype=dtype)
737737

738738

739-
@op(torch.ops.aten.index_put_)
740739
@op(torch.ops.aten.index_put)
741740
def _aten_index_put(self, indexes, values, accumulate=False):
742741
indexes = [slice(None, None, None) if i is None else i for i in indexes]
@@ -5618,6 +5617,8 @@ def _aten__assert_tensor_metadata(*args, **kwargs):
56185617
op_base.InplaceOp(torch.ops.aten.floor_divide),
56195618
torch.ops.aten.remainder_:
56205619
op_base.InplaceOp(torch.ops.aten.remainder),
5620+
torch.ops.aten.index_put_:
5621+
op_base.InplaceOp(torch.ops.aten.index_put),
56215622
}
56225623

56235624
# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.

torchax/torchax/tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ def __str__(self):
7070

7171
__repr__ = __str__
7272

73-
def __jax_array__(self):
74-
return self._elem
75-
7673
@property
7774
def shape(self):
7875
return torch.Size(self._elem.shape)
@@ -494,6 +491,8 @@ def _handle_tensor_constructor(self, func, args, kwargs):
494491
op = self._get_op_or_decomp(func)
495492
if op.needs_env:
496493
kwargs['env'] = self
494+
if op.is_jax_function:
495+
(args, kwargs) = self.t2j_iso((args, kwargs))
497496
res = op.func(*args, **kwargs)
498497
if isinstance(res, jax.Array):
499498
res = Tensor(res, self, requires_grad)

0 commit comments

Comments
 (0)