Commit d7c82bd

[megatron] chore: clean legacy code path part 1, make engine use mbridge by default (#4528)
### What does this PR do?

This is one of a series of PRs to clean up the legacy Megatron code path and make bridge the default path for Megatron. #4496

This PR makes the Megatron engine support only bridge, and marks the legacy code in `mcore/registry.py` as deprecated.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 2fd6591 commit d7c82bd

File tree

8 files changed: +117 additions, −123 deletions


recipe/open_math_reasoning/run_sft_qwen3_8b.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -55,7 +55,7 @@ MEGATRON_ENGINE_CONFIG="\
     engine.pipeline_model_parallel_size=${PP_SIZE} \
     engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
     engine.context_parallel_size=${CP_SIZE} \
-    engine.use_mbridge=False"
+    engine.use_mbridge=True"
 
 if [ "$backend" = "fsdp" ]; then
     ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
```

tests/trainer/config/legacy_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -111,7 +111,7 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: 42
       override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
-      use_mbridge: False
+      use_mbridge: True
       vanilla_mbridge: True
     profile: # profile the actor model in `update_policy`
       use_profile: False # open it when you want to profile the actor model
```

verl/models/mcore/registry.py

Lines changed: 48 additions & 29 deletions
```diff
@@ -22,6 +22,54 @@
 import torch
 import torch.nn as nn
 
+from .model_forward import gptmodel_forward_no_padding, model_forward_gen
+from .model_forward_fused import fused_forward_model_gen
+
+
+class SupportedVLM(Enum):
+    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
+    QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
+    QWEN3_VL = "Qwen3VLForConditionalGeneration"
+
+
+def get_mcore_forward_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    if hf_config.architectures[0] in SupportedVLM:
+        return model_forward_gen(True)
+    else:
+        # default to language model
+        return model_forward_gen(False)
+
+
+def get_mcore_forward_no_padding_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    return gptmodel_forward_no_padding
+
+
+def get_mcore_forward_fused_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    if hf_config.architectures[0] in SupportedVLM:
+        return fused_forward_model_gen(True)
+    else:
+        # default to language model
+        return fused_forward_model_gen(False)
+
+
+# ruff: noqa
+
+########################################################
+# below is the deprecated code
+########################################################
+
 from .config_converter import (
     PretrainedConfig,
     TransformerConfig,
@@ -33,8 +81,6 @@
     hf_to_mcore_config_qwen2moe,
     hf_to_mcore_config_qwen3moe,
 )
-from .model_forward import gptmodel_forward_no_padding, model_forward_gen
-from .model_forward_fused import fused_forward_model_gen
 from .model_initializer import (
     BaseModelInitializer,
     DeepseekV3Model,
@@ -239,33 +285,6 @@ def init_mcore_model(
     )
 
 
-def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_REGISTRY[model]
-
-
-def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_NOPAD_REGISTRY[model]
-
-
-def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_FUSED_REGISTRY[model]
-
-
 def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
     """
     Get the weight converter for given model architecture.
```
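
The new getters replace the registry lookup with a simple VLM-vs-language-model dispatch. A minimal usage sketch (the checkpoint name is illustrative; only the `get_mcore_forward_*` helpers come from `verl/models/mcore/registry.py`):

```python
from transformers import AutoConfig

from verl.models.mcore.registry import get_mcore_forward_fn, get_mcore_forward_fused_fn

# Hypothetical checkpoint: any HF config whose architectures[0] matches a SupportedVLM value
# gets the vision-language forward; everything else falls back to the language-model path.
hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

forward_fn = get_mcore_forward_fn(hf_config)              # model_forward_gen(True) for VLMs
fused_forward_fn = get_mcore_forward_fused_fn(hf_config)  # fused_forward_model_gen(True) for VLMs
```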

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -52,7 +52,7 @@ actor_rollout_ref:
        recompute_num_layers: null
        attention_backend: flash
      override_mcore_model_config: {}
-      use_mbridge: false
+      use_mbridge: true
      vanilla_mbridge: true
      use_remove_padding: true
      forward_only: false
@@ -433,7 +433,7 @@ critic:
        recompute_num_layers: null
        attention_backend: flash
      override_mcore_model_config: {}
-      use_mbridge: false
+      use_mbridge: true
      vanilla_mbridge: true
      use_remove_padding: true
      forward_only: false
```

verl/trainer/config/engine/megatron.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -75,7 +75,7 @@ override_transformer_config:
 override_mcore_model_config: {}
 
 # oc.select: default val for ref.megatron.use_mbridge
-use_mbridge: False
+use_mbridge: True
 
 # oc.select: default val for ref.megatron.vanilla_mbridge
 vanilla_mbridge: True
```

verl/utils/megatron_utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -1220,3 +1220,11 @@ def register_megatron_training_hooks(model: list[torch.nn.Module], optimizer):
         config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
         if len(model) == 1:
             config.param_sync_func = config.param_sync_func[0]
+
+
+def mapping_string_to_attn_backend(args: dict) -> dict:
+    if "attention_backend" in args and isinstance(args["attention_backend"], str):
+        from megatron.core.transformer.enums import AttnBackend
+
+        args["attention_backend"] = AttnBackend[args["attention_backend"]]
+    return args
```
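
A short usage sketch of the new helper (the override dict is illustrative; `"flash"` mirrors the `attention_backend: flash` default in the generated trainer config):

```python
from verl.utils.megatron_utils import mapping_string_to_attn_backend

# Overrides typically arrive from YAML/CLI with the backend spelled as a plain string.
override_transformer_config = {"attention_backend": "flash"}

# The helper swaps the string for the matching Megatron enum member, i.e. AttnBackend["flash"];
# dicts without a string "attention_backend" entry pass through unchanged.
override_transformer_config = mapping_string_to_attn_backend(override_transformer_config)
```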

verl/workers/config/engine.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ class McoreEngineConfig(EngineConfig):
     override_ddp_config: dict[str, Any] = field(default_factory=dict)
     override_transformer_config: dict[str, Any] = field(default_factory=dict)
     override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
-    use_mbridge: bool = False
+    use_mbridge: bool = True
     vanilla_mbridge: bool = True
     strategy: str = "megatron"
 
```
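
A minimal sketch of what the flipped default means for code that inspects the config class (the `dataclasses.fields` probe is illustrative, assuming `McoreEngineConfig` is a plain dataclass):

```python
import dataclasses

from verl.workers.config import McoreEngineConfig

# The class-level default is now True; the Megatron engine additionally asserts
# engine_config.use_mbridge at construction time (see transformer_impl.py below).
defaults = {f.name: f.default for f in dataclasses.fields(McoreEngineConfig)}
assert defaults["use_mbridge"] is True
assert defaults["vanilla_mbridge"] is True
```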

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 55 additions & 88 deletions
```diff
@@ -41,13 +41,11 @@
     load_megatron_optimizer,
     offload_megatron_model_to_cpu,
     offload_megatron_optimizer,
-    per_tensor_generator,
     register_megatron_training_hooks,
 )
 from verl.utils.model import (
     extract_multi_modal_inputs_tensordict,
     load_mcore_dist_weights,
-    load_megatron_gptmodel_weights,
 )
 from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
 
@@ -76,7 +74,7 @@ def __init__(
         self.engine_config = engine_config
         self.optimizer_config = optimizer_config
         self.checkpoint_config = checkpoint_config
-
+        assert self.engine_config.use_mbridge, "use_mbridge must be True"
         self._init_device_mesh()
 
         set_random_seed(seed=self.engine_config.seed)
@@ -110,70 +108,62 @@ def _init_device_mesh(self):
         )
 
     def _build_tf_config(self):
-        from verl.models.mcore import hf_to_mcore_config
-        from verl.models.mcore.config_converter import mapping_string_to_attn_backend
+        from verl.utils.megatron_utils import mapping_string_to_attn_backend
         from verl.utils.torch_dtypes import PrecisionType
 
         self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype)
-        if self.param_dtype == torch.float16:
-            assert self.engine_config.use_mbridge, "fp16 mode requires use_mbridge to be True"
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
 
         override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
 
-        use_mbridge = self.engine_config.use_mbridge
         self.provider = None
         self.vanilla_bridge = self.engine_config.vanilla_mbridge
-        if use_mbridge:
-            if self.vanilla_bridge:
-                from verl.models.mcore.mbridge import AutoBridge
-
-                bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
-                bridge.set_extra_args(**override_transformer_config)
-                tf_config = bridge.config
-                tf_config.fp16 = self.param_dtype == torch.float16
-                tf_config.bf16 = self.param_dtype == torch.bfloat16
-            else:
-                from verl.models.mcore.bridge import AutoBridge
-
-                # Use Megatron-Bridge to convert HF config to Megatron config
-                bridge = AutoBridge.from_hf_pretrained(
-                    self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
-                )
-                # Get Megatron provider and configure it
-                provider = bridge.to_megatron_provider(load_weights=False)
-
-                # In case of invalid overrides, we need to make sure some critical params are set correctly
-                provider.params_dtype = self.param_dtype
-
-                # Pass distributed info
-                provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
-                provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
-                provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
-                provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
-                provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
-                provider.context_parallel_size = self.engine_config.context_parallel_size
-                provider.sequence_parallel = self.engine_config.sequence_parallel
-
-                # Match verl implementation (need variable_seq_lengths)
-                from megatron.core.transformer.enums import AttnBackend
-
-                provider.attention_backend = AttnBackend.flash
-                provider.variable_seq_lengths = True
-                provider.moe_token_dispatcher_type = "alltoall"
-                provider.moe_router_load_balancing_type = "none"
-
-                # Apply transformer config overrides
-                for key, value in override_transformer_config.items():
-                    setattr(provider, key, value)
-
-                provider.finalize()
-                self.provider = provider
-                tf_config = None  # Will be set after model creation
-            self.bridge = bridge
+        if self.vanilla_bridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            tf_config.fp16 = self.param_dtype == torch.float16
+            tf_config.bf16 = self.param_dtype == torch.bfloat16
         else:
-            self.bridge = None
-            tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config)
+            from verl.models.mcore.bridge import AutoBridge
+
+            # Use Megatron-Bridge to convert HF config to Megatron config
+            bridge = AutoBridge.from_hf_pretrained(
+                self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
+            )
+            # Get Megatron provider and configure it
+            provider = bridge.to_megatron_provider(load_weights=False)
+
+            # In case of invalid overrides, we need to make sure some critical params are set correctly
+            provider.params_dtype = self.param_dtype
+
+            # Pass distributed info
+            provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
+            provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
+            provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
+            provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
+            provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
+            provider.context_parallel_size = self.engine_config.context_parallel_size
+            provider.sequence_parallel = self.engine_config.sequence_parallel
+
+            # Match verl implementation (need variable_seq_lengths)
+            from megatron.core.transformer.enums import AttnBackend
+
+            provider.attention_backend = AttnBackend.flash
+            provider.variable_seq_lengths = True
+            provider.moe_token_dispatcher_type = "alltoall"
+            provider.moe_router_load_balancing_type = "none"
+
+            # Apply transformer config overrides
+            for key, value in override_transformer_config.items():
+                setattr(provider, key, value)
+
+            provider.finalize()
+            self.provider = provider
+            tf_config = None  # Will be set after model creation
+        self.bridge = bridge
 
         if not self.bridge:
             self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)
@@ -232,28 +222,14 @@ def _build_megatron_module(self):
         if self.engine_config.use_dist_checkpointing:
             load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
         else:
-            if self.bridge is not None:
-                if self.vanilla_bridge:
-                    self.bridge.load_weights(module, self.model_config.local_path)
-                else:
-                    allowed_mismatched_params = []
-                    if self.is_value_model:
-                        allowed_mismatched_params = ["output_layer.weight"]
-                    self.bridge.load_hf_weights(
-                        module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
-                    )
+            if self.vanilla_bridge:
+                self.bridge.load_weights(module, self.model_config.local_path)
             else:
-                # (vermouth1992) this is a workaround to be compatible with the old API
-                tmp_config = OmegaConf.create(
-                    {"model": {"path": self.model_config.local_path, "use_shm": self.model_config.use_shm}}
-                )
-
-                load_megatron_gptmodel_weights(
-                    tmp_config,
-                    self.model_config.hf_config,
-                    module,
-                    params_dtype=self.dtype,
-                    is_value_model=is_value_model,
+                allowed_mismatched_params = []
+                if self.is_value_model:
+                    allowed_mismatched_params = ["output_layer.weight"]
+                self.bridge.load_hf_weights(
+                    module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
                 )
 
         if torch.distributed.get_rank() == 0:
@@ -562,16 +538,7 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
     def get_per_tensor_param(self):
        if self._is_offload_param:
            load_megatron_model_to_gpu(self.module, load_grad=False)
-        if self.bridge is not None:
-            per_tensor_param = self.bridge.export_weights(self.module)
-        else:
-            per_tensor_param = per_tensor_generator(
-                self.module,
-                self.model_config.hf_config,
-                self.weight_converter,
-                self.tf_config,
-                self.layer_name_mapping,
-            )
+        per_tensor_param = self.bridge.export_weights(self.module)
         # TODO: support megatron LoRA
         return per_tensor_param, None
 
```
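
For readers scanning the long diff, a condensed, hypothetical sketch of the bridge-only flow that remains (the function name is invented; the `AutoBridge` classes and methods are the ones used in the diff, with dtype handling, parallelism fields, and overrides elided):

```python
# Hedged sketch only: condensed from the diff above, not the actual MegatronEngine implementation.
def build_bridge_and_tf_config(engine_config, model_config, param_dtype):
    provider = None
    if engine_config.vanilla_mbridge:
        # verl's built-in mbridge: the TransformerConfig comes straight from the HF config.
        from verl.models.mcore.mbridge import AutoBridge

        bridge = AutoBridge.from_config(model_config.hf_config, dtype=param_dtype)
        tf_config = bridge.config
    else:
        # Megatron-Bridge: build a model provider; tf_config is filled in after model creation.
        from verl.models.mcore.bridge import AutoBridge

        bridge = AutoBridge.from_hf_pretrained(model_config.local_path)
        provider = bridge.to_megatron_provider(load_weights=False)
        # ... parallelism sizes and dtypes are set on the provider before finalize() in the real code ...
        provider.finalize()
        tf_config = None
    return bridge, provider, tf_config
```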
