Commit 145a28b
[worker] feat: New engine share actor and ref for LoRA (#4867)
### What does this PR do?

Continuation of #4673; sharing the actor and ref for LoRA is now also supported in the new engine.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

Signed-off-by: Hollow Man <[email protected]>
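Background for reviewers: "sharing actor and ref for LoRA" relies on the fact that the frozen base weights already define the reference policy when only a LoRA adapter is trained. The snippet below is a minimal, hypothetical sketch of that idea using PEFT — it is not verl's worker code, and the model name and adapter settings are placeholder assumptions — showing how a single engine can produce actor log probs with the adapter enabled and reference log probs with it disabled.

```python
# Hypothetical sketch, not verl's implementation: one engine serves both roles.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
base = AutoModelForCausalLM.from_pretrained(model_name)
model = get_peft_model(base, LoraConfig(r=16, target_modules=["q_proj", "v_proj"]))

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    actor_logits = model(**inputs).logits     # base + LoRA adapter (actor policy)
    with model.disable_adapter():             # adapter off -> frozen base weights
        ref_logits = model(**inputs).logits   # reference policy from the same engine
```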
1 parent f16b245 commit 145a28b

14 files changed: +54 -33 lines
examples/split_placement/main_ppo_split.py

Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
 
 from verl import DataProto
 from verl.trainer.ppo.ray_trainer import RayPPOTrainer
+from verl.trainer.ppo.utils import need_reference_policy
 from verl.utils.reward_score import gsm8k, math_reward
 
 
@@ -171,7 +172,7 @@ def main_task(config):
     }
 
     # use reference model
-    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+    if need_reference_policy(config):
         role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
         mapping[Role.RefPolicy] = actor_rollout_ref_pool_id

verl/experimental/fully_async_policy/fully_async_main.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from verl.experimental.fully_async_policy.fully_async_trainer import FullyAsyncTrainer
 from verl.experimental.fully_async_policy.message_queue import MessageQueue, MessageQueueClient
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager
-from verl.trainer.ppo.utils import Role
+from verl.trainer.ppo.utils import Role, need_reference_policy
 from verl.utils.fs import copy_to_local
 
 
@@ -122,7 +122,7 @@ def create_role_worker_mapping(config):
         role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
 
     # Add reference policy (if KL loss or reward is required)
-    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+    if need_reference_policy(config):
        role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker)
 
    return role_worker_mapping, ray_worker_group_cls

verl/experimental/fully_async_policy/fully_async_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def __init__(
 
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
-        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
+        self.use_reference_policy = need_reference_policy(self.config)
         self.use_rm = need_reward_model(self.role_worker_mapping)
         self.use_critic = need_critic(self.config)
         self.ray_worker_group_cls = ray_worker_group_cls

verl/experimental/one_step_off_policy/main_ppo.py

Lines changed: 2 additions & 2 deletions

@@ -124,7 +124,7 @@ def create_role_worker_mapping(config):
         role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
 
     # Add reference policy (if KL loss or reward is required)
-    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+    if need_reference_policy(config):
        role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker)
 
    return role_worker_mapping, ray_worker_group_cls
@@ -151,7 +151,7 @@ def run(self, config):
         # validate config
         validate_config(
             config=config,
-            use_reference_policy=need_reference_policy(role_worker_mapping),
+            use_reference_policy=need_reference_policy(config),
             use_critic=need_critic(config),
         )

verl/experimental/one_step_off_policy/ray_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def __init__(
 
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
-        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
+        self.use_reference_policy = need_reference_policy(self.config)
         self.use_rm = need_reward_model(self.role_worker_mapping)
         self.use_critic = need_critic(config)
         self.ray_worker_group_cls = ray_worker_group_cls

verl/experimental/transfer_queue/main_ppo.py

Lines changed: 3 additions & 8 deletions

@@ -23,13 +23,8 @@
 from omegaconf import OmegaConf
 
 from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
-from verl.trainer.main_ppo import (
-    TaskRunner as MainTaskRunner,
-)
-from verl.trainer.main_ppo import (
-    create_rl_dataset,
-    create_rl_sampler,
-)
+from verl.trainer.main_ppo import TaskRunner as MainTaskRunner
+from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
 from verl.trainer.ppo.reward import load_reward_manager
 from verl.trainer.ppo.utils import need_critic, need_reference_policy
 from verl.utils.config import validate_config
@@ -148,7 +143,7 @@ def run(self, config):
         # validate config
         validate_config(
             config=config,
-            use_reference_policy=need_reference_policy(self.role_worker_mapping),
+            use_reference_policy=need_reference_policy(config),
             use_critic=need_critic(config),
         )

verl/experimental/transfer_queue/ray_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -369,7 +369,7 @@ def __init__(
 
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
-        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
+        self.use_reference_policy = need_reference_policy(self.config)
         self.use_rm = need_reward_model(self.role_worker_mapping)
         self.use_critic = need_critic(self.config)
         self.ray_worker_group_cls = ray_worker_group_cls

verl/trainer/main_ppo.py

Lines changed: 8 additions & 3 deletions

@@ -133,9 +133,14 @@ def add_actor_rollout_worker(self, config):
 
         actor_rollout_cls = ActorRolloutRefWorker
         ray_worker_group_cls = RayWorkerGroup
+
+        lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+        ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
         # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker,
         # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker.
-        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+        if need_reference_policy(config) and not ref_in_actor:
             role = Role.ActorRolloutRef
         else:
             role = Role.ActorRollout
@@ -249,7 +254,7 @@ def add_ref_policy_worker(self, config, ref_policy_cls):
         if use_legacy_worker_impl == "disable":
             return
 
-        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+        if need_reference_policy(config):
             self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
             self.mapping[Role.RefPolicy] = "global_pool"

@@ -291,7 +296,7 @@ def run(self, config):
         # validate config
         validate_config(
             config=config,
-            use_reference_policy=need_reference_policy(self.role_worker_mapping),
+            use_reference_policy=need_reference_policy(config),
             use_critic=need_critic(config),
         )
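For reference, this is how the `ref_in_actor` check in the first hunk resolves on a few assumed config fragments. The helper name and the toy configs below are illustrative only; they mirror the hunk rather than quote verl code.

```python
# Illustrative only: reproduces the lora_rank / lora_adapter_path resolution above.
from omegaconf import OmegaConf

def resolves_ref_in_actor(model_cfg) -> bool:
    lora_rank = model_cfg.get("lora", {}).get("rank", 0)
    if lora_rank <= 0:
        lora_rank = model_cfg.get("lora_rank", 0)
    return lora_rank > 0 or model_cfg.get("lora_adapter_path") is not None

nested = OmegaConf.create({"lora": {"rank": 32}})   # nested LoRA config
flat = OmegaConf.create({"lora_rank": 32})          # flat lora_rank field
full_ft = OmegaConf.create({"lora_rank": 0})        # no LoRA -> full fine-tuning

print(resolves_ref_in_actor(nested))   # True: ref shares the actor worker
print(resolves_ref_in_actor(flat))     # True
print(resolves_ref_in_actor(full_ft))  # False: separate RefPolicy worker path
```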

verl/trainer/ppo/ray_trainer.py

Lines changed: 9 additions & 3 deletions

@@ -339,7 +339,7 @@ def __init__(
 
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
-        self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
+        self.use_reference_policy = need_reference_policy(self.config)
         # legacy reward model implementation
         self.use_rm = need_reward_model(self.role_worker_mapping)
         self.use_reward_loop = self.config.reward_model.use_reward_loop
@@ -1244,8 +1244,14 @@ def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
         # step 2: convert from padding to nopadding
         batch_td = left_right_2_no_padding(batch_td)
         # step 3: add meta info
-        tu.assign_non_tensor(batch_td, calculate_entropy=False, compute_loss=False)
-        output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
+        metadata = {"calculate_entropy": False, "compute_loss": False}
+        if self.ref_in_actor:
+            metadata["no_lora_adapter"] = True
+        tu.assign_non_tensor(batch_td, **metadata)
+        if self.ref_in_actor:
+            output = self.actor_rollout_wg.compute_log_prob(batch_td)
+        else:
+            output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
         # gather output
         log_probs = tu.get(output, "log_probs")
         # step 4. No padding to padding
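The controller-side change above only tags the batch with `no_lora_adapter` and routes it to the actor worker group. The sketch below shows, assuming a PEFT-wrapped model, how a worker might honor that flag; the function name and signature are hypothetical, not the actual ActorRolloutRefWorker code.

```python
# Hypothetical worker-side handling of the no_lora_adapter flag (illustrative).
import contextlib
import torch

def compute_log_prob(peft_model, input_ids, attention_mask, no_lora_adapter=False):
    # With the adapter disabled, the forward pass uses only the frozen base
    # weights, i.e. it behaves as the reference policy.
    adapter_ctx = peft_model.disable_adapter() if no_lora_adapter else contextlib.nullcontext()
    with torch.no_grad(), adapter_ctx:
        logits = peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
    log_probs = torch.log_softmax(logits[:, :-1], dim=-1)
    # log prob of each realized next token
    return log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```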

verl/trainer/ppo/utils.py

Lines changed: 3 additions & 3 deletions

@@ -70,10 +70,10 @@ def from_string(cls, name: str):
 
 
 def need_reference_policy(
-    role_worker_mapping: dict[Role, WorkerType],
+    config: DictConfig,
 ) -> bool:
-    """Given a role worker mapping, do we need ref policy."""
-    return Role.RefPolicy in role_worker_mapping or Role.ActorRolloutRef in role_worker_mapping
+    """Given the config, do we need ref policy."""
+    return config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss
 
 
 def need_reward_model(