
Commit cab22e7: Centralize Async TP Enablement with maybe_enable_async_tp API (#1619)
This PR addresses duplicated code related to enabling async TP across different parts of the codebase. It introduces a new API, `maybe_enable_async_tp()`, which centralizes the enablement logic and is reused consistently in all models. Note that while this PR fixes one async TP bug in TorchTitan, it does not fully resolve #1613, as there appear to be additional bugs in PyTorch's async TP implementation.
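With this change, each model's parallelize function calls the shared helper after applying its TP plan, instead of threading an enable_async_tp flag into its TP-apply function. A minimal sketch of the resulting call-site pattern (parallelize_mymodel is a hypothetical entry point and apply_tp stands in for the model-specific TP-apply function; the real call sites are in the diffs below):

from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp


def parallelize_mymodel(model, world_mesh, parallel_dims, job_config):
    # Hypothetical model-specific parallelize function, shown only to
    # illustrate the call-site pattern used in this PR.
    if parallel_dims.tp_enabled:
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=False,
        )
        # Centralized enablement: a no-op unless async TP is requested in the config.
        maybe_enable_async_tp(job_config, world_mesh["tp"])
    return model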

4 files changed: +37, -49 lines

torchtitan/distributed/tensor_parallel.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+from torchtitan.config import JobConfig
+from torchtitan.tools.logging import logger
+
+
+def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh):
+    if not job_config.parallelism.enable_async_tensor_parallel:
+        return
+
+    if not (job_config.compile.enable and "model" in job_config.compile.components):
+        raise RuntimeError("Async TP requires --training.compile")
+
+    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+    torch._inductor.config._micro_pipeline_tp = True
+    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
+    logger.info("Async TP is enabled")
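The helper has three outcomes: it returns immediately when async TP is not requested, raises when async TP is requested but the model is not compiled, and otherwise turns on Inductor's micro-pipeline TP pass and registers the TP process group for symmetric memory. A minimal sketch of the first two branches, using types.SimpleNamespace objects as stand-ins for JobConfig purely to show the control flow (real callers pass the actual JobConfig and the TP DeviceMesh):

from types import SimpleNamespace

from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp

# Async TP not requested: the helper returns without touching the mesh.
cfg_off = SimpleNamespace(
    parallelism=SimpleNamespace(enable_async_tensor_parallel=False),
    compile=SimpleNamespace(enable=False, components=[]),
)
maybe_enable_async_tp(cfg_off, tp_mesh=None)  # no-op

# Async TP requested without model compilation: the helper raises before using the mesh.
cfg_bad = SimpleNamespace(
    parallelism=SimpleNamespace(enable_async_tensor_parallel=True),
    compile=SimpleNamespace(enable=False, components=[]),
)
try:
    maybe_enable_async_tp(cfg_bad, tp_mesh=None)
except RuntimeError as err:
    print(err)  # Async TP requires --training.compile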

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ReordererSequenceParallel,
2929
TensorParallel,
3030
)
31+
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
3132

3233
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
3334
from torchtitan.tools.logging import logger
@@ -66,12 +67,6 @@ def parallelize_llama(
6667
job_config.compile.enable and "model" in job_config.compile.components
6768
)
6869
if parallel_dims.tp_enabled:
69-
if (
70-
job_config.parallelism.enable_async_tensor_parallel
71-
and not model_compile_enabled
72-
):
73-
raise RuntimeError("Async TP requires torch.compile")
74-
7570
enable_float8_linear = "float8" in job_config.model.converters
7671
float8_is_rowwise = job_config.float8.recipe_name in (
7772
"rowwise",
@@ -88,8 +83,8 @@ def parallelize_llama(
8883
world_mesh["tp"],
8984
loss_parallel=not job_config.parallelism.disable_loss_parallel,
9085
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
91-
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
9286
)
87+
maybe_enable_async_tp(job_config, world_mesh["tp"])
9388

9489
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
9590
apply_moe_ep_tp(
@@ -177,7 +172,6 @@ def apply_non_moe_tp(
177172
tp_mesh: DeviceMesh,
178173
loss_parallel: bool,
179174
enable_float8_tensorwise_tp: bool,
180-
enable_async_tp: bool,
181175
):
182176
"""Apply tensor parallelism."""
183177
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -256,14 +250,8 @@ def apply_non_moe_tp(
256250
parallelize_plan=layer_plan,
257251
)
258252

259-
if enable_async_tp:
260-
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
261-
262-
torch._inductor.config._micro_pipeline_tp = True
263-
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
264-
265253
logger.info(
266-
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
254+
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
267255
"Tensor Parallelism to the model"
268256
)
269257

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
2020
from torchtitan.distributed import ParallelDims
2121
from torchtitan.distributed.expert_parallel import NoParallel
22+
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
2223
from torchtitan.experiments.llama4.infra.parallelize import (
2324
apply_compile,
2425
apply_fsdp,
@@ -51,16 +52,7 @@ def parallelize_deepseekv3(
5152
):
5253
raise NotImplementedError("CP support for FlexAttention is still in progress.")
5354

54-
model_compile_enabled = (
55-
job_config.compile.enable and "model" in job_config.compile.components
56-
)
5755
if parallel_dims.tp_enabled:
58-
if (
59-
job_config.parallelism.enable_async_tensor_parallel
60-
and not model_compile_enabled
61-
):
62-
raise RuntimeError("Async TP requires --training.compile")
63-
6456
enable_float8_linear = "float8" in job_config.model.converters
6557
float8_is_rowwise = job_config.float8.recipe_name in (
6658
"rowwise",
@@ -79,8 +71,8 @@ def parallelize_deepseekv3(
7971
world_mesh["tp"],
8072
loss_parallel=not job_config.parallelism.disable_loss_parallel,
8173
enable_float8_tensorwise_tp=False,
82-
enable_async_tp=False,
8374
)
75+
maybe_enable_async_tp(job_config, world_mesh["tp"])
8476

8577
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
8678
apply_moe_ep_tp(
@@ -100,7 +92,7 @@ def parallelize_deepseekv3(
10092
if job_config.activation_checkpoint.mode != "none":
10193
apply_ac(model, job_config.activation_checkpoint)
10294

103-
if model_compile_enabled:
95+
if job_config.compile.enable and "model" in job_config.compile.components:
10496
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
10597
torch._dynamo.config.capture_scalar_outputs = True
10698
apply_compile(model)
@@ -167,7 +159,6 @@ def apply_non_moe_tp(
167159
tp_mesh: DeviceMesh,
168160
loss_parallel: bool,
169161
enable_float8_tensorwise_tp: bool,
170-
enable_async_tp: bool,
171162
):
172163
"""Apply tensor parallelism."""
173164
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -260,13 +251,7 @@ def apply_non_moe_tp(
260251
parallelize_plan=layer_plan,
261252
)
262253

263-
if enable_async_tp:
264-
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
265-
266-
torch._inductor.config._micro_pipeline_tp = True
267-
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
268-
269254
logger.info(
270-
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
255+
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
271256
"Tensor Parallelism to the model"
272257
)

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
3232
from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
3333
from torchtitan.distributed import ParallelDims
34+
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
3435
from torchtitan.tools.logging import logger
3536

3637

@@ -67,12 +68,6 @@ def parallelize_llama(
6768
job_config.compile.enable and "model" in job_config.compile.components
6869
)
6970
if parallel_dims.tp_enabled:
70-
if (
71-
job_config.parallelism.enable_async_tensor_parallel
72-
and not model_compile_enabled
73-
):
74-
raise RuntimeError("Async TP requires torch.compile")
75-
7671
enable_float8_linear = "float8" in job_config.model.converters
7772
float8_is_rowwise = job_config.float8.recipe_name in (
7873
"rowwise",
@@ -89,8 +84,8 @@ def parallelize_llama(
8984
world_mesh["tp"],
9085
loss_parallel=not job_config.parallelism.disable_loss_parallel,
9186
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
92-
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
9387
)
88+
maybe_enable_async_tp(job_config, world_mesh["tp"])
9489

9590
if job_config.activation_checkpoint.mode != "none":
9691
apply_ac(model, job_config.activation_checkpoint)
@@ -144,7 +139,6 @@ def apply_tp(
144139
tp_mesh: DeviceMesh,
145140
loss_parallel: bool,
146141
enable_float8_tensorwise_tp: bool,
147-
enable_async_tp: bool,
148142
):
149143
"""Apply tensor parallelism."""
150144
# 1. Parallelize the embedding and shard its outputs (which are the first
@@ -221,14 +215,8 @@ def apply_tp(
221215
parallelize_plan=layer_plan,
222216
)
223217

224-
if enable_async_tp:
225-
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
226-
227-
torch._inductor.config._micro_pipeline_tp = True
228-
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
229-
230218
logger.info(
231-
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
219+
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}"
232220
"Tensor Parallelism to the model"
233221
)
234222
