Commit e1cd47b

[algo, rollout, sglang] feat: Support router replay with sglang (verl-project#4840)
### What does this PR do?

Support router replay with sglang.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

Maybe use with sgl-project/sglang#15751 if you want to set `chunked_prefill_size = -1`.

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
1 parent 94f4654 commit e1cd47b

File tree

4 files changed: +154 −11 lines changed

examples/router_replay/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -69,4 +69,4 @@ actor_rollout_ref.actor.router_replay.mode="R3"
 actor_rollout_ref.rollout.enable_rollout_routing_replay=True
 ```
 
-R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284.
+R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284 and the SGLang implementation at https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051.
````
Lines changed: 110 additions & 0 deletions

```shell
set -x

NODES=6

# R2: enable routing replay
# R3: enable rollout routing replay
# If enabling R3, please set actor_rollout_ref.rollout.enable_rollout_routing_replay=True
# The R3 example is based on SGLang commit https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051

ROUTING_REPLAY_MODE="R3"

DIST_CKPT_PATH=""
HF_MODEL_PATH=""
TRAIN_DATA_PATH=""
TEST_DATA_PATH=""

export CUDA_DEVICE_MAX_CONNECTIONS=1 # For Megatron communication/computation overlapping
PP=6
VPP=None
TP=1
EP=8
ETP=1
SGLANG_INFER_TP=4
offload=True
gpu_memory_utilization=0.65
bs=3
micro_bs=3
use_dynamic_bsz=False
max_prompt_length=512
max_response_length=512
ppo_mini_batch_size=3
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))

exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${SGLANG_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}

python3 -m verl.trainer.main_ppo --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    algorithm.adv_estimator=grpo \
    data.train_files=$TRAIN_DATA_PATH \
    data.val_files=$TEST_DATA_PATH \
    data.train_batch_size=$bs \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.use_fused_kernels=True \
    actor_rollout_ref.model.path=$HF_MODEL_PATH \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE} \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=False \
    actor_rollout_ref.actor.megatron.param_offload=${offload} \
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
    actor_rollout_ref.actor.megatron.grad_offload=${offload} \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$SGLANG_INFER_TP \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.enable_rollout_routing_replay=True \
    actor_rollout_ref.rollout.skip_tokenizer_init=True \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.actor.megatron.use_mbridge=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$micro_bs \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
    actor_rollout_ref.ref.megatron.param_offload=${offload} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_grpo_example_gsm8k_math' \
    trainer.experiment_name="$exper_name" \
    trainer.nnodes=$NODES \
    trainer.n_gpus_per_node=8 \
    trainer.save_freq=-1 \
    trainer.test_freq=10 \
    trainer.total_training_steps=50000 \
    trainer.balance_batch=False \
    trainer.val_before_train=False 2>&1
```
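The token-length budgets in the script are simple functions of the prompt and response lengths; a quick sanity check of the arithmetic with the script's defaults:

```shell
# With 512 prompt + 512 response tokens, both actor_ppo_max_token_len and
# infer_ppo_max_token_len come out to (512 + 512) * 2 = 2048.
max_prompt_length=512
max_response_length=512
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
echo "$actor_ppo_max_token_len"
```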

verl/experimental/agent_loop/agent_loop.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -587,7 +587,12 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
         if output.routed_experts is not None:
             total_length = input_ids.shape[1]
             length, layer_num, topk_num = output.routed_experts.shape
-            experts_tensor = torch.from_numpy(output.routed_experts)
+            if isinstance(output.routed_experts, np.ndarray):
+                experts_tensor = torch.from_numpy(output.routed_experts)
+            elif isinstance(output.routed_experts, torch.Tensor):
+                experts_tensor = output.routed_experts
+            else:
+                raise TypeError(f"Unsupported type for routed_experts: {type(output.routed_experts)}")
             routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype)
 
             # Calculate start position: left padding means original prompt starts at the end
```
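The change above normalizes `routed_experts` to a `torch.Tensor` before it is copied into the padded buffer, since the sglang path can return either a NumPy array or a tensor. A standalone sketch of the same dispatch (the helper name is hypothetical; the PR performs this inline):

```python
import numpy as np
import torch

def as_experts_tensor(routed_experts):
    """Accept routed-expert ids as either np.ndarray or torch.Tensor."""
    if isinstance(routed_experts, np.ndarray):
        # Zero-copy view over the NumPy buffer.
        return torch.from_numpy(routed_experts)
    if isinstance(routed_experts, torch.Tensor):
        # Already a tensor (e.g. from the sglang backend): pass through.
        return routed_experts
    raise TypeError(f"Unsupported type for routed_experts: {type(routed_experts)}")

# Both input types yield a tensor with the same (length, layer_num, topk_num) shape.
from_numpy = as_experts_tensor(np.zeros((5, 4, 2), dtype=np.int64))
passthrough = as_experts_tensor(torch.zeros(5, 4, 2, dtype=torch.int64))
print(from_numpy.shape, passthrough.shape)
```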

verl/workers/rollout/sglang_rollout/async_sglang_server.py

Lines changed: 37 additions & 9 deletions
```diff
@@ -200,6 +200,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
         enable_weights_cpu_backup = True if self.rollout_mode == RolloutMode.COLOCATED else False
         args["enable_weights_cpu_backup"] = enable_weights_cpu_backup
 
+        if self.config.enable_rollout_routing_replay:
+            args.update({"enable_return_routed_experts": True})
+
         # NOTE: We can't directly call SGLang's launch_server since it's not an async function.
         # https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py
         sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config
@@ -297,16 +300,22 @@ async def generate(
         sampling_params["max_new_tokens"] = max_new_tokens
         return_logprob = sampling_params.pop("logprobs", False)
 
-        request = GenerateReqInput(
-            rid=request_id,
-            input_ids=prompt_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            image_data=image_data,
+        request = {
+            "rid": request_id,
+            "input_ids": prompt_ids,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "image_data": image_data,
             # TODO: support video input for sglang
             # video_data=video_data,
-        )
-        output = await self.tokenizer_manager.generate_request(request, None).__anext__()
+        }
+
+        if self.config.enable_rollout_routing_replay:
+            request.update({"return_routed_experts": True})
+
+        generate_request = GenerateReqInput(**request)
+
+        output = await self.tokenizer_manager.generate_request(generate_request, None).__anext__()
         if return_logprob:
             output_token_logprobs = output["meta_info"]["output_token_logprobs"]
             log_probs, token_ids = zip(
@@ -315,7 +324,26 @@ async def generate(
         else:
             token_ids = output["output_ids"]
             log_probs = None
-        return TokenOutput(token_ids=token_ids, log_probs=log_probs)
+
+        routed_experts = None
+        if self.config.enable_rollout_routing_replay:
+            if self.config.skip_tokenizer_init:
+                routed_experts = output.get("meta_info", {}).get("routed_experts", None)
+            else:
+                from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
+
+                hf_config = self.model_config.hf_config
+                if not hasattr(hf_config, "num_hidden_layers") or not hasattr(hf_config, "num_experts_per_tok"):
+                    raise AttributeError(
+                        "enable_rollout_routing_replay is set, but hf_config is missing "
+                        "'num_hidden_layers' or 'num_experts_per_tok'. This feature requires an MoE model "
+                        "configuration that defines these attributes."
+                    )
+                routed_experts = extract_routed_experts_from_meta_info(output).reshape(
+                    -1, hf_config.num_hidden_layers, hf_config.num_experts_per_tok
+                )
+
+        return TokenOutput(token_ids=token_ids, log_probs=log_probs, routed_experts=routed_experts)
 
 
 _rollout_worker_actor_cls = ray.remote(ServerAdapter)
```
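In the non-`skip_tokenizer_init` path above, the flat routed-experts payload is reshaped using the model's layer count and top-k width. A minimal sketch of that reshape, with stand-in values for `hf_config.num_hidden_layers` and `hf_config.num_experts_per_tok` (the real values come from the MoE model's config):

```python
import torch

# Stand-ins for hf_config.num_hidden_layers / hf_config.num_experts_per_tok
# (assumed small values for illustration).
num_hidden_layers = 4
num_experts_per_tok = 2
num_tokens = 3

# The server returns one expert id per (token, layer, topk) triple, flattened
# into a single sequence.
flat = torch.arange(num_tokens * num_hidden_layers * num_experts_per_tok)

# reshape(-1, layers, topk) recovers the per-token layout that the replay
# path consumes; the leading -1 infers the token count from the flat length.
routed_experts = flat.reshape(-1, num_hidden_layers, num_experts_per_tok)
print(routed_experts.shape)  # torch.Size([3, 4, 2])
```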
