|
41 | 41 | )
|
42 | 42 | from torch.distributed.device_mesh import init_device_mesh
|
43 | 43 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
|
| 44 | +from torch.distributed.tensor.parallel import parallelize_module |
| 45 | +from torch.distributed.tensor.parallel.style import ParallelStyle |
44 | 46 | from torchtnt.utils.device_mesh import GlobalMeshCoordinator
|
45 | 47 | from torchtnt.utils.precision import convert_precision_str_to_dtype
|
46 | 48 |
|
@@ -229,6 +231,20 @@ class FSDP2Strategy(Strategy):
|
229 | 231 | cpu_offload: bool = False
|
230 | 232 |
|
231 | 233 |
|
| 234 | +@dataclass |
| 235 | +class TPStrategy(Strategy): |
| 236 | + """ |
| 237 | + Dataclass representing the Tensor Parallelism strategy. Specify ``fsdp2_strategy`` to additionally apply FSDP2 sharding for a 2D parallelism setup. |
| 238 | +
|
| 239 | + Args: |
| 240 | + tp_plan: The plan used to parallelize the module. See https://pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.parallelize_module for details. |
| 241 | + fsdp2_strategy (optional): FSDP2 strategy used to configure the data parallel dimension of the 2D parallelism setup |
| 242 | + """ |
| 243 | + |
| 244 | + tp_plan: Union[ParallelStyle, Dict[str, ParallelStyle]] |
| 245 | + fsdp2_strategy: Optional[FSDP2Strategy] = None |
| 246 | + |
| 247 | + |
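As a point of reference, here is a minimal, hedged sketch of how a `TPStrategy` could be constructed. The submodule names (`w1`, `w2`) and the use of a default `FSDP2Strategy()` are illustrative assumptions; `ColwiseParallel` and `RowwiseParallel` come from `torch.distributed.tensor.parallel`.

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Map submodule names (FQNs) to parallel styles: shard the first linear
# column-wise and the second row-wise so the intermediate activation stays
# sharded across the TP group. "w1" / "w2" are hypothetical module names.
tp_plan = {
    "w1": ColwiseParallel(),
    "w2": RowwiseParallel(),
}

# TP only:
tp_strategy = TPStrategy(tp_plan=tp_plan)

# 2D parallelism: TP composed with FSDP2 sharding (assumes FSDP2Strategy's
# defaults are acceptable for the job).
tp_2d_strategy = TPStrategy(tp_plan=tp_plan, fsdp2_strategy=FSDP2Strategy())
```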
232 | 248 | @dataclass
|
233 | 249 | class TorchCompileParams:
|
234 | 250 | """
|
@@ -609,7 +625,55 @@ def prepare_module(
|
609 | 625 | global_mesh: Optional[GlobalMeshCoordinator] = None,
|
610 | 626 | ) -> torch.nn.Module:
|
611 | 627 | """
|
612 |
| - Utility to move a module to device, set up parallelism, activation checkpointing and compile. |
| 628 | + Utility to move a module to device, set up parallelism (None, DDP, FSDP, HSDP, TP), activation checkpointing, and compile. |
| 629 | + This function acts as a dispatcher, choosing between the 1D and 2D parallelism setups depending on the strategy used. |
| 630 | +
|
| 631 | + Args: |
| 632 | + module: module to be used. |
| 633 | + device: device to which module will be moved. |
| 634 | + strategy: the data parallelization strategy to be used. If a string, must be one of ``ddp``, ``fsdp``, or ``noop``. |
| 635 | + torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html. |
| 636 | + activation_checkpoint_params: params for enabling activation checkpointing. |
| 637 | + enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag. |
| 638 | + global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. |
| 639 | + """ |
| 640 | + if isinstance(strategy, TPStrategy): |
| 641 | + if global_mesh is None: |
| 642 | + raise ValueError( |
| 643 | + "TPStrategy expects global_mesh (GlobalMeshCoordinator) to be defined. Got None." |
| 644 | + ) |
| 645 | + return _prepare_module_2d( |
| 646 | + module, |
| 647 | + device, |
| 648 | + strategy=strategy, |
| 649 | + global_mesh=global_mesh, |
| 650 | + torch_compile_params=torch_compile_params, |
| 651 | + activation_checkpoint_params=activation_checkpoint_params, |
| 652 | + ) |
| 653 | + |
| 654 | + return _prepare_module_1d( |
| 655 | + module, |
| 656 | + device, |
| 657 | + strategy=strategy, |
| 658 | + torch_compile_params=torch_compile_params, |
| 659 | + activation_checkpoint_params=activation_checkpoint_params, |
| 660 | + enable_compiled_autograd=enable_compiled_autograd, |
| 661 | + global_mesh=global_mesh, |
| 662 | + ) |
| 663 | + |
| 664 | + |
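For illustration, a hedged sketch of invoking the dispatcher with a `TPStrategy`. `MyModel` and `build_global_mesh()` are hypothetical placeholders (the `GlobalMeshCoordinator` construction is not shown in this diff); only the `prepare_module` call shape is taken from the code above.

```python
import torch
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

device = torch.device("cuda", torch.cuda.current_device())
module = MyModel()                 # placeholder: a user-defined nn.Module
global_mesh = build_global_mesh()  # placeholder: returns a GlobalMeshCoordinator

module = prepare_module(
    module,
    device,
    strategy=TPStrategy(
        tp_plan={"w1": ColwiseParallel(), "w2": RowwiseParallel()},
        fsdp2_strategy=FSDP2Strategy(),  # omit for TP-only sharding
    ),
    # required whenever a TPStrategy is passed, otherwise a ValueError is raised
    global_mesh=global_mesh,
)
```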
| 665 | +def _prepare_module_1d( |
| 666 | + module: torch.nn.Module, |
| 667 | + device: torch.device, |
| 668 | + *, |
| 669 | + strategy: Optional[Union[Strategy, str]] = None, |
| 670 | + torch_compile_params: Optional[TorchCompileParams] = None, |
| 671 | + activation_checkpoint_params: Optional[ActivationCheckpointParams] = None, |
| 672 | + enable_compiled_autograd: bool = False, |
| 673 | + global_mesh: Optional[GlobalMeshCoordinator] = None, |
| 674 | +) -> torch.nn.Module: |
| 675 | + """ |
| 676 | + Utility to move a module to device, set up 1D parallelism (None, DDP, FSDP), activation checkpointing and compile. |
613 | 677 |
|
614 | 678 | Args:
|
615 | 679 | module: module to be used.
|
@@ -675,6 +739,51 @@ def prepare_module(
|
675 | 739 | return module
|
676 | 740 |
|
677 | 741 |
|
| 742 | +def _prepare_module_2d( |
| 743 | + module: torch.nn.Module, |
| 744 | + device: torch.device, |
| 745 | + *, |
| 746 | + strategy: TPStrategy, |
| 747 | + global_mesh: GlobalMeshCoordinator, |
| 748 | + torch_compile_params: Optional[TorchCompileParams] = None, |
| 749 | + activation_checkpoint_params: Optional[ActivationCheckpointParams] = None, |
| 750 | +) -> torch.nn.Module: |
| 751 | + """ |
| 752 | + Utility to move a module to device, set up 2D parallelism (FSDP / TP / HSDP), activation checkpointing and compile. |
| 753 | +
|
| 754 | + The order of composition is TP -> AC -> compile -> FSDP2. |
| 755 | +
|
| 756 | + Args: |
| 757 | + module: module to be used. |
| 758 | + device: device to which module will be moved. |
| 759 | + strategy: the TP parallelization strategy to be used. |
| 760 | + global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. |
| 761 | + torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html. |
| 762 | + activation_checkpoint_params: params for enabling activation checkpointing. |
| 763 | + """ |
| 764 | + |
| 765 | + # 1) apply TP |
| 766 | + parallelize_module(module, global_mesh.tp_mesh, parallelize_plan=strategy.tp_plan) |
| 767 | + |
| 768 | + # 2) apply AC if specified |
| 769 | + if activation_checkpoint_params: |
| 770 | + apply_ac(module, activation_checkpoint_params) |
| 771 | + |
| 772 | + # 3) apply torch.compile if specified |
| 773 | + if torch_compile_params: |
| 774 | + apply_torch_compile(module, torch_compile_params) |
| 775 | + |
| 776 | + # 4) apply data parallel / HSDP sharding (via FSDP2 APIs) if specified in TPStrategy |
| 777 | + if (fsdp2_strategy := strategy.fsdp2_strategy) is not None: |
| 778 | + prepare_fsdp2(module, device, fsdp2_strategy, global_mesh) |
| 779 | + else: |
| 780 | + # prepare_fsdp2 will handle materializing meta weights |
| 781 | + # so if fsdp2_strategy isn't used, we do it manually here |
| 782 | + materialize_meta_params(module, device) |
| 783 | + |
| 784 | + return module |
| 785 | + |
| 786 | + |
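For readers less familiar with the underlying PyTorch calls, below is a minimal sketch of the same TP -> AC -> compile -> FSDP2 ordering expressed with raw torch APIs instead of the torchtnt helpers. The mesh shape and model are illustrative assumptions; run under torchrun with `world_size == dp * tp`.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # torch.distributed._composable.fsdp on older releases
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2D mesh: outer "dp" dim for FSDP2 sharding, inner "tp" dim for tensor parallel.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# 1) TP first, on the "tp" sub-mesh (mirrors global_mesh.tp_mesh above).
parallelize_module(
    model,
    mesh["tp"],
    parallelize_plan={"0": ColwiseParallel(), "2": RowwiseParallel()},
)

# 2) / 3) activation checkpointing and torch.compile would be applied here,
# between TP and FSDP2, matching the order in _prepare_module_2d.

# 4) FSDP2 sharding last, on the "dp" sub-mesh.
fully_shard(model, mesh=mesh["dp"])
```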
678 | 787 | def convert_str_to_strategy(
|
679 | 788 | strategy: str,
|
680 | 789 | ) -> Union[DDPStrategy, FSDPStrategy, FSDP2Strategy, NOOPStrategy]:
|
|