@@ -8,28 +8,44 @@ class TestMutations(TestCase):
8
8
9
9
def setUp (self ):
10
10
self .env = torchax .tensor .Environment ()
11
+ self .env .config .debug_print_each_op = True
11
12
12
13
def test_add (self ):
13
14
with self .env :
14
- x = torch .tensor ([1 , 2 , 3 ], dtype = torch .int32 )
15
- y = torch .tensor ([4 , 5 , 6 ], dtype = torch .int32 )
15
+ x = torch .tensor ([1 , 2 , 3 ], device = 'jax' , dtype = torch .int32 )
16
+ y = torch .tensor ([4 , 5 , 6 ], device = 'jax' , dtype = torch .int32 )
16
17
x .add_ (y )
17
- self .assertEqual (x , torch .tensor ([5 , 7 , 9 ], dtype = torch .int32 ))
18
+ torch .testing .assert_close (x .cpu (),
19
+ torch .tensor ([5 , 7 , 9 ], dtype = torch .int32 ))
18
20
19
21
def test_sub (self ):
20
22
with self .env :
21
- x = torch .tensor ([1 , 2 , 3 ], dtype = torch .int32 )
22
- y = torch .tensor ([4 , 5 , 6 ], dtype = torch .int32 )
23
+ x = torch .tensor ([1 , 2 , 3 ], device = 'jax' , dtype = torch .int32 )
24
+ y = torch .tensor ([4 , 5 , 6 ], device = 'jax' , dtype = torch .int32 )
23
25
x .sub_ (y )
24
- self .assertEqual (x , torch .tensor ([- 3 , - 3 , - 3 ], dtype = torch .int32 ))
26
+ torch .testing .assert_close (x .cpu (),
27
+ torch .tensor ([- 3 , - 3 , - 3 ], dtype = torch .int32 ))
25
28
26
29
def test_mul (self ):
27
30
with self .env :
28
- x = torch .tensor ([1 , 2 , 3 ], dtype = torch .int32 )
29
- y = torch .tensor ([4 , 5 , 6 ], dtype = torch .int32 )
31
+ x = torch .tensor ([1 , 2 , 3 ], device = 'jax' , dtype = torch .int32 )
32
+ y = torch .tensor ([4 , 5 , 6 ], device = 'jax' , dtype = torch .int32 )
30
33
31
34
x .mul_ (y )
32
- self .assertEqual (x , torch .tensor ([4 , 10 , 18 ], dtype = torch .int32 ))
35
+ torch .testing .assert_close (x .cpu (),
36
+ torch .tensor ([4 , 10 , 18 ], dtype = torch .int32 ))
37
+
38
+ def test_index_copy (self ):
39
+ with self .env :
40
+ x = torch .zeros (5 , 3 , device = 'jax' )
41
+ t = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]],
42
+ device = 'jax' ,
43
+ dtype = torch .float )
44
+ index = torch .tensor ([0 , 4 , 2 ], device = 'jax' )
45
+ x .index_copy_ (0 , index , t )
46
+ expected = torch .tensor ([[1. , 2. , 3. ], [0. , 0. , 0. ], [7. , 8. , 9. ],
47
+ [0. , 0. , 0. ], [4. , 5. , 6. ]])
48
+ torch .testing .assert_close (x .cpu (), expected )
33
49
34
50
35
51
if __name__ == '__main__' :
0 commit comments