
Commit 7272dbd

support meta weight loading in DDP (#966)

JKSenthil authored and facebook-github-bot committed

Summary:
Pull Request resolved: #966
Reviewed By: galrotem
Differential Revision: D68837358
fbshipit-source-id: e3fcb6adf89e6ae5265a1cb0ccb2ad86a0b2c4e4
1 parent 46d6cee commit 7272dbd

File tree: 3 files changed (+72, -2 lines)
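For context, PyTorch's "meta" device holds only parameter metadata (shapes and dtypes) with no backing storage, which lets large models be constructed without allocating memory; such parameters must be materialized on a real device before the module can be used or wrapped in DDP. A minimal sketch of that starting point, not part of the commit and assuming PyTorch 2.x:

```python
# Minimal sketch: constructing a module on the meta device allocates no storage.
import torch

with torch.device("meta"):
    model = torch.nn.Linear(4096, 4096)

print(model.weight.is_meta)  # True -- metadata only, no data
# The weights must be materialized (e.g. via Module.to_empty()) on a real
# device before the module can be moved, trained, or wrapped in DDP.
```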

tests/utils/test_prepare_module.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -17,7 +17,9 @@
 from torchtnt.utils.prepare_module import (
     DDPStrategy,
     FSDPStrategy,
+    materialize_meta_params,
     NOOPStrategy,
+    on_meta_device,
     prepare_module,
     TorchCompileParams,
 )
@@ -214,3 +216,29 @@ def test_prepare_module_compile_module_state_dict(self) -> None:
                 torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
             )
         self.assertIsNotNone(compiled_module._compiled_call_impl)
+
+    @unittest.skipUnless(
+        torch_version_geq_2_1_0,
+        reason="Must be on torch 2.1.0+ to run test",
+    )
+    def test_materialize_meta_params(self) -> None:
+        # Create a simple module with parameters on the meta device
+        class SimpleModule(torch.nn.Module):
+            def __init__(self):
+                super(SimpleModule, self).__init__()
+                self.linear1 = torch.nn.Linear(10, 10, device="meta")
+                self.linear2 = torch.nn.Linear(10, 10, device="cpu")
+
+        module = SimpleModule()
+        device = torch.device("cpu")
+
+        self.assertFalse(on_meta_device(module))  # top level module has no params
+        self.assertTrue(on_meta_device(module.linear1))
+        self.assertFalse(on_meta_device(module.linear2))
+
+        # Call the function to test
+        materialize_meta_params(module, device)
+
+        # Check if the parameters are moved to the specified device
+        for param in module.parameters():
+            self.assertEqual(param.device, device)
```
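As the test hints, `materialize_meta_params` allocates storage on the target device but does not initialize values; real weights are expected to be loaded afterwards. A hedged usage sketch, where a random state dict stands in for an actual checkpoint:

```python
# Sketch of the intended flow after this change: materialize, then load real weights.
import torch
from torchtnt.utils.prepare_module import materialize_meta_params

module = torch.nn.Linear(10, 10, device="meta")
materialize_meta_params(module, torch.device("cpu"))  # storage allocated, values uninitialized

# In practice these values would come from a checkpoint; random tensors here for illustration.
state_dict = {"weight": torch.randn(10, 10), "bias": torch.randn(10)}
module.load_state_dict(state_dict)
assert not module.weight.is_meta
```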

tests/utils/test_prepare_module_gpu.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -41,6 +41,11 @@ def test_prepare_ddp(self) -> None:
             "nccl",
             self._test_prepare_ddp,
         )
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_prepare_ddp_meta_device,
+        )
 
     @staticmethod
     def _test_prepare_ddp() -> None:
@@ -54,6 +59,18 @@ def _test_prepare_ddp() -> None:
         tc = unittest.TestCase()
         tc.assertTrue(isinstance(ddp_module, DDP))
 
+    @staticmethod
+    def _test_prepare_ddp_meta_device() -> None:
+        module = torch.nn.Linear(2, 2, device="meta")
+        device = init_from_env()
+        ddp_module = prepare_ddp(
+            module,
+            device,
+            DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
+        )
+        tc = unittest.TestCase()
+        tc.assertTrue(isinstance(ddp_module, DDP))
+
     @skip_if_not_gpu
     @skip_if_not_distributed
     def test_prepare_fsdp(self) -> None:
```
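Outside the test harness, the same flow would look roughly like the sketch below. This is a hedged example, assuming a distributed process group has already been set up by the launcher (e.g. torchrun) and that `init_from_env` is importable from `torchtnt.utils.env` as used by the tests:

```python
# Hedged end-to-end sketch mirroring the new GPU test (not part of the commit).
import torch
from torchtnt.utils.env import init_from_env  # assumed import path
from torchtnt.utils.prepare_module import DDPStrategy, prepare_ddp

device = init_from_env()                       # picks the device for this rank
module = torch.nn.Linear(2, 2, device="meta")  # no storage allocated yet
ddp_module = prepare_ddp(module, device, DDPStrategy())
# prepare_ddp now materializes the meta parameters on `device` before wrapping;
# real weights (e.g. from a checkpoint) would typically be loaded afterwards.
```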

torchtnt/utils/prepare_module.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -50,7 +50,7 @@
     StateDictType,
 )
 
-from torchtnt.utils.rank_zero_log import rank_zero_warn
+from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.version import is_torch_version_geq
 
 
@@ -188,7 +188,7 @@ def prepare_ddp(
     Utility to move a module to device and wrap in `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_.
 
     Args:
-        module: module to be wrapped in DDP
+        module: module to be wrapped in DDP. If the module has parameters on the meta device, they will be materialized on the target device prior to DDP wrapping.
         device: device to which module will be moved
         strategy: an instance of :class:`~torchtnt.utils.prepare_module.DDPStrategy` which defines the settings of DDP APIs
 
@@ -207,6 +207,10 @@ def prepare_ddp(
         # remove ddp comm hook variables from params dict
         del params_dict["comm_state"]
         del params_dict["comm_hook"]
+
+    materialize_meta_params(module, device)
+
+    # now move the rest of the module to device
     module = module.to(device)
 
     # remove sync batch norm from params dict before converting module
@@ -424,3 +428,24 @@ def convert_str_to_strategy(
         f"Strategy {strategy} not supported. Please use one of {list(string_to_strategy_mapping.keys())}"
     )
     return string_to_strategy_mapping[strategy]
+
+
+def on_meta_device(module: torch.nn.Module) -> bool:
+    try:
+        return next(module.parameters(recurse=False)).device.type == "meta"
+    except StopIteration:
+        return False
+
+
+def materialize_meta_params(module: torch.nn.Module, device: torch.device) -> None:
+    """
+    Materialize meta-device parameters on the given device.
+
+    Args:
+        module: module to be used.
+        device: device to which module will be moved.
+    """
+    for name, submodule in module.named_modules():
+        if on_meta_device(submodule):
+            rank_zero_info(f"{name} is on meta device, initializing on device {device}")
+            submodule.to_empty(device=device, recurse=False)
```
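The implementation uses `Module.to_empty()` rather than `Module.to()` because a meta tensor has no data to copy: `to()` typically fails on meta parameters, while `to_empty()` allocates uninitialized storage of the right shape and dtype on the target device. A small illustration, not from the commit:

```python
# Contrast between .to() and .to_empty() for meta-device modules.
import torch

m = torch.nn.Linear(8, 8, device="meta")

try:
    m.to(torch.device("cpu"))            # usually raises: cannot copy out of a meta tensor
except (NotImplementedError, RuntimeError) as e:
    print(f"to() failed as expected: {e}")

m.to_empty(device=torch.device("cpu"))   # allocates uninitialized storage instead
print(m.weight.device)                   # cpu
```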
