
Commit 9591481

JKSenthil authored and facebook-github-bot committed
enable loss parallel support (#1019)
Summary:
Pull Request resolved: #1019

Reviewed By: diego-urgell

Differential Revision: D79193109

fbshipit-source-id: 1a6fa82ab7eebb4186a7462c9d3ae8d3d19d12f9
1 parent 5d188e8 commit 9591481

File tree: 1 file changed (+16, -1 lines)


torchtnt/framework/auto_unit.py

Lines changed: 16 additions & 1 deletion
@@ -14,6 +14,7 @@
 from dataclasses import dataclass
 from typing import (
     Any,
+    Callable,
     cast,
     ContextManager,
     Generic,
@@ -29,6 +30,7 @@
 from pyre_extensions import none_throws
 from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.parallel.loss import loss_parallel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.swa_utils import SWALR
 from torchtnt.framework._unit_utils import _step_requires_iterator
@@ -53,6 +55,7 @@
     prepare_module,
     Strategy,
     TorchCompileParams,
+    TPStrategy,
 )
 from torchtnt.utils.swa import AveragedModel
 from typing_extensions import Literal
@@ -477,6 +480,7 @@ class AutoUnit(
         enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
         zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.
         global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.
+        enable_loss_parallel: if True, the loss will be computed in parallel across all ranks. This is only supported for TP strategy + cross entropy loss.

     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
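
For orientation, a minimal sketch of how a user might opt into the new flag. The TPStrategy import path is inferred from the FSDPStrategy reference above, and the unit, module, and mesh names are illustrative placeholders, not part of this commit:

    from torchtnt.utils.prepare_module import TPStrategy  # path assumed, mirroring FSDPStrategy

    # my_unit_cls, my_module, my_tp_strategy, and my_global_mesh are hypothetical placeholders.
    unit = my_unit_cls(
        module=my_module,
        strategy=my_tp_strategy,        # must be a TPStrategy, see the check added below
        global_mesh=my_global_mesh,     # a GlobalMeshCoordinator describing the mesh topology
        enable_loss_parallel=True,      # new flag added in this commit
    )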
@@ -514,6 +518,7 @@ def __init__(
         enable_prefetch: bool = True,
         zero_grad_at_train_step_start: bool = False,
         global_mesh: Optional[GlobalMeshCoordinator] = None,
+        enable_loss_parallel: bool = False,
     ) -> None:
         super().__init__(
             module=module,
@@ -589,6 +594,16 @@ def __init__(
         # keep track of when to zero grad at train step start
         self._weight_updated_in_prev_step = False

+        if enable_loss_parallel:
+            if not isinstance(strategy, TPStrategy):
+                raise ValueError(
+                    "enable_loss_parallel is only supported with TPStrategy"
+                )
+        # pyre-fixme[24]: Attribute must be annotated.
+        self.maybe_loss_parallel: Callable = (
+            loss_parallel if enable_loss_parallel else contextlib.nullcontext
+        )
+
     def __setattr__(self, name: str, value: object) -> None:
         if isinstance(value, torch.nn.Module):
             self._validate_module_attr(name, value)
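
The assignment above works because loss_parallel and contextlib.nullcontext are both zero-argument callables that return context managers, so train_step can call self.maybe_loss_parallel() unconditionally. A minimal, self-contained sketch of the same pattern; the stub below stands in for the real loss_parallel and is not torchtnt code:

    import contextlib
    from typing import Callable, ContextManager

    @contextlib.contextmanager
    def loss_parallel_stub():
        # stand-in for torch.distributed.tensor.parallel.loss.loss_parallel
        print("loss parallel context entered")
        yield

    def pick_loss_context(enabled: bool) -> Callable[[], ContextManager]:
        # mirrors: loss_parallel if enable_loss_parallel else contextlib.nullcontext
        return loss_parallel_stub if enabled else contextlib.nullcontext

    maybe_loss_parallel = pick_loss_context(enabled=False)
    with maybe_loss_parallel():  # no-op when disabled; same call site either way
        pass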
@@ -700,7 +715,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         )

         grad_scaler = self.grad_scaler
-        with maybe_no_sync, maybe_detect_anomaly:
+        with maybe_no_sync, maybe_detect_anomaly, self.maybe_loss_parallel():
             with self.maybe_autocast_precision:
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.compute_loss"
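
For context on what the new context manager does at this call site: inside loss_parallel(), torch.nn.functional.cross_entropy can operate on logits sharded as a DTensor along the class dimension without gathering them first. Below is a hedged sketch of a compute_loss a user could pair with enable_loss_parallel=True; the unit name, data layout, and State import path are assumptions, not part of this diff:

    from typing import Any, Tuple

    import torch
    import torch.nn.functional as F
    from torchtnt.framework.auto_unit import AutoUnit
    from torchtnt.framework.state import State  # import path assumed

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyUnit(AutoUnit[Batch]):  # hypothetical user-defined unit
        def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
            inputs, targets = data
            # With enable_loss_parallel=True, train_step wraps this call in
            # loss_parallel(), so logits may remain sharded on the class dim.
            logits = self.module(inputs)
            loss = F.cross_entropy(logits, targets)
            return loss, logits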
