
Commit 4da9704

JKSenthil authored and facebook-github-bot committed
support str precision in mp_policy for fsdp2 (#985)
Summary: Pull Request resolved: #985
Reviewed By: diego-urgell
Differential Revision: D71144318
fbshipit-source-id: be4b4176f91be1f9e2057324f56afbbbed56d0dd
1 parent 6e5b158 commit 4da9704
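
For orientation, here is a minimal usage sketch of what this change enables, based on the FSDP2Strategy and prepare_fsdp2 APIs shown in the diffs below. Running it assumes an initialized distributed process group and a CUDA device, which are not set up here.

import torch

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

module = torch.nn.Linear(8, 8)
device = torch.device("cuda")

# mp_policy now also accepts a precision string ("bf16" or "fp16"), in addition
# to a torch.dtype or a full MixedPrecisionPolicy.
strategy = FSDP2Strategy(mp_policy="bf16")
prepare_fsdp2(module, device, strategy)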

File tree: 3 files changed, +85 -6 lines

tests/utils/test_prepare_module.py

Lines changed: 24 additions & 0 deletions

@@ -11,10 +11,12 @@
 from unittest.mock import patch
 
 import torch
+from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
+    _check_and_convert_mp_policy_dtypes,
     DDPStrategy,
     FSDPStrategy,
     materialize_meta_params,
@@ -242,3 +244,25 @@ def __init__(self):
         # Check if the parameters are moved to the specified device
         for param in module.parameters():
             self.assertEqual(param.device, device)
+
+    def test_check_and_convert_mp_policy_dtypes(self) -> None:
+        mp_policy = MixedPrecisionPolicy(
+            # pyre-ignore: Incompatible parameter type [6] (intentional for this test)
+            param_dtype="bf16",
+            # pyre-ignore: Incompatible parameter type [6] (intentional for this test)
+            reduce_dtype="fp16",
+            cast_forward_inputs=False,
+        )
+        new_mp_policy = _check_and_convert_mp_policy_dtypes(mp_policy)
+        self.assertEqual(new_mp_policy.param_dtype, torch.bfloat16)
+        self.assertEqual(new_mp_policy.reduce_dtype, torch.float16)
+        self.assertEqual(new_mp_policy.output_dtype, None)
+        self.assertFalse(new_mp_policy.cast_forward_inputs)
+
+        # pyre-ignore: Incompatible parameter type [6] (intentional for this test)
+        invalid_mp_policy = MixedPrecisionPolicy(param_dtype=16)
+        with self.assertRaisesRegex(
+            ValueError,
+            "MixedPrecisionPolicy requires all dtypes to be torch.dtype.",
+        ):
+            _check_and_convert_mp_policy_dtypes(invalid_mp_policy)
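
The assertions above pin down the expected string-to-dtype mapping ("bf16" to torch.bfloat16, "fp16" to torch.float16). The conversion itself is delegated to convert_precision_str_to_dtype from torchtnt.utils.precision, whose implementation is not part of this diff; a hypothetical stand-in consistent with these assertions could look like the following.

import torch

# Hypothetical stand-in for torchtnt.utils.precision.convert_precision_str_to_dtype;
# only the "bf16" and "fp16" cases are exercised by the test above.
_PRECISION_STR_TO_DTYPE = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def convert_precision_str_to_dtype_sketch(precision: str) -> torch.dtype:
    # Fail loudly on strings the mapping does not know about.
    if precision not in _PRECISION_STR_TO_DTYPE:
        raise ValueError(f"Unknown precision string: {precision}")
    return _PRECISION_STR_TO_DTYPE[precision]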

tests/utils/test_prepare_module_gpu.py

Lines changed: 14 additions & 3 deletions

@@ -10,7 +10,10 @@
 from typing import Any
 
 import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    MixedPrecisionPolicy,
+)
 
 try:
     from torch.distributed.fsdp import fully_shard
@@ -329,7 +332,8 @@ def test_prepare_fsdp2(self) -> None:
     @staticmethod
     def _test_prepare_fsdp2_none_sharded_raises() -> None:
         """
-        Test with a strategy that does not shard any modules, should raise error
+        Test with a strategy that does not shard any modules; should raise an error.
+        Also raises an error for an invalid mp_policy.
         """
         tc = unittest.TestCase()
 
@@ -339,6 +343,11 @@ def _test_prepare_fsdp2_none_sharded_raises() -> None:
         with tc.assertRaises(ValueError):
             prepare_fsdp2(module, device, strategy)
 
+        # pyre-ignore[6]: Incompatible parameter type (intentional for testing)
+        strategy = FSDP2Strategy(mp_policy=MixedPrecisionPolicy(param_dtype=16))
+        with tc.assertRaises(ValueError):
+            prepare_fsdp2(module, device, strategy)
+
     @staticmethod
     def _test_prepare_fsdp2_shard_all() -> None:
         """
@@ -364,7 +373,9 @@ def _test_prepare_fsdp2_submodule() -> None:
         for t in (torch.nn.Linear, "Linear"):
             module = SimpleModule()
             device = torch.device("cuda")
-            strategy = FSDP2Strategy(modules_to_shard=(t,))
+            # pyre-ignore: Incompatible parameter type [6] (intentional for this test)
+            mp_policy = MixedPrecisionPolicy(param_dtype="bf16")
+            strategy = FSDP2Strategy(modules_to_shard=(t,), mp_policy=mp_policy)
             prepare_fsdp2(module, device, strategy)
 
             for submodule in module.modules():

torchtnt/utils/prepare_module.py

Lines changed: 47 additions & 3 deletions

@@ -38,6 +38,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torchtnt.utils.precision import convert_precision_str_to_dtype
 
 try:
     from torch.distributed.fsdp import (
@@ -218,7 +219,7 @@ class FSDP2Strategy(Strategy):
         Iterable[Union[str, Type[torch.nn.Module]]],
     ] = "all"
     reshard_after_forward: Union[bool, int] = True
-    mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
+    mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
     cpu_offload: bool = False
 
 
@@ -375,13 +376,20 @@ def prepare_fsdp2(
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
     if (mp_policy := strategy.mp_policy) is not None:
         if isinstance(mp_policy, MixedPrecisionPolicy):
+            mp_policy = _check_and_convert_mp_policy_dtypes(mp_policy)
             fsdp_kwargs["mp_policy"] = mp_policy
-        else:
+        elif isinstance(mp_policy, str):
+            dtype = convert_precision_str_to_dtype(mp_policy)
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+                param_dtype=dtype,
+                reduce_dtype=dtype,
+                output_dtype=dtype,
+            )
+        elif isinstance(mp_policy, torch.dtype):
             fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                 param_dtype=mp_policy,
                 reduce_dtype=mp_policy,
                 output_dtype=mp_policy,
-                cast_forward_inputs=True,
             )
 
     # parse out the modules_to_shard argument
@@ -636,3 +644,39 @@ def materialize_meta_params(module: torch.nn.Module, device: torch.device) -> None:
         if on_meta_device(submodule):
             rank_zero_info(f"{name} is on meta device, intializing on device {device}")
             submodule.to_empty(device=device, recurse=False)
+
+
+def _check_and_convert_mp_policy_dtypes(
+    mp_policy: MixedPrecisionPolicy,
+) -> MixedPrecisionPolicy:
+    """
+    Converts precision strings to torch.dtype and validates that all dtypes are of type torch.dtype.
+    Returns new MixedPrecisionPolicy as its attributes are frozen (cannot assign new values to fields)
+    """
+
+    dtypes = (mp_policy.param_dtype, mp_policy.reduce_dtype, mp_policy.output_dtype)
+    dtypes = filter(None, dtypes)
+    for dtype in dtypes:
+        if not isinstance(dtype, (str, torch.dtype)):
+            raise ValueError(
+                f"MixedPrecisionPolicy requires all dtypes to be torch.dtype or string. Got dtype={dtype} with type {type(dtype)}"
+            )
+
+    param_dtype = mp_policy.param_dtype
+    reduce_dtype = mp_policy.reduce_dtype
+    output_dtype = mp_policy.output_dtype
+    if isinstance(mp_policy.param_dtype, str):
+        param_dtype = convert_precision_str_to_dtype(mp_policy.param_dtype)
+    if isinstance(mp_policy.reduce_dtype, str):
+        reduce_dtype = convert_precision_str_to_dtype(mp_policy.reduce_dtype)
+    if isinstance(mp_policy.output_dtype, str):
+        output_dtype = convert_precision_str_to_dtype(mp_policy.output_dtype)
+
+    new_mp_policy = MixedPrecisionPolicy(
+        param_dtype=param_dtype,
+        reduce_dtype=reduce_dtype,
+        output_dtype=output_dtype,
+        cast_forward_inputs=mp_policy.cast_forward_inputs,
+    )
+
+    return new_mp_policy
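
Taken together, the FSDP2Strategy change and the new branches in prepare_fsdp2 make the forms below roughly equivalent ways to request bf16 mixed precision. This is a sketch: the str and torch.dtype forms set param, reduce, and output dtypes uniformly, while an explicit MixedPrecisionPolicy allows per-field control and may itself contain strings, which _check_and_convert_mp_policy_dtypes normalizes before the policy is added to the fully_shard kwargs (fsdp_kwargs above).

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

from torchtnt.utils.prepare_module import FSDP2Strategy

# Precision string: expanded to a MixedPrecisionPolicy with all dtypes set to bf16.
s1 = FSDP2Strategy(mp_policy="bf16")

# torch.dtype: same expansion as the string form.
s2 = FSDP2Strategy(mp_policy=torch.bfloat16)

# Explicit policy: per-field control; string fields are converted and validated.
s3 = FSDP2Strategy(
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
    )
)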
