diff --git a/recipes_source/distributed_device_mesh.rst b/recipes_source/distributed_device_mesh.rst index 3a04b8de4b..cc982c1179 100644 --- a/recipes_source/distributed_device_mesh.rst +++ b/recipes_source/distributed_device_mesh.rst @@ -138,7 +138,7 @@ users would not need to manually create and manage shard group and replicate gro # HSDP: MeshShape(2, 4) mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")) model = FSDP( - ToyModel(), device_mesh=mesh_2d + ToyModel(), mesh=mesh_2d ) Let's create a file named ``hsdp.py``.