
Commit 9e24689

[DeepSeek] add torch.compile + async TP (#1588)
Verified that torch.compile works. However, I didn't see async TP in the trace. cc @danielvegamyhre @fegin, could you help take a look?
1 parent: a54725c

File tree

3 files changed (+20, -13 lines)


torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 2 additions & 3 deletions
@@ -108,10 +108,9 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if job_config.training.compile:
-        apply_compile(model)
-
         # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
         torch._dynamo.config.capture_scalar_outputs = True
+        apply_compile(model)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

@@ -503,7 +502,7 @@ def apply_compile(model: nn.Module):
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for llama4 moe
+        # TODO: remove when torch.compile supports fullgraph=True for MoE
         fullgraph = True
         if transformer_block.moe_enabled:
             fullgraph = False
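
For context, a minimal sketch of the per-TransformerBlock compile pattern that `apply_compile` implements. Only the loop header, the TODO, and the `fullgraph` toggle appear in the hunk above; the `torch.compile` call and the `register_module` re-attachment are assumptions about the rest of the function:

    import torch
    import torch.nn as nn

    def apply_compile(model: nn.Module) -> None:
        # Compile each TransformerBlock individually, so the compiled
        # artifact is reused across the model's repeated layer structure.
        for layer_id, transformer_block in model.layers.named_children():
            # Per the TODO above: MoE blocks still contain graph breaks,
            # so fullgraph=True is only asserted for dense blocks.
            fullgraph = not transformer_block.moe_enabled
            compiled = torch.compile(transformer_block, fullgraph=fullgraph)
            # Assumed re-attachment: swap the compiled wrapper in under
            # the original child name.
            model.layers.register_module(layer_id, compiled)

The reordering in the first hunk keeps the dynamo flag assignment ahead of the `apply_compile` call, so the configuration is guaranteed to be in effect however eagerly the blocks are traced.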

torchtitan/models/deepseek_v3/README.md

Lines changed: 4 additions & 3 deletions
@@ -47,6 +47,7 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
 - Tensor Parallel (TP)
 - Expert Parallel (EP)
 - Pipeline Parallel (PP)
+- torch.compile
 
 
 ## HuggingFace -> DCP Checkpoint Conversion

@@ -65,8 +66,8 @@ Some limitations:
 ## To be added
 - Parallelism
     - Context Parallel support for DeepSeek V3
-    - torch.compile
 - Quantization
 - Testing
-    - performance and loss convergence tests
-    - CI integration
+    - loss convergence tests (verified)
+    - performance (WIP)
+    - CI integration (WIP)

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 14 additions & 7 deletions
@@ -47,12 +47,11 @@ def parallelize_deepseekv3(
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
-        if job_config.parallelism.enable_async_tensor_parallel:
-            # TODO(jianiw): This branch needs to be tested and enabled
-            raise NotImplementedError(
-                "Currently, async TP is not tested for deepseekv3. \
-                torch.compile is not supported yet, which is required for async TP."
-            )
+        if (
+            job_config.parallelism.enable_async_tensor_parallel
+            and not job_config.training.compile
+        ):
+            raise RuntimeError("Async TP requires --training.compile")
 
     enable_float8_linear = "float8" in job_config.model.converters
     float8_is_rowwise = job_config.float8.recipe_name in (
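
Why the guard changed shape: as the `apply_non_moe_tp` hunk below shows, async TP is driven by an inductor option (`torch._inductor.config._micro_pipeline_tp`), so it can only take effect when the model actually runs under torch.compile. The old branch rejected async TP for DeepSeek-V3 unconditionally; the new check fails fast only on the inconsistent combination of async TP without `--training.compile`.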
@@ -94,7 +93,9 @@ def parallelize_deepseekv3(
     apply_ac(model, job_config.activation_checkpoint)
 
     if job_config.training.compile:
-        raise NotImplementedError("torch.compile is not supported yet for deepseekv3")
+        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
+        torch._dynamo.config.capture_scalar_outputs = True
+        apply_compile(model)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
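
The `capture_scalar_outputs` flag matters here because token-choice routing derives tensor sizes from runtime values. A self-contained illustration, not from the commit (the `route_to_expert` toy below is invented), of the class of code this setting unlocks:

    import torch

    # By default torch.compile graph-breaks on .item(); with this flag the
    # scalar can instead be traced as an unbacked symbolic int.
    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile
    def route_to_expert(scores: torch.Tensor) -> torch.Tensor:
        # Token-choice MoE: how many tokens an expert receives is known
        # only at runtime, so the output size is data-dependent.
        num_tokens = (scores > 0.5).sum().item()
        return torch.zeros(num_tokens, 16)

    print(route_to_expert(torch.rand(128)).shape)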
@@ -251,6 +252,12 @@ def apply_non_moe_tp(
         parallelize_plan=layer_plan,
     )
 
+    if enable_async_tp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
     logger.info(
         f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
