
Commit 0b0769e

Authored by Kite0011, gemini-code-assist[bot], and ISEEKYAN
[megatron] feat: use yaml to manage mbridge args (verl-project#4584)
### What does this PR do?

Use a dict to manage mbridge args, support modifying mbridge input args, and avoid mbridge version compatibility issues.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yan Bai <[email protected]>
1 parent 7cf31bf commit 0b0769e
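
The "API and Usage Example" section above was left as the template placeholder, so here is a hedged sketch of the new knob, reconstructed from the diffs below. The YAML path mirrors the `checkpoint` blocks touched by this commit, and the two keys are the ones documented in `verl/trainer/config/actor/actor.yaml`; expressing it via OmegaConf in Python is an assumption (the generated configs use OmegaConf's `${oc.select:...}` resolver, so OmegaConf is in play):

```python
from omegaconf import OmegaConf

# Python equivalent of setting, under actor_rollout_ref.actor.checkpoint (YAML):
#   checkpoint:
#     mbridge_config:
#       distributed_filesystem: true
#       memory_efficient: true
cfg = OmegaConf.create({
    "checkpoint": {
        "async_save": False,
        "mbridge_config": {
            "distributed_filesystem": True,  # key documented in actor/actor.yaml below
            "memory_efficient": True,        # key documented in actor/actor.yaml below
        },
    }
})
print(OmegaConf.to_yaml(cfg))
```

A key placed under `mbridge_config` is forwarded to mbridge's `save_weights` only if the installed mbridge version actually accepts a parameter of that name; unknown keys are dropped rather than raising a `TypeError`, which is how the version-compatibility goal is met (see `megatron_checkpoint_manager.py` below).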

File tree: 7 files changed (+34 / -6 lines changed)

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -107,6 +107,7 @@ actor_rollout_ref:
         - extra
       load_contents: ${.save_contents}
       async_save: false
+      mbridge_config: {}
     use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
@@ -551,6 +552,7 @@ critic:
       - extra
     load_contents: ${.save_contents}
     async_save: false
+    mbridge_config: {}
   profiler:
     _target_: verl.utils.profiler.ProfilerConfig
     tool: ${oc.select:global_profiler.tool,null}
```

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -87,6 +87,7 @@ actor_rollout_ref:
         - extra
       load_contents: ${.save_contents}
       async_save: false
+      mbridge_config: {}
     use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
@@ -469,6 +470,7 @@ critic:
       - extra
     load_contents: ${.save_contents}
     async_save: false
+    mbridge_config: {}
   profiler:
     _target_: verl.utils.profiler.ProfilerConfig
     tool: ${oc.select:global_profiler.tool,null}
```
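
A side note on the interpolation syntax visible in both generated files, for readers unfamiliar with OmegaConf: `${.save_contents}` is a relative reference to a sibling key in the same node, and `${oc.select:path,default}` looks a value up and falls back to the default when the path is missing. A minimal, self-contained demo (the key names echo the YAML above; this block is illustrative, not part of the commit):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {},  # use_fused_kernels deliberately absent, so oc.select must fall back
    "checkpoint": {
        "save_contents": ["model", "optimizer", "extra"],
        "load_contents": "${.save_contents}",  # leading dot: sibling key in this node
    },
    "use_fused_kernels": "${oc.select:model.use_fused_kernels,false}",
})
assert list(cfg.checkpoint.load_contents) == ["model", "optimizer", "extra"]
assert cfg.use_fused_kernels is False  # the default after the comma kicks in
```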

verl/trainer/config/actor/actor.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -132,6 +132,13 @@ checkpoint:
 
   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
+
+  # Mbridge config extension.
+  # When vanilla_mbridge=True and your filesystem is a distributed filesystem
+  # (i.e., a file written on node A is immediately readable on node B),
+  # set `mbridge_config.distributed_filesystem=True` and `mbridge_config.memory_efficient=True`
+  # to speed up checkpoint saving by 10x.
+  mbridge_config: {}
 
 # optimizer configs
 optim:
```
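
The comment defines "distributed filesystem" operationally: a file written on node A must be immediately readable on node B. Before enabling `distributed_filesystem=True`, it may be worth probing the checkpoint mount for exactly that property. A minimal sketch, assuming `torch.distributed` is already initialized with one process per node (the helper name is hypothetical, not part of this commit):

```python
import os
import torch.distributed as dist

def probe_shared_filesystem(ckpt_dir: str) -> bool:
    """Rank 0 writes a probe file; every other rank must see it after a barrier."""
    probe = os.path.join(ckpt_dir, ".fs_probe")
    if dist.get_rank() == 0:
        with open(probe, "w") as f:
            f.write("ok")
    dist.barrier()   # ensure the write happened before anyone checks
    visible = os.path.exists(probe)
    dist.barrier()   # keep ranks in sync before cleanup
    if dist.get_rank() == 0:
        os.remove(probe)
    return visible
```

If any rank returns `False`, the mount is node-local and the speedup flags should stay off.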

verl/trainer/config/config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -36,6 +36,7 @@ class CheckpointConfig(BaseConfig):
     save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     async_save: bool = False
+    mbridge_config: dict[str, Any] = field(default_factory=dict)
 
 
 @dataclass
```
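
For readers wondering why the new field uses `field(default_factory=dict)` rather than `= {}`: Python dataclasses reject mutable defaults, precisely because one shared dict would leak mbridge settings across config instances. A minimal illustration with a hypothetical stand-in class (not verl's actual `CheckpointConfig`):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class CheckpointCfg:  # hypothetical stand-in for CheckpointConfig
    mbridge_config: dict[str, Any] = field(default_factory=dict)

a, b = CheckpointCfg(), CheckpointCfg()
a.mbridge_config["memory_efficient"] = True
assert b.mbridge_config == {}  # each instance gets its own fresh dict
```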

verl/trainer/config/critic/critic.yaml

Lines changed: 3 additions & 0 deletions

```diff
@@ -98,6 +98,9 @@ checkpoint:
   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
 
+  # Mbridge config extension.
+  mbridge_config: {}
+
 # profile the critic model in `update_critic`
 profiler:
```

verl/trainer/config/sft_trainer_engine.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -55,6 +55,8 @@ checkpoint:
 
   # For more flexibility, you can specify the contents to load from the checkpoint.
   load_contents: ${checkpoint.save_contents}
+  # Mbridge config extension.
+  mbridge_config: {}
 
 trainer:
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
```

verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 17 additions & 6 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import json
 import logging
 import os
@@ -578,9 +579,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
             hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
-                self.bridge.save_weights(
-                    self.model, hf_ckpt_path, distributed_filesystem=True, memory_efficient=True
-                )
+                extended_args = {}
+                mbridge_config = getattr(self.checkpoint_config, "mbridge_config", None) or {}
+                for sig in inspect.signature(self.bridge.save_weights).parameters:
+                    if sig == "weights_path" or sig == "models":
+                        continue
+                    if sig in mbridge_config:
+                        extended_args[sig] = mbridge_config[sig]
+                self.bridge.save_weights(self.model, hf_ckpt_path, **extended_args)
             else:
                 self.bridge.save_hf_weights(self.model, hf_ckpt_path)
 
@@ -651,9 +657,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         if self.bridge is not None:
             hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
-                self.bridge.save_weights(
-                    self.model, hf_model_ckpt_path, distributed_filesystem=True, memory_efficient=True
-                )
+                extended_args = {}
+                mbridge_config = getattr(self.checkpoint_config, "mbridge_config", None) or {}
+                for sig in inspect.signature(self.bridge.save_weights).parameters:
+                    if sig == "weights_path" or sig == "models":
+                        continue
+                    if sig in mbridge_config:
+                        extended_args[sig] = mbridge_config[sig]
+                self.bridge.save_weights(self.model, hf_model_ckpt_path, **extended_args)
             else:
                 self.bridge.save_hf_weights(self.model, hf_model_ckpt_path)
         else:
```
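
One behavioral consequence of the hunks above: the old code hard-coded `distributed_filesystem=True, memory_efficient=True`, while the new code passes nothing unless `mbridge_config` says so, so an empty dict now means `save_weights` runs with whatever defaults the installed mbridge ships. A toy demonstration of the same filtering loop, using a hypothetical `save_weights` signature (the real one may differ between mbridge versions, which is the point):

```python
import inspect

# Hypothetical stand-in for mbridge's save_weights; signature is an assumption.
def save_weights(models, weights_path, distributed_filesystem=False, memory_efficient=False):
    return distributed_filesystem, memory_efficient

def build_kwargs(func, mbridge_config: dict) -> dict:
    """Mirror of the committed loop: forward only recognized, non-positional keys."""
    out = {}
    for name in inspect.signature(func).parameters:
        if name in ("models", "weights_path"):  # verl passes these positionally
            continue
        if name in mbridge_config:
            out[name] = mbridge_config[name]
    return out

assert save_weights("m", "p", **build_kwargs(save_weights, {})) == (False, False)
assert save_weights("m", "p", **build_kwargs(save_weights, {"memory_efficient": True})) == (False, True)
assert build_kwargs(save_weights, {"typo_flag": 1}) == {}  # unknown keys are dropped silently
```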
