
Commit d71a41b

JKSenthil authored and facebook-github-bot committed
add fsdp2 support (#967)
Summary: Pull Request resolved: #967

Reviewed By: anshulverma

Differential Revision: D68735961

fbshipit-source-id: 69bdd1bd700dd58f4c92ed6ba8bc4ae0b4432dc0
1 parent 7272dbd commit d71a41b

File tree

2 files changed: +291, -1 lines changed


tests/utils/test_prepare_module_gpu.py

Lines changed: 133 additions & 0 deletions
@@ -7,19 +7,32 @@
 
 # pyre-strict
 import unittest
+from typing import Any
 
 import torch
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+try:
+    from torch.distributed.fsdp import fully_shard
+except ImportError:
+
+    def noop(*args: Any, **kwargs: Any) -> None:
+        pass
+
+    fully_shard = noop
+
 from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
     _is_fsdp_module,
     DDPStrategy,
+    FSDP2Strategy,
     FSDPStrategy,
     prepare_ddp,
     prepare_fsdp,
+    prepare_fsdp2,
     prepare_module,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
@@ -111,6 +124,11 @@ def _test_is_fsdp_module() -> None:
         model = FSDP(model)
         assert _is_fsdp_module(model)
 
+        # test fsdp2
+        model = torch.nn.Linear(1, 1, device=device)
+        fully_shard(model)
+        assert _is_fsdp_module(model)
+
     @skip_if_not_distributed
     @skip_if_not_gpu
     def test_fdsp_precision(self) -> None:
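
For context on the new assertions: fully_shard is FSDP2's composable API, so it shards a module in place rather than wrapping it in an FSDP object, and torchtnt detects it through the module's attached state rather than an isinstance check. A minimal single-rank sketch of the same check outside the test harness (hedged: it assumes a torch build whose torch.distributed.fsdp exposes fully_shard, one CUDA device, and illustrative rendezvous env vars):

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, fully_shard
from torch.distributed.tensor import DTensor
from torchtnt.utils.prepare_module import _is_fsdp_module

# single-rank process group; the MASTER_ADDR/MASTER_PORT values are illustrative
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = torch.nn.Linear(1, 1, device="cuda")
fully_shard(model)  # shards in place; the same nn.Linear instance is used afterwards

assert not isinstance(model, FSDP)  # FSDP2 does not use the FSDP wrapper class
assert _is_fsdp_module(model)  # torchtnt still detects it via the module state
assert isinstance(model.weight, DTensor)  # parameters become sharded DTensors

dist.destroy_process_group()
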
@@ -276,3 +294,118 @@ def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None:
         tc = unittest.TestCase()
 
         tc.assertTrue(isinstance(fsdp_module, FSDP))
+
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_prepare_fsdp2(self) -> None:
+        """
+        Launch tests of FSDP2 strategy
+        """
+
+        spawn_multi_process(
+            1,
+            "nccl",
+            self._test_prepare_fsdp2_none_sharded_raises,
+        )
+
+        spawn_multi_process(
+            1,
+            "nccl",
+            self._test_prepare_fsdp2_shard_all,
+        )
+
+        spawn_multi_process(
+            1,
+            "nccl",
+            self._test_prepare_fsdp2_submodule,
+        )
+
+        spawn_multi_process(
+            1,
+            "nccl",
+            self._test_prepare_fsdp2_meta_device,
+        )
+
+    @staticmethod
+    def _test_prepare_fsdp2_none_sharded_raises() -> None:
+        """
+        Test with a strategy that does not shard any modules, should raise error
+        """
+        tc = unittest.TestCase()
+
+        module = SimpleModule()
+        device = torch.device("cuda")
+        strategy = FSDP2Strategy(modules_to_shard=[])
+        with tc.assertRaises(ValueError):
+            prepare_fsdp2(module, device, strategy)
+
+    @staticmethod
+    def _test_prepare_fsdp2_shard_all() -> None:
+        """
+        Test with a strategy that shards all modules
+        """
+        tc = unittest.TestCase()
+
+        module = SimpleModule()
+        device = torch.device("cuda")
+        strategy = FSDP2Strategy(modules_to_shard="all")
+        prepare_fsdp2(module, device, strategy)
+
+        for submodule in module.modules():
+            tc.assertTrue(_is_fsdp_module(submodule))
+
+    @staticmethod
+    def _test_prepare_fsdp2_submodule() -> None:
+        """
+        Test with a strategy that shards modules (either str or module type)
+        """
+        tc = unittest.TestCase()
+
+        for t in (torch.nn.Linear, "Linear"):
+            module = SimpleModule()
+            device = torch.device("cuda")
+            strategy = FSDP2Strategy(modules_to_shard=(t,))
+            prepare_fsdp2(module, device, strategy)
+
+            for submodule in module.modules():
+                if isinstance(submodule, torch.nn.Conv2d):
+                    tc.assertFalse(_is_fsdp_module(submodule))
+                else:
+                    # linear and SimpleModule are fsdp modules
+                    tc.assertTrue(_is_fsdp_module(submodule))
+
+    @staticmethod
+    def _test_prepare_fsdp2_meta_device() -> None:
+        """
+        Test with a strategy that shards specific modules on meta device
+        """
+        tc = unittest.TestCase()
+
+        module = SimpleModule(meta_device=True)
+        device = torch.device("cuda")
+        strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
+        prepare_fsdp2(module, device, strategy)
+
+        for submodule in module.modules():
+            if isinstance(submodule, torch.nn.Conv2d):
+                tc.assertFalse(_is_fsdp_module(submodule))
+            else:
+                # linear and SimpleModule are fsdp modules
+                tc.assertTrue(_is_fsdp_module(submodule))
+
+
+class SimpleModule(torch.nn.Module):
+    def __init__(self, meta_device: bool = False) -> None:
+        super(SimpleModule, self).__init__()
+        self.linear = torch.nn.Linear(10, 10, device="meta" if meta_device else None)
+        self.conv1 = torch.nn.Conv2d(
+            in_channels=3,
+            out_channels=16,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            device="meta" if meta_device else None,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)

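The meta-device test exercises an initialization path worth spelling out: prepare_fsdp2 shards parameters while they still live on the meta device and only materializes them afterwards (via materialize_meta_params), so each rank allocates its own shard rather than a full copy of the model. A hedged sketch of that flow outside the test harness, assuming an already-initialized process group and a CUDA device (the layer sizes are illustrative):

import torch
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

# build the model on the meta device: no parameter memory is allocated yet
with torch.device("meta"):
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 4096),
    )

# shard the Linear submodules first, then materialize the already-sharded
# weights on the target device
strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
model = prepare_fsdp2(model, torch.device("cuda"), strategy)

# parameters are now real tensors on the CUDA device; note they may still
# need an explicit weight (re)initialization or a checkpoint load
for p in model.parameters():
    assert not p.is_meta
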
torchtnt/utils/prepare_module.py

Lines changed: 158 additions & 1 deletion
@@ -15,8 +15,11 @@
     ContextManager,
     Dict,
     Iterable,
+    Literal,
     Optional,
+    Set,
     Tuple,
+    Type,
     Union,
 )
 
@@ -30,6 +33,29 @@
     checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.device_mesh import init_device_mesh
+
+try:
+    from torch.distributed.fsdp import (
+        CPUOffloadPolicy,
+        fully_shard,
+        MixedPrecisionPolicy,
+    )
+    from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
+except ImportError:
+
+    def noop(*args: Any, **kwargs: Any) -> None:
+        pass
+
+    class NOOP:
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            pass
+
+    fully_shard = noop
+    MixedPrecisionPolicy = NOOP
+    CPUOffloadPolicy = NOOP
+    FSDPState = NOOP
+
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     StateDictType as _StateDictType,
@@ -146,6 +172,52 @@ def __post_init__(self) -> None:
             self.mixed_precision = self.mixed_precision.to_native_mixed_precision()
 
 
+@dataclass
+class FSDP2Strategy(Strategy):
+    """
+    Dataclass representing the `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_ strategy.
+    For more details on the args, see the link.
+
+    Args:
+        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
+        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
+        cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
+
+    Note:
+        It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has
+        communication overhead.
+
+    Example:
+        >>> model
+        TransformerDecoder(
+          (tok_embeddings): Embedding(128256, 4096)
+          (layers): ModuleList(
+            (0-31): 32 x TransformerSelfAttentionLayer(
+              (attn): MultiHeadAttention(
+                (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (pos_embeddings): RotaryPositionalEmbeddings()
+              )
+              ...
+            )
+          (output): Linear(in_features=4096, out_features=128256, bias=False)
+        )
+        >>> # You can either specify the module to shard as a name ("Linear") or the module type (torch.nn.Linear)
+        >>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
+    """
+
+    modules_to_shard: Union[
+        Literal["all"],
+        Iterable[Union[str, Type[torch.nn.Module]]],
+    ] = "all"
+    reshard_after_forward: Union[bool, int] = True
+    mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
+    cpu_offload: bool = False
+
+
 @dataclass
 class TorchCompileParams:
     """
@@ -272,6 +344,89 @@ def prepare_fsdp(
     return module
 
 
+def prepare_fsdp2(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[FSDP2Strategy] = None,
+    process_group: Optional[ProcessGroup] = None,
+) -> torch.nn.Module:
+    """
+    Utility to move a module to device and wrap in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_
+
+    Args:
+        module: module to be wrapped in FSDP
+        device: device to which module will be moved
+        strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of FSDP APIs
+    """
+    strategy = strategy or FSDP2Strategy()
+
+    # prepare kwargs for fully_shard api
+    pg = process_group or dist.distributed_c10d._get_default_group()
+    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    fsdp_kwargs: Dict[str, Any] = {
+        "mesh": mesh,  # TODO we only configure 1D mesh for now, look into supporting HSDP
+        "reshard_after_forward": strategy.reshard_after_forward,
+    }
+    if strategy.cpu_offload:
+        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
+    if (mp_policy := strategy.mp_policy) is not None:
+        if isinstance(mp_policy, MixedPrecisionPolicy):
+            fsdp_kwargs["mixed_precision"] = mp_policy
+        else:
+            fsdp_kwargs["mixed_precision"] = MixedPrecisionPolicy(
+                param_dtype=mp_policy,
+                reduce_dtype=mp_policy,
+                output_dtype=mp_policy,
+                cast_forward_inputs=True,
+            )
+
+    # parse out the modules_to_shard argument
+    modules_to_shard = strategy.modules_to_shard
+
+    shard_all = modules_to_shard == "all"
+    shard_module_names: Set[str] = set()
+    shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
+    if not shard_all:
+        assert (
+            type(modules_to_shard) is not str
+        ), f"modules_to_shard must be an iterable of modules or 'all', got {shard_all}"
+
+        for item in modules_to_shard:
+            if isinstance(item, str):
+                shard_module_names.add(item)
+            else:
+                shard_module_types = shard_module_types + (item,)
+
+    # apply the fsdp2 sharding bottoms up
+    num_layers_sharded = 0
+    for _, m in reversed(list(module.named_modules())):
+        if shard_all:
+            # fully_shard does not support containers that do not implement forward
+            if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+        elif (
+            isinstance(m, shard_module_types) or type(m).__name__ in shard_module_names
+        ):
+            # if m exists in shard_module_types, then shard it
+            fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded with fsdp2. Please check if shard conditions are working as expected."
+        )
+
+    # shard the top level model, so that all params are moved off cpu to gpu
+    if not _is_fsdp_module(module):
+        fully_shard(module, **fsdp_kwargs)
+
+    # materialized sharded meta weights to device
+    materialize_meta_params(module, device)
+
+    return module
+
+
 class FSDPOptimizerWrapper:
     """
     Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
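
The core of prepare_fsdp2 is the bottom-up application of fully_shard: iterating named_modules() in reverse shards leaves and inner blocks before their parents, and the final fully_shard on the root groups whatever parameters remain. Roughly the same pattern written directly against the fully_shard API would look like the sketch below (hedged: it assumes a torch version where fully_shard lives in torch.distributed.fsdp and an already-initialized process group):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard


def shard_linears_bottom_up(model: torch.nn.Module, world_size: int) -> torch.nn.Module:
    # 1D mesh over all ranks, mirroring the TODO in prepare_fsdp2 about HSDP support
    mesh = init_device_mesh("cuda", mesh_shape=(world_size,))
    fsdp_kwargs = {"mesh": mesh, "reshard_after_forward": True}

    # children before parents: each fully_shard call only claims parameters
    # that a deeper call has not already taken ownership of
    for _, submodule in reversed(list(model.named_modules())):
        if isinstance(submodule, torch.nn.Linear):
            fully_shard(submodule, **fsdp_kwargs)

    # shard the root so leftover parameters (embeddings, norms, ...) are covered;
    # prepare_fsdp2 additionally skips this when the root is already sharded
    fully_shard(model, **fsdp_kwargs)
    return model
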
@@ -301,7 +456,7 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
     # Also check for composable FSDP API
     maybe_composable_state = _get_module_state(module)
     if maybe_composable_state is not None:
-        return isinstance(maybe_composable_state, _FSDPState)
+        return isinstance(maybe_composable_state, (_FSDPState, FSDPState))
 
     return False
 
@@ -366,6 +521,8 @@ def prepare_module(
                     "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters"
                 )
             module = prepare_fsdp(module, device, strategy)
+        elif isinstance(strategy, FSDP2Strategy):
+            module = prepare_fsdp2(module, device, strategy)
     else:
         module = module.to(device)
 
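
With this dispatch in place, FSDP2 is reachable through the same prepare_module entry point as the existing DDP and FSDP strategies. A hedged end-to-end sketch, assuming a multi-process launch (e.g. torchrun) so that init_from_env can set up the process group; the model is illustrative:

import torch
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module

# initializes the process group (when launched distributed) and picks the local device
device = init_from_env()

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)

# same call site as DDPStrategy / FSDPStrategy, now also accepting FSDP2Strategy
model = prepare_module(
    model,
    device,
    strategy=FSDP2Strategy(modules_to_shard=(torch.nn.Linear,)),
)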