Skip to content

Commit 26b7efc

Browse files
committed
Clone the grads and use inplace method for requires_grad
1 parent d7bc294 commit 26b7efc

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

test/spmd/test_xla_dtensor_to_local.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,49 @@ def test_to_local_requires_grad(self):
6464
# All gradients should be 1.0 since we did a sum()
6565
self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))
6666

def test_to_local_grad_independence(self):
    """Gradients on the tensor returned by to_local() must be a separate
    copy: mutating one must not affect the other."""
    num_devices = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(num_devices)))

    source = torch.randn(100_000, 88, requires_grad=True)
    sharded = XLAShardedTensor(source, device_mesh, [Shard(0)])

    # Populate .grad on the sharded tensor via a simple reduction.
    sharded.sum().backward()

    local_copy = sharded.to_local()

    # Immediately after to_local() both gradients hold the same values.
    self.assertTrue(torch.allclose(local_copy.grad, sharded.grad))

    # Poke a single element of the local gradient only.
    local_copy.grad[0, 0] = 999.0

    # The two .grad tensors are distinct objects with now-divergent values.
    self.assertIsNot(local_copy.grad, sharded.grad)
    self.assertFalse(torch.allclose(local_copy.grad, sharded.grad))
92+
def test_to_local_grad_none_handling(self):
93+
"""Test that to_local() handles None gradients correctly."""
94+
world_size = xr.global_runtime_device_count()
95+
mesh = DeviceMesh("xla", list(range(world_size)))
96+
97+
tensor = torch.randn(100_000, 88, requires_grad=True)
98+
sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
99+
100+
# Don't do backward pass, so grad remains None
101+
self.assertIsNone(sharded_tensor.grad)
102+
103+
# Get local tensor
104+
local_tensor = sharded_tensor.to_local()
105+
106+
# Verify local tensor has correct properties
107+
self.assertTrue(local_tensor.requires_grad)
108+
self.assertIsNone(local_tensor.grad)
109+
67110

68111
if __name__ == "__main__":
69112
result = unittest.main(exit=False)

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def to_local(self):
168168
result = self.global_tensor.clone()
169169
# Since global tensor is detached, add requires_grad and grad values back to the local tensor
170170
if self.requires_grad:
171-
result.requires_grad = self.requires_grad
172-
result.grad = self.grad
171+
result.requires_grad_(self.requires_grad)
172+
result.grad = self.grad.clone() if self.grad is not None else None
173173

174174
return result
175175

Comments (0)