 from peft import LoraConfig, PeftModel, get_peft_model
 
 
+def _liger_supports_vllm_is_ratio():
+    """Check whether the installed liger-kernel supports the `vllm_is_ratio` parameter in `LigerFusedLinearGRPOLoss`."""
+    try:
+        import inspect
+
+        from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+
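+        # Older liger-kernel releases do not accept this kwarg, so probe the forward signature at runtime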
+        return "vllm_is_ratio" in inspect.signature(LigerFusedLinearGRPOLoss.forward).parameters
+    except Exception:
+        return False
+
+
 def multiply_tool(a: int, b: int) -> int:
     """
     Multiplies two integers.
@@ -988,6 +1000,68 @@ def test_training_with_off_policy_mask_with_liger(self): |
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    @require_liger_kernel
+    @pytest.mark.xfail(
+        not _liger_supports_vllm_is_ratio(),
+        reason="Requires vllm_is_ratio support in liger-kernel (linkedin/Liger-Kernel#1088)",
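+        # strict=True turns an unexpected pass (XPASS) into a failure, flagging when this gate can be dropped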
+        strict=True,
+    )
+    def test_compute_liger_loss_passes_vllm_is_ratio(self):
+        """Test that `importance_sampling_ratio` from the inputs is passed to `liger_grpo_loss` as `vllm_is_ratio`."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            use_liger_kernel=True,
+            report_to="none",
+            logging_strategy="no",
+        )
+
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Wrap _generate_and_score_completions to inject importance_sampling_ratio
+        original_gen = trainer._generate_and_score_completions
+
+        def gen_with_is_ratio(*args, **kwargs):
+            result = original_gen(*args, **kwargs)
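+            # Fill a (batch, completion_length) tensor with the constant 0.5 so the assertions below can check it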
+            B, T = result["completion_ids"].shape
+            result["importance_sampling_ratio"] = torch.full((B, T), 0.5, device=result["completion_ids"].device)
+            return result
+
+        # Wrap liger_grpo_loss.forward to capture the vllm_is_ratio argument
+        original_fwd = trainer.liger_grpo_loss.forward
+        captured_vllm_is_ratios = []
+
+        def capturing_forward(*args, **kwargs):
+            captured_vllm_is_ratios.append(kwargs.get("vllm_is_ratio"))
+            return original_fwd(*args, **kwargs)
+
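+        # Install both monkeypatches before training so every loss computation goes through them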
+        trainer.liger_grpo_loss.forward = capturing_forward
+        trainer._generate_and_score_completions = gen_with_is_ratio
+
+        trainer.train()
+
+        # Verify vllm_is_ratio was passed in every call to liger_grpo_loss
+        assert len(captured_vllm_is_ratios) > 0, "liger_grpo_loss.forward was never called"
+        for vllm_is_ratio in captured_vllm_is_ratios:
+            assert vllm_is_ratio is not None, (
+                "vllm_is_ratio should not be None when importance_sampling_ratio is present"
+            )
+            assert (vllm_is_ratio == 0.5).all(), (
+                "vllm_is_ratio values should match the injected importance_sampling_ratio"
+            )
+
+        release_memory(trainer.model, trainer)
+
     def test_training_with_bias_correction_kl(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         training_args = GRPOConfig(