
Commit 70abcd1

JKSenthil authored and facebook-github-bot committed
add TPStrategy + 1d/2d parallel dispatcher (#1002)
Summary:
Pull Request resolved: #1002

# Context
TP, AC, FSDP2, and torch compile require a specific order (TP -> AC -> compile -> fsdp2). The current order in `prepare_module` is incompatible with this.

# This Diff
1) Adds a `TPStrategy` dataclass
2) Renames the old `prepare_module` logic for plain DDP/FSDP/FSDP2 to `_prepare_module_1d`
3) Shards TP via `_prepare_module_2d`, which handles TP and TP+FSDP2+HSDP and applies the correct order of operations
4) Reworks `prepare_module` to dispatch to `_prepare_module_2d` for TP/any 2D parallel application, and to `_prepare_module_1d` for all other strategies

Reviewed By: galrotem

Differential Revision: D74410708

fbshipit-source-id: 04fda80ef619784d5d9ad5c4db0377e77dc43c75
1 parent 98e78df commit 70abcd1

File tree: 2 files changed (+181, -1 lines)


tests/utils/test_prepare_module.py

Lines changed: 71 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
     _check_and_convert_mp_policy_dtypes,
+    _prepare_module_2d,
     apply_torch_compile,
     DDPStrategy,
     FSDP2Strategy,
@@ -29,6 +30,7 @@
     prepare_fsdp2,
     prepare_module,
     TorchCompileParams,
+    TPStrategy,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
 from torchtnt.utils.version import is_torch_version_geq
@@ -294,6 +296,75 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
             module, mesh=mock_mesh, reshard_after_forward=True
         )

+    @patch("torchtnt.utils.prepare_module._prepare_module_2d")
+    @patch("torchtnt.utils.prepare_module._prepare_module_1d")
+    def test_prepare_module_dispatching(
+        self, mock_prepare_module_1d: Mock, mock_prepare_module_2d: Mock
+    ) -> None:
+        """
+        Test that prepare_module dispatches to the correct 1d/2d function based on the strategy
+        """
+
+        module = torch.nn.Linear(2, 2, device="cpu")
+        device = torch.device("cpu")
+        strategy = TPStrategy(tp_plan={}, fsdp2_strategy=None)
+
+        with self.assertRaisesRegex(ValueError, "TPStrategy expects global_mesh"):
+            prepare_module(
+                module,
+                device,
+                strategy=strategy,
+                global_mesh=None,
+            )
+
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        prepare_module(
+            module,
+            device,
+            strategy=strategy,
+            global_mesh=mock_global_mesh,
+        )
+        mock_prepare_module_2d.assert_called_with(
+            module,
+            device,
+            strategy=strategy,
+            global_mesh=mock_global_mesh,
+            torch_compile_params=None,
+            activation_checkpoint_params=None,
+        )
+
+        strategy = FSDP2Strategy()
+        prepare_module(
+            module,
+            device,
+            strategy=strategy,
+            global_mesh=mock_global_mesh,
+        )
+        mock_prepare_module_1d.assert_called_with(
+            module,
+            device,
+            strategy=strategy,
+            global_mesh=mock_global_mesh,
+            torch_compile_params=None,
+            activation_checkpoint_params=None,
+            enable_compiled_autograd=False,
+        )
+
+    @patch("torchtnt.utils.prepare_module.parallelize_module")
+    def test_prepare_module_2d(self, mock_parallelize_module: Mock) -> None:
+        """
+        Test that prepare_module_2d invokes TP apis
+        """
+
+        module = torch.nn.Linear(2, 2, device="cpu")
+        device = torch.device("cpu")
+        strategy = TPStrategy(tp_plan={}, fsdp2_strategy=None)
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        _prepare_module_2d(
+            module, device, strategy=strategy, global_mesh=mock_global_mesh
+        )
+        mock_parallelize_module.assert_called_once()
+
     def test_apply_torch_compile_recursive_module_types(self) -> None:
         """
         Test that recursive_module_types is apply correctly.
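These tests stub out `GlobalMeshCoordinator` with `MagicMock`, so no real device mesh is built. For orientation, here is a minimal sketch, assuming a 2D mesh with dimension names `"dp"` and `"tp"` (the names are illustrative, not taken from this diff), of the kind of mesh such a coordinator would wrap:

```python
# Sketch only: a 2D device mesh comparable to what GlobalMeshCoordinator coordinates.
# Assumes torch.distributed is initialized (e.g. via torchrun) with
# world_size == dp_size * tp_size; the dim names "dp"/"tp" are assumptions.
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def build_2d_mesh(dp_size: int, tp_size: int) -> DeviceMesh:
    return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))


# mesh["tp"] would be the sub-mesh a TP plan targets (cf. global_mesh.tp_mesh below),
# while mesh["dp"] would be the dimension FSDP2/HSDP shards over.
```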

torchtnt/utils/prepare_module.py

Lines changed: 110 additions & 1 deletion
@@ -41,6 +41,8 @@
 )
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
+from torch.distributed.tensor.parallel import parallelize_module
+from torch.distributed.tensor.parallel.style import ParallelStyle
 from torchtnt.utils.device_mesh import GlobalMeshCoordinator
 from torchtnt.utils.precision import convert_precision_str_to_dtype
@@ -229,6 +231,20 @@ class FSDP2Strategy(Strategy):
     cpu_offload: bool = False


+@dataclass
+class TPStrategy(Strategy):
+    """
+    Dataclass representing Tensor Parallelism strategy. Specify the FSDP strategy for 2D parallelism setup.
+
+    Args:
+        tp_plan: The plan used to parallelize the module. See https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.parallelize_module for details.
+        fsdp2_strategy (optional): fsdp2 strategy to configure 2D parallel strategy
+    """
+
+    tp_plan: Union[ParallelStyle, Dict[str, ParallelStyle]]
+    fsdp2_strategy: Optional[FSDP2Strategy] = None
+
+
 @dataclass
 class TorchCompileParams:
     """
@@ -609,7 +625,55 @@ def prepare_module(
     global_mesh: Optional[GlobalMeshCoordinator] = None,
 ) -> torch.nn.Module:
     """
-    Utility to move a module to device, set up parallelism, activation checkpointing and compile.
+    Utility to move a module to device, set up parallelism (None, DDP, FSDP, HSDP, TP), activation checkpointing and compile.
+    This function acts as a dispatcher to choose between 1D and 2D parallelism setup, depending on the strategy used.
+
+    Args:
+        module: module to be used.
+        device: device to which module will be moved.
+        strategy: the data parallelization strategy to be used. if a string, must be one of ``ddp``, ``fsdp``, or ``noop``.
+        torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html.
+        activation_checkpoint_params: params for enabling activation checkpointing.
+        enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+    """
+    if isinstance(strategy, TPStrategy):
+        if global_mesh is None:
+            raise ValueError(
+                "TPStrategy expects global_mesh (GlobalMeshCoordinator) to be defined. Got None."
+            )
+        return _prepare_module_2d(
+            module,
+            device,
+            strategy=strategy,
+            global_mesh=global_mesh,
+            torch_compile_params=torch_compile_params,
+            activation_checkpoint_params=activation_checkpoint_params,
+        )
+
+    return _prepare_module_1d(
+        module,
+        device,
+        strategy=strategy,
+        torch_compile_params=torch_compile_params,
+        activation_checkpoint_params=activation_checkpoint_params,
+        enable_compiled_autograd=enable_compiled_autograd,
+        global_mesh=global_mesh,
+    )
+
+
+def _prepare_module_1d(
+    module: torch.nn.Module,
+    device: torch.device,
+    *,
+    strategy: Optional[Union[Strategy, str]] = None,
+    torch_compile_params: Optional[TorchCompileParams] = None,
+    activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
+    enable_compiled_autograd: bool = False,
+    global_mesh: Optional[GlobalMeshCoordinator] = None,
+) -> torch.nn.Module:
+    """
+    Utility to move a module to device, set up 1D parallelism (None, DDP, FSDP), activation checkpointing and compile.

     Args:
         module: module to be used.
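As a usage sketch (not part of this diff): a caller builds a `TPStrategy` from a tensor-parallel plan, optionally pairs it with an `FSDP2Strategy` for 2D sharding, and passes it to `prepare_module` together with a `GlobalMeshCoordinator`. The `ColwiseParallel`/`RowwiseParallel` styles come from `torch.distributed.tensor.parallel`; the submodule names `"w1"`/`"w2"` are assumptions about a hypothetical model, and constructing the mesh coordinator itself is outside this diff.

```python
# Illustrative only: the "w1"/"w2" module paths are assumptions about the user's
# model; GlobalMeshCoordinator construction is not shown here (outside this diff).
import torch
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module, TPStrategy

strategy = TPStrategy(
    tp_plan={
        "w1": ColwiseParallel(),  # shard the first linear column-wise over the TP mesh
        "w2": RowwiseParallel(),  # shard the second linear row-wise over the TP mesh
    },
    fsdp2_strategy=FSDP2Strategy(),  # optional: adds the data-parallel / HSDP dimension
)

module = torch.nn.Linear(2, 2)  # stand-in for a model containing "w1"/"w2" submodules
device = torch.device("cpu")

# Without a GlobalMeshCoordinator, the TP path is rejected up front:
try:
    prepare_module(module, device, strategy=strategy, global_mesh=None)
except ValueError as err:
    print(err)  # TPStrategy expects global_mesh (GlobalMeshCoordinator) to be defined. Got None.

# With a real coordinator, the same call dispatches to _prepare_module_2d; any other
# strategy (e.g. "ddp", FSDPStrategy, FSDP2Strategy) still takes _prepare_module_1d.
```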
@@ -675,6 +739,51 @@ def prepare_module(
     return module


+def _prepare_module_2d(
+    module: torch.nn.Module,
+    device: torch.device,
+    *,
+    strategy: TPStrategy,
+    global_mesh: GlobalMeshCoordinator,
+    torch_compile_params: Optional[TorchCompileParams] = None,
+    activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
+) -> torch.nn.Module:
+    """
+    Utility to move a module to device, set up 2D parallelism (FSDP / TP / HSDP), activation checkpointing and compile.
+
+    Order of composability is TP -> AC -> compile -> fsdp2.
+
+    Args:
+        module: module to be used.
+        device: device to which module will be moved.
+        strategy: the TP parallelization strategy to be used.
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+        torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html.
+        activation_checkpoint_params: params for enabling activation checkpointing.
+    """
+
+    # 1) apply TP
+    parallelize_module(module, global_mesh.tp_mesh, parallelize_plan=strategy.tp_plan)
+
+    # 2) apply AC if specified
+    if activation_checkpoint_params:
+        apply_ac(module, activation_checkpoint_params)
+
+    # 3) apply torch.compile if specified
+    if torch_compile_params:
+        apply_torch_compile(module, torch_compile_params)
+
+    # 4) apply data parallel / HSDP sharding (via FSDP2 apis) if specified in TPStrategy
+    if (fsdp2_strategy := strategy.fsdp2_strategy) is not None:
+        prepare_fsdp2(module, device, fsdp2_strategy, global_mesh)
+    else:
+        # prepare_fsdp2 will handle materializing meta weights
+        # so if fsdp2_strategy isn't used, we do it manually here
+        materialize_meta_params(module, device)
+
+    return module
+
+
 def convert_str_to_strategy(
     strategy: str,
 ) -> Union[DDPStrategy, FSDPStrategy, FSDP2Strategy, NOOPStrategy]:
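For readers who want to see what the `TP -> AC -> compile -> fsdp2` ordering corresponds to in raw PyTorch terms, here is a hedged sketch using the underlying APIs directly. It is not the torchtnt implementation (which goes through `parallelize_module`, `apply_ac`, `apply_torch_compile`, and `prepare_fsdp2` as shown above); the `tp_plan`, the mesh dimension names, and the `fully_shard` import path (public in recent PyTorch releases) are assumptions.

```python
# Hedged sketch of the TP -> AC -> compile -> FSDP2 ordering with raw PyTorch APIs.
# Assumes a 2D DeviceMesh with dims named ("dp", "tp"); the tp_plan is illustrative.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point in recent PyTorch
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def shard_2d(module: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    # 1) TP first: shard parameters over the tensor-parallel sub-mesh.
    parallelize_module(module, mesh["tp"], parallelize_plan={"w1": ColwiseParallel()})

    # 2) AC next: wrap submodules with activation checkpointing
    #    (a real setup would pass a check_fn selecting e.g. transformer blocks).
    apply_activation_checkpointing(module)

    # 3) compile after AC, so the checkpoint wrappers are captured in the compiled graph.
    module.compile()

    # 4) FSDP2 last: shard the data-parallel dimension on top of everything else.
    fully_shard(module, mesh=mesh["dp"])
    return module
```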
