@@ -44,7 +44,8 @@ def test_to_local_requires_grad(self):
     tensor = torch.randn(100_000, 88, requires_grad=True)
 
     # Create XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Verify requires_grad is set
     self.assertTrue(sharded_tensor.requires_grad)
@@ -70,7 +71,8 @@ def test_to_local_grad_independence(self):
     mesh = DeviceMesh("xla", list(range(world_size)))
 
     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Create gradients
     res = sharded_tensor.sum()
@@ -95,7 +97,8 @@ def test_to_local_grad_none_handling(self):
     mesh = DeviceMesh("xla", list(range(world_size)))
 
     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Don't do backward pass, so grad remains None
     self.assertIsNone(sharded_tensor.grad)
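
All three hunks only re-wrap the same long constructor call: the sharded tensor is built with the source tensor's requires_grad flag and then its grad behavior is checked. A minimal standalone sketch of that pattern follows; the constructor arguments are taken from the diff above, while the import paths and the world_size helper (xr.global_runtime_device_count()) are assumptions about the rest of the test file, not confirmed by this diff.

# Minimal sketch (hypothetical standalone script, not part of the PR).
import torch
import torch_xla.runtime as xr  # assumed source of the device count
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard
from torch_xla.distributed.spmd import XLAShardedTensor  # assumed import path

world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))

# Build the sharded tensor exactly as in the diff: shard dim 0 across the
# mesh and inherit requires_grad from the source tensor.
tensor = torch.randn(100_000, 88, requires_grad=True)
sharded_tensor = XLAShardedTensor(
    tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

assert sharded_tensor.requires_grad
assert sharded_tensor.grad is None  # no backward pass yet, so grad stays None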