
Commit be7ab62

fix the failing CI by reverting to default requires_grad
1 parent bd9c9f3 commit be7ab62

File tree

2 files changed (+17 −11 lines)

test/spmd/test_xla_dtensor_to_local.py

Lines changed: 3 additions & 3 deletions

@@ -44,7 +44,7 @@ def test_to_local_requires_grad(self):
     tensor = torch.randn(100_000, 88, requires_grad=True)

     # Create XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Verify requires_grad is set
     self.assertTrue(sharded_tensor.requires_grad)
@@ -70,7 +70,7 @@ def test_to_local_grad_independence(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Create gradients
     res = sharded_tensor.sum()
@@ -95,7 +95,7 @@ def test_to_local_grad_none_handling(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])
+    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Don't do backward pass, so grad remains None
     self.assertIsNone(sharded_tensor.grad)
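
With the constructor default reverted (see the torch_xla change below), an XLAShardedTensor no longer inherits requires_grad from the tensor it wraps, so the tests now pass it explicitly. A minimal sketch of that calling pattern, assuming an SPMD-enabled torch_xla environment; the import paths and the world_size setup are assumptions filled in around the lines shown in the diff:

import torch
import torch_xla.runtime as xr
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard
from torch_xla.distributed.spmd import XLAShardedTensor

# Assumed setup: one mesh dimension spanning all available XLA devices.
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))

tensor = torch.randn(100_000, 88, requires_grad=True)

# The constructor default is now requires_grad=False, so gradient tracking
# has to be requested explicitly when wrapping an autograd-enabled tensor.
sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)],
                                  requires_grad=tensor.requires_grad)
assert sharded_tensor.requires_grad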

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 14 additions & 8 deletions

@@ -115,7 +115,7 @@ def __new__(cls,
         dtype=elem.dtype,
         layout=elem.layout,
         device=elem.device,
-        requires_grad=kwargs.get("requires_grad", elem.requires_grad))
+        requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem

     # Initialize mesh, partition, and spec information
@@ -165,14 +165,20 @@ def to_local(self):
       torch.Tensor: The global tensor representation with appropriate requires_grad setting.
     """

-    # Create a new tensor with the same values of global_tensor
-    result = self.global_tensor.clone()
-    # Since global tensor is detached, add requires_grad and grad values back to the local tensor
-    if self.requires_grad:
-      result.requires_grad_(self.requires_grad)
-      result.grad = self.grad.clone() if self.grad is not None else None

-    return result
+    if not self.requires_grad:
+      # When requires_grad is False, global_tensor is the original tensor
+      return self.global_tensor
+    else:
+      # When requires_grad is True, global_tensor is detached
+      # Create a new tensor with the same values of global_tensor
+      result = self.global_tensor.clone()
+      # Since global tensor is detached, add requires_grad and grad values back to the local tensor
+      if self.requires_grad:
+        result.requires_grad_(self.requires_grad)
+        result.grad = self.grad.clone() if self.grad is not None else None
+
+      return result

   @property
   def sharding_spec(self):
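
The rewritten to_local above now branches on requires_grad: when it is False, global_tensor is the original tensor and is returned as-is; when it is True, global_tensor was stored detached, so a clone is made, gradient tracking is re-enabled on it, and any accumulated grad is copied over. A standalone sketch of that branching on plain CPU tensors; the helper name and free-function framing are illustrative, not part of the module:

from typing import Optional

import torch


def to_local_sketch(global_tensor: torch.Tensor, requires_grad: bool,
                    grad: Optional[torch.Tensor]) -> torch.Tensor:
  # Mirrors the branching in XLAShardedTensor.to_local after this commit.
  if not requires_grad:
    # global_tensor is the original (non-detached) tensor; return it directly.
    return global_tensor
  # global_tensor was stored detached, so clone it and restore autograd state.
  result = global_tensor.clone()
  result.requires_grad_(True)
  result.grad = grad.clone() if grad is not None else None
  return result


values = torch.randn(4, 3)
local = to_local_sketch(values, requires_grad=True, grad=torch.ones_like(values))
assert local.requires_grad and torch.equal(local.grad, torch.ones_like(values))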
