
Commit f1ebb63

JKSenthil authored and facebook-github-bot committed
disable reshard_after_forward in top level module (#1009)
Summary: Pull Request resolved: #1009

Reviewed By: richardwang-at-fb

Differential Revision: D76364770

fbshipit-source-id: 0c5eb22617ca7d903b2db7b159b2ea31f76e7244
1 parent db8367b commit f1ebb63

File tree

2 files changed: +9 -2 lines


tests/utils/test_prepare_module.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
             global_mesh=mock_global_mesh,
         )
         mock_fully_shard.assert_called_with(
-            module, mesh=mock_mesh, reshard_after_forward=True
+            module, mesh=mock_mesh, reshard_after_forward=False
         )
 
     @patch("torchtnt.utils.prepare_module._prepare_module_2d")
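
For illustration, a minimal, self-contained sketch of the mock-assertion pattern the updated test relies on. The shard_root helper below is hypothetical; the real test patches torchtnt's fully_shard import and exercises prepare_fsdp2.

    from unittest.mock import MagicMock

    def shard_root(fully_shard, module, mesh):
        # Illustrative stand-in for the code path under test: the root module is
        # passed to fully_shard with reshard_after_forward forced to False.
        fully_shard(module, mesh=mesh, reshard_after_forward=False)

    mock_fully_shard = MagicMock()
    module, mesh = object(), object()
    shard_root(mock_fully_shard, module, mesh)
    # Same style of check as the updated assertion in the diff above.
    mock_fully_shard.assert_called_with(module, mesh=mesh, reshard_after_forward=False)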

torchtnt/utils/prepare_module.py

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import logging
+from copy import deepcopy
 from dataclasses import asdict, dataclass, field
 from functools import partial
 from typing import (
@@ -468,7 +469,13 @@ def prepare_fsdp2(
 
     # shard the top level model, so that all params are moved off cpu to gpu
     if not _is_fsdp2_module(module):
-        fully_shard(module, **fsdp_kwargs)
+        # disable reshard_after_forward for top level module
+        # as result is DTensor which may be incompatible with
+        # certain loss computation
+        root_kwargs = deepcopy(fsdp_kwargs)
+        root_kwargs["reshard_after_forward"] = False
+
+        fully_shard(module, **root_kwargs)
 
     # materialized sharded meta weights to device
     materialize_meta_params(module, device)
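
For context, a minimal sketch of the pattern this hunk implements, assuming a PyTorch build where fully_shard is importable from torch.distributed._composable.fsdp and accepts mesh / reshard_after_forward keyword arguments. The shard_model helper and the child-by-child loop are illustrative only, not torchtnt's prepare_fsdp2 logic.

    from copy import deepcopy

    from torch import nn
    from torch.distributed._composable.fsdp import fully_shard  # assumed import path


    def shard_model(module: nn.Module, **fsdp_kwargs) -> nn.Module:
        # Shard submodules with the caller-provided kwargs; resharding their
        # parameters after forward keeps memory low for the rest of the step.
        for child in module.children():
            fully_shard(child, **fsdp_kwargs)

        # For the root module, copy the kwargs and force reshard_after_forward=False.
        # Per the diff's comment, the root's result may otherwise be a DTensor that
        # some loss computations cannot consume.
        root_kwargs = deepcopy(fsdp_kwargs)
        root_kwargs["reshard_after_forward"] = False
        fully_shard(module, **root_kwargs)
        return module

Copying fsdp_kwargs before mutating it (as the diff does with deepcopy) keeps the override local to the root call, so submodules sharded earlier with the same kwargs are unaffected.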
