|
9 | 9 |
|
10 | 10 | import os
|
11 | 11 | import unittest
|
12 |
| -from typing import Literal, Optional, Union |
| 12 | +from typing import Callable, Literal, Optional, Union |
13 | 13 | from unittest.mock import MagicMock, patch
|
14 | 14 | from urllib.parse import parse_qs, urlparse
|
15 | 15 |
|
16 | 16 | import torch
|
17 | 17 | import torch.distributed as dist
|
18 | 18 | import torch.distributed.launcher as launcher
|
19 | 19 | from pyre_extensions import none_throws
|
| 20 | +from torch.distributed import ProcessGroup |
20 | 21 | from torchtnt.utils.distributed import (
|
21 | 22 | _validate_global_rank_world_size,
|
22 | 23 | all_gather_tensors,
|
|
25 | 26 | get_global_rank,
|
26 | 27 | get_local_rank,
|
27 | 28 | get_local_world_size,
|
| 29 | + get_or_create_gloo_pg, |
28 | 30 | get_process_group_backend_from_device,
|
29 | 31 | get_tcp_init_method,
|
30 | 32 | get_world_size,
|
@@ -463,3 +465,94 @@ def _test_method_for_rank_zero() -> str:
|
463 | 465 | val_from_test_method = _test_method_for_rank_zero()
|
464 | 466 | tc = unittest.TestCase()
|
465 | 467 | tc.assertEqual(val_from_test_method, "foo")
|
| 468 | + |
@skip_if_not_distributed
def test_get_or_create_gloo_pg(self) -> None:
    """Run the ``get_or_create_gloo_pg`` checks on two spawned gloo ranks."""
    num_ranks = 2
    spawn_multi_process(num_ranks, "gloo", self._test_get_or_create_gloo_pg)
| 472 | + |
@staticmethod
@patch("torchtnt.utils.distributed.dist.destroy_process_group")
def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
    """Per-rank scenarios for the ``get_or_create_gloo_pg`` context manager.

    ``destroy_process_group`` (as looked up through
    ``torchtnt.utils.distributed.dist``) is mocked for the whole body, so every
    scenario can assert whether the context manager requested a group teardown
    without any group actually being destroyed.

    Args:
        mock_destroy_process_group: mock injected by the ``patch`` decorator.
    """

    def _nccl_once_then_gloo() -> Callable[[Optional[ProcessGroup]], str]:
        """Build a ``get_backend`` stand-in: NCCL on its first call, GLOO afterwards."""
        pending = [dist.Backend.NCCL]

        def fake_get_backend(_) -> str:
            # Only the very first query is answered with NCCL.
            if pending:
                return pending.pop()
            return dist.Backend.GLOO  # what the real process group uses

        return fake_get_backend

    def _patched_backend():
        """Return a fresh (not yet entered) patch of ``dist.get_backend``."""
        return patch(
            "torchtnt.utils.distributed.dist.get_backend",
            side_effect=_nccl_once_then_gloo(),
        )

    case = unittest.TestCase()

    # Scenario: distributed reported as uninitialized -> yields None, no teardown.
    with patch(
        "torchtnt.utils.distributed.dist.is_initialized",
        return_value=False,
    ), get_or_create_gloo_pg() as group:
        case.assertIsNone(group)
    mock_destroy_process_group.assert_not_called()

    # Scenario: default group is already gloo -> yields WORLD, no teardown.
    mock_destroy_process_group.reset_mock()
    with get_or_create_gloo_pg() as group:
        case.assertIs(group, dist.group.WORLD)
    mock_destroy_process_group.assert_not_called()

    # Scenario: explicit gloo candidate -> yields that same group, no teardown.
    mock_destroy_process_group.reset_mock()
    candidate = dist.new_group(backend=dist.Backend.GLOO)
    with get_or_create_gloo_pg(candidate) as group:
        case.assertIs(group, candidate)
    mock_destroy_process_group.assert_not_called()

    # Scenario: backend reported as NCCL -> a separate gloo group is yielded
    # and teardown is requested for it on exit.
    mock_destroy_process_group.reset_mock()
    with _patched_backend(), get_or_create_gloo_pg() as group:
        group = none_throws(group)
        case.assertIsNot(group, dist.group.WORLD)
        case.assertEqual(group._get_backend_name(), dist.Backend.GLOO)
    mock_destroy_process_group.assert_called_once_with(group)

    # Scenario: body raises while using an existing group -> the exception
    # propagates and no teardown is requested.
    mock_destroy_process_group.reset_mock()
    with case.assertRaisesRegex(Exception, "Test Exception"):
        candidate = dist.new_group(backend=dist.Backend.GLOO)
        with get_or_create_gloo_pg(candidate) as group:
            case.assertIs(group, candidate)
            raise Exception("Test Exception")
    mock_destroy_process_group.assert_not_called()

    # Scenario: body raises while using a newly created group -> the exception
    # propagates and teardown is still requested for that group.
    mock_destroy_process_group.reset_mock()
    with case.assertRaisesRegex(Exception, "Test Exception"):
        with _patched_backend(), get_or_create_gloo_pg() as group:
            case.assertIsNot(group, dist.group.WORLD)
            case.assertEqual(
                none_throws(group)._get_backend_name(), dist.Backend.GLOO
            )
            raise Exception("Test Exception")
    mock_destroy_process_group.assert_called_once_with(group)
0 commit comments