
Commit cb31137

JKSenthil authored and facebook-github-bot committed
add global_mesh support in AutoUnit (#1003)
Summary: Pull Request resolved: #1003

Adds a `global_mesh` arg in AutoUnit and forwards it into `prepare_module` for model sharding.

Reviewed By: vdogaru

Differential Revision: D74410711

fbshipit-source-id: fb7caedef706c9d8f7876f14d6d31e1d4aaa7151
1 parent 70abcd1 commit cb31137
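
For context, the sketch below shows how the new argument surfaces to users of AutoUnit. It is illustrative only: the AutoUnit abstract methods (`compute_loss`, `configure_optimizers_and_lr_scheduler`), the `DDPStrategy` import path, and the pre-built `GlobalMeshCoordinator` instance are assumptions not shown in this commit.

# Hedged sketch: passing the new `global_mesh` kwarg from a user-defined AutoUnit.
# The GlobalMeshCoordinator construction is elided because its constructor is not
# part of this diff; treat `global_mesh` as an already-built coordinator.
from typing import Any, Optional, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State
from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy  # assumed import path

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyAutoUnit(AutoUnit[Batch]):
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[TLRScheduler]]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None


global_mesh: GlobalMeshCoordinator = ...  # assumed: built elsewhere for your topology

unit = MyAutoUnit(
    module=torch.nn.Linear(1, 1),
    strategy=DDPStrategy(),
    global_mesh=global_mesh,  # new kwarg; forwarded to prepare_module
)

Per the docstring added in this commit, the coordinator is intended for TP or 2D parallelism strategies; DDPStrategy is used above only to keep the sketch minimal (it is also what the new unit test uses).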

File tree

tests/framework/test_auto_unit.py
torchtnt/framework/auto_unit.py

2 files changed: +37 -1 lines changed

tests/framework/test_auto_unit.py

Lines changed: 30 additions & 1 deletion
@@ -9,7 +9,7 @@
 
 import unittest
 from typing import Any, Literal, Optional, Tuple, TypeVar
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch
 
@@ -37,6 +37,7 @@
 from torchtnt.framework.train import train
 from torchtnt.framework.unit import TPredictData
 from torchtnt.utils.device import copy_data_to_device
+from torchtnt.utils.device_mesh import GlobalMeshCoordinator
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
@@ -780,6 +781,34 @@ def test_gradient_accumulation_fsdp2(self, _) -> None:
 
         auto_unit.train_progress.increment_step()
 
+    @patch("torchtnt.framework.auto_unit.prepare_module")
+    def test_global_mesh(self, mock_prepare_module: Mock) -> None:
+        """
+        Test that the global mesh is forwarded correctly in the AutoUnit.
+        """
+        module = torch.nn.Linear(1, 1)
+        device = torch.device("cpu")
+        strategy = DDPStrategy()
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        mock_prepare_module.return_value = module
+
+        DummyAutoUnit(
+            module=module,
+            device=device,
+            strategy=strategy,
+            global_mesh=mock_global_mesh,
+        )
+
+        mock_prepare_module.assert_called_once_with(
+            module,
+            device,
+            strategy=strategy,
+            torch_compile_params=None,
+            activation_checkpoint_params=None,
+            enable_compiled_autograd=False,
+            global_mesh=mock_global_mesh,
+        )
+
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
 
torchtnt/framework/auto_unit.py

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,7 @@
 from torchtnt.framework.unit import EvalUnit, PredictUnit, TPredictData, TrainUnit
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.device import copy_data_to_device
+from torchtnt.utils.device_mesh import GlobalMeshCoordinator
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.precision import (
@@ -326,6 +327,7 @@ def __init__(
         torch_compile_params: Optional[TorchCompileParams] = None,
         detect_anomaly: Optional[bool] = None,
         enable_prefetch: bool = False,
+        global_mesh: Optional[GlobalMeshCoordinator] = None,
     ) -> None:
        """
        AutoPredictUnit is a convenience for users who are running inference and would like to have certain features handled for them, such as:
@@ -348,6 +350,7 @@
            strategy: the data parallelization strategy to be used. if a string, must be one of ``ddp`` or ``fsdp``.
            torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
            detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
+           global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.
 
        Note:
            Torch compile support is only available in PyTorch 2.0 or higher.
@@ -365,6 +368,7 @@ def __init__(
            self.device,
            strategy=strategy,
            torch_compile_params=torch_compile_params,
+           global_mesh=global_mesh,
        )
 
    # pyre-fixme[3]: Return annotation cannot be `Any`.
@@ -474,6 +478,7 @@ class AutoUnit(
            in a much more efficient way.
        enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
        zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.
+       global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.
 
    Note:
        Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -510,6 +515,7 @@ def __init__(
        loss_backward_retain_graph: Optional[bool] = None,
        enable_prefetch: bool = True,
        zero_grad_at_train_step_start: bool = False,
+       global_mesh: Optional[GlobalMeshCoordinator] = None,
    ) -> None:
        super().__init__(
            module=module,
@@ -554,6 +560,7 @@ def __init__(
            torch_compile_params=torch_compile_params,
            activation_checkpoint_params=activation_checkpoint_params,
            enable_compiled_autograd=enable_compiled_autograd,
+           global_mesh=global_mesh,
        )
 
        self.grad_scaler: Optional[GradScaler] = None
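
The kwargs asserted in `test_global_mesh` mirror the call the AutoUnit constructor now makes. A minimal sketch of invoking `prepare_module` directly with the same arguments follows; the `torchtnt.utils.prepare_module` import path and the pre-built coordinator are assumptions, not part of this diff.

# Hedged sketch: calling prepare_module directly with the new kwarg, mirroring the
# arguments asserted in test_global_mesh. Coordinator construction is elided.
import torch
from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.prepare_module import DDPStrategy, prepare_module

global_mesh: GlobalMeshCoordinator = ...  # assumed: built elsewhere for your topology

prepared = prepare_module(
    torch.nn.Linear(1, 1),
    torch.device("cpu"),
    strategy=DDPStrategy(),
    torch_compile_params=None,
    activation_checkpoint_params=None,
    enable_compiled_autograd=False,
    global_mesh=global_mesh,  # forwarded from the AutoUnit constructor in this commit
)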
