Skip to content

Commit b378a46

Browse files
committed
run yapf
1 parent be7ab62 commit b378a46

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

test/spmd/test_xla_dtensor_to_local.py

Lines changed: 6 additions & 3 deletions
@@ -44,7 +44,8 @@ def test_to_local_requires_grad(self):
     tensor = torch.randn(100_000, 88, requires_grad=True)

     # Create XLAShardedTensor
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Verify requires_grad is set
     self.assertTrue(sharded_tensor.requires_grad)
@@ -70,7 +71,8 @@ def test_to_local_grad_independence(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Create gradients
     res = sharded_tensor.sum()
@@ -95,7 +97,8 @@ def test_to_local_grad_none_handling(self):
     mesh = DeviceMesh("xla", list(range(world_size)))

     tensor = torch.randn(100_000, 88, requires_grad=True)
-    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)
+    sharded_tensor = XLAShardedTensor(
+        tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad)

     # Don't do backward pass, so grad remains None
     self.assertIsNone(sharded_tensor.grad)

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -165,7 +165,6 @@ def to_local(self):
       torch.Tensor: The global tensor representation with appropriate requires_grad setting.
     """

-
     if not self.requires_grad:
       # When requires_grad is False, global_tensor is the original tensor
       return self.global_tensor

0 commit comments

Comments (0)