Summary:
Pull Request resolved: #1003
Adds a `global_mesh` arg to `AutoUnit` and forwards it into `prepare_module` for model sharding
Reviewed By: vdogaru
Differential Revision: D74410711
fbshipit-source-id: fb7caedef706c9d8f7876f14d6d31e1d4aaa7151
AutoPredictUnit is a convenience for users who are running inference and would like to have certain features handled for them, such as data parallel replication or sharding, torch.compile, and autograd anomaly detection (see the ``__init__`` arguments in the diff below).
@@ -348,6 +350,7 @@ def __init__(
         strategy: the data parallelization strategy to be used. if a string, must be one of ``ddp`` or ``fsdp``.
         torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html
         detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.

     Note:
         Torch compile support is only available in PyTorch 2.0 or higher.
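For context, the kind of 2D (data parallel x tensor parallel) topology this argument targets can be expressed with PyTorch's `init_device_mesh` and handed to torchtnt's `GlobalMeshCoordinator`. The sketch below is minimal and partly assumed: `init_device_mesh` is the standard PyTorch API, but the `GlobalMeshCoordinator` constructor arguments shown are an assumption, so check the torchtnt docs for the actual signature.

```python
# Sketch: build a 2D device mesh for TP / 2D parallelism.
from torch.distributed.device_mesh import init_device_mesh
from torchtnt.utils.device_mesh import GlobalMeshCoordinator

# 8 GPUs arranged as 2 data-parallel replicas x 4 tensor-parallel shards.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Assumed constructor: wrap the mesh so AutoUnit / prepare_module can
# route sharding decisions through it (verify against the torchtnt docs).
global_mesh = GlobalMeshCoordinator(mesh)
```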
@@ -365,6 +368,7 @@ def __init__(
             self.device,
             strategy=strategy,
             torch_compile_params=torch_compile_params,
+            global_mesh=global_mesh,
         )

         # pyre-fixme[3]: Return annotation cannot be `Any`.
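The hunk above is the forwarding path: the constructor now passes `global_mesh` straight into `prepare_module`. Here is a rough sketch of the equivalent direct call, assuming `prepare_module`'s public signature matches the call site in the diff; the module and `FSDPStrategy()` defaults are illustrative, and `global_mesh` is the coordinator from the earlier sketch.

```python
# Sketch: calling prepare_module directly with the new global_mesh
# argument, mirroring the call site shown in the diff above.
import torch
from torchtnt.utils.prepare_module import FSDPStrategy, prepare_module

module = torch.nn.Linear(16, 16)  # illustrative module
device = torch.device("cuda")

prepared_module = prepare_module(
    module,
    device,
    strategy=FSDPStrategy(),    # sharding strategy to apply
    torch_compile_params=None,  # no torch.compile in this sketch
    global_mesh=global_mesh,    # coordinator built in the earlier sketch
)
```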
@@ -474,6 +478,7 @@ class AutoUnit(
             in a much more efficient way.
         enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
         zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.
+        global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.

     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
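End to end, the new argument surfaces on the `AutoUnit` constructor. A hedged usage sketch follows: `compute_loss` and `configure_optimizers_and_lr_scheduler` are the standard AutoUnit hooks, while the subclass body, the model, and the hyperparameters are illustrative, and `global_mesh` is again the coordinator from the earlier sketch.

```python
# Sketch: passing the new global_mesh arg through an AutoUnit subclass.
from typing import Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit


class MyAutoUnit(AutoUnit[Tuple[torch.Tensor, torch.Tensor]]):
    def compute_loss(self, state, data):
        # Standard AutoUnit hook: return (loss, outputs) for a batch.
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(self, module):
        # Standard AutoUnit hook: return (optimizer, lr_scheduler or None).
        return torch.optim.SGD(module.parameters(), lr=0.01), None


unit = MyAutoUnit(
    module=torch.nn.Linear(16, 4),   # illustrative model
    device=torch.device("cuda"),
    strategy="fsdp",                 # string form per the docstring above
    global_mesh=global_mesh,         # coordinator from the earlier sketch
)
```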