
Commit 01e579c

Update torch compat version to 2.7.1 (#9455)
Parent: 83dc9da

File tree: 4 files changed, +63 −4 lines

torchax/dev-requirements.txt
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 -f https://download.pytorch.org/whl/torch
-torch==2.6.0 ; sys_platform == 'darwin' # macOS
-torch==2.6.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
+torch==2.7.1 ; sys_platform == 'darwin' # macOS
+torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
 yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
 flax==0.10.6
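A quick sanity check of the new pin (illustrative only; the local-version suffix depends on platform) can be done at runtime:

import torch

# Both the macOS wheel (2.7.1) and the CPU wheel (2.7.1+cpu) report a
# version string beginning with "2.7.1".
assert torch.__version__.startswith("2.7.1"), torch.__version__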

torchax/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ authors = [
   {name = "Han Qi", email = "[email protected]"},
   {name = "Pytorch/XLA team", email = "[email protected]"},
 ]
-description = "torchax is a library for running PyTorch on TPU"
+description = "torchax is a library for running Jax and PyTorch together"
 readme = "README.md"
 classifiers = [
     "Development Status :: 3 - Alpha",

torchax/torchax/decompositions.py
Lines changed: 7 additions & 0 deletions

@@ -761,6 +761,13 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
     torch.ops.aten._chunk_cat.out,
     torch.ops.aten._weight_norm_interface.default,
     torch.ops.aten._weight_norm_interface.out,
+    torch.ops.aten.__iand__.Tensor,
+    torch.ops.aten.__ixor__.Tensor,
+    torch.ops.aten.__ilshift__.Tensor,
+    torch.ops.aten.__ilshift__.Scalar,
+    torch.ops.aten.__irshift__.Tensor,
+    torch.ops.aten.__irshift__.Scalar,
+    torch.ops.aten.__ior__.Tensor,
 ])

 MUTABLE_DECOMPOSITION = [
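For context, these are the aten ops that Python's augmented-assignment operators dispatch to, so listing them lets the decomposition table rewrite the in-place forms. A minimal eager-mode illustration in plain PyTorch (not torchax-specific):

import torch

x = torch.tensor([0b1100], dtype=torch.int32)
x &= torch.tensor([0b1010], dtype=torch.int32)  # dispatches to aten.__iand__.Tensor
x <<= 1                                         # dispatches to aten.__ilshift__.Scalar
x |= torch.tensor([0b0001], dtype=torch.int32)  # dispatches to aten.__ior__.Tensor
print(x)  # tensor([17], dtype=torch.int32)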

torchax/torchax/ops/jaten.py
Lines changed: 53 additions & 1 deletion

@@ -375,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
   return self[tuple(dims)]


+@op(torch.ops.aten.positive)
 @op(torch.ops.aten.detach)
 def _aten_detach(self):
   return self
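torch.positive is a no-op on real-valued tensors (it raises on bool tensors), so routing it to the same identity handler as detach matches eager semantics. A quick check in plain PyTorch:

import torch

x = torch.tensor([1.0, -2.0])
assert torch.equal(torch.positive(x), x)  # identity on real tensors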
@@ -2951,12 +2952,14 @@ def _aten_log2(x):

 # aten.logical_and
 @op(torch.ops.aten.logical_and)
+@op(torch.ops.aten.__and__)
 def _aten_logical_and(self, other):
   return jnp.logical_and(self, other)


 # aten.logical_or
 @op(torch.ops.aten.logical_or)
+@op(torch.ops.aten.__or__)
 def _aten_logical_or(self, other):
   return jnp.logical_or(self, other)

@@ -2998,6 +3001,7 @@ def _aten_logcumsumexp(self, dim=None):
 # aten.max_pool3d_backward
 # aten.logical_xor
 @op(torch.ops.aten.logical_xor)
+@op(torch.ops.aten.__xor__)
 def _aten_logical_xor(self, other):
   return jnp.logical_xor(self, other)
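With __and__, __or__, and __xor__ registered to the logical handlers, Python's binary operators on tensors now lower to jnp.logical_* under torchax. A minimal sketch, assuming a torchax-enabled environment (torchax.enable_globally is the setup call per the torchax README), shown with boolean tensors, where bitwise and logical semantics coincide:

import torch
import torchax

torchax.enable_globally()  # assumed setup; see torchax README

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
print(a & b)  # -> jnp.logical_and
print(a | b)  # -> jnp.logical_or
print(a ^ b)  # -> jnp.logical_xor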

@@ -4946,7 +4950,7 @@ def _aten__linalg_solve_ex(a, b):
   res = jnp.linalg.solve(a, b)
   if batched:
     res = res.squeeze(-1)
-  info_shape = a.shape[0] if len(a.shape) >= 3 else []
+  info_shape = a.shape[:-2]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
   return res, info
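The old expression produced a bare int (or []) and ignored extra batch dimensions; a.shape[:-2] yields exactly the batch dims, giving one LAPACK-style status code per matrix. For example:

import jax.numpy as jnp

# For a batched system with a.shape == (4, 2, 3, 3), info should carry one
# status code per 3x3 matrix, i.e. shape (4, 2). The old a.shape[0] would
# have collapsed this to 4.
a_shape = (4, 2, 3, 3)
info = jnp.zeros(a_shape[:-2], dtype=jnp.int32)
print(info.shape)  # (4, 2)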

@@ -5497,6 +5501,50 @@ def _aten_pad(self, pad, mode='constant', value=None):
   )


+@op(torch.ops.aten.is_nonzero)
+def _aten_is_nonzero(a):
+  a = jnp.squeeze(a)
+  if a.shape == (0,):
+    raise RuntimeError('bool value of Tensor with no values is ambiguous')
+  if a.ndim != 0:
+    raise RuntimeError(
+        'bool value of Tensor with more than one value is ambiguous')
+  return a.item() != 0
+
+
+@op(torch.ops.aten.logit)
+def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+  """
+  Computes the logit function of the input tensor.
+
+  logit(p) = log(p / (1 - p))
+
+  Args:
+    self: Input tensor.
+    eps: A small value to clip the input tensor to avoid log(0) or division by zero.
+      If None, no clipping is performed.
+
+  Returns:
+    A tensor with the logit of each element of the input.
+  """
+  if eps is not None:
+    self = jnp.clip(self, eps, 1.0 - eps)
+  res = jnp.log(self / (1.0 - self))
+  res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return res
+
+
+@op(torch.ops.aten.floor_divide)
+def _aten_floor_divide(x, y):
+  res = jnp.floor_divide(x, y)
+  return res
+
+
+@op(torch.ops.aten._assert_tensor_metadata)
+def _aten__assert_tensor_metadata(*args, **kwargs):
+  pass
+
+
 mutation_ops_to_functional = {
     torch.ops.aten.add_:
         op_base.InplaceOp(torch.ops.aten.add),
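A worked example of the logit lowering above: with eps clipping, boundary inputs 0 and 1 are pulled into [eps, 1 - eps] before log(p / (1 - p)), avoiding -inf/inf (printed values are approximate):

import jax.numpy as jnp

p = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])
eps = 1e-6
clipped = jnp.clip(p, eps, 1.0 - eps)
print(jnp.log(clipped / (1.0 - clipped)))
# ~[-13.82, -1.10, 0.0, 1.10, 13.82]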
@@ -5565,6 +5613,10 @@ def _aten_pad(self, pad, mode='constant', value=None):
         op_base.InplaceOp(torch.ops.aten.scatter),
     torch.ops.aten.bitwise_or_:
         op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.floor_divide_:
+        op_base.InplaceOp(torch.ops.aten.floor_divide),
+    torch.ops.aten.remainder_:
+        op_base.InplaceOp(torch.ops.aten.remainder),
 }

 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
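A sketch of what this mapping means in practice (assumed semantics of op_base.InplaceOp, following the pattern of the existing entries): the in-place variant is rewritten to its functional counterpart and the result is written back, so under torchax x.floor_divide_(y) behaves like x = torch.floor_divide(x, y). In plain PyTorch terms:

import torch

x = torch.tensor([7.0, 9.0])
x.floor_divide_(torch.tensor([2.0, 4.0]))  # functionalized to aten.floor_divide
print(x)  # tensor([3., 2.])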
