
Commit 255a6ab

fix qwen3 compile config in parallelize.py (#1623)
fix compile config in parallelize.py, ref https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py
1 parent 2bfcdd8 commit 255a6ab
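
The gist of the fix: torchtitan moved compilation from a single `training.compile` flag to a per-component compile config, and the qwen3 parallelize path still read the old flag. Below is a minimal sketch of the new gate against a stubbed-out config; the `Compile`/`JobConfig` stubs are hypothetical stand-ins, not torchtitan's real classes, and only the field names (`compile.enable`, `compile.components`) come from the diff below.

from dataclasses import dataclass, field


@dataclass
class Compile:
    # Stand-in for the compile config: a global switch plus the list of
    # components ("model", "loss", ...) that should be compiled.
    enable: bool = False
    components: list[str] = field(default_factory=lambda: ["model", "loss"])


@dataclass
class JobConfig:
    compile: Compile = field(default_factory=Compile)


job_config = JobConfig(compile=Compile(enable=True, components=["model"]))

# The gate the diff introduces: the model is compiled only when compile is
# enabled AND "model" is among the requested components, so a loss-only
# compile no longer triggers model compilation.
model_compile_enabled = (
    job_config.compile.enable and "model" in job_config.compile.components
)
print(model_compile_enabled)  # True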

File tree: 1 file changed (+14, -9)

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 14 additions & 9 deletions
@@ -45,12 +45,22 @@ def parallelize_qwen3(
     Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
     ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
+
+    if (
+        job_config.parallelism.context_parallel_degree > 1
+        and model.model_args.use_flex_attn
+    ):
+        raise NotImplementedError("CP support for FlexAttention is still in progress.")
+
+    model_compile_enabled = (
+        job_config.compile.enable and "model" in job_config.compile.components
+    )
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
-            and not job_config.training.compile
+            and not model_compile_enabled
         ):
-            raise RuntimeError("Async TP requires --training.compile")
+            raise RuntimeError("Async TP requires torch.compile")

         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -75,7 +85,7 @@ def parallelize_qwen3(
         apply_ac(model, job_config.activation_checkpoint)

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
-    if job_config.training.compile:
+    if model_compile_enabled:
         apply_compile(model)

     if parallel_dims.fsdp_enabled:
@@ -95,11 +105,6 @@ def parallelize_qwen3(
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
         )

-        if parallel_dims.dp_replicate_enabled:
-            logger.info("Applied HSDP to the model")
-        else:
-            logger.info("Applied FSDP to the model")
-
         if parallel_dims.dp_replicate_enabled:
             logger.info("Applied HSDP to the model")
         else:
@@ -116,7 +121,7 @@ def parallelize_qwen3(
         apply_ddp(
             model,
             world_mesh,
-            enable_compile=job_config.training.compile,
+            enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

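For context, the stale references were not just stylistic drift: once the flag moved under the compile config, reading `job_config.training.compile` fails outright, assuming the old field was removed rather than aliased (which the referenced llama3 parallelize.py suggests). Reusing the hypothetical `JobConfig` stub from the sketch above:

# The stub JobConfig above has no `training` section, mirroring a config
# where the old flag was removed; stale reads fail loudly at runtime.
try:
    job_config.training.compile  # the pre-refactor flag path
except AttributeError as exc:
    print(f"stale config access fails: {exc}")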