
Commit 3232a91

JKSenthil authored and facebook-github-bot committed
fix stale weight bug for FSDP EMA + AutoUnit (#962)
Summary: Pull Request resolved: #962
Reviewed By: anshulverma
Differential Revision: D68450131
fbshipit-source-id: 8f8981f39ea654a9e83af612c7d93880066308e3
1 parent 9984243 commit 3232a91

File tree

3 files changed: +149, -3 lines changed


tests/framework/test_auto_unit_gpu.py

Lines changed: 128 additions & 1 deletion
@@ -10,7 +10,7 @@
 import unittest

 from copy import deepcopy
-from typing import TypeVar
+from typing import Tuple, TypeVar
 from unittest.mock import MagicMock, patch

 import torch
@@ -27,7 +27,9 @@

 from torchtnt.framework.auto_unit import AutoPredictUnit, SWALRParams, SWAParams
 from torchtnt.framework.evaluate import evaluate
+from torchtnt.framework.fit import fit
 from torchtnt.framework.predict import predict
+from torchtnt.framework.state import ActivePhase, State
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env, seed
@@ -38,6 +40,25 @@
 T = TypeVar("T")


+Batch = Tuple[torch.Tensor, torch.Tensor]
+
+
+class DummySWAAutoUnit(DummyAutoUnit):
+    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, object]:
+        """
+        Computes loss for given batch. If in EVAL or PREDICT phase, uses swa model's output
+        """
+        inputs, targets = data
+        if state.active_phase == ActivePhase.TRAIN:
+            outputs = self.module(inputs)
+        else:
+            outputs = self.swa_model(inputs) if self.swa_model else self.module(inputs)
+
+        loss = torch.nn.functional.cross_entropy(outputs, targets)
+
+        return loss, outputs
+
+
 class TestAutoUnitGPU(unittest.TestCase):
     @skip_if_not_gpu
     @skip_if_not_distributed
@@ -184,6 +205,112 @@ def forward(self, x):
             for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
                 torch.testing.assert_close(p2, p1, check_device=False)

+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_stochastic_weight_averaging_fsdp_with_eval(self) -> None:
+        """
+        Test that swa params with FSDP is identical to non-FSDP swa
+        """
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_stochastic_weight_averaging_fsdp_with_eval,
+        )
+
+    @staticmethod
+    def _test_stochastic_weight_averaging_fsdp_with_eval() -> None:
+        """
+        Compares the swa model parameters after training without FSDP and with FSDP.
+        They should be identical.
+        """
+
+        class Net(torch.nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l1 = torch.nn.Linear(2, 2)
+                self.b1 = torch.nn.BatchNorm1d(2)
+                self.l2 = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.l1(x)
+                x = self.b1(x)
+                x = self.l2(x)
+                return x
+
+        # so all ranks start with same initialized weights
+        device = init_from_env()
+        seed(0)
+        my_module = Net()
+
+        auto_unit = DummySWAAutoUnit(
+            module=deepcopy(my_module),
+            device=device,
+            step_lr_interval="step",
+            swa_params=SWAParams(
+                warmup_steps_or_epochs=1,
+                step_or_epoch_update_freq=1,
+                swalr_params=SWALRParams(
+                    anneal_steps_or_epochs=3,
+                ),
+                averaging_method="ema",
+            ),
+        )
+
+        auto_unit_fsdp = DummySWAAutoUnit(
+            module=my_module,
+            device=device,
+            step_lr_interval="step",
+            strategy=FSDPStrategy(),
+            swa_params=SWAParams(
+                warmup_steps_or_epochs=1,
+                step_or_epoch_update_freq=1,
+                swalr_params=SWALRParams(
+                    anneal_steps_or_epochs=3,
+                ),
+                averaging_method="ema",
+            ),
+        )
+
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        eval_dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        fit(
+            auto_unit,
+            dataloader,
+            eval_dataloader,
+            max_epochs=3,
+            max_train_steps_per_epoch=5,
+            evaluate_every_n_epochs=0,
+        )
+
+        fit(
+            auto_unit_fsdp,
+            dataloader,
+            eval_dataloader,
+            max_epochs=3,
+            max_train_steps_per_epoch=5,
+            # this is key arg, to ensure that swa model is updated
+            # even after swa model forward pass is used in eval
+            evaluate_every_n_epochs=1,
+        )
+
+        swa_params = list(auto_unit.swa_model.parameters())
+        swa_buffers = list(auto_unit.swa_model.buffers())
+        with FSDP.summon_full_params(auto_unit_fsdp.swa_model):
+            swa_fsdp_params = auto_unit_fsdp.swa_model.parameters()
+            swa_fsdp_buffers = auto_unit_fsdp.swa_model.buffers()
+
+            # Iterate and compare each parameter
+            for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
+                torch.testing.assert_close(p2, p1, check_device=False)
+
+            # Iterate and compare each buffer
+            for b1, b2 in zip(swa_buffers, swa_fsdp_buffers, strict=True):
+                torch.testing.assert_close(b2, b1, check_device=False)
+
     @skip_if_not_gpu
     @patch("torch.autocast")
     def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:

torchtnt/framework/auto_unit.py

Lines changed: 5 additions & 1 deletion
@@ -461,7 +461,7 @@ class AutoUnit(
         detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
         clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
         clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
-        swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
+        swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging (Please see note if using with FSDP)
         torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
         activation_checkpoint_params: params for enabling activation checkpointing
         training: if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
@@ -481,6 +481,10 @@ class AutoUnit(
    Note:
        Torch compile support is only available in PyTorch 2.0 or higher.

+    Note:
+        If using SWA with FSDP, the SWA model will be sharded with the same FSDP configuration as the original model. If you need the swa model's output in evaluation / prediction step,
+        please call `self.swa_model(inputs, ...)` to ensure all hooks (especially for FSDP) are fired correctly.
+
    """

    def __init__(
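For context, below is a minimal sketch of the usage the new note describes, modeled on the DummySWAAutoUnit added in the test file of this commit. The unit name MySWAAutoUnit and the cross-entropy loss are illustrative assumptions, not part of the library.

from typing import Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import ActivePhase, State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MySWAAutoUnit(AutoUnit[Batch]):
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs, targets = data
        if state.active_phase == ActivePhase.TRAIN:
            outputs = self.module(inputs)
        else:
            # Call the swa model itself (not its inner .module) so that all of its
            # forward hooks, including FSDP's, fire as the docstring note advises.
            outputs = self.swa_model(inputs) if self.swa_model else self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs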

torchtnt/utils/swa.py

Lines changed: 16 additions & 1 deletion
@@ -6,9 +6,10 @@

 # pyre-strict

-from typing import Callable, List, Literal, Optional
+from typing import Any, Callable, List, Literal, Optional

 import torch
+from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy

 _AVERAGED_MODEL_AVAIL: bool = True

@@ -105,6 +106,20 @@ def __init__(
             use_buffers=use_buffers,
         )

+    # pyre-ignore: Missing return annotation [3]: Return type must be specified as type other than `Any`
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        output = self.module(*args, **kwargs)
+
+        # for fsdp modules, we need to manually reshard the swa_model in case the
+        # model fwd was used in evaluation loop, due to how fsdp manages the param state
+        # see https://github.com/pytorch/pytorch/issues/117742
+        for m in FullyShardedDataParallel.fsdp_modules(self.module):
+            if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
+                # pyre-ignore: Incompatible parameter type [6]: In call `torch.distributed.fsdp._runtime_utils._reshard`, for 2nd positional argument, expected `FlatParamHandle` but got `Optional[FlatParamHandle]`.
+                torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True)
+
+        return output
+
     def update_parameters(self, model: torch.nn.Module) -> None:
         self._num_updates += 1
         if self._use_lit:
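For reference, here is a hedged sketch of the same reshard-after-forward workaround outside the AveragedModel wrapper, assuming an already FSDP-wrapped module fsdp_model and the hypothetical helper name forward_and_reshard. It relies on the same private FSDP internals used in the diff above (see https://github.com/pytorch/pytorch/issues/117742), so it illustrates the mechanism rather than a public API.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp._runtime_utils import _reshard


def forward_and_reshard(fsdp_model: torch.nn.Module, *args, **kwargs):
    # Run the forward pass; when called outside the training loop (e.g. in an
    # eval step), FSDP may leave the flat parameters unsharded afterwards,
    # which is the condition behind the stale-weight bug this commit fixes.
    output = fsdp_model(*args, **kwargs)
    # Manually reshard every FSDP submodule that actually owns parameters.
    for m in FSDP.fsdp_modules(fsdp_model):
        if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD:
            _reshard(m, m._handle, True)
    return output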
