diff --git a/recipes_source/distributed_device_mesh.rst b/recipes_source/distributed_device_mesh.rst
index cf87ecf16a9..3a04b8de4bf 100644
--- a/recipes_source/distributed_device_mesh.rst
+++ b/recipes_source/distributed_device_mesh.rst
@@ -121,7 +121,7 @@ users would not need to manually create and manage shard group and replicate gro
     import torch.nn as nn
     from torch.distributed.device_mesh import init_device_mesh
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+    from torch.distributed.fsdp import fully_shard as FSDP


     class ToyModel(nn.Module):
@@ -136,9 +136,9 @@ users would not need to manually create and manage shard group and replicate gro

     # HSDP: MeshShape(2, 4)
-    mesh_2d = init_device_mesh("cuda", (2, 4))
+    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
     model = FSDP(
-        ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
+        ToyModel(), mesh=mesh_2d
     )

 Let's create a file named ``hsdp.py``.
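
For context, here is a minimal, self-contained sketch of how the updated snippet might be exercised end to end with ``fully_shard``. It is a sketch, not the recipe's actual ``hsdp.py``: the ``torchrun --nproc_per_node=8`` launch, the ``ToyModel`` layer sizes, and the training-loop scaffolding are assumptions for illustration, and ``fully_shard`` is assumed to take the device mesh through its ``mesh`` keyword and to return the module it shards in place:

    # hsdp_sketch.py -- illustrative only, not the recipe's hsdp.py.
    # Assumed launch: torchrun --nproc_per_node=8 hsdp_sketch.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard


    class ToyModel(nn.Module):
        # Layer sizes are illustrative; the recipe defines its own ToyModel.
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # Bind each rank to its local GPU (torchrun sets LOCAL_RANK).
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # HSDP on 8 ranks: 2 replicate groups, each sharding across 4 GPUs.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))

    # fully_shard applies hybrid sharding in place when given a 2D mesh
    # and returns the module it wrapped.
    model = fully_shard(ToyModel().cuda(), mesh=mesh_2d)

    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(3):
        loss = model(torch.rand(8, 10, device="cuda")).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    dist.destroy_process_group()

With the ``(2, 4)`` mesh, parameters are sharded within each group of four ranks and replicated across the two groups, which is the hybrid-sharding behavior that ``ShardingStrategy.HYBRID_SHARD`` provided in the FSDP1 API.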