
Commit a9eff9c

JKSenthil authored and facebook-github-bot committed
support fsdp2 grad scaler (#997)
Summary:
Pull Request resolved: #997

# Context
FSDP1 requires its own sharded grad scaler, while FSDP2 uses the original grad scaler (amp.grad_scaler). See https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md

# This Diff
1) Separates the FSDP1 and FSDP2 module check functions
2) Only uses the sharded grad scaler for FSDP1 modules

Reviewed By: galrotem
Differential Revision: D74410706
fbshipit-source-id: 5454069ae303a31932182ad1b06a9c8920fd5d07
1 parent f688719 commit a9eff9c
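The rule described in the Context section reduces to a small branch on precision and FSDP version. Below is a minimal, hedged sketch of that rule; the helper name pick_grad_scaler is illustrative and not part of this diff, and it assumes FSDP1 modules can be detected with an isinstance check against FullyShardedDataParallel while FSDP2 and unsharded modules need no special scaler.

import torch
from torch.amp import GradScaler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def pick_grad_scaler(module: torch.nn.Module, precision: torch.dtype):
    # Loss scaling is only needed for fp16; bf16/fp32 get no scaler.
    if precision != torch.float16:
        return None
    # FSDP1 wraps the module in FullyShardedDataParallel, so an isinstance
    # check is enough to route it to the sharded scaler.
    if isinstance(module, FSDP):
        return ShardedGradScaler()
    # FSDP2 (fully_shard) and unsharded modules use the stock amp scaler.
    return GradScaler("cuda")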

File tree

4 files changed (+25, -18 lines)


tests/utils/test_precision.py

Lines changed: 5 additions & 3 deletions
@@ -42,14 +42,16 @@ def test_convert_precision_str_to_dtype_throws(self) -> None:
 
     def test_get_grad_scaler_from_precision(self) -> None:
         grad_scaler = get_grad_scaler_from_precision(
-            torch.float32, is_fsdp_module=False
+            torch.float32, is_fsdp1_module=False
         )
         self.assertIsNone(grad_scaler)
 
         grad_scaler = get_grad_scaler_from_precision(
-            torch.float16, is_fsdp_module=False
+            torch.float16, is_fsdp1_module=False
         )
         self.assertIsInstance(grad_scaler, GradScaler)
 
-        grad_scaler = get_grad_scaler_from_precision(torch.float16, is_fsdp_module=True)
+        grad_scaler = get_grad_scaler_from_precision(
+            torch.float16, is_fsdp1_module=True
+        )
         self.assertIsInstance(grad_scaler, ShardedGradScaler)

torchtnt/framework/auto_unit.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@
     GradScaler,
 )
 from torchtnt.utils.prepare_module import (
+    _is_fsdp1_module,
     _is_fsdp2_module,
-    _is_fsdp_module,
     ActivationCheckpointParams,
     FSDPStrategy,
     prepare_fsdp,
@@ -560,7 +560,7 @@ def __init__(
         if self.precision:
            self.grad_scaler = get_grad_scaler_from_precision(
                 self.precision,
-                is_fsdp_module=_is_fsdp_module(self.module),
+                is_fsdp1_module=_is_fsdp1_module(self.module),
             )
 
         self.step_lr_interval = step_lr_interval
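For context on where this constructor change lands, here is a hedged sketch of a user-defined AutoUnit that would exercise the changed code path; the subclass, model, data shapes, and hyperparameters are illustrative only, and exact AutoUnit constructor arguments can vary across torchtnt versions.

from typing import Any, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State


class MyUnit(AutoUnit):
    def compute_loss(self, state: State, data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(self, module: torch.nn.Module):
        return torch.optim.SGD(module.parameters(), lr=0.01), None


# precision="fp16" makes AutoUnit call get_grad_scaler_from_precision; with an
# unsharded module it now passes is_fsdp1_module=False and gets the stock GradScaler.
# An FSDP1-wrapped module would instead select ShardedGradScaler.
unit = MyUnit(module=torch.nn.Linear(8, 2), precision="fp16")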

torchtnt/utils/precision.py

Lines changed: 5 additions & 4 deletions
@@ -38,22 +38,23 @@ def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]:
 
 
 def get_grad_scaler_from_precision(
-    precision: torch.dtype, *, is_fsdp_module: Optional[bool] = False
+    precision: torch.dtype, *, is_fsdp1_module: Optional[bool] = False
 ) -> Optional[GradScaler]:
     """
     Returns the correct grad scaler to use based on the precision and whether
-    or not the model is FSDP.
+    or not the model is FSDP. FSDP required it's own sharded grad scaler. FSDP2 uses
+    the original grad scaler (amp.grad_scaler). See https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
 
     Args:
         precision: the precision being used
-        is_fsdp_module: whether the grad scaler is for an FSDP module
+        is_fsdp1_module: whether the grad scaler is for an FSDP1 module
 
     Returns:
         The appropriate grad scaler to use, ``None`` if no grad scaler should be used.
     """
 
     if precision == torch.float16:
-        if is_fsdp_module:
+        if is_fsdp1_module:
             from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
             return ShardedGradScaler()
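Whichever scaler this function returns, it is driven the same way in an fp16 training step. Below is a minimal sketch of that usage; it assumes a CUDA device and uses a stand-in model, optimizer, and data rather than anything from this diff.

import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Stand-in for get_grad_scaler_from_precision(torch.float16, is_fsdp1_module=False)
scaler = torch.amp.GradScaler("cuda")

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randint(0, 2, (4,), device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
scaler.update()                # adjust the scale factor for the next iteration
optimizer.zero_grad()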

torchtnt/utils/prepare_module.py

Lines changed: 13 additions & 9 deletions
@@ -67,7 +67,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
     FullyShardedDataParallel as FSDP,
     StateDictType as _StateDictType,
 )
-from torch.distributed.fsdp._common_utils import _FSDPState
 from torch.distributed.fsdp.api import OptimStateDictConfig, StateDictConfig
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     BackwardPrefetch as _BackwardPrefetch,
@@ -435,7 +434,7 @@ def prepare_fsdp2(
     )
 
     # shard the top level model, so that all params are moved off cpu to gpu
-    if not _is_fsdp_module(module):
+    if not _is_fsdp2_module(module):
         fully_shard(module, **fsdp_kwargs)
 
     # materialized sharded meta weights to device
@@ -515,18 +514,23 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
 
 def _is_fsdp_module(module: torch.nn.Module) -> bool:
-    if isinstance(module, FSDP):
-        return True
+    """
+    Checks if a module is wrapped in FSDP or FSDP2
+    """
+    return _is_fsdp1_module(module) or _is_fsdp2_module(module)
 
-    # Also check for composable FSDP API
-    maybe_composable_state = _get_module_state(module)
-    if maybe_composable_state is not None:
-        return isinstance(maybe_composable_state, (_FSDPState, FSDPState))
 
-    return False
+def _is_fsdp1_module(module: torch.nn.Module) -> bool:
+    """
+    Checks if a module is sharded by original FSDP
+    """
+    return isinstance(module, FSDP)
 
 
 def _is_fsdp2_module(module: torch.nn.Module) -> bool:
+    """
+    Checks if a module is sharded by FSDP2
+    """
     maybe_composable_state = _get_module_state(module)
     if maybe_composable_state is not None:
         return isinstance(maybe_composable_state, FSDPState)
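A quick, CPU-only way to see the split after this commit: on a plain, unsharded module both new helpers report falsy, so get_grad_scaler_from_precision falls through to the stock GradScaler for fp16. The helpers are private (underscore-prefixed) and imported here only for illustration.

import torch
from torchtnt.utils.precision import get_grad_scaler_from_precision
from torchtnt.utils.prepare_module import _is_fsdp1_module, _is_fsdp2_module

model = torch.nn.Linear(4, 4)
print(_is_fsdp1_module(model), _is_fsdp2_module(model))  # both falsy for an unsharded module

scaler = get_grad_scaler_from_precision(
    torch.float16, is_fsdp1_module=_is_fsdp1_module(model)
)
print(type(scaler).__name__)  # GradScaler; an FSDP1-wrapped module would yield ShardedGradScaler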
