|
14 | 14 | from dataclasses import dataclass
|
15 | 15 | from typing import (
|
16 | 16 | Any,
|
| 17 | + Callable, |
17 | 18 | cast,
|
18 | 19 | ContextManager,
|
19 | 20 | Generic,
|
|
29 | 30 | from pyre_extensions import none_throws
|
30 | 31 | from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
|
31 | 32 | from torch.distributed.tensor import DTensor
|
| 33 | +from torch.distributed.tensor.parallel.loss import loss_parallel |
32 | 34 | from torch.nn.parallel import DistributedDataParallel as DDP
|
33 | 35 | from torch.optim.swa_utils import SWALR
|
34 | 36 | from torchtnt.framework._unit_utils import _step_requires_iterator
|
|
53 | 55 | prepare_module,
|
54 | 56 | Strategy,
|
55 | 57 | TorchCompileParams,
|
| 58 | + TPStrategy, |
56 | 59 | )
|
57 | 60 | from torchtnt.utils.swa import AveragedModel
|
58 | 61 | from typing_extensions import Literal
|
@@ -477,6 +480,7 @@ class AutoUnit(
|
477 | 480 | enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
|
478 | 481 | zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.
|
479 | 482 | global_mesh: an instance of :class:`~torchtnt.utils.device_mesh.GlobalMeshCoordinator` which defines the global mesh topology. Needed to configure TP or 2D parallelism strategies.
|
| 483 | + enable_loss_parallel: if True, the loss is computed under :func:`torch.distributed.tensor.parallel.loss.loss_parallel`, which computes cross-entropy in parallel across tensor-parallel ranks without gathering the sharded logits. Only supported with a TP strategy and a cross-entropy loss. |
480 | 484 |
|
481 | 485 | Note:
|
482 | 486 | Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
|
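The docstring addition above describes the new flag; below is a minimal usage sketch of how a unit might opt in. The `MyUnit` subclass body, `my_transformer`, `my_tp_strategy`, and `my_global_mesh` are illustrative assumptions and not part of this diff.

```python
import torch
import torch.nn.functional as F
from torchtnt.framework.auto_unit import AutoUnit

class MyUnit(AutoUnit):
    # compute_loss is the AutoUnit hook invoked inside the (optional)
    # loss_parallel context; with TP, `logits` arrives sharded on the class dim.
    def compute_loss(self, state, data):
        inputs, labels = data
        logits = self.module(inputs)
        return F.cross_entropy(logits, labels), logits

    def configure_optimizers_and_lr_scheduler(self, module):
        return torch.optim.AdamW(module.parameters()), None

unit = MyUnit(
    module=my_transformer,        # assumed user-provided nn.Module
    strategy=my_tp_strategy,      # an instance of TPStrategy (construction omitted)
    global_mesh=my_global_mesh,   # GlobalMeshCoordinator describing the TP mesh
    enable_loss_parallel=True,    # loss is computed on the sharded logits
)
```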
@@ -514,6 +518,7 @@ def __init__(
|
514 | 518 | enable_prefetch: bool = True,
|
515 | 519 | zero_grad_at_train_step_start: bool = False,
|
516 | 520 | global_mesh: Optional[GlobalMeshCoordinator] = None,
|
| 521 | + enable_loss_parallel: bool = False, |
517 | 522 | ) -> None:
|
518 | 523 | super().__init__(
|
519 | 524 | module=module,
|
@@ -589,6 +594,16 @@ def __init__(
|
589 | 594 |        # keep track of when to zero grad at train step start
|
590 | 595 |        self._weight_updated_in_prev_step = False
|
591 | 596 |
|
| 597 | +        if enable_loss_parallel: |
| 598 | +            if not isinstance(strategy, TPStrategy): |
| 599 | +                raise ValueError( |
| 600 | +                    "enable_loss_parallel is only supported with TPStrategy" |
| 601 | +                ) |
| 602 | +        # pyre-fixme[24]: Attribute must be annotated. |
| 603 | +        self.maybe_loss_parallel: Callable = ( |
| 604 | +            loss_parallel if enable_loss_parallel else contextlib.nullcontext |
| 605 | +        ) |
| 606 | +
592 | 607 |    def __setattr__(self, name: str, value: object) -> None:
|
593 | 608 |        if isinstance(value, torch.nn.Module):
|
594 | 609 |            self._validate_module_attr(name, value)
|
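The `maybe_loss_parallel` attribute added in the hunk above follows a common pattern: store a zero-argument callable that yields either the real context manager or `contextlib.nullcontext`, so the call site can unconditionally write `with self.maybe_loss_parallel():`. A standalone sketch of that pattern, with illustrative names that are not torchtnt APIs:

```python
import contextlib
from typing import Callable, ContextManager

@contextlib.contextmanager
def feature_context():
    # Stand-in for loss_parallel: does real setup/teardown only when selected.
    print("feature enabled")
    yield
    print("feature disabled")

def pick_context(enabled: bool) -> Callable[[], ContextManager[None]]:
    # Both branches are zero-argument callables, so the call site is uniform:
    # `with pick_context(enabled)():` works whether or not the feature is on.
    return feature_context if enabled else contextlib.nullcontext

with pick_context(enabled=False)():
    pass  # no-op context via contextlib.nullcontext
```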
@@ -700,7 +715,7 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
|
700 | 715 |        )
|
701 | 716 |
|
702 | 717 |        grad_scaler = self.grad_scaler
|
703 | | -        with maybe_no_sync, maybe_detect_anomaly: |
| 718 | +        with maybe_no_sync, maybe_detect_anomaly, self.maybe_loss_parallel(): |
704 | 719 |            with self.maybe_autocast_precision:
|
705 | 720 |                with get_timing_context(
|
706 | 721 |                    state, f"{self.__class__.__name__}.compute_loss"
|
|
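For context on what the added context manager changes: per the PyTorch tensor-parallel docs, `loss_parallel()` lets `F.cross_entropy` (and the matching backward pass) run on logits that stay sharded on the class dimension, avoiding an all-gather of the vocabulary. A rough sketch of the call it wraps; the device-mesh setup and the sharded `logits` DTensor are assumed to already exist and are not created here:

```python
import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

def tp_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # `logits` is assumed to be a DTensor sharded on the class/vocab dimension,
    # as produced by a tensor-parallel output projection.
    with loss_parallel():
        # Inside the context, cross_entropy (and its backward, which must also
        # run under the same context) operates on the sharded logits directly.
        loss = F.cross_entropy(logits, labels, reduction="mean")
        loss.backward()
    return loss
```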