# [fsdp, megatron] fix: Engine Rollout Worker LoRA Parameter Update (verl-project#4836)

Commit 94f4654 (parent e69998c)
### What does this PR do?
Updating vLLM LoRA weights currently raises an error for both the FSDP and the Megatron engine rollout workers:

- **FSDP**: fails because the layered-summon arguments are not passed when collecting per-tensor LoRA parameters.
- **Megatron**: fails because the wrong weight-export method is called on `AutoBridge` (`export_weights` instead of `export_hf_weights`).

Minimal sketches of each fix are shown after the corresponding tracebacks below.
### Test
#### FSDP
_Script_:
```bash
LEGACY_MODE='disable'
export CUDA_VISIBLE_DEVICES=0,1
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$DATA_PATH/gsm8k/train.parquet \
data.val_files=$DATA_PATH/gsm8k/test.parquet \
data.train_batch_size=2 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.rollout.agent.num_workers=2 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.model.lora_rank=64 \
actor_rollout_ref.model.lora_alpha=32 \
actor_rollout_ref.actor.optim.lr=3e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=2 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console"]' \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2.5_3b_grpo_lora' \
trainer.n_gpus_per_node=2 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.use_legacy_worker_impl=$LEGACY_MODE \
trainer.total_epochs=15 \
actor_rollout_ref.actor.use_torch_compile=False \
actor_rollout_ref.actor.fsdp_config.use_torch_compile=False \
trainer.val_before_train=False \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.ref.fsdp_config.use_torch_compile=False
```
_Error_:
```python
File "/home/jacob.a.helwig/verl/verl/workers/engine_workers.py", line 572, in wake_up
await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
File "/home/jacob.a.helwig/verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 252, in update_weights
self.inference_engine.worker.add_lora(lora_request)
File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 494, in add_lora
return self.model_runner.add_lora(lora_request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/v1/worker/lora_model_runner_mixin.py", line 171, in add_lora
return self.lora_manager.add_adapter(lora_request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/worker_manager.py", line 251, in add_adapter
lora = self._load_adapter(lora_request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/utils/vllm/utils.py", line 100, in hijack__load_adapter
lora = self._lora_model_cls.from_lora_tensors(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/models.py", line 135, in from_lora_tensors
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verl/lib/python3.12/site-packages/vllm/lora/utils.py", line 150, in parse_fine_tuned_lora_name
raise ValueError(f"{name} is unsupported LoRA weight")
ValueError: model.embed_tokens.weight is unsupported LoRA weight
```
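The `ValueError` comes from vLLM rejecting a base-model tensor (`model.embed_tokens.weight`) that should never have been in the LoRA tensor dict: without the layered-summon arguments, the FSDP engine exports full weights rather than just the adapter. A minimal sketch of the shape of the fix; the keyword names `layered_summon` and `base_sync_done` are assumptions drawn from the rollout config and the traceback, not the verbatim patch:

```python
# Hypothetical sketch of the FSDP-side fix in engine_workers.wake_up.
# The kwargs layered_summon / base_sync_done are assumed, not verbatim
# from the patch; only the surrounding call shape comes from the traceback.
async def wake_up(self):
    # Forward the layered-summon setting so the engine gathers only the
    # LoRA adapter tensors instead of summoning full FSDP-sharded base
    # weights (which is how model.embed_tokens.weight ended up in the
    # LoRA tensor dict above).
    per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param(
        layered_summon=self.config.rollout.layered_summon,  # assumed kwarg
        base_sync_done=self.base_sync_done,                 # assumed kwarg
    )
    # Unchanged call from the traceback: push the (now LoRA-only) tensors
    # into the vLLM engine.
    await self.rollout.update_weights(
        per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
    )
```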
#### Megatron
_Script_:
```bash
############################ Quick Config ############################
rollout_name="vllm" # sglang or vllm
project_name='verl_grpo_example_gsm8k_math'
exp_name='qwen2_7b_megatron_lora'
adv_estimator=grpo
max_prompt_length=1024
max_response_length=1024
train_prompt_bsz=2
############################ Paths ############################
gsm8k_train_path=$DATA_PATH/gsm8k/train.parquet
gsm8k_test_path=$DATA_PATH/gsm8k/test.parquet
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"
############################ Parameter Groups ############################
DATA=(
data.train_files="$train_files"
data.val_files="$test_files"
data.max_prompt_length=$max_prompt_length
data.max_response_length=$max_response_length
data.train_batch_size=$train_prompt_bsz
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
)
MODEL=(
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct
actor_rollout_ref.model.lora.rank=64
actor_rollout_ref.model.lora.alpha=32
actor_rollout_ref.model.lora.lora_A_init_method=kaiming
# # Optional: Use canonical LoRA
# actor_rollout_ref.model.lora.type="canonical_lora"
# actor_rollout_ref.model.lora.target_modules='["linear_q","linear_k","linear_v","linear_proj","linear_fc1_up","linear_fc1_gate","linear_fc2"]'
# # Optional: Add dropout to LoRA layers
# actor_rollout_ref.model.lora.dropout=0.05
# actor_rollout_ref.model.lora.dropout_position=pre
)
ACTOR=(
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.ppo_mini_batch_size=2
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=5
actor_rollout_ref.actor.use_dynamic_bsz=True
actor_rollout_ref.actor.megatron.use_mbridge=True
actor_rollout_ref.actor.megatron.vanilla_mbridge=False
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1
actor_rollout_ref.actor.megatron.sequence_parallel=False
actor_rollout_ref.actor.use_kl_loss=True
actor_rollout_ref.actor.kl_loss_coef=0.001
actor_rollout_ref.actor.kl_loss_type=low_var_kl
actor_rollout_ref.actor.entropy_coeff=0
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
)
ROLLOUT=(
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=5
actor_rollout_ref.rollout.tensor_model_parallel_size=1
actor_rollout_ref.rollout.name=$rollout_name
actor_rollout_ref.rollout.gpu_memory_utilization=0.6
actor_rollout_ref.rollout.n=4
)
REF=(
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=5
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1
actor_rollout_ref.ref.megatron.sequence_parallel=False
)
ALGORITHM=(
algorithm.adv_estimator=$adv_estimator
algorithm.use_kl_in_reward=False
)
TRAINER=(
trainer.logger='["console"]'
trainer.project_name=$project_name
trainer.experiment_name=$exp_name
trainer.n_gpus_per_node=2
trainer.nnodes=1
trainer.save_freq=20
trainer.test_freq=5
trainer.total_epochs=15
trainer.val_before_train=False
trainer.use_legacy_worker_impl=disable
)
############################ Launch ############################
python3 -m verl.trainer.main_ppo \
--config-path=config \
--config-name='ppo_megatron_trainer.yaml' \
"${DATA[@]}" \
"${ALGORITHM[@]}" \
"${MODEL[@]}" \
"${ROLLOUT[@]}" \
"${ACTOR[@]}" \
"${REF[@]}" \
"${TRAINER[@]}" \
"$@"
```
_Error_:
```python
File "/home/jacob.a.helwig/verl/verl/trainer/ppo/ray_trainer.py", line 1409, in fit
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 949, in generate_sequences
self.wake_up()
File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 997, in wake_up
self._run_all([replica.wake_up() for replica in self.rollout_replicas])
File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 1011, in _run_all
asyncio.run(run_all())
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/runners.py", line 195, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/home/jacob.a.helwig/verl/verl/experimental/agent_loop/agent_loop.py", line 1009, in run_all
await asyncio.gather(*tasks)
File "/home/jacob.a.helwig/verl/verl/workers/rollout/replica.py", line 200, in wake_up
await asyncio.gather(*[server.wake_up.remote() for server in self.servers])
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/tasks.py", line 684, in _wrap_awaitable
return await awaitable
^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(AttributeError): ray::vLLMHttpServer.wake_up() (pid=2329228, ip=10.55.149.115, actor_id=7c53a3aba97589c515d254ba01000000, repr=<verl.workers.rollout.vllm_rollout.vllm_async_server.vLLMHttpServer object at 0x75b0d61d8e00>)
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 449, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 556, in wake_up
await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers])
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/asyncio/tasks.py", line 684, in _wrap_awaitable
return await awaitable
^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(AttributeError): ray::WorkerDict.wake_up() (pid=2320299, ip=10.55.149.115, actor_id=bccd2231d5e9ef6dce81c33c01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x75eadfc7be00>)
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/miniconda3/envs/verlMega/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/single_controller/ray/base.py", line 848, in async_func
return await getattr(self.worker_dict[key], name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/single_controller/base/decorator.py", line 462, in async_inner
return await func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/utils/transferqueue_utils.py", line 319, in dummy_async_inner
output = await func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/workers/engine_workers.py", line 566, in wake_up
per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jacob.a.helwig/verl/verl/workers/engine/megatron/transformer_impl.py", line 541, in get_per_tensor_param
per_tensor_param = self.bridge.export_weights(self.module)
^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'AutoBridge' object has no attribute 'export_weights'. Did you mean: 'export_hf_weights'?. Did you mean: '_return_value'?
```
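Here the traceback names the fix directly: `AutoBridge` has no `export_weights`; the HF-style exporter is `export_hf_weights`. A minimal sketch of the corrected call in `get_per_tensor_param`, where only the method name is confirmed by the error hint and the rest is assumed context:

```python
# Sketch of the Megatron-side fix in
# verl/workers/engine/megatron/transformer_impl.py::get_per_tensor_param.
# export_hf_weights is taken from the AttributeError hint; the peft_config
# handling is assumed context, not the verbatim patch.
def get_per_tensor_param(self):
    # AutoBridge exposes export_hf_weights, which yields (name, tensor)
    # pairs keyed by HuggingFace parameter names -- the format
    # vllm_rollout.update_weights expects -- rather than export_weights,
    # which does not exist (hence the AttributeError).
    per_tensor_param = self.bridge.export_hf_weights(self.module)
    peft_config = getattr(self, "peft_config", None)  # assumed attribute
    return per_tensor_param, peft_config
```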
**Files changed**: 2 files (+8, −3) under `verl/workers` (the engine workers and `engine/megatron`).
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
538 | 538 | | |
539 | 539 | | |
540 | 540 | | |
541 | | - | |
| 541 | + | |
542 | 542 | | |
543 | | - | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
544 | 547 | | |
545 | 548 | | |
546 | 549 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
566 | 566 | | |
567 | 567 | | |
568 | 568 | | |
569 | | - | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
570 | 572 | | |
571 | 573 | | |
572 | 574 | | |
| |||
0 commit comments