     set_optimizer_state_dict,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torchtnt.utils.precision import convert_precision_str_to_dtype
 
 try:
     from torch.distributed.fsdp import (
@@ -218,7 +219,7 @@ class FSDP2Strategy(Strategy):
         Iterable[Union[str, Type[torch.nn.Module]]],
     ] = "all"
     reshard_after_forward: Union[bool, int] = True
-    mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
+    mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
     cpu_offload: bool = False
 
 
@@ -375,13 +376,20 @@ def prepare_fsdp2(
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
     if (mp_policy := strategy.mp_policy) is not None:
         if isinstance(mp_policy, MixedPrecisionPolicy):
+            mp_policy = _check_and_convert_mp_policy_dtypes(mp_policy)
             fsdp_kwargs["mp_policy"] = mp_policy
-        else:
+        elif isinstance(mp_policy, str):
+            dtype = convert_precision_str_to_dtype(mp_policy)
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+                param_dtype=dtype,
+                reduce_dtype=dtype,
+                output_dtype=dtype,
+            )
+        elif isinstance(mp_policy, torch.dtype):
             fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                 param_dtype=mp_policy,
                 reduce_dtype=mp_policy,
                 output_dtype=mp_policy,
-                cast_forward_inputs=True,
             )
 
     # parse out the modules_to_shard argument
@@ -636,3 +644,39 @@ def materialize_meta_params(module: torch.nn.Module, device: torch.device) -> No
         if on_meta_device(submodule):
             rank_zero_info(f"{name} is on meta device, intializing on device {device}")
             submodule.to_empty(device=device, recurse=False)
+
+
+def _check_and_convert_mp_policy_dtypes(
+    mp_policy: MixedPrecisionPolicy,
+) -> MixedPrecisionPolicy:
+    """
+    Converts precision strings to torch.dtype and validates that all dtypes are of type torch.dtype.
+    Returns a new MixedPrecisionPolicy, since its attributes are frozen (new values cannot be assigned to its fields).
+    """
+
+    dtypes = (mp_policy.param_dtype, mp_policy.reduce_dtype, mp_policy.output_dtype)
+    dtypes = filter(None, dtypes)
+    for dtype in dtypes:
+        if not isinstance(dtype, (str, torch.dtype)):
+            raise ValueError(
+                f"MixedPrecisionPolicy requires all dtypes to be torch.dtype or string. Got dtype={dtype} with type {type(dtype)}"
+            )
+
+    param_dtype = mp_policy.param_dtype
+    reduce_dtype = mp_policy.reduce_dtype
+    output_dtype = mp_policy.output_dtype
+    if isinstance(mp_policy.param_dtype, str):
+        param_dtype = convert_precision_str_to_dtype(mp_policy.param_dtype)
+    if isinstance(mp_policy.reduce_dtype, str):
+        reduce_dtype = convert_precision_str_to_dtype(mp_policy.reduce_dtype)
+    if isinstance(mp_policy.output_dtype, str):
+        output_dtype = convert_precision_str_to_dtype(mp_policy.output_dtype)
+
+    new_mp_policy = MixedPrecisionPolicy(
+        param_dtype=param_dtype,
+        reduce_dtype=reduce_dtype,
+        output_dtype=output_dtype,
+        cast_forward_inputs=mp_policy.cast_forward_inputs,
+    )
+
+    return new_mp_policy
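
Not part of the diff: a minimal usage sketch of the widened mp_policy field. It assumes the file being modified is torchtnt/utils/prepare_module.py (so FSDP2Strategy is importable from torchtnt.utils.prepare_module) and that "bf16" is one of the precision strings convert_precision_str_to_dtype accepts; both are assumptions for illustration only.

# Illustrative sketch, not part of the change set. Import paths and the "bf16"
# precision string are assumptions, as noted above.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

from torchtnt.utils.prepare_module import FSDP2Strategy

# Plain string: prepare_fsdp2 converts it once and uses the resulting dtype for
# param_dtype, reduce_dtype, and output_dtype.
strategy = FSDP2Strategy(mp_policy="bf16")

# torch.dtype: same expansion, without the string-to-dtype conversion step.
strategy = FSDP2Strategy(mp_policy=torch.bfloat16)

# MixedPrecisionPolicy with string fields: normalized into a new policy holding
# torch.dtypes by _check_and_convert_mp_policy_dtypes before it lands in fsdp_kwargs.
strategy = FSDP2Strategy(
    mp_policy=MixedPrecisionPolicy(param_dtype="bf16", reduce_dtype=torch.float32)
)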