Commit 7390c77

JKSenthil authored and facebook-github-bot committed
fix mp policy forwarding in fsdp2 (#970)
Summary: Pull Request resolved: #970

Reviewed By: galrotem, anshulverma

Differential Revision: D69669442

fbshipit-source-id: fff1e475ab1a31fc3291ee249c21292af1fc0561
1 parent 5bc1702 commit 7390c77

File tree

2 files changed (+3, −3 lines)

tests/utils/test_prepare_module_gpu.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -348,7 +348,7 @@ def _test_prepare_fsdp2_shard_all() -> None:
 
     module = SimpleModule()
     device = torch.device("cuda")
-    strategy = FSDP2Strategy(modules_to_shard="all")
+    strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)
     prepare_fsdp2(module, device, strategy)
 
     for submodule in module.modules():
```
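For context, here is a minimal usage sketch of the two forms `mp_policy` accepts, per the handling shown in the `prepare_module.py` diff below: a bare `torch.dtype` (expanded to cover param, reduce, and output dtypes) or a full `MixedPrecisionPolicy`. The stand-in model and the process-group setup are illustrative assumptions, not part of this commit.

```python
# Hypothetical usage sketch, not part of the commit. Assumes an
# initialized process group and a CUDA device, which prepare_fsdp2 needs.
import torch
# On older torch releases this import lives under torch.distributed._composable.fsdp
from torch.distributed.fsdp import MixedPrecisionPolicy
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

module = torch.nn.Linear(16, 16)  # stand-in for a real model

# Shorthand: a bare dtype, as in the test above.
strategy = FSDP2Strategy(modules_to_shard="all", mp_policy=torch.bfloat16)

# Or explicit: a MixedPrecisionPolicy controlling each dtype separately.
# strategy = FSDP2Strategy(
#     modules_to_shard="all",
#     mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
# )

prepare_fsdp2(module, torch.device("cuda"), strategy)
```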

torchtnt/utils/prepare_module.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -371,9 +371,9 @@ def prepare_fsdp2(
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
     if (mp_policy := strategy.mp_policy) is not None:
         if isinstance(mp_policy, MixedPrecisionPolicy):
-            fsdp_kwargs["mixed_precision"] = mp_policy
+            fsdp_kwargs["mp_policy"] = mp_policy
         else:
-            fsdp_kwargs["mixed_precision"] = MixedPrecisionPolicy(
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                 param_dtype=mp_policy,
                 reduce_dtype=mp_policy,
                 output_dtype=mp_policy,
```