
Commit cf92dac

[trainer] feat: VeOmniEngine supports qwen3_vl ulysses (verl-project#4806)
### What does this PR do?

as title.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 252d769 commit cf92dac

File tree

2 files changed: +45 -10 lines changed


tests/special_e2e/sft/test_sft_engine_all.sh

Lines changed: 2 additions & 3 deletions
```diff
@@ -25,9 +25,8 @@ echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2"
 BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
 
 # test with veomni
-# FIXME(ji-huazhong): set SP=1 cause qwen_vl do not support SP right now
-echo "run with sp1 fsdp_size4 num_gpus8 fsdp_strategy fsdp2"
-BACKEND=veomni SP_SIZE=1 FSDP_SIZE=8 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
+echo "run with sp2 fsdp_size4 num_gpus8 fsdp_strategy fsdp2"
+BACKEND=veomni SP_SIZE=2 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
 
 
 # test with megatron
```
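
With qwen3_vl now supporting Ulysses under VeOmni, the test can run with sequence parallelism enabled: the 8 GPUs factor into FSDP_SIZE × SP_SIZE = 4 × 2 instead of the previous 8 × 1. A minimal sketch of that factorization, assuming a torchrun launch and the dimension names `dp`/`sp` (illustrative only; the real mesh is built inside VeOmni, not by the test script):

```python
# Illustrative only: how the 8 ranks above factor into a dp x sp mesh.
# Assumes torchrun --nproc-per-node=8 and FSDP_SIZE * SP_SIZE == NUM_GPUS.
from torch.distributed.device_mesh import init_device_mesh

NUM_GPUS, FSDP_SIZE, SP_SIZE = 8, 4, 2
assert FSDP_SIZE * SP_SIZE == NUM_GPUS

# Each rank lands in one "dp" group of size 4 and one "sp" group of size 2.
mesh = init_device_mesh("cuda", (FSDP_SIZE, SP_SIZE), mesh_dim_names=("dp", "sp"))
print(mesh.get_local_rank("dp"), mesh.get_local_rank("sp"))
```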

verl/workers/engine/veomni/transformer_impl.py

Lines changed: 43 additions & 7 deletions
```diff
@@ -14,7 +14,8 @@
 
 
 import logging
-from typing import Any, Callable
+from dataclasses import dataclass, field
+from typing import Any, Callable, Sequence
 
 import torch
 import torch.distributed as dist
@@ -133,7 +134,6 @@ def initialize(self):
             attn_implementation=self.engine_config.attn_implementation,
             moe_implementation=self.engine_config.moe_implementation,
             init_device=self.engine_config.init_device,
-            force_use_huggingface=self.engine_config.force_use_huggingface,
         )
 
         module_config = self.module.config
@@ -250,10 +250,7 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
         return postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)
 
     def get_data_parallel_rank(self):
-        if parallel_state.get_parallel_state().ulysses_size > 1:
-            return parallel_state.get_parallel_state().device_mesh["dp"].get_local_rank()
-        else:
-            return torch.distributed.get_rank()
+        return parallel_state.get_parallel_state().device_mesh.get_local_rank("dp")
 
     def get_data_parallel_size(self):
         return torch.distributed.get_world_size() // parallel_state.get_parallel_state().ulysses_size
```
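
The `get_data_parallel_rank` simplification relies on the VeOmni device mesh always exposing a `dp` dimension: when `sp == 1`, the local rank along `dp` equals the global rank, so the removed `torch.distributed.get_rank()` branch returned the same value. A small sketch of that equivalence, with the mesh construction assumed for illustration (VeOmni's `parallel_state` builds its own mesh):

```python
# Assumed setup for illustration; run under torchrun. Not VeOmni's parallel_state.
import os
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

world = int(os.environ["WORLD_SIZE"])
sp_size = 1  # Ulysses disabled
mesh = init_device_mesh("cuda", (world // sp_size, sp_size), mesh_dim_names=("dp", "sp"))
# With sp == 1, every rank is alone in its sp group, so its "dp" local rank
# is just its global rank -- the old branch on ulysses_size was redundant.
assert mesh.get_local_rank("dp") == dist.get_rank()
```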
```diff
@@ -299,7 +296,7 @@ def __enter__(self):
         assert isinstance(self.engine, VeOmniEngine)
         super().__enter__()
         self.engine.ulysses_sharding_manager.__enter__()
-        self.engine.module.eval()
+        self.engine.module.train()
 
     def __exit__(self, exc_type, exc_value, traceback):
         assert isinstance(self.engine, VeOmniEngine)
```
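
Flipping `eval()` to `train()` matters here because this context manager wraps training forward/backward passes, and mode-dependent layers should behave as in training. A tiny reminder of what the flag controls (standard PyTorch behavior, not VeOmni-specific code):

```python
import torch.nn as nn

drop = nn.Dropout(p=0.5)
drop.train()
assert drop.training      # dropout active: correct inside a training context
drop.eval()
assert not drop.training  # dropout disabled: what the old eval() call caused
```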
```diff
@@ -333,6 +330,41 @@ def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
 
 
+@dataclass
+class OmniSequenceShardCollator:
+    """
+    Data collator to chunk inputs along the sequence length.
+    """
+
+    # features to slice sequence dimension
+    sp_slice_features: dict[str, int] = field(
+        default_factory=lambda: {
+            "input_ids": -1,
+            "labels": -1,
+            "pixel_values": 0,
+            "pixel_values_videos": 0,
+        },
+        metadata={"help": "features to slice sequence dimension."},
+    )
+
+    def __post_init__(self):
+        self.sp_size = parallel_state.get_parallel_state().sp_size
+        self.sp_rank = parallel_state.get_parallel_state().sp_rank
+
+    def sp_slice(self, feature: torch.Tensor, dim: int = -1) -> dict[str, "torch.Tensor"]:
+        seq_length = feature.size(dim)
+        sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
+        return feature.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size)
+
+    def __call__(self, batch: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
+        # sp slice
+        for key in batch.keys():
+            if key in self.sp_slice_features.keys():
+                batch[key] = self.sp_slice(batch[key], dim=self.sp_slice_features[key])
+
+        return batch
+
+
 @EngineRegistry.register(model_type="language_model", backend=["veomni"], device=["cuda", "npu"])
 class VeOmniEngineWithLMHead(VeOmniEngine, FSDPEngineWithLMHead):
     def prepare_model_inputs(self, micro_batch: TensorDict):
```
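
`OmniSequenceShardCollator` gives each sequence-parallel rank one contiguous chunk of `ceil(seq_len / sp_size)` elements per listed feature: text features (`input_ids`, `labels`) are sliced along the last dimension, pixel features along the first. A standalone sketch of the slicing arithmetic, with `sp_rank`/`sp_size` passed explicitly instead of read from `parallel_state`, and assuming the sequence length is already padded to a multiple of `sp_size` (otherwise `narrow` would run past the end on the last rank):

```python
import torch

def sp_slice(feature: torch.Tensor, sp_rank: int, sp_size: int, dim: int = -1) -> torch.Tensor:
    # Same arithmetic as OmniSequenceShardCollator.sp_slice: ceil-divide the
    # sequence into sp_size contiguous chunks and keep this rank's chunk.
    seq_length = feature.size(dim)
    sp_chunk_size = (seq_length + sp_size - 1) // sp_size
    return feature.narrow(dim, sp_rank * sp_chunk_size, sp_chunk_size)

input_ids = torch.arange(8).unsqueeze(0)          # shape (1, 8); seq dim is -1
print(sp_slice(input_ids, sp_rank=0, sp_size=2))  # tensor([[0, 1, 2, 3]])
print(sp_slice(input_ids, sp_rank=1, sp_size=2))  # tensor([[4, 5, 6, 7]])
```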
```diff
@@ -344,4 +376,8 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
             video_mask = input_ids_rmpad == VL_TYPE2INDEX[self.module.config.model_type]["VIDEO_INPUT_INDEX"]
             model_inputs.update({"image_mask": image_mask, "video_mask": video_mask})
 
+        if parallel_state.get_parallel_state().sp_enabled:
+            omni_sequence_shard_collator = OmniSequenceShardCollator()
+            omni_sequence_shard_collator(model_inputs)
+
         return model_inputs, output_args
```
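
When `sp_enabled` is true, the collator is applied in place to the already-prepared `model_inputs`, so each Ulysses rank feeds only its own sequence chunk into the model; keys not listed in `sp_slice_features` pass through untouched. A hedged usage sketch (the dict below is an invented stand-in for the real `model_inputs`, and constructing the collator requires VeOmni's `parallel_state` to be initialized):

```python
import torch

# Invented stand-in tensors; real keys and shapes come from the engine.
model_inputs = {
    "input_ids": torch.arange(16).unsqueeze(0),     # listed: sliced along dim -1
    "pixel_values": torch.randn(8, 1176),           # listed: sliced along dim 0
    "position_ids": torch.arange(16).unsqueeze(0),  # not listed: left whole
}

collator = OmniSequenceShardCollator()  # reads sp_size / sp_rank from parallel_state
collator(model_inputs)                  # slices the listed keys in place
```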
