@@ -6,25 +6,44 @@ Wright <https://github.com/lessw2020>`__, `Rohan Varma
<https://github.com/rohan-varma/>`__, `Yanli Zhao
<https://github.com/zhaojuanmao>`__

+ .. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+       :class-card: card-prerequisites
+
+       * PyTorch's Fully Sharded Data Parallel Module: A wrapper for sharding module parameters across
+         data parallel workers.
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+       :class-card: card-prerequisites
+
+       * PyTorch 1.12 or later
+       * Read about the `FSDP API <https://pytorch.org/docs/main/fsdp.html>`__.
+

This tutorial introduces more advanced features of Fully Sharded Data Parallel
(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please
refer to the `FSDP getting started tutorial
<https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__.

In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text
- summarization as a working example.
+ summarization as a working example.

The example uses Wikihow and, for simplicity, we will showcase the training on a
- single node, P4dn instance with 8 A100 GPUs. We will soon have a blog post on
- large scale FSDP training on a multi-node cluster, please stay tuned for that on
- the PyTorch medium channel.
+ single node, a P4dn instance with 8 A100 GPUs. We now have several blog posts
+ (`link1 <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__,
+ `link2 <https://engineering.fb.com/2021/07/15/open-source/fsdp/>`__)
+ and a `paper <https://arxiv.org/abs/2304.11277>`__ on
+ large-scale FSDP training on a multi-node cluster.

FSDP is a production-ready package with a focus on ease of use, performance, and
long-term support. One of the main benefits of FSDP is reducing the memory
footprint on each GPU. This enables training of larger models with lower total
memory vs DDP, and leverages the overlap of computation and communication to
- train models efficiently.
+ train models efficiently.
This reduced memory pressure can be leveraged to either train larger models or
increase the batch size, potentially helping overall training throughput. You can
read more about PyTorch FSDP `here
@@ -47,21 +66,21 @@ Recap on How FSDP Works

At a high level, FSDP works as follows:

- *In constructor*
+ *In the constructor*

* Shard model parameters and each rank only keeps its own shard

- *In forward pass*
+ *In the forward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
-   parameter for this FSDP unit Run forward computation
- * Discard non-owned parameter shards it has just collected to free memory
+   parameter for this FSDP unit and run the forward computation
+ * Discard the non-owned parameter shards it has just collected to free memory

- *In backward pass*
+ *In the backward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
-   parameter in this FSDP unit Run backward computation
- * Discard non-owned parameters to free memory.
+   parameter in this FSDP unit and run the backward computation
+ * Discard the non-owned parameters to free memory
* Run `reduce_scatter` to sync gradients


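To make the recap above concrete, here is a minimal sketch (an illustration added here, not code from the original tutorial) that can be launched with `torchrun --nproc_per_node <num_gpus>`; the toy model and hyperparameters are arbitrary:

.. code-block:: python

    import os
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = FSDP(nn.Linear(1024, 1024).cuda())  # constructor: parameters are sharded across ranks
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 1024, device="cuda")).sum()  # forward: all_gather full params, then free non-owned shards
    loss.backward()   # backward: all_gather again, compute grads, reduce_scatter them
    optimizer.step()  # each rank updates only the parameter shard it owns

    dist.destroy_process_group()
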
@@ -80,15 +99,11 @@ examples

*Setup*

- 1.1 Install PyTorch Nightlies
-
- We will install PyTorch nightlies, as some of the features such as activation
- checkpointing is available in nightlies and will be added in next PyTorch
- release after 1.12.
+ 1.1 Install the latest PyTorch

- .. code-block:: bash
+ .. code-block:: bash

-     pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
+     pip3 install torch torchvision torchaudio

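To confirm that the installed version meets the prerequisite, a quick optional check (not part of the original tutorial) is:

.. code-block:: python

    import torch

    print(torch.__version__)          # should be 1.12 or later
    print(torch.cuda.is_available())  # this tutorial assumes CUDA GPUs are available

    # the FSDP wrapper used throughout this tutorial should import without error
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
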
1.2 Dataset Setup

@@ -154,7 +169,7 @@ Next, we add the following code snippets to a Python script “T5_training.py”
    import tqdm
    from datetime import datetime

- 1.4 Distributed training setup.
+ 1.4 Distributed training setup.
Here we use two helper functions to initialize the processes for distributed
training, and then to clean up after training completion. In this tutorial, we
are going to use torch elastic, using `torchrun
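The two helper functions referred to above typically look like the following sketch; it relies on `torchrun` exporting the usual `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables, and the exact versions are in the full script:

.. code-block:: python

    import os
    import torch.distributed as dist

    def setup():
        # torchrun already sets RANK, LOCAL_RANK and WORLD_SIZE, so the NCCL
        # process group can be initialized without extra arguments
        dist.init_process_group("nccl")

    def cleanup():
        dist.destroy_process_group()
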
@@ -191,13 +206,13 @@ metrics.
        date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
        print(f"--> current date and time of run = {date_of_run}")
        return date_of_run
-
+
    def format_metrics_to_gb(item):
        """quick function to format numbers to gigabyte and round to 4 digit precision"""
        metric_num = item / g_gigabyte
        metric_num = round(metric_num, ndigits=4)
        return metric_num
-
+

2.2 Define a train function:

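The train function is a fairly standard distributed training loop. The following condensed sketch assumes the imports from section 1.3 and the batch keys produced by the wikihow dataset used in this tutorial; argument names and logging are illustrative, and the complete version is in the full script:

.. code-block:: python

    def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
        model.train()
        local_rank = int(os.environ["LOCAL_RANK"])
        fsdp_loss = torch.zeros(2).to(local_rank)  # running (loss_sum, batch_count)
        if sampler:
            sampler.set_epoch(epoch)  # reshuffle the distributed sampler every epoch
        for batch in train_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            optimizer.zero_grad()
            output = model(input_ids=batch["source_ids"],
                           attention_mask=batch["source_mask"],
                           labels=batch["target_ids"])
            loss = output["loss"]
            loss.backward()
            optimizer.step()
            fsdp_loss[0] += loss.item()
            fsdp_loss[1] += 1
        # aggregate the loss across ranks so that rank 0 can log a global average
        dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
        if rank == 0:
            print(f"Epoch {epoch}: average train loss = {fsdp_loss[0] / fsdp_loss[1]:.4f}")
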
@@ -275,7 +290,7 @@ metrics.

.. code-block:: python

-
+
    def fsdp_main(args):

        model, tokenizer = setup_model("t5-base")
@@ -292,7 +307,7 @@ metrics.


        # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-         train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+         train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
        val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

        sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
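The samplers are then handed to the data loaders; a sketch of the typical wiring, with illustrative keyword values:

.. code-block:: python

    # inside fsdp_main, continuing from the samplers above
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler1}
    test_kwargs = {"batch_size": args.test_batch_size, "sampler": sampler2}
    cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
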
@@ -430,7 +445,7 @@ metrics.

.. code-block:: python

-
+
    if __name__ == '__main__':
        # Training settings
        parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
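The rest of the main block usually just parses a few hyperparameters, seeds the RNG, and hands control to `fsdp_main`; the flags and defaults below are illustrative rather than the exact ones from the script:

.. code-block:: python

    # continuing inside the __main__ block shown above
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    fsdp_main(args)
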
@@ -463,7 +478,7 @@ metrics.

To run the training using torchrun:

- .. code-block:: bash
+ .. code-block:: bash

    torchrun --nnodes 1 --nproc_per_node 4 T5_training.py

@@ -487,7 +502,7 @@ communication efficient. In PyTorch 1.12, FSDP added this support and now we
have a wrapping policy for transformers.

It can be created as follows, where the T5Block represents the T5 transformer
- layer class (holding MHSA and FFN).
+ layer class (holding MHSA and FFN).


.. code-block:: python
@@ -499,7 +514,7 @@ layer class (holding MHSA and FFN).
        },
    )
    torch.cuda.set_device(local_rank)
-
+

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy)
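The policy whose closing lines appear truncated above is typically built with `functools.partial` around `transformer_auto_wrap_policy`; a sketch consistent with that snippet:

.. code-block:: python

    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.t5.modeling_t5 import T5Block

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,  # each T5 transformer block becomes its own FSDP unit
        },
    )

With this policy, every `T5Block` (attention plus feed-forward) is wrapped as a separate FSDP unit, so parameters are gathered one block at a time instead of for the whole model at once.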
@@ -513,22 +528,22 @@ Mixed Precision
FSDP supports flexible mixed precision training allowing for arbitrary reduced
precision types (such as fp16 or bfloat16). Currently, BFloat16 is only available
on Ampere GPUs, so you need to confirm native support before you use it. On
- V100s for example, BFloat16 can still be run but due to it running non-natively,
+ V100s, for example, BFloat16 can still be run, but because it runs non-natively,
it can result in significant slowdowns.

To check if BFloat16 is natively supported, you can use the following:

.. code-block:: python
-
+
    bf16_ready = (
        torch.version.cuda
-         and torch.cuda.is_bf16_supported()
+         and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

- One of the advantages of mixed percision in FSDP is providing granular control
+ One of the advantages of mixed precision in FSDP is that it provides granular control
over different precision levels for parameters, gradients, and buffers as
follows:

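The policies referred to here are instances of the `MixedPrecision` class; for example, a BFloat16 policy like the `bfSixteen` used later in this tutorial casts parameters, gradient communication, and buffers to bf16, and an fp16 variant is analogous:

.. code-block:: python

    from torch.distributed.fsdp import MixedPrecision

    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,   # parameters are cast to bf16 for compute
        reduce_dtype=torch.bfloat16,  # gradient reduce_scatter communication runs in bf16
        buffer_dtype=torch.bfloat16,  # buffers are kept in bf16
    )

    fpSixteen = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
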
@@ -571,7 +586,7 @@ with the following policy:
.. code-block:: python

    grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
-
+

In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper:

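For completeness, the wrapper call referenced by this sentence looks roughly like the following; leaving `mixed_precision` unset keeps full fp32 precision:

.. code-block:: python

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen)
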
@@ -604,9 +619,9 @@ CPU-based initialization:
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())
-

-
+
+
Sharding Strategy
-----------------
By default, the FSDP sharding strategy is set to fully shard the model parameters,
@@ -627,7 +642,7 @@ instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows:
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)  # ZERO2

This will reduce the communication overhead in FSDP; in this case, it holds the full
- parameters after forward and through the backwards pass.
+ parameters after the forward pass and through the backward pass.

This saves an all_gather during the backward pass, so there is less communication at the
cost of a higher memory footprint. Note that full model params are freed at the
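In context, the initialization with this strategy would look roughly like the following, reusing the arguments introduced earlier:

.. code-block:: python

    from torch.distributed.fsdp import ShardingStrategy

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)  # ZeRO-2 style: shard only grads and optimizer states
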
@@ -652,12 +667,12 @@ wrapper in 2.4 as follows:
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
-
+
`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`.
`BACKWARD_POST` means that the next FSDP unit's params will not be requested
until the current FSDP unit processing is complete, thus minimizing memory
overhead. In some cases, using `BACKWARD_PRE` can increase model training speed
- up to 2-10%, with even higher speed improvements noted for larger models.
+ by up to 2-10%, with even higher speed improvements noted for larger models.

Model Checkpoint Saving, by streaming to the Rank0 CPU
------------------------------------------------------
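The pattern this section describes gathers a full, unsharded state dict that is offloaded to CPU and materialized only on rank 0, so a single host can write the checkpoint without exhausting GPU memory. A sketch using the FSDP state dict API (the file name is illustrative):

.. code-block:: python

    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()  # full state dict, gathered to CPU on rank 0 only

    if dist.get_rank() == 0:
        torch.save(cpu_state, "t5_checkpoint.pt")
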
@@ -696,7 +711,7 @@ Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping
policy, especially for transformer models, along with mixed precision and
backward prefetch, should speed up your training runs. Also, features such as
initializing the model on device and checkpoint saving via streaming to the CPU
- should help to avoid OOM error in dealing with large models.
+ should help to avoid OOM errors when dealing with large models.

We are actively working to add new features to FSDP for the next release. If
you have feedback, feature requests, questions, or are encountering issues