
Commit 32a4d82

JKSenthil authored and facebook-github-bot committed

support fsdp2 optimizer checkpointing (#977)

Summary:
Pull Request resolved: #977

Reviewed By: diego-urgell

Differential Revision: D70337970

fbshipit-source-id: 52f0915f01dba18ee420c9b38282db9f6eb926df

1 parent be1ed63 · commit 32a4d82

File tree

4 files changed: +91 additions, -14 deletions

tests/framework/test_app_state_mixin.py (13 additions & 1 deletion)

@@ -19,7 +19,7 @@
 
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import FSDPOptimizerWrapper
+from torchtnt.utils.prepare_module import FSDP2OptimizerWrapper, FSDPOptimizerWrapper
 from torchtnt.utils.stateful import MultiStateful
 
 
@@ -269,3 +269,15 @@ def test_construct_tracked_optimizers_and_schedulers(self) -> None:
         self.assertIsInstance(result["optimizer"], FSDPOptimizerWrapper)
         self.assertIsInstance(result["optim2"], torch.optim.Optimizer)
         self.assertIsInstance(result["lr_scheduler"], TLRScheduler)
+
+        with patch(
+            "torchtnt.framework.unit._is_fsdp_module", side_effect=lambda m: m == module
+        ), patch(
+            "torchtnt.framework.unit._is_fsdp2_module",
+            side_effect=lambda m: m == module,
+        ):
+            result = auto_unit._construct_tracked_optimizers_and_schedulers()
+
+        self.assertIsInstance(result["optimizer"], FSDP2OptimizerWrapper)
+        self.assertIsInstance(result["optim2"], torch.optim.Optimizer)
+        self.assertIsInstance(result["lr_scheduler"], TLRScheduler)

torchtnt/framework/unit.py (43 additions & 11 deletions)

@@ -20,7 +20,12 @@
 
 from torchtnt.framework.state import State
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
+from torchtnt.utils.prepare_module import (
+    _is_fsdp2_module,
+    _is_fsdp_module,
+    FSDP2OptimizerWrapper,
+    FSDPOptimizerWrapper,
+)
 from torchtnt.utils.progress import Progress
 from torchtnt.utils.stateful import MetricStateful, Stateful
 
@@ -199,13 +204,27 @@ def __delattr__(self, name: str) -> None:
 
     def _construct_tracked_optimizers_and_schedulers(
         self,
-    ) -> Dict[str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper, TLRScheduler]]:
+    ) -> Dict[
+        str,
+        Union[
+            torch.optim.Optimizer,
+            FSDPOptimizerWrapper,
+            FSDP2OptimizerWrapper,
+            TLRScheduler,
+        ],
+    ]:
         """
-        Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
+        Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper/FSDP2OptimizerWrapper.
         """
         # construct custom tracked optimizers with FSDP optimizers
         tracked_optimizers_and_schedulers: Dict[
-            str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper, TLRScheduler]
+            str,
+            Union[
+                torch.optim.Optimizer,
+                FSDPOptimizerWrapper,
+                FSDP2OptimizerWrapper,
+                TLRScheduler,
+            ],
         ] = {}
         tracked_optimizers_and_schedulers.update(self._construct_tracked_optimizers())
 
@@ -224,25 +243,38 @@ def _construct_tracked_optimizers_and_schedulers(
 
     def _construct_tracked_optimizers(
         self,
-    ) -> Dict[str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper]]:
+    ) -> Dict[
+        str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper, FSDP2OptimizerWrapper]
+    ]:
         """
-        Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
+        Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper/FSDP2OptimizerWrapper.
         """
-        fsdp_tracked_optimizers: Dict[str, FSDPOptimizerWrapper] = {}
+        fsdp_tracked_optimizers: Dict[
+            str, Union[FSDPOptimizerWrapper, FSDP2OptimizerWrapper]
+        ] = {}
         for module in self.tracked_modules().values():
             if _is_fsdp_module(module):
                 # find optimizers for module, if exists
                 optimizer_list = _find_optimizers_for_module(
                     module, self.tracked_optimizers()
                 )
+
+                is_fsdp2 = _is_fsdp2_module(module)
+
                 for optim_name, optimizer in optimizer_list:
-                    fsdp_tracked_optimizers[optim_name] = FSDPOptimizerWrapper(
-                        module, optimizer
-                    )
+                    if is_fsdp2:
+                        fsdp_tracked_optimizers[optim_name] = FSDP2OptimizerWrapper(
+                            module, optimizer
+                        )
+                    else:
+                        fsdp_tracked_optimizers[optim_name] = FSDPOptimizerWrapper(
+                            module, optimizer
+                        )
 
         # construct custom tracked optimizers with FSDP optimizers
         tracked_optimizers: Dict[
-            str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper]
+            str,
+            Union[torch.optim.Optimizer, FSDPOptimizerWrapper, FSDP2OptimizerWrapper],
         ] = {
             key: value
             for key, value in self.tracked_optimizers().items()
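
Pulled out of the diff context, the dispatch that _construct_tracked_optimizers now performs can be summarized by the sketch below. It is only an illustration: the wrap_for_checkpointing helper is hypothetical and not part of torchtnt. The key point it shows is that fully_shard-ed (FSDP2) modules also satisfy _is_fsdp_module, so the narrower _is_fsdp2_module check decides which wrapper to use.

from typing import Union

import torch

from torchtnt.utils.prepare_module import (
    _is_fsdp2_module,
    _is_fsdp_module,
    FSDP2OptimizerWrapper,
    FSDPOptimizerWrapper,
)


def wrap_for_checkpointing(
    module: torch.nn.Module, optimizer: torch.optim.Optimizer
) -> Union[torch.optim.Optimizer, FSDPOptimizerWrapper, FSDP2OptimizerWrapper]:
    """Hypothetical helper mirroring the dispatch in _construct_tracked_optimizers."""
    if _is_fsdp_module(module):
        # FSDP2 modules pass _is_fsdp_module too, so check the more specific
        # predicate first to pick the distributed-state-dict wrapper.
        if _is_fsdp2_module(module):
            return FSDP2OptimizerWrapper(module, optimizer)
        return FSDPOptimizerWrapper(module, optimizer)
    return optimizer


# Non-sharded modules keep their optimizer untouched.
model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
assert wrap_for_checkpointing(model, optim) is optim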

torchtnt/utils/prepare_module.py (30 additions & 0 deletions)

@@ -33,6 +33,10 @@
     checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.checkpoint.state_dict import (
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
 from torch.distributed.device_mesh import init_device_mesh
 
 try:
@@ -449,6 +453,24 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.optimizer.load_state_dict(optim_state_dict)
 
 
+class FSDP2OptimizerWrapper:
+    """
+    Wrapper for FSDP2 optimizer which uses distributed state dict APIs.
+    """
+
+    def __init__(
+        self, module: torch.nn.Module, optimizer: torch.optim.Optimizer
+    ) -> None:
+        self.module = module
+        self.optimizer = optimizer
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_optimizer_state_dict(self.module, self.optimizer)
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_optimizer_state_dict(self.module, self.optimizer, state_dict)
+
+
 def _is_fsdp_module(module: torch.nn.Module) -> bool:
     if isinstance(module, FSDP):
         return True
@@ -461,6 +483,14 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
     return False
 
 
+def _is_fsdp2_module(module: torch.nn.Module) -> bool:
+    maybe_composable_state = _get_module_state(module)
+    if maybe_composable_state is not None:
+        return isinstance(maybe_composable_state, FSDPState)
+
+    return False
+
+
 def prepare_module(
     module: torch.nn.Module,
     device: torch.device,
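
As a rough usage sketch (not part of the commit): FSDP2OptimizerWrapper is meant to sit in front of the optimizer of a fully_shard-ed module so that state_dict()/load_state_dict() round-trip through the distributed state dict APIs. The snippet below assumes a recent PyTorch that exposes fully_shard under torch.distributed.fsdp (older releases keep it under torch.distributed._composable.fsdp) and a launch via torchrun on GPU hosts; it is a sketch under those assumptions, not a definitive recipe.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

from torchtnt.utils.prepare_module import FSDP2OptimizerWrapper

# Standard torchrun setup: pick the local GPU and join the default process group.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")

model = torch.nn.Linear(16, 16, device="cuda")
fully_shard(model)  # FSDP2: per-parameter sharding; params become DTensors
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

wrapper = FSDP2OptimizerWrapper(model, optimizer)

# state_dict() goes through get_optimizer_state_dict, so each rank produces a
# checkpointable (sharded) optimizer state dict; load_state_dict restores it.
optim_state = wrapper.state_dict()
wrapper.load_state_dict(optim_state)

dist.destroy_process_group()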

torchtnt/utils/stateful.py (5 additions & 2 deletions)

@@ -10,7 +10,7 @@
 
 import torch
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import FSDPOptimizerWrapper
+from torchtnt.utils.prepare_module import FSDP2OptimizerWrapper, FSDPOptimizerWrapper
 from torchtnt.utils.progress import Progress
 
 from typing_extensions import Protocol, runtime_checkable
@@ -28,7 +28,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
 StatefulDict = Dict[str, Stateful]
 ModuleDict = Dict[str, torch.nn.Module]
 OptimizerAndLRSchedulerDict = Dict[
-    str, Union[TLRScheduler, torch.optim.Optimizer, FSDPOptimizerWrapper]
+    str,
+    Union[
+        TLRScheduler, torch.optim.Optimizer, FSDPOptimizerWrapper, FSDP2OptimizerWrapper
+    ],
 ]
 ProgressDict = Dict[str, Progress]
 
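
The widened OptimizerAndLRSchedulerDict alias only changes typing, not behavior: wrapped FSDP/FSDP2 optimizers become legal values alongside plain optimizers and schedulers. Below is a small sketch (not from the commit) of the alias in use; no sharding is involved, so only the plain optimizer and scheduler case is exercised.

import torch

from torchtnt.utils.stateful import OptimizerAndLRSchedulerDict

model = torch.nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10)

# FSDPOptimizerWrapper / FSDP2OptimizerWrapper values would live in the same
# dict when the corresponding modules are sharded.
tracked: OptimizerAndLRSchedulerDict = {"optimizer": optim, "lr_scheduler": sched}

# Every value satisfies the Stateful protocol, so it can be snapshotted uniformly.
snapshot = {name: obj.state_dict() for name, obj in tracked.items()}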
