Skip to content

Commit b53f0f1

Browse files
[veomni] feat: support model resharding between veomni and rollout engine (#5033)
### What does this PR do? test with Qwen3-30B-A3B-Instruct. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). 
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) - [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`. --------- Co-authored-by: A1waysBeenHere <moyicong1999@163.com>
1 parent 772c224 commit b53f0f1

File tree

3 files changed

+209
-121
lines changed

3 files changed

+209
-121
lines changed

verl/utils/veomni_utils.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

verl/workers/engine/veomni/transformer_impl.py

Lines changed: 126 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717
from dataclasses import dataclass, field
18-
from typing import Any, Callable, Sequence
18+
from typing import Any, Callable, Optional, Sequence
1919

2020
import torch
2121
import torch.distributed as dist
@@ -33,20 +33,22 @@
3333
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
3434
from verl.utils.device import get_device_id, get_device_name
3535
from verl.utils.fsdp_utils import fsdp_version
36+
from verl.utils.model import convert_weight_keys
3637
from verl.utils.profiler import log_gpu_memory_usage
37-
from verl.utils.veomni_utils import (
38-
load_veomni_model_to_gpu,
39-
load_veomni_optimizer,
40-
offload_veomni_model_to_cpu,
41-
offload_veomni_optimizer,
42-
)
4338
from verl.workers.config import HFModelConfig, VeOmniEngineConfig, VeOmniOptimizerConfig
4439
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
4540

4641
from ..base import BaseEngineCtx, EngineRegistry
4742
from ..fsdp.transformer_impl import FSDPEngine, FSDPEngineWithLMHead
4843
from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches
49-
from .utils import VL_TYPE2INDEX
44+
from .utils import (
45+
MOE_PARAM_HANDERS,
46+
VL_TYPE2INDEX,
47+
load_veomni_model_to_gpu,
48+
load_veomni_optimizer,
49+
offload_veomni_model_to_cpu,
50+
offload_veomni_optimizer,
51+
)
5052

5153
logger = logging.getLogger(__file__)
5254

@@ -61,23 +63,19 @@ def __init__(
6163
**kwargs,
6264
):
6365
"""
64-
Initialize the FSDPEngine.
66+
Initialize the VeOmniEngine.
6567
6668
Sets up distributed device meshes, LoRA, and offload policies based on config.
6769
6870
Args:
69-
config: Configuration object with FSDP and model settings.
71+
config: Configuration object with VeOmni and model settings.
7072
"""
7173

72-
# TODO: Preprocessing operations for the MOE model are appended here,
73-
# instead of relying on Veomni's transformation scripts.
74-
7574
self.model_config = model_config
7675
self.engine_config = engine_config
7776
self.optimizer_config = optimizer_config
7877
self.checkpoint_config = checkpoint_config
79-
80-
self.mode = None
78+
assert self.engine_config.data_parallel_mode == "fsdp2", "VeOmniEngine only supports fsdp2."
8179

8280
self.rank = dist.get_rank()
8381

@@ -223,34 +221,6 @@ def _build_model_optimizer(self):
223221
self.engine_config.activation_gpu_limit,
224222
)
225223

226-
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
227-
"""
228-
Move model parameters, optimizer states, or both to the specified device.
229-
Note that this function executes irrespective of offload config. It serves as manual control.
230-
231-
Args:
232-
device: Target device identifier.
233-
model: If True, move the model.
234-
optimizer: If True, move the optimizer states.
235-
"""
236-
super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)
237-
238-
device_name = get_device_name()
239-
240-
assert device in (device_name, "cpu")
241-
if device == device_name:
242-
if model:
243-
load_veomni_model_to_gpu(self.module)
244-
if optimizer and self.optimizer is not None:
245-
load_veomni_optimizer(self.optimizer, device)
246-
elif device == "cpu":
247-
if model:
248-
offload_veomni_model_to_cpu(self.module)
249-
if optimizer and self.optimizer is not None:
250-
offload_veomni_optimizer(self.optimizer)
251-
else:
252-
raise ValueError(f"Invalid device type: {device}")
253-
254224
def optimizer_step(self):
255225
"""
256226
Perform an optimization step using the optimizer.
@@ -348,6 +318,117 @@ def eval_mode(self, **kwargs):
348318
"""
349319
return EngineEvalModeCtx(self, **kwargs)
350320

321+
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
322+
"""
323+
Move model parameters, optimizer states, or both to the specified device.
324+
Note that this function executes irrespective of offload config. It serves as manual control.
325+
326+
Args:
327+
device: Target device identifier.
328+
model: If True, move the model.
329+
optimizer: If True, move the optimizer states.
330+
"""
331+
super(FSDPEngine, self).to(device=device, model=model, optimizer=optimizer, grad=grad)
332+
333+
device_name = get_device_name()
334+
335+
assert device in (device_name, "cpu")
336+
if device == device_name:
337+
if model:
338+
load_veomni_model_to_gpu(self.module)
339+
if optimizer and self.optimizer is not None:
340+
load_veomni_optimizer(self.optimizer, device)
341+
elif device == "cpu":
342+
if model:
343+
offload_veomni_model_to_cpu(self.module)
344+
if optimizer and self.optimizer is not None:
345+
offload_veomni_optimizer(self.optimizer)
346+
else:
347+
raise ValueError(f"Invalid device type: {device}")
348+
349+
def save_checkpoint(
350+
self,
351+
local_path: str,
352+
hdfs_path: Optional[str] = None,
353+
global_step: int = 0,
354+
max_ckpt_to_keep: Optional[int] = None,
355+
**kwargs,
356+
) -> None:
357+
"""
358+
Save VeOmni checkpoint, handling parameter offload as needed.
359+
"""
360+
origin_module_device = next(self.module.parameters()).device.type
361+
if self._is_offload_param or origin_module_device == "cpu":
362+
load_veomni_model_to_gpu(self.module)
363+
364+
self.checkpoint_manager.save_checkpoint(
365+
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
366+
)
367+
368+
torch.distributed.barrier()
369+
if self._is_offload_param:
370+
offload_veomni_model_to_cpu(self.module)
371+
372+
def load_checkpoint(
373+
self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs
374+
) -> None:
375+
"""
376+
Load VeOmni checkpoint, restoring parameters and optimizer state.
377+
"""
378+
if self._is_offload_param:
379+
load_veomni_model_to_gpu(self.module)
380+
381+
self.checkpoint_manager.load_checkpoint(
382+
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
383+
)
384+
385+
torch.distributed.barrier()
386+
if self._is_offload_param:
387+
offload_veomni_model_to_cpu(self.module)
388+
389+
if self._is_offload_optimizer:
390+
offload_veomni_optimizer(self.optimizer)
391+
392+
def get_per_tensor_param(self, **kwargs):
393+
load_veomni_model_to_gpu(self.module)
394+
395+
params = self.module.state_dict()
396+
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
397+
398+
if self._is_offload_param:
399+
offload_veomni_model_to_cpu(self.module)
400+
401+
device = get_device_id()
402+
ps = parallel_state.get_parallel_state()
403+
model_type = getattr(self.module.config, "model_type", "default")
404+
process_func = MOE_PARAM_HANDERS.get(model_type, lambda n, t: iter([(n, t)]))
405+
406+
def param_generator():
407+
for name, param in params.items():
408+
unsharded_tensor = param.full_tensor() if isinstance(param, DTensor) else param
409+
410+
is_expert_layer = "mlp.experts." in name
411+
is_proj = any(p in name for p in ["down_proj", "gate_proj", "up_proj", "gate_up_proj"])
412+
413+
if is_expert_layer and is_proj and ps.ep_enabled:
414+
output_shape = list(unsharded_tensor.shape)
415+
output_shape[0] *= ps.ep_size
416+
stacked_tensor = torch.empty(output_shape, dtype=unsharded_tensor.dtype, device=device)
417+
418+
# all gather expert tensors [32, H, I] -> [128, H, I]
419+
torch.distributed.all_gather_into_tensor(stacked_tensor, unsharded_tensor, group=ps.ep_group)
420+
yield from process_func(name, stacked_tensor)
421+
422+
del stacked_tensor
423+
else:
424+
if is_expert_layer:
425+
yield from process_func(name, unsharded_tensor)
426+
else:
427+
yield name, unsharded_tensor
428+
429+
# TODO: support VeOmni LoRA
430+
return param_generator(), None
431+
351432

352433
class EngineEvalModeCtx(BaseEngineCtx):
353434
def __init__(self, engine: VeOmniEngine, **kwargs):
@@ -382,6 +463,8 @@ def __enter__(self):
382463
assert isinstance(self.engine, VeOmniEngine)
383464
super().__enter__()
384465
self.engine.ulysses_sharding_manager.__enter__()
466+
# TODO: Switch to eval mode after Integrating the CI environment
467+
# VeOmni (ref: https://github.com/ByteDance-Seed/VeOmni/pull/421)
385468
self.engine.module.train()
386469

387470
def __exit__(self, exc_type, exc_value, traceback):

0 commit comments

Comments (0)