Skip to content

Commit 2254961

Browse files
jeffdailypytorchmergebot
authored andcommitted
[CI] fix test_pointwise_ops.py test_mul_div_scalar_partial (pytorch#170510)
Support any world size; 2, 3 or 4. Pull Request resolved: pytorch#170510 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <[email protected]>
1 parent 66407ac commit 2254961

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/distributed/tensor/test_pointwise_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,8 @@ def test_mul_div_scalar_partial(self):
467467

468468
self.assertTrue(res._spec.placements[0].is_partial())
469469
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
470-
self.assertEqual(res, 12)
470+
expected = sum(i * 2 for i in range(self.world_size))
471+
self.assertEqual(res, expected)
471472

472473
res = aten.div.Scalar(dt, 2)
473474
self.assertEqual(
@@ -478,7 +479,8 @@ def test_mul_div_scalar_partial(self):
478479
self.assertTrue(res._spec.placements[0].is_partial())
479480
res = res.redistribute(dt.device_mesh, placements=[Replicate()])
480481

481-
self.assertEqual(res, 3)
482+
expected = expected / 4.0
483+
self.assertEqual(res, expected)
482484

483485
@with_comms
484486
@parametrize("op,reduce_op", [(torch.maximum, "max"), (torch.minimum, "min")])

0 commit comments

Comments
 (0)