Commit d823665
Pass vllm_is_ratio to LigerFusedLinearGRPOLoss in compute_liger_loss
When using `use_vllm=True` with `use_liger_kernel=True`, the vLLM importance sampling correction (`importance_sampling_ratio`) was computed and stored in the inputs dict but never passed to the Liger GRPO loss. This caused ~100x larger grad_norm compared to the non-Liger path, leading to training instability.

Fix: pass `inputs.get("importance_sampling_ratio")` as `vllm_is_ratio` to `self.liger_grpo_loss(...)` in `compute_liger_loss`, matching the correction already applied in the standard `_compute_loss` path.

Requires linkedin/Liger-Kernel#1088, which adds `vllm_is_ratio` parameter support to `LigerFusedLinearGRPOLoss`.
1 parent 7c4e7f8 commit d823665
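For intuition, here is a minimal sketch of what dropping the correction does, assuming the stored ratio simply weights each token's loss term; the function and variable names below are illustrative, not TRL's actual internals:

```python
import torch


def apply_vllm_is_correction(per_token_loss, importance_sampling_ratio=None):
    """Illustrative sketch only: weight each token's loss by the vLLM importance-sampling ratio.

    per_token_loss: (batch, seq_len) tensor of per-token policy-loss terms.
    importance_sampling_ratio: (batch, seq_len) tensor, or None when the correction
    is disabled. When the ratio is dropped (the pre-fix Liger path), every token is
    implicitly weighted by 1.0, so the resulting gradients differ from the corrected path.
    """
    if importance_sampling_ratio is None:
        return per_token_loss
    return per_token_loss * importance_sampling_ratio


# Example: a ratio of 0.5 halves each token's contribution to the loss and gradient.
loss = torch.ones(2, 4, requires_grad=True)
ratio = torch.full((2, 4), 0.5)
corrected = apply_vllm_is_correction(loss, ratio)
```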

2 files changed: +75, -0

tests/test_grpo_trainer.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -58,6 +58,18 @@
 from peft import LoraConfig, PeftModel, get_peft_model
 
 
+def _liger_supports_vllm_is_ratio():
+    """Check if the installed liger-kernel supports vllm_is_ratio parameter in LigerFusedLinearGRPOLoss."""
+    try:
+        import inspect
+
+        from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+
+        return "vllm_is_ratio" in inspect.signature(LigerFusedLinearGRPOLoss.forward).parameters
+    except Exception:
+        return False
+
+
 def multiply_tool(a: int, b: int) -> int:
     """
     Multiplies two integers.
@@ -988,6 +1000,68 @@ def test_training_with_off_policy_mask_with_liger(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    @require_liger_kernel
+    @pytest.mark.xfail(
+        not _liger_supports_vllm_is_ratio(),
+        reason="Requires vllm_is_ratio support in liger-kernel (linkedin/Liger-Kernel#1088)",
+        strict=True,
+    )
+    def test_compute_liger_loss_passes_vllm_is_ratio(self):
+        """Test that importance_sampling_ratio from inputs is passed to liger_grpo_loss as vllm_is_ratio."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            use_liger_kernel=True,
+            report_to="none",
+            logging_strategy="no",
+        )
+
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Wrap _generate_and_score_completions to inject importance_sampling_ratio
+        original_gen = trainer._generate_and_score_completions
+
+        def gen_with_is_ratio(*args, **kwargs):
+            result = original_gen(*args, **kwargs)
+            B, T = result["completion_ids"].shape
+            result["importance_sampling_ratio"] = torch.full((B, T), 0.5, device=result["completion_ids"].device)
+            return result
+
+        # Wrap liger_grpo_loss.forward to capture vllm_is_ratio argument
+        original_fwd = trainer.liger_grpo_loss.forward
+        captured_vllm_is_ratios = []
+
+        def capturing_forward(*args, **kwargs):
+            captured_vllm_is_ratios.append(kwargs.get("vllm_is_ratio"))
+            return original_fwd(*args, **kwargs)
+
+        trainer.liger_grpo_loss.forward = capturing_forward
+        trainer._generate_and_score_completions = gen_with_is_ratio
+
+        trainer.train()
+
+        # Verify vllm_is_ratio was passed in every call to liger_grpo_loss
+        assert len(captured_vllm_is_ratios) > 0, "liger_grpo_loss.forward was never called"
+        for vllm_is_ratio in captured_vllm_is_ratios:
+            assert vllm_is_ratio is not None, (
+                "vllm_is_ratio should not be None when importance_sampling_ratio is present"
+            )
+            assert (vllm_is_ratio == 0.5).all(), (
+                "vllm_is_ratio values should match the injected importance_sampling_ratio"
+            )
+
+        release_memory(trainer.model, trainer)
+
     def test_training_with_bias_correction_kl(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         training_args = GRPOConfig(
```

trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1897,6 +1897,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
             bias=unwrapped_model.lm_head.bias,
             old_per_token_logps=inputs.get("old_per_token_logps"),
             ref_per_token_logps=inputs.get("ref_per_token_logps"),
+            vllm_is_ratio=inputs.get("importance_sampling_ratio"),
         )
         # Extract metrics from the liger_grpo_loss output
         # KL divergence is the first metric when beta is non-zero
```
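Because liger-kernel releases prior to linkedin/Liger-Kernel#1088 do not accept `vllm_is_ratio`, a caller that needs to stay compatible with both could feature-detect the parameter before forwarding it, mirroring the `_liger_supports_vllm_is_ratio` helper in the test above. The `build_liger_grpo_kwargs` helper below is hypothetical, a sketch rather than code from this commit; with this pattern, older installations simply skip the correction instead of raising a TypeError.

```python
import inspect

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss


def build_liger_grpo_kwargs(inputs):
    """Hypothetical sketch: assemble optional kwargs for LigerFusedLinearGRPOLoss,
    forwarding vllm_is_ratio only when the installed liger-kernel supports it."""
    kwargs = {
        "old_per_token_logps": inputs.get("old_per_token_logps"),
        "ref_per_token_logps": inputs.get("ref_per_token_logps"),
    }
    # Parameter added by linkedin/Liger-Kernel#1088; absent in older releases.
    if "vllm_is_ratio" in inspect.signature(LigerFusedLinearGRPOLoss.forward).parameters:
        kwargs["vllm_is_ratio"] = inputs.get("importance_sampling_ratio")
    return kwargs
```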
