
Commit b5a9bc1

fix the failing CI by reverting to default requires_grad

Parent: 26b7efc

File tree: 2 files changed, +17 -11 lines

test/spmd/test_xla_dtensor_to_local.py

Lines changed: 3 additions & 3 deletions
@@ -44,7 +44,7 @@ def test_to_local_requires_grad(self):
     tensor = torch.randn(100_000, 88, requires_grad=True)
 
     # Create XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Verify requires_grad is set
     self.assertTrue(sharded_tensor.requires_grad)

@@ -70,7 +70,7 @@ def test_to_local_grad_independence(self):
     mesh = DeviceMesh("xla", list(range(world_size)))
 
     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Create gradients
     res = sharded_tensor.sum()

@@ -95,7 +95,7 @@ def test_to_local_grad_none_handling(self):
     mesh = DeviceMesh("xla", list(range(world_size)))
 
     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
 
     # Don't do backward pass, so grad remains None
     self.assertIsNone(sharded_tensor.grad)
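
The tests now forward requires_grad explicitly because the constructor default is reverted to False in the second file below. A minimal standalone sketch of the calling pattern follows; the import paths and the device-count helper are assumptions based on the test file's context, not a verbatim excerpt:

import torch
import torch_xla.runtime as xr
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard
from torch_xla.distributed.spmd import XLAShardedTensor

# Build a 1-D mesh over every available XLA device, as the tests do.
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))

tensor = torch.randn(100_000, 88, requires_grad=True)

# With the default reverted to False, autograd tracking on the wrapper
# only happens when the caller forwards the flag explicitly.
sharded = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
assert sharded.requires_grad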

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 14 additions & 8 deletions
@@ -114,7 +114,7 @@ def __new__(cls,
         dtype=elem.dtype,
         layout=elem.layout,
         device=elem.device,
-        requires_grad=kwargs.get("requires_grad", elem.requires_grad))
+        requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
 
     # Initialize mesh, partition, and spec information

@@ -164,14 +164,20 @@ def to_local(self):
       torch.Tensor: The global tensor representation with appropriate requires_grad setting.
     """
 
-    # Create a new tensor with the same values of global_tensor
-    result = self.global_tensor.clone()
-    # Since global tensor is detached, add requires_grad and grad values back to the local tensor
-    if self.requires_grad:
-      result.requires_grad_(self.requires_grad)
-      result.grad = self.grad.clone() if self.grad is not None else None
 
-    return result
+    if not self.requires_grad:
+      # When requires_grad is False, global_tensor is the original tensor
+      return self.global_tensor
+    else:
+      # When requires_grad is True, global_tensor is detached
+      # Create a new tensor with the same values of global_tensor
+      result = self.global_tensor.clone()
+      # Since global tensor is detached, add requires_grad and grad values back to the local tensor
+      if self.requires_grad:
+        result.requires_grad_(self.requires_grad)
+        result.grad = self.grad.clone() if self.grad is not None else None
+
+      return result
 
   @property
   def sharding_spec(self):
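
Taken together, the two hunks mean a wrapped tensor no longer inherits requires_grad implicitly, and to_local() only clones when autograd state has to be reattached. A rough behavioral sketch, reusing the mesh and imports from the snippet above (illustrative only, not a verbatim excerpt from the repository):

# Default path: requires_grad not passed, so the wrapper defaults to False
# and global_tensor keeps holding the original (non-detached) tensor.
plain = XLAShardedTensor(torch.randn(8, 8), mesh, [Shard(0)])
assert plain.requires_grad is False
# to_local() can hand back the stored global tensor without cloning.
local_plain = plain.to_local()

# Autograd path: the caller opts in explicitly; global_tensor was detached
# at construction, so to_local() clones it and reattaches requires_grad
# (and .grad, if one has been accumulated).
src = torch.randn(8, 8, requires_grad=True)
tracked = XLAShardedTensor(src, mesh, [Shard(0)], requires_grad=src.requires_grad)
local_tracked = tracked.to_local()
assert local_tracked.requires_grad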
