Commit 6dd3297

Merge branch 'main' into tp_tutorial_2

2 parents: 630e1d2 + 755434d

File tree

1 file changed: 2 additions, 2 deletions

intermediate_source/TP_tutorial.rst

Lines changed: 2 additions & 2 deletions
@@ -328,7 +328,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w

     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard

     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -342,7 +342,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
     # apply Tensor Parallel intra-host on tp_mesh
     model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
+    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)


 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
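For context on the change, here is a minimal, self-contained sketch of the pattern the updated lines describe: Tensor Parallel applied intra-host via parallelize_module, and FSDP2's fully_shard applied inter-host on the data-parallel sub-mesh. The ToyMLP module, its layer names, and the helper function are hypothetical stand-ins for the tutorial's Llama model, not part of the commit; the fully_shard import path assumes a PyTorch build that exposes it under torch.distributed.fsdp, as the new line in this diff does.

    # Sketch only: ToyMLP, its layer names, and build_2d_parallel_model are
    # hypothetical stand-ins for the tutorial's Llama model and training script.
    # Assumes a torchrun launch on 64 GPUs where each rank has already selected
    # its local CUDA device.
    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class ToyMLP(nn.Module):
        """Stand-in for one feed-forward block of the real model."""

        def __init__(self, dim: int = 1024, hidden: int = 4096) -> None:
            super().__init__()
            self.up_proj = nn.Linear(dim, hidden, bias=False)
            self.down_proj = nn.Linear(hidden, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down_proj(torch.relu(self.up_proj(x)))

    def build_2d_parallel_model() -> nn.Module:
        # 2-D mesh is [dp, tp]: 64 GPUs split into 8-way DP x 8-way TP.
        mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]

        model = ToyMLP().to("cuda")

        # Tensor Parallel intra-host: shard up_proj column-wise and down_proj
        # row-wise so the intermediate activation stays sharded on tp_mesh.
        tp_plan = {
            "up_proj": ColwiseParallel(),
            "down_proj": RowwiseParallel(),
        }
        model_tp = parallelize_module(model, tp_mesh, tp_plan)

        # FSDP2 inter-host: fully_shard shards parameters in place over dp_mesh
        # instead of wrapping the module the way FullyShardedDataParallel did.
        fully_shard(model_tp, mesh=dp_mesh)
        return model_tp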
