@@ -64,6 +64,49 @@ def test_to_local_requires_grad(self):
    # All gradients should be 1.0 since we did a sum()
    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

+  def test_to_local_grad_independence(self):
+    """Test that gradients are independent between the original and local tensors."""
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    tensor = torch.randn(100_000, 88, requires_grad=True)
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+
+    # Create gradients
+    res = sharded_tensor.sum()
+    res.backward()
+
+    # Get the local tensor
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify the gradients are initially the same
+    self.assertTrue(torch.allclose(local_tensor.grad, sharded_tensor.grad))
+
+    # Modify the local tensor's gradient
+    local_tensor.grad[0, 0] = 999.0
+
+    # Verify the gradients are now independent (not the same object)
+    self.assertFalse(local_tensor.grad is sharded_tensor.grad)
+    self.assertFalse(torch.allclose(local_tensor.grad, sharded_tensor.grad))
+
+  def test_to_local_grad_none_handling(self):
+    """Test that to_local() handles None gradients correctly."""
+    world_size = xr.global_runtime_device_count()
+    mesh = DeviceMesh("xla", list(range(world_size)))
+
+    tensor = torch.randn(100_000, 88, requires_grad=True)
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+
+    # Don't run a backward pass, so grad remains None
+    self.assertIsNone(sharded_tensor.grad)
+
+    # Get the local tensor
+    local_tensor = sharded_tensor.to_local()
+
+    # Verify the local tensor has the correct properties
+    self.assertTrue(local_tensor.requires_grad)
+    self.assertIsNone(local_tensor.grad)
+

if __name__ == "__main__":
  result = unittest.main(exit=False)
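
Taken together, the two new tests pin down a contract for to_local(): gradient values are copied rather than aliased, and a missing gradient stays None while requires_grad is preserved. The following is a minimal plain-PyTorch sketch of that contract, runnable without an XLA device; the helper name to_local_like() and its body are illustrative assumptions for this note, not the torch_xla XLAShardedTensor.to_local() implementation.

# Sketch only: mimics the grad-copy semantics exercised by the tests above.
import torch


def to_local_like(src: torch.Tensor) -> torch.Tensor:
  # Copy the values into a fresh leaf tensor that tracks gradients like src.
  local = src.detach().clone()
  local.requires_grad_(src.requires_grad)
  # Copy the gradient by value (if present) so later in-place edits to one
  # .grad buffer can never leak into the other.
  if src.grad is not None:
    local.grad = src.grad.clone()
  return local


t = torch.randn(4, 3, requires_grad=True)
t.sum().backward()
local = to_local_like(t)
assert torch.allclose(local.grad, t.grad)  # equal values initially
assert local.grad is not t.grad  # but independent objects
local.grad[0, 0] = 999.0
assert not torch.allclose(local.grad, t.grad)  # edits do not propagate back

fresh = torch.randn(4, 3, requires_grad=True)  # no backward pass yet
local_fresh = to_local_like(fresh)
assert local_fresh.requires_grad and local_fresh.grad is None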