
Commit ec123e6

[megatron] feat: Support LoRA training with FP16 using Megatron-Bridge. (verl-project#4648)
### What does this PR do?

> Support LoRA training with FP16 using Megatron-Bridge.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

<img width="1351" height="492" alt="image" src="https://github.com/user-attachments/assets/cfd55b68-00b9-4f36-9a0a-fc608073e216" />

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 536a978 commit ec123e6

4 files changed: +155 −0 lines changed

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
#!/usr/bin/env bash
set -xeuo pipefail
pwd=`pwd`

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi

TP=${TP:-2}
PP=${PP:-2}
CP=${CP:-2}
EP=${EP:-4}
ETP=${ETP:-1}

ALL_OFFLOAD=${ALL_OFFLOAD:-True}

optimizer_offload_fraction=1.

dtype="float16" # ["bfloat16", "float16"]
rollout_name="vllm"
project_name='verl_grpo_example_gsm8k_math_fp16'
exp_name='qwen3_30b_a3b_megatron_lora'
adv_estimator=grpo

# Paths
MODEL_PATH=$HOME/Qwen/Qwen3-30B-A3B-Instruct-2507
CKPTS_DIR=${pwd}/ckpt/${exp_name}

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet

########################### Parameter Arrays ###########################

DATA=(
    data.train_files=${gsm8k_train_path}
    data.val_files=${gsm8k_test_path}
    data.train_batch_size=128
    data.max_prompt_length=1024
    data.max_response_length=1024
    data.truncation='error'
    data.filter_overlong_prompts=True
    data.shuffle=False
    data.return_raw_chat=$return_raw_chat
    data.filter_overlong_prompts_workers=128
)

MODEL=(
    actor_rollout_ref.model.path=${MODEL_PATH}
    actor_rollout_ref.model.lora.rank=16
    actor_rollout_ref.model.lora.alpha=32
    actor_rollout_ref.model.lora.dtype=${dtype}
    actor_rollout_ref.model.use_fused_kernels=True
)

ACTOR=(
    actor_rollout_ref.actor.optim.lr=3e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=16
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=False
    actor_rollout_ref.actor.use_dynamic_bsz=True
    actor_rollout_ref.actor.use_kl_loss=True
    actor_rollout_ref.actor.kl_loss_coef=0.001
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.actor.megatron.context_parallel_size=${CP}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD}
    actor_rollout_ref.actor.megatron.dtype=${dtype}
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True
    +actor_rollout_ref.actor.megatron.override_ddp_config.grad_reduce_in_fp32=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction}
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=${ALL_OFFLOAD}
)

ROLLOUT=(
    actor_rollout_ref.rollout.tensor_model_parallel_size=8
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True
    actor_rollout_ref.rollout.name=${rollout_name}
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5
    actor_rollout_ref.rollout.enforce_eager=True
    actor_rollout_ref.rollout.free_cache_engine=True
    actor_rollout_ref.rollout.n=4
    actor_rollout_ref.rollout.dtype=${dtype}
)

REF=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True
    actor_rollout_ref.ref.megatron.dtype=${dtype}
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP}
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP}
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP}
    actor_rollout_ref.ref.megatron.context_parallel_size=${CP}
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP}
    actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD}
)

ALGORITHM=(
    algorithm.adv_estimator=${adv_estimator}
)

TRAINER=(
    trainer.critic_warmup=0
    trainer.logger='["console","wandb"]'
    trainer.project_name=${project_name}
    trainer.experiment_name=${exp_name}
    trainer.n_gpus_per_node=8
    trainer.nnodes=1
    trainer.save_freq=20
    trainer.test_freq=5
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.max_actor_ckpt_to_keep=1
    trainer.default_local_dir="${CKPTS_DIR}"
    trainer.log_val_generations=10
)

########################### Launch ###########################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    2>&1 | tee ${pwd}/log/${exp_name}_$(date +'%Y%m%d_%H%M%S').log
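For orientation, the precision surface of this recipe is the single `dtype` variable, which is fed to the LoRA adapters, the actor and reference Megatron configs, and the rollout engine, together with the FP32 gradient reduction and precision-aware-optimizer overrides in the `ACTOR` array. The snippet below is a hedged reading aid, not part of the commit; every key is copied verbatim from the arrays above.

```python
# Hedged reading aid, not part of the commit: the Hydra override keys from the
# example script above that determine the FP16 behaviour of this LoRA recipe.
FP16_RELATED_OVERRIDES = [
    # One dtype string drives every component.
    "actor_rollout_ref.model.lora.dtype=float16",
    "actor_rollout_ref.actor.megatron.dtype=float16",
    "actor_rollout_ref.ref.megatron.dtype=float16",
    "actor_rollout_ref.rollout.dtype=float16",
    # Gradient all-reduce stays in FP32 and Megatron's precision-aware
    # optimizer is enabled, as set in the ACTOR array above.
    "+actor_rollout_ref.actor.megatron.override_ddp_config.grad_reduce_in_fp32=True",
    "+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True",
]

for key in FP16_RELATED_OVERRIDES:
    print(key)
```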

verl/utils/megatron_utils.py

Lines changed: 2 additions & 0 deletions

@@ -269,6 +269,8 @@ def peft_pre_wrap_hook(model):
     model = provider.provide_distributed_model(
         wrap_with_ddp=wrap_config.wrap_with_ddp,
         ddp_config=ddp_config,
+        fp16=provider.fp16,
+        bf16=provider.bf16,
     )

     # Extract TransformerConfig from the created model

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 4 additions & 0 deletions

@@ -152,6 +152,10 @@ def _build_tf_config(self):
         # In case of invalid overrides, we need to make sure some critical params are set correctly
         provider.params_dtype = self.param_dtype

+        # Ensure dtype settings propagate to Megatron-Bridge/TE
+        provider.fp16 = self.param_dtype == torch.float16
+        provider.bf16 = self.param_dtype == torch.bfloat16
+
         # Pass distributed info
         provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
         provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size

verl/workers/megatron_workers.py

Lines changed: 4 additions & 0 deletions

@@ -199,6 +199,10 @@ def _init_hf_config_and_tf_config(
     # In case of invalid overrides, we need to make sure some critical params are set correctly
     provider.params_dtype = dtype

+    # Ensure dtype settings propagate to Megatron-Bridge/TE
+    provider.fp16 = fp16
+    provider.bf16 = bf16
+
     # Pass distributed info
     provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size
     provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size
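Taken together, the three changes above make the boolean precision flags follow `params_dtype`: the engine and worker derive `fp16`/`bf16` from the configured dtype, and the PEFT pre-wrap path forwards them into `provide_distributed_model`. The self-contained sketch below illustrates why the explicit assignments matter; it uses a stand-in `Provider` dataclass with assumed defaults, not Megatron-Bridge's real provider, and is for illustration only.

```python
import torch
from dataclasses import dataclass


@dataclass
class Provider:
    """Stand-in for illustration; NOT Megatron-Bridge's model provider."""
    params_dtype: torch.dtype = torch.bfloat16
    fp16: bool = False
    bf16: bool = True


provider = Provider()

# Setting params_dtype alone (what already happened before this PR) does not
# touch the boolean flags, so a float16 LoRA run could still be treated as bf16.
provider.params_dtype = torch.float16
assert (provider.fp16, provider.bf16) == (False, True)

# The assignments added in this commit keep the flags consistent with the dtype;
# the PEFT pre-wrap hook then forwards them to provide_distributed_model.
provider.fp16 = provider.params_dtype == torch.float16
provider.bf16 = provider.params_dtype == torch.bfloat16
assert (provider.fp16, provider.bf16) == (True, False)
```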
