Commit d66120d

[fsdp] fix: reward model also reads override config attn_implementation (verl-project#4458)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
  - verl-project#3978 missed the reward model case

### Test

Only needs to be tested in CI.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 896db9b commit d66120d

File tree

1 file changed: +6 −2 lines


verl/workers/fsdp_workers.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1657,7 +1657,12 @@ def _build_model(self, config):
         self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))

         trust_remote_code = config.model.get("trust_remote_code", False)
-        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
+        override_config = OmegaConf.to_container(OmegaConf.create(config.model.get("override_config", {})))
+        model_config = AutoConfig.from_pretrained(
+            local_path,
+            trust_remote_code=trust_remote_code,
+            attn_implementation=override_config.get("attn_implementation", "flash_attention_2"),
+        )
         model_config.num_labels = 1

         # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
@@ -1672,7 +1677,6 @@ def _build_model(self, config):
             pretrained_model_name_or_path=local_path,
             config=model_config,
             torch_dtype=torch.bfloat16,
-            attn_implementation="flash_attention_2",
             trust_remote_code=trust_remote_code,
         )
```

0 commit comments