Skip to content

Commit 3e0d4b1

Browse files
authored
[algo] feat: reduce routed expert padding via NestedTensor and uint8 dtype (verl-project#5240)
This PR optimizes the routed expert data to reduce communication and memory overhead.

- Converts `routed_experts` into a `NestedTensor` representation to avoid padding-heavy dense tensors.
- Packs routed expert data into `uint8` format to reduce transmission size.
- Removes unnecessary `attention_mask` propagation for routed expert execution.

### Experimental Results

The results indicate that the proposed optimization reduces padding-related communication and memory overhead by **around 15%** compared to the original implementation, while preserving execution correctness.

<img width="425" height="292" alt="企业微信截图_92aac7da-0169-491e-af86-c1a38661ac7e" src="https://github.com/user-attachments/assets/393cd669-d75d-405c-b61f-95d8926076e9" />

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Signed-off-by: xhx1022 <1737006628@qq.com>
1 parent 0b0769e commit 3e0d4b1

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

verl/utils/megatron/router_replay_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from verl.models.mcore.util import (
3838
postprocess_packed_seqs,
3939
preprocess_packed_seqs,
40+
preprocess_thd_no_padding,
4041
)
4142
from verl.utils.device import get_device_name
4243
from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction
@@ -233,7 +234,10 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N
233234
None: The function updates internal RouterReplay instances in-place.
234235
"""
235236
with torch.no_grad():
236-
layers_topk_idx_rmpad, _ = preprocess_packed_seqs(layers_topk_idx, attention_mask, pre_process=True)
237+
if layers_topk_idx.is_nested:
238+
layers_topk_idx_rmpad, _ = preprocess_thd_no_padding(layers_topk_idx, pre_process=True)
239+
else:
240+
layers_topk_idx_rmpad, _ = preprocess_packed_seqs(layers_topk_idx, attention_mask, pre_process=True)
237241
layers_topk_idx_rmpad = layers_topk_idx_rmpad.contiguous() # 1, dynamic_bs_all, layer_num, topk
238242

239243
# 1, dynamic_bs_split, layer_num, topk

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,14 +666,12 @@ def prepare_model_inputs(self, batch: TensorDict):
666666
multi_modal_inputs = extract_multi_modal_inputs(batch.get("multi_modal_inputs", []))
667667

668668
routed_experts = batch.get("routed_experts", [])
669-
attention_mask = batch.get("attention_mask", None)
670669

671670
return {
672671
"input_ids": input_ids,
673672
"loss_mask": loss_mask,
674673
"multi_modal_inputs": multi_modal_inputs,
675674
"routed_experts": routed_experts,
676-
"attention_mask": attention_mask,
677675
}
678676

679677
def prepare_model_outputs(self, output: dict, data: TensorDict):
@@ -712,8 +710,7 @@ def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micr
712710

713711
if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
714712
layers_topk_idx = model_inputs["routed_experts"]
715-
attention_mask = model_inputs["attention_mask"].to(bool)
716-
set_router_replay_data(layers_topk_idx, attention_mask, self.tf_config, vp_rank)
713+
set_router_replay_data(layers_topk_idx, None, self.tf_config, vp_rank)
717714

718715
if pad_mode == DatasetPadMode.NO_PADDING:
719716
label = input_ids.clone()

verl/workers/utils/padding.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from tensordict import TensorDict
1818

1919
from verl.utils import tensordict_utils as tu
20-
from verl.utils.attention_utils import unpad_input
20+
from verl.utils.attention_utils import index_first_axis, unpad_input
2121

2222

2323
def left_right_2_no_padding(data: TensorDict) -> TensorDict:
@@ -70,6 +70,16 @@ def left_right_2_no_padding(data: TensorDict) -> TensorDict:
7070
data["position_ids"] = position_ids_nested
7171
data["loss_mask"] = data["response_mask"]
7272

73+
routed_experts = data.get("routed_experts", None)
74+
if routed_experts is not None and not routed_experts.is_nested:
75+
if routed_experts.max() <= 255:
76+
routed_experts = routed_experts.to(torch.uint8)
77+
routed_experts_rmpad = index_first_axis(routed_experts.unsqueeze(-1).flatten(0, 1), indices)
78+
routed_experts_nested = torch.nested.nested_tensor_from_jagged(
79+
routed_experts_rmpad.squeeze(-1), offsets=cu_seqlens
80+
)
81+
data["routed_experts"] = routed_experts_nested
82+
7383
return data
7484

7585

0 commit comments

Comments
 (0)