@@ -375,6 +375,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1):
   return self[tuple(dims)]


+@op(torch.ops.aten.positive)
 @op(torch.ops.aten.detach)
 def _aten_detach(self):
   return self
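`torch.positive` is the identity on numeric tensors, so stacking its registration on the same identity lowering used for `detach` is enough. A one-line sanity check in plain PyTorch (illustrative, not part of the patch):

import torch

x = torch.tensor([1.5, -2.0, 0.0])
assert torch.equal(torch.positive(x), x)  # positive is elementwise +x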
@@ -2951,12 +2952,14 @@ def _aten_log2(x):

 # aten.logical_and
 @op(torch.ops.aten.logical_and)
+@op(torch.ops.aten.__and__)
 def _aten_logical_and(self, other):
   return jnp.logical_and(self, other)


 # aten.logical_or
 @op(torch.ops.aten.logical_or)
+@op(torch.ops.aten.__or__)
 def _aten_logical_or(self, other):
   return jnp.logical_or(self, other)

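Python's `&` and `|` on tensors dispatch to `aten.__and__` and `aten.__or__`, so registering these aliases lets operator syntax reach the existing logical lowerings; the `__xor__` hunk below follows the same pattern. One caveat: on integer tensors PyTorch's `&`/`|` are bitwise, so these `jnp.logical_*` lowerings only match the boolean case. A quick illustration of the lowered semantics (plain jnp, not part of the patch):

import jax.numpy as jnp

a = jnp.array([True, True, False])
b = jnp.array([True, False, False])
print(jnp.logical_and(a, b))  # [ True False False]  <- what `a & b` lowers to
print(jnp.logical_or(a, b))   # [ True  True False]  <- what `a | b` lowers to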
@@ -2998,6 +3001,7 @@ def _aten_logcumsumexp(self, dim=None):
 # aten.max_pool3d_backward
 # aten.logical_xor
 @op(torch.ops.aten.logical_xor)
+@op(torch.ops.aten.__xor__)
 def _aten_logical_xor(self, other):
   return jnp.logical_xor(self, other)

@@ -4946,7 +4950,7 @@ def _aten__linalg_solve_ex(a, b):
   res = jnp.linalg.solve(a, b)
   if batched:
     res = res.squeeze(-1)
-  info_shape = a.shape[0] if len(a.shape) >= 3 else []
+  info_shape = a.shape[:-2]
   info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32))
   return res, info

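The old expression returned `a.shape[0]` (a bare int) whenever `a` had at least one batch dimension, and a Python list otherwise, so `info` came out wrong for inputs with more than one batch dimension. `a.shape[:-2]` is the full batch shape, matching the shape `torch.linalg.solve_ex` gives its `info` output. A small check of the shape arithmetic (illustrative):

import jax.numpy as jnp

a = jnp.zeros((5, 2, 3, 3))    # two batch dims of 3x3 systems
assert a.shape[:-2] == (5, 2)  # one int32 status per batched system

a2 = jnp.zeros((3, 3))         # unbatched input
assert a2.shape[:-2] == ()     # scalar info, as torch.linalg.solve_ex returns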
@@ -5497,6 +5501,50 @@ def _aten_pad(self, pad, mode='constant', value=None):
   )


+@op(torch.ops.aten.is_nonzero)
+def _aten_is_nonzero(a):
+  a = jnp.squeeze(a)
+  if a.shape == (0,):
+    raise RuntimeError('bool value of Tensor with no values is ambiguous')
+  if a.ndim != 0:
+    raise RuntimeError(
+        'bool value of Tensor with more than one value is ambiguous')
+  return a.item() != 0
+
+
+@op(torch.ops.aten.logit)
+def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+  """
+  Computes the logit function of the input tensor.
+
+  logit(p) = log(p / (1 - p))
+
+  Args:
+    self: Input tensor.
+    eps: A small value to clip the input tensor to avoid log(0) or division by zero.
+      If None, no clipping is performed.
+
+  Returns:
+    A tensor with the logit of each element of the input.
+  """
+  if eps is not None:
+    self = jnp.clip(self, eps, 1.0 - eps)
+  res = jnp.log(self / (1.0 - self))
+  res = res.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  return res
+
+
+@op(torch.ops.aten.floor_divide)
+def _aten_floor_divide(x, y):
+  res = jnp.floor_divide(x, y)
+  return res
+
+
+@op(torch.ops.aten._assert_tensor_metadata)
+def _aten__assert_tensor_metadata(*args, **kwargs):
+  pass
+
+
 mutation_ops_to_functional = {
     torch.ops.aten.add_:
         op_base.InplaceOp(torch.ops.aten.add),
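Two behavior notes on the block above: `_aten_is_nonzero` reproduces PyTorch's `RuntimeError`s for empty and multi-element tensors before reading the single value, and `_aten__assert_tensor_metadata` appears to be a deliberate no-op, since the metadata assertion has no JAX-side equivalent. For `logit`, the `eps` clip is what keeps the log finite at the endpoints; a small standalone check (plain jnp, mirroring the lowering above):

import jax.numpy as jnp

p = jnp.array([0.0, 0.25, 0.5, 1.0])
eps = 1e-6
clipped = jnp.clip(p, eps, 1.0 - eps)      # avoid log(0) and division by zero
print(jnp.log(clipped / (1.0 - clipped)))  # finite even at p = 0 and p = 1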
@@ -5565,6 +5613,10 @@ def _aten_pad(self, pad, mode='constant', value=None):
         op_base.InplaceOp(torch.ops.aten.scatter),
     torch.ops.aten.bitwise_or_:
         op_base.InplaceOp(torch.ops.aten.bitwise_or),
+    torch.ops.aten.floor_divide_:
+        op_base.InplaceOp(torch.ops.aten.floor_divide),
+    torch.ops.aten.remainder_:
+        op_base.InplaceOp(torch.ops.aten.remainder),
 }

 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
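The new `floor_divide_` and `remainder_` entries reuse the existing pattern: each in-place variant is functionalized by wrapping its out-of-place counterpart in `op_base.InplaceOp`. The sketch below shows the usual shape of such a wrapper; it is illustrative only, since the real `op_base.InplaceOp` lives in torchax and may differ:

import torch

class InplaceOpSketch:
  """Illustrative only: the real wrapper is op_base.InplaceOp."""

  def __init__(self, functional_op):
    self.functional_op = functional_op  # e.g. torch.ops.aten.floor_divide

  def __call__(self, target, *args, **kwargs):
    # Compute the functional result, then write it back into `target`
    # so the caller observes in-place semantics.
    result = self.functional_op(target, *args, **kwargs)
    target.copy_(result)
    return target

floor_divide_ = InplaceOpSketch(torch.floor_divide)
t = torch.tensor([7.0, 9.0])
floor_divide_(t, 2.0)
print(t)  # tensor([3., 4.])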