     load_megatron_optimizer,
     offload_megatron_model_to_cpu,
     offload_megatron_optimizer,
-    per_tensor_generator,
     register_megatron_training_hooks,
 )
 from verl.utils.model import (
     extract_multi_modal_inputs_tensordict,
     load_mcore_dist_weights,
-    load_megatron_gptmodel_weights,
 )
 from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
 
@@ -76,7 +74,7 @@ def __init__(
         self.engine_config = engine_config
         self.optimizer_config = optimizer_config
         self.checkpoint_config = checkpoint_config
-
+        assert self.engine_config.use_mbridge, "use_mbridge must be True"
         self._init_device_mesh()
 
         set_random_seed(seed=self.engine_config.seed)
@@ -110,70 +108,62 @@ def _init_device_mesh(self):
         )
 
     def _build_tf_config(self):
-        from verl.models.mcore import hf_to_mcore_config
-        from verl.models.mcore.config_converter import mapping_string_to_attn_backend
+        from verl.utils.megatron_utils import mapping_string_to_attn_backend
         from verl.utils.torch_dtypes import PrecisionType
 
         self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype)
-        if self.param_dtype == torch.float16:
-            assert self.engine_config.use_mbridge, "fp16 mode requires use_mbridge to be True"
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
 
         override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
 
-        use_mbridge = self.engine_config.use_mbridge
         self.provider = None
         self.vanilla_bridge = self.engine_config.vanilla_mbridge
-        if use_mbridge:
-            if self.vanilla_bridge:
-                from verl.models.mcore.mbridge import AutoBridge
-
-                bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
-                bridge.set_extra_args(**override_transformer_config)
-                tf_config = bridge.config
-                tf_config.fp16 = self.param_dtype == torch.float16
-                tf_config.bf16 = self.param_dtype == torch.bfloat16
-            else:
-                from verl.models.mcore.bridge import AutoBridge
-
-                # Use Megatron-Bridge to convert HF config to Megatron config
-                bridge = AutoBridge.from_hf_pretrained(
-                    self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
-                )
-                # Get Megatron provider and configure it
-                provider = bridge.to_megatron_provider(load_weights=False)
-
-                # In case of invalid overrides, we need to make sure some critical params are set correctly
-                provider.params_dtype = self.param_dtype
-
-                # Pass distributed info
-                provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
-                provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
-                provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
-                provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
-                provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
-                provider.context_parallel_size = self.engine_config.context_parallel_size
-                provider.sequence_parallel = self.engine_config.sequence_parallel
-
-                # Match verl implementation (need variable_seq_lengths)
-                from megatron.core.transformer.enums import AttnBackend
-
-                provider.attention_backend = AttnBackend.flash
-                provider.variable_seq_lengths = True
-                provider.moe_token_dispatcher_type = "alltoall"
-                provider.moe_router_load_balancing_type = "none"
-
-                # Apply transformer config overrides
-                for key, value in override_transformer_config.items():
-                    setattr(provider, key, value)
-
-                provider.finalize()
-                self.provider = provider
-                tf_config = None  # Will be set after model creation
-            self.bridge = bridge
+        if self.vanilla_bridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            tf_config.fp16 = self.param_dtype == torch.float16
+            tf_config.bf16 = self.param_dtype == torch.bfloat16
         else:
-            self.bridge = None
-            tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config)
+            from verl.models.mcore.bridge import AutoBridge
+
+            # Use Megatron-Bridge to convert HF config to Megatron config
+            bridge = AutoBridge.from_hf_pretrained(
+                self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
+            )
+            # Get Megatron provider and configure it
+            provider = bridge.to_megatron_provider(load_weights=False)
+
+            # In case of invalid overrides, we need to make sure some critical params are set correctly
+            provider.params_dtype = self.param_dtype
+
+            # Pass distributed info
+            provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
+            provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
+            provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
+            provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
+            provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
+            provider.context_parallel_size = self.engine_config.context_parallel_size
+            provider.sequence_parallel = self.engine_config.sequence_parallel
+
+            # Match verl implementation (need variable_seq_lengths)
+            from megatron.core.transformer.enums import AttnBackend
+
+            provider.attention_backend = AttnBackend.flash
+            provider.variable_seq_lengths = True
+            provider.moe_token_dispatcher_type = "alltoall"
+            provider.moe_router_load_balancing_type = "none"
+
+            # Apply transformer config overrides
+            for key, value in override_transformer_config.items():
+                setattr(provider, key, value)
+
+            provider.finalize()
+            self.provider = provider
+            tf_config = None  # Will be set after model creation
+        self.bridge = bridge
 
         if not self.bridge:
             self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)
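
For reference (not part of the diff): with `use_mbridge` now asserted in `__init__`, the two remaining branches of `_build_tf_config` reduce to the sketch below. The function `build_bridge_and_config` and its arguments (`hf_config`, `local_path`, `param_dtype`, `overrides`, `vanilla_bridge`) are placeholders introduced here for illustration; the `AutoBridge` calls are the ones that appear in the changed code.

```python
# Standalone sketch of the new bridge selection (paraphrased from the diff above;
# build_bridge_and_config and its parameters are hypothetical names).
def build_bridge_and_config(hf_config, local_path, param_dtype, overrides, vanilla_bridge=True):
    if vanilla_bridge:
        # verl's built-in mbridge derives the TransformerConfig directly from the HF config.
        from verl.models.mcore.mbridge import AutoBridge

        bridge = AutoBridge.from_config(hf_config, dtype=param_dtype)
        bridge.set_extra_args(**overrides)
        tf_config = bridge.config
    else:
        # Megatron-Bridge returns a model provider; the TransformerConfig is only
        # available after the provider has built the model, so tf_config stays None here.
        from verl.models.mcore.bridge import AutoBridge

        bridge = AutoBridge.from_hf_pretrained(local_path)
        provider = bridge.to_megatron_provider(load_weights=False)
        provider.params_dtype = param_dtype
        for key, value in overrides.items():
            setattr(provider, key, value)
        provider.finalize()
        tf_config = None
    return bridge, tf_config
```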
@@ -232,28 +222,14 @@ def _build_megatron_module(self):
         if self.engine_config.use_dist_checkpointing:
             load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
         else:
-            if self.bridge is not None:
-                if self.vanilla_bridge:
-                    self.bridge.load_weights(module, self.model_config.local_path)
-                else:
-                    allowed_mismatched_params = []
-                    if self.is_value_model:
-                        allowed_mismatched_params = ["output_layer.weight"]
-                    self.bridge.load_hf_weights(
-                        module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
-                    )
+            if self.vanilla_bridge:
+                self.bridge.load_weights(module, self.model_config.local_path)
             else:
-                # (vermouth1992) this is a workaround to be compatible with the old API
-                tmp_config = OmegaConf.create(
-                    {"model": {"path": self.model_config.local_path, "use_shm": self.model_config.use_shm}}
-                )
-
-                load_megatron_gptmodel_weights(
-                    tmp_config,
-                    self.model_config.hf_config,
-                    module,
-                    params_dtype=self.dtype,
-                    is_value_model=is_value_model,
+                allowed_mismatched_params = []
+                if self.is_value_model:
+                    allowed_mismatched_params = ["output_layer.weight"]
+                self.bridge.load_hf_weights(
+                    module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
                 )
 
         if torch.distributed.get_rank() == 0:
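
For reference (not part of the diff): when dist checkpointing is not used, HF weights are now always loaded through the bridge. A minimal sketch of the new branch, with `module`, `local_path`, and `is_value_model` standing in for the engine's attributes:

```python
# Minimal sketch of the simplified weight loading above (function and parameter
# names are placeholders; the bridge methods are the ones used in the diff).
def load_initial_weights(bridge, vanilla_bridge, module, local_path, is_value_model):
    if vanilla_bridge:
        # verl's mbridge path
        bridge.load_weights(module, local_path)
    else:
        # Megatron-Bridge path; a value model swaps out the LM head, so its
        # output_layer.weight is allowed to mismatch the HF checkpoint.
        allowed = ["output_layer.weight"] if is_value_model else []
        bridge.load_hf_weights(module, local_path, allowed_mismatched_params=allowed)
```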
@@ -562,16 +538,7 @@ def get_per_tensor_param(self):
     def get_per_tensor_param(self):
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.module, load_grad=False)
-        if self.bridge is not None:
-            per_tensor_param = self.bridge.export_weights(self.module)
-        else:
-            per_tensor_param = per_tensor_generator(
-                self.module,
-                self.model_config.hf_config,
-                self.weight_converter,
-                self.tf_config,
-                self.layer_name_mapping,
-            )
+        per_tensor_param = self.bridge.export_weights(self.module)
         # TODO: support megatron LoRA
         return per_tensor_param, None
 
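
For reference (not part of the diff): per-tensor export now always goes through the bridge, so the old `per_tensor_generator` fallback and its `weight_converter` / `tf_config` / `layer_name_mapping` inputs are no longer needed. A hedged usage sketch, assuming `engine` is an instance of this engine class and that `export_weights` yields `(name, tensor)` pairs:

```python
# Hypothetical consumer of the simplified API; `engine` is an instance of the
# engine class above, and export_weights is assumed to yield (name, tensor) pairs.
per_tensor_param, lora_params = engine.get_per_tensor_param()
for name, tensor in per_tensor_param:
    print(name, tuple(tensor.shape))  # e.g. stream weights to a rollout worker
```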