
Commit 055aa15

diego-urgell authored and facebook-github-bot committed
Add util function to get module state dict (#984)
Summary: Pull Request resolved: #984

Reviewed By: galrotem

Differential Revision: D71218699

fbshipit-source-id: f97209f4b90f5c3978cf5f56cbddc966d3ebb807
1 parent 4da9704 commit 055aa15


2 files changed (+102, -0 lines)

tests/utils/test_prepare_module_gpu.py

Lines changed: 60 additions & 0 deletions
@@ -14,6 +14,10 @@
     FullyShardedDataParallel as FSDP,
     MixedPrecisionPolicy,
 )
+from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
+from torchtnt.framework.train import train
+from torchtnt.utils.distributed import get_global_rank
+from torchtnt.utils.prepare_module import get_module_state_dict
 
 try:
     from torch.distributed.fsdp import fully_shard
@@ -404,6 +408,62 @@ def _test_prepare_fsdp2_meta_device() -> None:
             # linear and SimpleModule are fsdp modules
             tc.assertTrue(_is_fsdp_module(submodule))
 
+    def test_get_module_state_dict(self) -> None:
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_get_module_state_dict,
+        )
+
+    @staticmethod
+    def _test_get_module_state_dict() -> None:
+        rank = get_global_rank()
+
+        fsdp_strategy = FSDPStrategy(
+            sharding_strategy="FULL_SHARD",
+            auto_wrap_policy=lambda module, recurse, nonwrapped_numel: True,
+        )
+        ddp_strategy = DDPStrategy()
+
+        for strategy, rank0_only in (
+            (fsdp_strategy, True),
+            (fsdp_strategy, False),
+            (ddp_strategy, True),
+            (ddp_strategy, False),
+            (None, True),
+            (None, False),
+        ):
+            module = torch.nn.Sequential(
+                torch.nn.Linear(2, 100),
+                torch.nn.Linear(100, 2),
+            )
+
+            unit = DummyAutoUnit(
+                module=module,
+                strategy=strategy,
+            )
+
+            dataloader = generate_random_dataloader(10, 2, 10)
+            train(unit, dataloader, max_epochs=1)
+
+            module_sd = get_module_state_dict(unit.module, rank0_only=rank0_only)
+
+            tc = unittest.TestCase()
+
+            # For FSDP, if the user passed rank0_only=True, we should get an empty state dict
+            # on all ranks except rank 0
+            if rank0_only and isinstance(strategy, FSDPStrategy) and rank != 0:
+                tc.assertEqual(module_sd, {})
+
+            else:
+                # Make sure that the generated state dict has the actual model keys,
+                # and the values are actual tensors as opposed to ShardedTensor.
+                tc.assertCountEqual(
+                    ["0.weight", "0.bias", "1.weight", "1.bias"],
+                    list(module_sd.keys()),
+                )
+                tc.assertIsInstance(module_sd["0.weight"], torch.Tensor)
+
 
 class SimpleModule(torch.nn.Module):
     def __init__(self, meta_device: bool = False) -> None:

torchtnt/utils/prepare_module.py

Lines changed: 42 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import logging
 from dataclasses import asdict, dataclass
 from functools import partial
 from typing import (
@@ -38,6 +39,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
 from torchtnt.utils.precision import convert_precision_str_to_dtype
 
 try:
@@ -85,6 +87,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
 from torchtnt.utils.version import is_torch_version_geq
 
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 @dataclass
 class Strategy:
     """Dataclass representing a parallelization strategy"""
@@ -680,3 +685,40 @@ def _check_and_convert_mp_policy_dtypes(
     )
 
     return new_mp_policy
+
+
+def get_module_state_dict(
+    module: torch.nn.Module, rank0_only: bool = False
+) -> Dict[str, Any]:
+    """
+    Given a module, return a state dict that can be loaded into a CPU instance of the module. This requires a different implementation depending on the strategy:
+    - If FSDP, we need to gather all the sharded parameters and offload the state dict to CPU in order to avoid OOM.
+    - If DDP, we need to unwrap the module to avoid the extra state_dict prefix.
+    - Otherwise, we can just return the state dict as is.
+
+    Args:
+        module: module to be used.
+        rank0_only: This flag only works for FSDP. If True, only rank 0 will return the state dict. Other ranks will return an empty dict.
+            For DDP or the no-strategy case, we don't move the state dict to CPU -- it can be loaded directly into the module.
+
+    Note: Even if the state_dict parameters are on GPU, it can still be loaded into a CPU module.
+    """
+    logger.info("Generating module state dict")
+
+    # TODO: Add support for FSDP2
+    if isinstance(module, FSDP):
+        state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
+        with FSDP.state_dict_type(module, _StateDictType.FULL_STATE_DICT, state_cfg):
+            return module.state_dict()
+
+    if rank0_only:
+        logger.warning(
+            "Provided rank0_only=True, but this is a no-op for DDP or no strategy. Returning the state dict on the module's device."
+        )
+
+    if isinstance(module, DDP):
+        module = module.module
+
+    state_dict = module.state_dict()
+
+    return state_dict
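
For context, here is a minimal usage sketch of the new helper (not part of this commit). It shows the unwrapped-module path, which needs no distributed setup; with an FSDP- or DDP-wrapped module the same call would gather or unwrap as described in the docstring above. The cpu_copy module and the shapes are illustrative assumptions.

import torch

from torchtnt.utils.prepare_module import get_module_state_dict

# Unwrapped module: get_module_state_dict simply returns module.state_dict().
module = torch.nn.Sequential(
    torch.nn.Linear(2, 100),
    torch.nn.Linear(100, 2),
)
state_dict = get_module_state_dict(module, rank0_only=False)

# The keys match the underlying module, so the dict loads into a fresh CPU copy.
cpu_copy = torch.nn.Sequential(
    torch.nn.Linear(2, 100),
    torch.nn.Linear(100, 2),
)
cpu_copy.load_state_dict(state_dict)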
