Summary:
Pull Request resolved: #1014
# Context
FSDP2 users may want to shard based on layer names.
# This Diff
Adds a `shard_predicates` parameter so custom functions can be used to decide whether submodules should be sharded.
Reviewed By: galrotem
Differential Revision: D77236696
fbshipit-source-id: 2789e4019f20d5abdd6405770b326cc36e6d3bf0
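A minimal usage sketch (not part of the diff; the import path and exact constructor signature are assumptions based on the docstring changes below): a strategy that shards all `nn.Linear` submodules by type and, additionally, any submodule whose fully qualified name contains "layers" via a predicate.

```python
# Hypothetical usage sketch, assuming FSDP2Strategy is importable from
# torchtnt.utils.prepare_module and accepts the arguments documented below.
import torch.nn as nn

from torchtnt.utils.prepare_module import FSDP2Strategy


def shard_decoder_layers(fqn: str, module: nn.Module) -> bool:
    # Shard any submodule whose fully qualified name mentions "layers".
    return "layers" in fqn


strategy = FSDP2Strategy(
    modules_to_shard=[nn.Linear],             # shard by module type
    shard_predicates=[shard_decoder_layers],  # additionally shard by name
    reshard_after_forward=True,
)
```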
@@ -192,15 +194,18 @@ class FSDP2Strategy(Strategy):
     For more details on the args, see the link.
 
     Args:
-        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
-        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types. Specify None to not shard any modules with this flag.
+        shard_predicates: A list of predicates to decide which modules to shard with FSDP. Each predicate takes a module name (fqn) and the module itself. If any predicate returns True, the submodule is sharded.
+        reshard_after_forward: If True, reshards parameters post-forward pass to save memory.
         mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
         cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
 
     Note:
         It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has
         communication overhead.
 
+    Note: modules_to_shard and shard_predicates are applied sequentially. If a module is specified in modules_to_shard, it will be sharded regardless of shard_predicates, and vice-versa
+
     Example:
         >>> model
         TransformerDecoder(
@@ -222,10 +227,15 @@ class FSDP2Strategy(Strategy):
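The note on sequential application in the docstring above implies decision logic roughly like the following sketch (hypothetical helper, not the library's actual implementation): a submodule is sharded if it matches `modules_to_shard` or if any predicate in `shard_predicates` returns True.

```python
# Hypothetical sketch of the combination described in the note above; only the
# list form of modules_to_shard is handled ('all'/None omitted for brevity).
from typing import Callable, Sequence, Type, Union

import torch.nn as nn


def should_shard(
    fqn: str,
    module: nn.Module,
    modules_to_shard: Sequence[Union[str, Type[nn.Module]]],
    shard_predicates: Sequence[Callable[[str, nn.Module], bool]],
) -> bool:
    # Sharded if the module matches modules_to_shard by type or by name...
    by_spec = any(
        isinstance(module, spec) if isinstance(spec, type) else spec in fqn
        for spec in modules_to_shard
    )
    # ...or if any user-supplied predicate returns True.
    by_predicate = any(pred(fqn, module) for pred in shard_predicates)
    return by_spec or by_predicate
```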