Skip to content

Commit ac0f206

Browse files
tianyu-lpytorchmergebot
authored andcommitted
[dtensor] fix side-effect on dtype for _like ops (pytorch#146869)
fixes pytorch#146749 Pull Request resolved: pytorch#146869 Approved by: https://github.com/yifuwang, https://github.com/janeyx99, https://github.com/ngimel
1 parent d774a63 commit ac0f206

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

test/distributed/tensor/test_tensor_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,12 @@ def test_zeros_like(self):
226226

227227
input_tensor = torch.randn(4, 8, requires_grad=True)
228228
dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
229-
zeros_like_dt = torch.zeros_like(dist_tensor)
230-
zeros_expected = torch.zeros(4, 8)
229+
zeros_like_dt = torch.zeros_like(dist_tensor, dtype=torch.bfloat16)
230+
zeros_expected = torch.zeros(4, 8, dtype=torch.bfloat16)
231231
self.assertEqual(zeros_expected, zeros_like_dt.to_local())
232+
# make sure there is no side effect on the input tensor dtype
233+
self.assertEqual(dist_tensor.dtype, torch.float32)
234+
self.assertEqual(zeros_like_dt.dtype, torch.bfloat16)
232235

233236
@with_comms
234237
@skip_if_lt_x_gpu(4)

torch/distributed/tensor/_ops/_tensor_ops.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,16 @@ def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
148148
assert isinstance(select_strategy, OpStrategy)
149149
for arg_strategy in select_strategy.strategies:
150150
arg_spec = arg_strategy.output_spec
151-
if is_tensor_partial(arg_spec):
152-
# if the arg_spec have partial, accept partial
153-
# in the input_specs but output replicate for
154-
# those corresponding mesh dims
155-
output_spec = DTensorSpec(
156-
mesh=arg_spec.mesh,
157-
placements=tuple(
158-
Replicate() if isinstance(p, Partial) else p
159-
for p in arg_spec.placements
160-
),
161-
)
162-
create_like_strategy.strategies.append(
163-
PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,))
164-
)
165-
166-
else:
167-
create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
151+
output_spec = DTensorSpec(
152+
mesh=arg_spec.mesh,
153+
placements=tuple(
154+
Replicate() if isinstance(p, Partial) else p
155+
for p in arg_spec.placements
156+
),
157+
)
158+
create_like_strategy.strategies.append(
159+
PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,))
160+
)
168161

169162
return create_like_strategy
170163

0 commit comments

Comments
 (0)