Skip to content

Commit 3377801

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Context manager to get temporary gloo pg (#902)
Summary: Pull Request resolved: #902 Reviewed By: JKSenthil Differential Revision: D62414608 fbshipit-source-id: 2686f0a018a7ff06d0a989a8b6bfdcd4928c8b87
1 parent f03fd9b commit 3377801

File tree

2 files changed

+139
-2
lines changed

2 files changed

+139
-2
lines changed

tests/utils/test_distributed.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99

1010
import os
1111
import unittest
12-
from typing import Literal, Optional, Union
12+
from typing import Callable, Literal, Optional, Union
1313
from unittest.mock import MagicMock, patch
1414
from urllib.parse import parse_qs, urlparse
1515

1616
import torch
1717
import torch.distributed as dist
1818
import torch.distributed.launcher as launcher
1919
from pyre_extensions import none_throws
20+
from torch.distributed import ProcessGroup
2021
from torchtnt.utils.distributed import (
2122
_validate_global_rank_world_size,
2223
all_gather_tensors,
@@ -25,6 +26,7 @@
2526
get_global_rank,
2627
get_local_rank,
2728
get_local_world_size,
29+
get_or_create_gloo_pg,
2830
get_process_group_backend_from_device,
2931
get_tcp_init_method,
3032
get_world_size,
@@ -463,3 +465,94 @@ def _test_method_for_rank_zero() -> str:
463465
val_from_test_method = _test_method_for_rank_zero()
464466
tc = unittest.TestCase()
465467
tc.assertEqual(val_from_test_method, "foo")
468+
469+
@skip_if_not_distributed
470+
def test_get_or_create_gloo_pg(self) -> None:
471+
spawn_multi_process(2, "gloo", self._test_get_or_create_gloo_pg)
472+
473+
@staticmethod
474+
@patch("torchtnt.utils.distributed.dist.destroy_process_group")
475+
def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
476+
477+
# Get a side effect for the get_backend function that returns NCCL the first time it is called,
478+
# and then will return GLOO for subsequent calls. For use with _test_get_or_create_gloo_pg.
479+
def _get_backend_side_effect() -> Callable[[Optional[ProcessGroup]], str]:
480+
called_get_backend = False
481+
482+
def get_backend(_) -> str:
483+
# We just want to return NCCL the first time we call this function.
484+
nonlocal called_get_backend
485+
if not called_get_backend:
486+
called_get_backend = True
487+
return dist.Backend.NCCL
488+
else:
489+
return dist.Backend.GLOO # real PG
490+
491+
return get_backend
492+
493+
tc = unittest.TestCase()
494+
495+
# Test not distributed - no-op
496+
with patch(
497+
"torchtnt.utils.distributed.dist.is_initialized",
498+
return_value=False,
499+
):
500+
with get_or_create_gloo_pg() as pg:
501+
tc.assertIsNone(pg)
502+
503+
mock_destroy_process_group.assert_not_called()
504+
505+
# Test no-op since gloo pg already exists
506+
mock_destroy_process_group.reset_mock()
507+
with get_or_create_gloo_pg() as pg:
508+
tc.assertIs(pg, dist.group.WORLD)
509+
510+
mock_destroy_process_group.assert_not_called()
511+
512+
# Test creating new gloo candidate pg - no op
513+
mock_destroy_process_group.reset_mock()
514+
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
515+
with get_or_create_gloo_pg(gloo_pg) as pg:
516+
tc.assertIs(pg, gloo_pg)
517+
518+
mock_destroy_process_group.assert_not_called()
519+
520+
# Test with NCCL backend - should create a new gloo pg and destroy
521+
mock_destroy_process_group.reset_mock()
522+
523+
with patch(
524+
"torchtnt.utils.distributed.dist.get_backend",
525+
side_effect=_get_backend_side_effect(),
526+
):
527+
with get_or_create_gloo_pg() as pg:
528+
pg = none_throws(pg)
529+
tc.assertIsNot(pg, dist.group.WORLD)
530+
tc.assertEqual(pg._get_backend_name(), dist.Backend.GLOO)
531+
532+
mock_destroy_process_group.assert_called_once_with(pg)
533+
534+
# Test exception handling with existing pg - forward exception, group should not be destroyed
535+
mock_destroy_process_group.reset_mock()
536+
with tc.assertRaisesRegex(Exception, "Test Exception"):
537+
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
538+
with get_or_create_gloo_pg(gloo_pg) as pg:
539+
tc.assertIs(pg, gloo_pg)
540+
raise Exception("Test Exception")
541+
542+
mock_destroy_process_group.assert_not_called()
543+
544+
# Test exception handling with new pg - forward exception, group should be destroyed
545+
mock_destroy_process_group.reset_mock()
546+
with tc.assertRaisesRegex(Exception, "Test Exception"):
547+
with patch(
548+
"torchtnt.utils.distributed.dist.get_backend",
549+
side_effect=_get_backend_side_effect(),
550+
):
551+
with get_or_create_gloo_pg() as pg:
552+
tc.assertIsNot(pg, dist.group.WORLD)
553+
tc.assertEqual(
554+
none_throws(pg)._get_backend_name(), dist.Backend.GLOO
555+
)
556+
raise Exception("Test Exception")
557+
558+
mock_destroy_process_group.assert_called_once_with(pg)

torchtnt/utils/distributed.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
# pyre-strict
99

1010

11+
import logging
1112
import os
1213
import tempfile
14+
from contextlib import contextmanager
1315
from dataclasses import dataclass
1416
from datetime import timedelta
1517
from functools import wraps
16-
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
18+
from typing import Any, Callable, cast, Dict, Generator, List, Optional, TypeVar, Union
1719

1820
import torch
1921
import torch.nn.functional as F
@@ -28,6 +30,8 @@
2830
TParams = ParameterSpecification("TParams")
2931
TReturn = TypeVar("TReturn")
3032

33+
logger: logging.Logger = logging.getLogger(__name__)
34+
3135

3236
class PGWrapper:
3337
"""
@@ -641,3 +645,43 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
641645
return cast(T, val)
642646

643647
return wrapper
648+
649+
650+
@contextmanager
651+
def get_or_create_gloo_pg(
652+
candidate_pg: Optional[dist.ProcessGroup] = None,
653+
) -> Generator[Optional[dist.ProcessGroup], None, None]:
654+
"""
655+
Context manager to ensure that a gloo process group is used for the contained operations. First checks if the
656+
WORLD process group, or the provided candidate process group, is already gloo-based. In case it is, that is returned.
657+
Otherwise, a new gloo process group will be created and returned. Upon exiting the context, if a new process group
658+
was created, it will be destroyed.
659+
660+
Note: If the distributed environment is not initialized, this context manager will return None and will be no-op.
661+
662+
Args:
663+
candidate_pg: Optional process group to check if it is gloo-based. If None, the WORLD process group will be checked.
664+
"""
665+
gloo_pg_created = False
666+
667+
if not dist.is_initialized():
668+
logger.info("Not in a distributed environment, gloo process group not created")
669+
pg = None
670+
671+
else:
672+
pg = candidate_pg or dist.group.WORLD
673+
if dist.get_backend(pg) != dist.Backend.GLOO:
674+
logger.info("Creating temporary gloo process group")
675+
pg = dist.new_group(
676+
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
677+
)
678+
gloo_pg_created = True
679+
680+
try:
681+
yield pg
682+
683+
finally:
684+
# Cleanup temporary gloo pg if it was created
685+
if gloo_pg_created:
686+
dist.destroy_process_group(pg)
687+
logger.info("Destroyed temporary gloo process group")

0 commit comments

Comments
 (0)