    ContextManager,
    Dict,
    Iterable,
+    Literal,
    Optional,
+    Set,
    Tuple,
+    Type,
    Union,
)

    checkpoint_wrapper,
    CheckpointImpl,
)
+from torch.distributed.device_mesh import init_device_mesh
+
+try:
+    from torch.distributed.fsdp import (
+        CPUOffloadPolicy,
+        fully_shard,
+        MixedPrecisionPolicy,
+    )
+    from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
+except ImportError:
+
+    def noop(*args: Any, **kwargs: Any) -> None:
+        pass
+
+    class NOOP:
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            pass
+
+    fully_shard = noop
+    MixedPrecisionPolicy = NOOP
+    CPUOffloadPolicy = NOOP
+    FSDPState = NOOP
+
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType as _StateDictType,
@@ -146,6 +172,52 @@ def __post_init__(self) -> None:
        self.mixed_precision = self.mixed_precision.to_native_mixed_precision()


+@dataclass
+class FSDP2Strategy(Strategy):
+    """
+    Dataclass representing the `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_ strategy.
+    For more details on the args, see the link.
+
+    Args:
+        modules_to_shard: The modules to shard across devices. Options are "all" to shard all submodules, or an iterable of module names and/or module types.
+        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        mp_policy: Controls the mixed precision policy. If only a dtype is provided, it is used to cast all relevant parts of the model. If None, no mixed precision is used.
+        cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
+
+    Note:
+        It is recommended to specify particular modules to shard, since unnecessarily sharding every submodule adds communication overhead.
+
+    Example:
+        >>> model
+        TransformerDecoder(
+          (tok_embeddings): Embedding(128256, 4096)
+          (layers): ModuleList(
+            (0-31): 32 x TransformerSelfAttentionLayer(
+              (attn): MultiHeadAttention(
+                (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (pos_embeddings): RotaryPositionalEmbeddings()
+              )
+              ...
+            )
+          (output): Linear(in_features=4096, out_features=128256, bias=False)
+        )
+        >>> # You can specify a module to shard either by name ("Linear") or by module type (torch.nn.Linear)
+        >>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
+    """
+
+    modules_to_shard: Union[
+        Literal["all"],
+        Iterable[Union[str, Type[torch.nn.Module]]],
+    ] = "all"
+    reshard_after_forward: Union[bool, int] = True
+    mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
+    cpu_offload: bool = False
+
+
@dataclass
class TorchCompileParams:
    """
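For reference, a minimal sketch of constructing the new strategy; the layer class name is taken from the docstring example above, and `mp_policy` is given as a bare dtype, which `prepare_fsdp2` expands into a full `MixedPrecisionPolicy`:

```python
import torch

from torchtnt.utils.prepare_module import FSDP2Strategy

# Shard the transformer blocks and every Linear layer, run compute in bf16,
# and offload parameters to CPU to reduce GPU memory usage.
strategy = FSDP2Strategy(
    modules_to_shard=["TransformerSelfAttentionLayer", torch.nn.Linear],
    reshard_after_forward=True,
    mp_policy=torch.bfloat16,
    cpu_offload=True,
)
```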
@@ -272,6 +344,89 @@ def prepare_fsdp(
    return module


+def prepare_fsdp2(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[FSDP2Strategy] = None,
+    process_group: Optional[ProcessGroup] = None,
+) -> torch.nn.Module:
+    """
+    Utility to move a module to device and wrap in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_
+
+    Args:
+        module: module to be wrapped in FSDP2
+        device: device to which module will be moved
+        strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of the FSDP2 APIs
+        process_group: process group over which to shard. If None, the default process group is used.
+    """
+    strategy = strategy or FSDP2Strategy()
+
+    # prepare kwargs for the fully_shard api
+    pg = process_group or dist.distributed_c10d._get_default_group()
+    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    fsdp_kwargs: Dict[str, Any] = {
+        "mesh": mesh,  # TODO we only configure 1D mesh for now, look into supporting HSDP
+        "reshard_after_forward": strategy.reshard_after_forward,
+    }
+    if strategy.cpu_offload:
+        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
+    if (mp_policy := strategy.mp_policy) is not None:
+        if isinstance(mp_policy, MixedPrecisionPolicy):
+            fsdp_kwargs["mp_policy"] = mp_policy
+        else:
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+                param_dtype=mp_policy,
+                reduce_dtype=mp_policy,
+                output_dtype=mp_policy,
+                cast_forward_inputs=True,
+            )
+
+    # parse out the modules_to_shard argument
+    modules_to_shard = strategy.modules_to_shard
+
+    shard_all = modules_to_shard == "all"
+    shard_module_names: Set[str] = set()
+    shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
+    if not shard_all:
+        assert (
+            type(modules_to_shard) is not str
+        ), f"modules_to_shard must be an iterable of modules or 'all', got {modules_to_shard}"
+
+        for item in modules_to_shard:
+            if isinstance(item, str):
+                shard_module_names.add(item)
+            else:
+                shard_module_types = shard_module_types + (item,)
+
+    # apply the fsdp2 sharding bottom-up
+    num_layers_sharded = 0
+    for _, m in reversed(list(module.named_modules())):
+        if shard_all:
+            # fully_shard does not support containers that do not implement forward
+            if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+        elif (
+            isinstance(m, shard_module_types) or type(m).__name__ in shard_module_names
+        ):
+            # shard m if its type or class name was requested
+            fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded with fsdp2. Please check if shard conditions are working as expected."
+        )
+
+    # shard the top-level module, so that all params are moved off cpu to gpu
+    if not _is_fsdp_module(module):
+        fully_shard(module, **fsdp_kwargs)
+
+    # materialize sharded meta weights on device
+    materialize_meta_params(module, device)
+
+    return module
+
+
class FSDPOptimizerWrapper:
    """
    Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
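A sketch of calling `prepare_fsdp2` directly, assuming the process group is initialized under `torchrun` (which sets `LOCAL_RANK`); `MyTransformer` is a hypothetical model whose blocks match the sharded layer name:

```python
import os

import torch
import torch.distributed as dist

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

model = MyTransformer()  # hypothetical model containing TransformerSelfAttentionLayer blocks
model = prepare_fsdp2(
    model,
    device,
    strategy=FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer"]),
)
```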
@@ -301,7 +456,7 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
    # Also check for composable FSDP API
    maybe_composable_state = _get_module_state(module)
    if maybe_composable_state is not None:
-        return isinstance(maybe_composable_state, _FSDPState)
+        return isinstance(maybe_composable_state, (_FSDPState, FSDPState))

    return False

@@ -366,6 +521,8 @@ def prepare_module(
                "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters"
            )
        module = prepare_fsdp(module, device, strategy)
+    elif isinstance(strategy, FSDP2Strategy):
+        module = prepare_fsdp2(module, device, strategy)
    else:
        module = module.to(device)

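With the branch above, callers that go through `prepare_module` pick up the new path simply by passing an `FSDP2Strategy`; a sketch, reusing the hypothetical `model` and `device` from the previous example (the `strategy` keyword name is assumed from the existing `prepare_module` signature):

```python
from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module

module = prepare_module(
    model,
    device,
    strategy=FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"]),
)
```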