@@ -45,12 +45,22 @@ def parallelize_qwen3(
45
45
Sequence length { job_config .training .seq_len } must be divisible by the product of TP degree
46
46
({ parallel_dims .tp } ) and 2 * CP degree ({ parallel_dims .cp } ).
47
47
"""
48
+
49
+ if (
50
+ job_config .parallelism .context_parallel_degree > 1
51
+ and model .model_args .use_flex_attn
52
+ ):
53
+ raise NotImplementedError ("CP support for FlexAttention is still in progress." )
54
+
55
+ model_compile_enabled = (
56
+ job_config .compile .enable and "model" in job_config .compile .components
57
+ )
48
58
if parallel_dims .tp_enabled :
49
59
if (
50
60
job_config .parallelism .enable_async_tensor_parallel
51
- and not job_config . training . compile
61
+ and not model_compile_enabled
52
62
):
53
- raise RuntimeError ("Async TP requires --training .compile" )
63
+ raise RuntimeError ("Async TP requires torch .compile" )
54
64
55
65
enable_float8_linear = "float8" in job_config .model .converters
56
66
float8_is_rowwise = job_config .float8 .recipe_name in (
@@ -75,7 +85,7 @@ def parallelize_qwen3(
75
85
apply_ac (model , job_config .activation_checkpoint )
76
86
77
87
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
78
- if job_config . training . compile :
88
+ if model_compile_enabled :
79
89
apply_compile (model )
80
90
81
91
if parallel_dims .fsdp_enabled :
@@ -95,11 +105,6 @@ def parallelize_qwen3(
95
105
reshard_after_forward_policy = job_config .parallelism .fsdp_reshard_after_forward ,
96
106
)
97
107
98
- if parallel_dims .dp_replicate_enabled :
99
- logger .info ("Applied HSDP to the model" )
100
- else :
101
- logger .info ("Applied FSDP to the model" )
102
-
103
108
if parallel_dims .dp_replicate_enabled :
104
109
logger .info ("Applied HSDP to the model" )
105
110
else :
@@ -116,7 +121,7 @@ def parallelize_qwen3(
116
121
apply_ddp (
117
122
model ,
118
123
world_mesh ,
119
- enable_compile = job_config . training . compile ,
124
+ enable_compile = model_compile_enabled ,
120
125
enable_compiled_autograd = job_config .parallelism .enable_compiled_autograd ,
121
126
)
122
127
0 commit comments