Commit c790552

[megatron] fix: set model to eval during compute_log_prob/compute_values (verl-project#4708)
### What does this PR do?

The current code doesn't set the model to eval mode when running compute_log_prob, ref_compute_log_prob, and compute_values in the megatron backend, which can go wrong if the model contains dropout layers, e.g. with LoRA.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
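For context, the sketch below illustrates the failure mode this commit guards against. It is not verl code: `TinyLM`, its layer sizes, and the helper `token_log_probs` are invented for illustration. With a dropout layer active in train mode, two forward passes over the same batch return different log-probabilities, so recomputed log probs or values would be noisy; eval mode makes them deterministic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy LM with a dropout layer, standing in for a policy
# whose adapters use dropout (e.g. LoRA with lora_dropout > 0).
class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.dropout = nn.Dropout(p=0.5)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return self.head(self.dropout(self.embed(input_ids)))

def token_log_probs(model, input_ids):
    # Log-probability the model assigns to each input token, shape (batch, seq).
    logp = F.log_softmax(model(input_ids), dim=-1)
    return logp.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

torch.manual_seed(0)
model = TinyLM()
tokens = torch.randint(0, 100, (2, 8))

model.train()  # dropout active: repeated passes over the same batch disagree
print(torch.allclose(token_log_probs(model, tokens),
                     token_log_probs(model, tokens)))  # almost surely False

model.eval()   # dropout disabled: passes are deterministic
print(torch.allclose(token_log_probs(model, tokens),
                     token_log_probs(model, tokens)))  # True
```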
Parent: 04a8490

File tree: 2 files changed, +10 −0 lines

verl/workers/actor/megatron_actor.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -196,6 +196,9 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
         Returns:
             DataProto: torch.Tensor: the log_prob tensor
         """
+        prev_modes = [m.training for m in self.actor_module]
+        for module in self.actor_module:
+            module.eval()
         use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False)
         micro_batch_size = data.meta_info.get("micro_batch_size", None)
         max_token_len = data.meta_info.get("max_token_len", None)
@@ -306,6 +309,8 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
         # add empty cache after each compute
         get_torch_device().empty_cache()
 
+        for module, mode in zip(self.actor_module, prev_modes, strict=False):
+            module.train(mode)
         return log_probs, entropys, layers_topk_idx
 
     def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
```

verl/workers/critic/megatron_critic.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -87,6 +87,9 @@ def _validate_config(self, config) -> None:
 
     @GPUMemoryLogger("megatron critic", logger=logger)
     def compute_values(self, data: DataProto) -> DataProto:
+        prev_modes = [m.training for m in self.critic_module]
+        for module in self.critic_module:
+            module.eval()
         responses = data.batch["responses"]
         attention_mask = data.batch["attention_mask"]
         use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False)
@@ -139,6 +142,8 @@ def compute_values(self, data: DataProto) -> DataProto:
         # add empty cache after each compute
         get_torch_device().empty_cache()
 
+        for module, mode in zip(self.critic_module, prev_modes, strict=False):
+            module.train(mode)
         return values
 
     def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
```
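Both files apply the same save-eval-restore pattern: the previous training flags are captured, every module chunk is switched to eval, and the flags are restored afterwards (the actor and critic are updated in train mode later, so eval must not stick). As an aside, not part of this commit, the pattern could be factored into a small context manager; a hypothetical sketch, which would also restore the flags if the forward pass raised:

```python
from contextlib import contextmanager

@contextmanager
def eval_mode(modules):
    """Temporarily put every module in eval mode, restoring each one's
    previous training flag on exit, even if an exception is raised."""
    prev_modes = [m.training for m in modules]
    for m in modules:
        m.eval()
    try:
        yield
    finally:
        for m, mode in zip(modules, prev_modes):
            m.train(mode)

# Hypothetical usage inside compute_log_prob / compute_values:
#     with eval_mode(self.actor_module):
#         ...run the forward passes...
```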
