
Commit 58b6ea7

diego-urgell authored and facebook-github-bot committed
Attempt to extend nccl collective timeout (#858)
Summary:
Pull Request resolved: #858

We have two remaining tests that are still failing, with the following error message:

```
[Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=BROADCAST, NumelIn=2, NumelOut=2, Timeout(ms)=60000) ran for 60033 milliseconds before timing out.
```

Let's attempt to increase the collective timeout for those tests. There's no guarantee this will work, but it's worth trying. Otherwise we may consider deleting the failing tests to avoid flakiness.

Reviewed By: galrotem

Differential Revision: D59342738

fbshipit-source-id: 220f1f359eb0f98e5175e93badc7e998ae00db64
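For context, the mechanism this commit relies on is PyTorch's standard process-group timeout: `torch.distributed.init_process_group` accepts a `timeout` argument that the NCCL watchdog enforces on every collective. A minimal sketch, with illustrative values not taken from this commit:

```python
# Sketch of the underlying mechanism: raising the NCCL collective timeout
# at process-group initialization. Values here are illustrative only.
from datetime import timedelta

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    rank=0,
    world_size=1,
    # The NCCL watchdog aborts any collective that runs longer than this.
    timeout=timedelta(seconds=180),
)
```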
1 parent 5dad8d3 commit 58b6ea7

3 files changed: +17 −5


tests/utils/test_checkpoint_gpu.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -24,9 +24,7 @@ class TestCheckpointUtilsGPU(unittest.TestCase):
     @skip_if_not_gpu
     def test_get_checkpoint_dirpaths_distributed(self) -> None:
         spawn_multi_process(
-            2,
-            "nccl",
-            self._test_get_checkpoint_dirpaths,
+            2, "nccl", self._test_get_checkpoint_dirpaths, timeout_s=180
         )
 
     @staticmethod
```

tests/utils/test_distributed_gpu.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -52,6 +52,7 @@ def test_pg_wrapper_scatter_object_list_nccl(self) -> None:
             2,
             "nccl",
             self._test_pg_wrapper_scatter_object_list,
+            timeout_s=180,
         )
 
     @classmethod
```

torchtnt/utils/distributed.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -518,6 +518,7 @@ class ProcessGroupSetupParams:
     backend: str
     port: str
     world_size: int
+    timeout_s: int
 
 
 def spawn_multi_process(
@@ -538,6 +539,11 @@ def spawn_multi_process(
         method_args: args for the method
         method_kwargs: kwargs for the method
 
+    Note:
+        The default timeout used for distributed collectives in the process group is 60 seconds.
+        This can be overridden by passing a `timeout_s` key in the `method_kwargs`. It will be
+        extracted before passing to the method call.
+
     Returns:
         A list, l, where l[i] is the return value of method(*method_args, **methods_kwargs) on rank i
     """
@@ -550,7 +556,12 @@ def spawn_multi_process(
         # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
         _init_pg_and_rank_and_launch_method,
         args=(
-            ProcessGroupSetupParams(backend=backend, port=port, world_size=world_size),
+            ProcessGroupSetupParams(
+                backend=backend,
+                port=port,
+                world_size=world_size,
+                timeout_s=method_kwargs.pop("timeout_s", 60),
+            ),
             mp_output_dict,
             method,
             method_args,
@@ -582,7 +593,9 @@ def _init_pg_and_rank_and_launch_method(
         rank=rank,
         world_size=pg_setup_params.world_size,
         backend=pg_setup_params.backend,
-        timeout=timedelta(seconds=60),  # setting up timeout for distributed collectives
+        timeout=timedelta(  # setting up timeout for distributed collectives
+            seconds=pg_setup_params.timeout_s
+        ),
     )
     try:
         # pyre-ignore: spawn_multi_process uses unsafe types to begin with
```
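With this change, callers opt into a longer collective timeout per test. A minimal usage sketch of the new kwarg follows; `_worker` is a hypothetical stand-in for a test method, and `timeout_s` is popped from `method_kwargs` so it is never forwarded to the worker itself:

```python
# Hypothetical usage of the timeout_s kwarg added in this commit.
from torchtnt.utils.distributed import spawn_multi_process


def _worker() -> None:
    # Hypothetical per-rank body; real tests run collectives here.
    ...


spawn_multi_process(
    2,        # world_size
    "nccl",   # backend
    _worker,
    timeout_s=180,  # process-group collective timeout; defaults to 60s
)
```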
