
Commit 0102d04

[trainer] feat: add reward loop config to default config (verl-project#4452)
### What does this PR do?

Future PRs will gradually migrate from the legacy RM implementation to the reward loop (for rule-based, genrm, disrm, ...). This PR adds reward loop configs to the defaults; they inherit the legacy reward model config, so no current API is broken (a minimal, hypothetical config sketch follows this template). Specifically, future PRs will:

- align results between the reward loop disrm and the legacy FSDP/Megatron disrm
- deprecate the FSDP/Megatron disrm and use the reward loop disrm as the default
- use the reward loop rule-based, disrm-based, and genrm-based rewards as the default
- deprecate the legacy reward model config

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
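
Below is a minimal, hypothetical sketch (not part of the PR) of what the new toggle looks like from the config side. Only `reward_model.use_reward_loop` is taken from this change; the surrounding keys and values are illustrative.

```python
# Hedged sketch: toggling the new reward-loop path via config.
# Only `reward_model.use_reward_loop` comes from this PR; other keys are illustrative.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "reward_model": {
            "enable": False,          # legacy reward-model worker stays off
            "use_reward_loop": True,  # new flag; the new defaults set it to true
        }
    }
)

if cfg.reward_model.use_reward_loop:
    print("reward scores are computed through the reward loop worker")
else:
    print("legacy reward-model path")
```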
1 parent f332fc8 commit 0102d04

7 files changed (+145, -10 lines changed)

verl/experimental/agent_loop/agent_loop.py

Lines changed: 11 additions & 8 deletions

@@ -297,12 +297,15 @@ def __init__(
             self.processor.chat_template = self.config.actor_rollout_ref.model.custom_chat_template
             self.tokenizer.chat_template = self.config.actor_rollout_ref.model.custom_chat_template
 
-        self.reward_manager_worker = RewardLoopWorker.options(
-            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
-                node_id=ray.get_runtime_context().get_node_id(),
-                soft=False,
-            ),
-        ).remote(self.config, self.reward_router_address)
+        use_reward_loop = True if self.config.reward_model.use_reward_loop else None
+        self.use_reward_loop = use_reward_loop
+        if use_reward_loop:
+            self.reward_loop_worker = RewardLoopWorker.options(
+                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                    node_id=ray.get_runtime_context().get_node_id(),
+                    soft=False,
+                ),
+            ).remote(self.config, self.reward_router_address)
 
         trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
         RolloutTraceConfig.init(
@@ -551,7 +554,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
         enable_async_reward = (
             self.reward_router_address is not None and self.config.reward_model.enable_resource_pool
         ) or not self.config.reward_model.enable
-        if output.reward_score is None and enable_async_reward:
+        if output.reward_score is None and enable_async_reward and self.use_reward_loop:
             batch = TensorDict(
                 {
                     "prompts": prompt_output["input_ids"],  # [1, prompt_length]
@@ -572,7 +575,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
                 batch=batch,
                 non_tensor_batch=non_tensor_batch,
             )
-            result = await self.reward_manager_worker.compute_score.remote(data)
+            result = await self.reward_loop_worker.compute_score.remote(data)
             output.reward_score = result["reward_score"]
             output.extra_fields["reward_extra_info"] = result["reward_extra_info"]
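
For readers unfamiliar with the Ray call used above, here is a standalone sketch (not verl code; `DummyWorker` is an illustrative stand-in class) of pinning an actor to the caller's node with a hard `NodeAffinitySchedulingStrategy`, which is how the `RewardLoopWorker` is placed when the reward loop is enabled.

```python
# Standalone Ray sketch: schedule an actor on the same node as the caller.
# DummyWorker is an illustrative stand-in, not a verl class.
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init(ignore_reinit_error=True)


@ray.remote
class DummyWorker:
    def where_am_i(self) -> str:
        # Node id the actor actually landed on.
        return ray.get_runtime_context().get_node_id()


worker = DummyWorker.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().get_node_id(),  # the caller's node
        soft=False,  # hard constraint: fail rather than fall back to another node
    ),
).remote()

assert ray.get(worker.where_am_i.remote()) == ray.get_runtime_context().get_node_id()
```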

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 24 additions & 0 deletions

@@ -571,6 +571,30 @@ reward_model:
   use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True}
   dtype: bfloat16
   load_weight: true
+  use_reward_loop: true
+  rollout:
+    _target_: verl.workers.config.RolloutConfig
+    name: ???
+    dtype: bfloat16
+    gpu_memory_utilization: 0.5
+    enforce_eager: true
+    cudagraph_capture_sizes: null
+    free_cache_engine: true
+    data_parallel_size: 1
+    expert_parallel_size: 1
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_model_len: null
+    max_num_seqs: 1024
+    load_format: auto
+    engine_kwargs: {}
+    limit_images: null
+    enable_chunked_prefill: true
+    enable_prefix_caching: true
+    disable_log_stats: true
+    skip_tokenizer_init: true
+    prompt_length: 512
+    response_length: 512
 algorithm:
   rollout_correction:
     rollout_is: null
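
The `${oc.select:...}` entries in the generated config fall back to a default when the referenced key is absent. A small OmegaConf sketch of that behavior, with the structure around the key simplified for illustration:

```python
# Minimal sketch of the ${oc.select:key,default} resolver used in the config above.
# The key path mirrors the diff; the surrounding structure is simplified.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "actor_rollout_ref": {"actor": {"megatron": {}}},  # use_remove_padding not set
        "reward_model": {
            "use_remove_padding": "${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,true}"
        },
    }
)

print(cfg.reward_model.use_remove_padding)  # -> True, the fallback default
```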

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 24 additions & 0 deletions

@@ -495,6 +495,30 @@ reward_model:
     save_path: ${oc.select:global_profiler.save_path,null}
     tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
   ulysses_sequence_parallel_size: 1
+  use_reward_loop: true
+  rollout:
+    _target_: verl.workers.config.RolloutConfig
+    name: ???
+    dtype: bfloat16
+    gpu_memory_utilization: 0.5
+    enforce_eager: true
+    cudagraph_capture_sizes: null
+    free_cache_engine: true
+    data_parallel_size: 1
+    expert_parallel_size: 1
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_model_len: null
+    max_num_seqs: 1024
+    load_format: auto
+    engine_kwargs: {}
+    limit_images: null
+    enable_chunked_prefill: true
+    enable_prefix_caching: true
+    disable_log_stats: true
+    skip_tokenizer_init: true
+    prompt_length: 512
+    response_length: 512
 algorithm:
   rollout_correction:
     rollout_is: null

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ defaults:
   # Critic model config.
   - critic@critic: megatron_critic
   # Reward model config.
-  - reward_model@reward_model: megatron_reward_model
+  - reward_model@reward_model: megatron_reward_loop
   # Rollout correction config.
   - algorithm@algorithm.rollout_correction: rollout_correction
   - _self_

verl/trainer/config/ppo_trainer.yaml

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ defaults:
   - critic@critic: dp_critic
 
   # Reward model config.
-  - reward_model@reward_model: dp_reward_model
+  - reward_model@reward_model: dp_reward_loop
 
   # Rollout correction config.
   - algorithm@algorithm.rollout_correction: rollout_correction

verl/trainer/config/reward_model/dp_reward_loop.yaml

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+defaults:
+  - dp_reward_model
+  - _self_
+
+use_reward_loop: True
+reward_manager: naive
+enable: False
+
+# Whether to deploy the model to a separate resource pool.
+enable_resource_pool: False
+n_gpus_per_node: 0
+nnodes: 0
+
+model:
+  path: ~/models/FsfairX-LLaMA3-RM-v0.1
+  external_lib: ${actor_rollout_ref.model.external_lib}
+  trust_remote_code: False
+
+rollout:
+  _target_: verl.workers.config.RolloutConfig
+  name: ???
+  dtype: bfloat16
+  gpu_memory_utilization: 0.5
+  enforce_eager: true
+  cudagraph_capture_sizes: null
+  free_cache_engine: true
+  data_parallel_size: 1
+  expert_parallel_size: 1
+  tensor_model_parallel_size: 2
+  max_num_batched_tokens: 8192
+  max_model_len: null
+  max_num_seqs: 1024
+  load_format: auto
+  engine_kwargs: {}
+  limit_images: null
+  enable_chunked_prefill: true
+  enable_prefix_caching: true
+  disable_log_stats: true
+  skip_tokenizer_init: true
+
+  prompt_length: 512
+  response_length: 512
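
The two new config files compose via Hydra's defaults list: they first pull in the legacy reward model config, then apply their own keys last (`_self_`). Below is a rough stand-in for that composition using a plain OmegaConf merge; apart from `use_reward_loop`, `enable`, and `reward_manager`, the keys and values are illustrative, not taken from verl.

```python
# Rough stand-in for Hydra defaults composition: the reward-loop config inherits the
# legacy reward model config and only overrides what it declares.
# `micro_batch_size_per_gpu` is an illustrative legacy key, not taken from this PR.
from omegaconf import OmegaConf

legacy_reward_model = OmegaConf.create(
    {"enable": True, "reward_manager": "naive", "micro_batch_size_per_gpu": 8}
)
reward_loop_overrides = OmegaConf.create(
    {"use_reward_loop": True, "enable": False}
)

merged = OmegaConf.merge(legacy_reward_model, reward_loop_overrides)  # later config wins
print(merged.enable, merged.use_reward_loop, merged.micro_batch_size_per_gpu)
# -> False True 8
```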

verl/trainer/config/reward_model/megatron_reward_loop.yaml

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+defaults:
+  - megatron_reward_model
+  - _self_
+
+use_reward_loop: True
+reward_manager: naive
+enable: False
+
+# Whether to deploy the model to a separate resource pool.
+enable_resource_pool: False
+n_gpus_per_node: 0
+nnodes: 0
+
+model:
+  path: ~/models/FsfairX-LLaMA3-RM-v0.1
+  external_lib: ${actor_rollout_ref.model.external_lib}
+  trust_remote_code: False
+
+rollout:
+  _target_: verl.workers.config.RolloutConfig
+  name: ???
+  dtype: bfloat16
+  gpu_memory_utilization: 0.5
+  enforce_eager: true
+  cudagraph_capture_sizes: null
+  free_cache_engine: true
+  data_parallel_size: 1
+  expert_parallel_size: 1
+  tensor_model_parallel_size: 2
+  max_num_batched_tokens: 8192
+  max_model_len: null
+  max_num_seqs: 1024
+  load_format: auto
+  engine_kwargs: {}
+  limit_images: null
+  enable_chunked_prefill: true
+  enable_prefix_caching: true
+  disable_log_stats: true
+  skip_tokenizer_init: true
+
+  prompt_length: 512
+  response_length: 512
