@@ -47,8 +47,7 @@ def test_same_manual_seed(self):
4747 y = torch .randn ((3 , 3 ))
4848 self .assertIsInstance (y , tensor .Tensor )
4949
50- self .assertTrue (
51- torch .equal (torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem )))
50+ self .assertTrue (torch .allclose (x , y ))
5251
5352 def test_different_manual_seed (self ):
5453 with xla_env :
@@ -60,8 +59,7 @@ def test_different_manual_seed(self):
6059 y = torch .randn ((3 , 3 ))
6160 self .assertIsInstance (y , tensor .Tensor )
6261
63- self .assertFalse (
64- torch .equal (torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem )))
62+ self .assertFalse (torch .allclose (x , y ))
6563
6664 def test_jit_with_rng (self ):
6765
@@ -76,20 +74,14 @@ def random_op():
7674
7775 # Result always expected to be the same for a jitted function because seeds
7876 # are baked in
79- torch .testing .assert_close (
80- torchax .tensor .j2t (random_jit ()._elem ),
81- torchax .tensor .j2t (random_jit ()._elem ),
82- atol = 0 ,
83- rtol = 0 )
77+ torch .testing .assert_close (random_jit (), random_jit (), atol = 0 , rtol = 0 )
8478
8579 def test_generator_seed (self ):
8680 with xla_env :
8781 x = torch .randn (2 , 3 , generator = torch .Generator ().manual_seed (0 ))
8882 y = torch .randn (2 , 3 , generator = torch .Generator ().manual_seed (0 ))
8983
90- # Values will be different, but still check device, layout, dtype, etc
91- torch .testing .assert_close (
92- torchax .tensor .j2t (x ._elem ), torchax .tensor .j2t (y ._elem ))
84+ torch .testing .assert_close (x , y )
9385
9486 def test_buffer (self ):
9587
0 commit comments