)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
+from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.precision import convert_precision_str_to_dtype

try:
@@ -367,7 +368,7 @@ def prepare_fsdp2(
    module: torch.nn.Module,
    device: torch.device,
    strategy: Optional[FSDP2Strategy] = None,
-    process_group: Optional[ProcessGroup] = None,
+    global_mesh: Optional[GlobalMeshCoordinator] = None,
) -> torch.nn.Module:
    """
    Utility to move a module to device and wrap in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_
@@ -376,12 +377,18 @@ def prepare_fsdp2(
        module: module to be wrapped in FSDP
        device: device to which module will be moved
        strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of FSDP APIs
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+            If not provided, a default 1D mesh covering the entire world size will be created.
    """
    strategy = strategy or FSDP2Strategy()

    # prepare kwargs for fully_shard api
-    pg = process_group or dist.distributed_c10d._get_default_group()
-    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    if global_mesh is None:
+        pg = dist.distributed_c10d._get_default_group()
+        mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    else:
+        mesh = global_mesh.dp_mesh
+
    fsdp_kwargs: Dict[str, Any] = {
        "mesh": mesh,  # TODO we only configure 1D mesh for now, look into supporting HSDP
        "reshard_after_forward": strategy.reshard_after_forward,
@@ -599,6 +606,7 @@ def prepare_module(
    torch_compile_params: Optional[TorchCompileParams] = None,
    activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
    enable_compiled_autograd: bool = False,
+    global_mesh: Optional[GlobalMeshCoordinator] = None,
) -> torch.nn.Module:
    """
    Utility to move a module to device, set up parallelism, activation checkpointing and compile.
@@ -610,6 +618,8 @@ def prepare_module(
        torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html.
        activation_checkpoint_params: params for enabling activation checkpointing.
        enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology.
+            Pass this only when configuring an HSDP setup with FSDP2.
    """

    if strategy:
@@ -652,7 +662,7 @@ def prepare_module(
            )
            module = prepare_fsdp(module, device, strategy)
        elif isinstance(strategy, FSDP2Strategy):
-            module = prepare_fsdp2(module, device, strategy)
+            module = prepare_fsdp2(module, device, strategy, global_mesh=global_mesh)
        else:
            module = module.to(device)
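
A hedged usage sketch of the updated entry point follows. prepare_module, FSDP2Strategy, and GlobalMeshCoordinator come from the diff above, but the coordinator's constructor is not part of this change, so the sketch takes an already-built coordinator as a parameter; wrap_for_hsdp is a hypothetical helper name.

# Hypothetical helper showing the new call path; assumes a GlobalMeshCoordinator
# has already been constructed elsewhere (its constructor is not shown in this diff).
import torch
from torchtnt.utils.device_mesh import GlobalMeshCoordinator
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module


def wrap_for_hsdp(
    module: torch.nn.Module,
    device: torch.device,
    global_mesh: GlobalMeshCoordinator,
) -> torch.nn.Module:
    # With a coordinator, prepare_fsdp2 uses global_mesh.dp_mesh instead of a
    # flat 1D mesh over the default process group, enabling an HSDP layout.
    return prepare_module(
        module,
        device,
        strategy=FSDP2Strategy(),
        global_mesh=global_mesh,
    )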