Skip to content

Commit cc71d2a

Browse files
committed
initial effort
1 parent 1f02964 commit cc71d2a

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

torchtitan/experiments/rl/config_registry.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,47 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
104104
)
105105

106106

107+
def rl_grpo_qwen3_30b_a3b() -> RLTrainer.Config:
    """Build the GRPO training config for Qwen3-30B-A3B MoE.

    Targets a 6-GPU layout: 4 GPUs for generation (vLLM rollouts) and
    2 GPUs for policy training.
    """
    # Policy-training side (2 GPUs): small LR with a short linear warmup.
    # Expert parallelism is left at degree 1; only tensor parallelism (=2)
    # is active for the trainer.
    trainer_cfg = PolicyTrainer.Config(
        optimizer=OptimizersContainer.Config(lr=2e-6),
        lr_scheduler=LRSchedulersContainer.Config(
            warmup_steps=2,
            decay_type="linear",
        ),
        training=TrainingConfig(),
        parallelism=ParallelismConfig(
            tensor_parallel_degree=2,
            expert_parallel_degree=1,
            expert_tensor_parallel_degree=1,
        ),
    )

    # Generation side (4 GPUs): bf16 vLLM rollouts with 4-way tensor
    # parallelism; compilation and CUDA graphs are disabled ("none").
    generator_cfg = VLLMGenerator.Config(
        model_dtype="bfloat16",
        compile=GeneratorCompileConfig(
            backend="none",
            cudagraph_mode="none",
        ),
        parallelism=ParallelismConfig(
            tensor_parallel_degree=4,
            data_parallel_replicate_degree=1,
        ),
        # GRPO draws a group of samples per prompt to compute relative rewards.
        num_samples_per_prompt=8,
        sampling=SamplingConfig(
            temperature=0.8,
            top_p=0.95,
            max_tokens=100,
        ),
        attention_backend="CUSTOM",
    )

    return RLTrainer.Config(
        model_spec=model_registry("30B-A3B"),
        hf_assets_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-30B-A3B",
        num_steps=10,
        # batch_invariant_mode keeps results independent of batch composition.
        batch_invariant_mode=True,
        trainer=trainer_cfg,
        generator=generator_cfg,
    )
146+
147+
107148
def rl_grpo_qwen3_debug() -> RLTrainer.Config:
108149
"""Debug config for quick iteration -- small model, few steps (2 GPUs: 1 gen + 1 train)."""
109150
return RLTrainer.Config(

torchtitan/experiments/rl/models/parallelize.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torchtitan.config.configs import CompileConfig
2828
from torchtitan.distributed import ParallelDims
2929
from torchtitan.distributed.compile import apply_compile_dense_rl
30+
from torchtitan.models.llama4.parallelize import apply_moe_ep_tp
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -65,6 +66,15 @@ def parallelize_qwen3(
6566
has_position_id=has_position_id,
6667
)
6768

69+
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
70+
apply_moe_ep_tp(
71+
model,
72+
tp_mesh=parallel_dims.get_optional_mesh("tp"),
73+
ep_mesh=parallel_dims.get_optional_mesh("ep"),
74+
etp_mesh=parallel_dims.get_optional_mesh("etp"),
75+
ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]),
76+
)
77+
6878
if (
6979
compile_config is not None
7080
and compile_config.enable
@@ -169,11 +179,6 @@ def apply_non_moe_tp(
169179
"feed_forward.w3": ColwiseParallel(use_local_output=False),
170180
}
171181
)
172-
else:
173-
raise ValueError(
174-
"Running vLLM inference with torchtitan Qwen3 MoE model is not supported yet."
175-
)
176-
177182
parallelize_module(
178183
# pyrefly: ignore [bad-argument-type]
179184
module=transformer_block,

torchtitan/models/llama4/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def parallelize_llama(
214214
if parallel_dims.dp_replicate_enabled
215215
else ["efsdp"]
216216
)
217-
edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names)
217+
edp_mesh: DeviceMesh | None = parallel_dims.get_optional_mesh(edp_mesh_names)
218218

219219
apply_fsdp(
220220
model,

0 commit comments

Comments
 (0)