
Commit 9258eb3

Authored by George
[TRL_SFT_Trainer] Fix and Update Examples code (#1161)
SUMMARY:

* Fix examples script failure: https://github.com/neuralmagic/llm-compressor-testing/actions/runs/13350457472/job/37286313648

PROBLEM:

1. Missing required positional arguments:

```bash
cpy 2 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_constant.py'
...
TypeError: SessionManagerMixIn.__init__() missing 2 required positional arguments: 'data_args' and 'model_args'
```

2. Unexpected keyword argument `max_seq_length`:

```bash
(.venv) gohashi@janice:~/llm-compressor$ cpy 2 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_constant.py'
...
TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'max_seq_length'
```

3. `NoneType` has no attribute `save_compressed`:

```bash
(.venv) gohashi@janice:~/llm-compressor$ cpy 2 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_constant.py'
...
AttributeError: 'NoneType' object has no attribute 'save_compressed'
```

4. Deprecated `tokenizer` keyword:

```bash
(.venv) gohashi@janice:~/llm-compressor$ cpy 2 '/home/gohashi/llm-compressor/examples/trl_mixin/ex_trl_constant.py'
...
/home/gohashi/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py:97: FutureWarning: `tokenizer` is deprecated and removed starting from version 0.16.0 for `SFTTrainer.__init__`. Use `processing_class` instead.
...
```

SOLUTION:

1. Caused by https://github.com/vllm-project/llm-compressor/pull/1103/files#diff-059b8cf7e48691cd2d5ddda1d0ba5f584657a70c5804797d38c902b433777335R69-R70, where `model_args` and `data_args` became required. Pass them in the example and make `data_args` optional.
2. `max_seq_length` is not part of llm-compressor's `TrainingArguments`, which `super().__init__()` consumes first; it belongs to TRL's `SFTConfig` (imported as `TRLSFTConfig`), which itself inherits from `TrainingArguments`. Accept the SFT options as a plain dict and build a `TRLSFTConfig` from it directly.
3. Make `model_args` required so it is never `None`.
4. Deprecation warning: update `tokenizer` to `processing_class`.

TEST PLAN:

* Pass [examples/trl_mixin/ex_trl_constant.py](https://github.com/vllm-project/llm-compressor/compare/sessionmixin-revert-signature?expand=1#diff-f14ef5a7e5c54f35e347fd75ed37e39b8f6db081199bd6233cf14d2c1b4bdef9)
* Pass existing tests
1 parent 2053ee9 commit 9258eb3

File tree: 3 files changed, +13 / -25 lines

examples/trl_mixin/ex_trl_constant.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -3,7 +3,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import DataCollatorForCompletionOnlyLM
 
-from llmcompressor.args import TrainingArguments
+from llmcompressor.args import ModelArguments
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
@@ -39,21 +39,23 @@ def formatting_prompts_func(example):
 response_template = "Answer:"
 collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
 
-training_args = TrainingArguments(
+trl_sft_config_args = dict(
     output_dir=output_dir,
     num_train_epochs=0.6,
     logging_steps=50,
     gradient_checkpointing=True,
+    max_seq_length=512,
 )
+model_args = ModelArguments(model=model)
 
 trainer = SFTTrainer(
     model=model,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     recipe=recipe,
     train_dataset=dataset,
     formatting_func=formatting_prompts_func,
     data_collator=collator,
-    args=training_args,
-    max_seq_length=512,
+    trl_sft_config_args=trl_sft_config_args,
+    model_args=model_args,
 )
 trainer.train()
```

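For reference, the updated example boils down to the sketch below. Only the `trl_sft_config_args` / `model_args` / `SFTTrainer(...)` wiring is taken from the diff above; the GSM8K dataset loading, the formatting function, the recipe path, and the local `from sft_trainer import SFTTrainer` import are assumptions filled in for illustration and may differ from the full example file.

```python
# Sketch of the updated example; parts not shown in the diff are illustrative assumptions.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from llmcompressor.args import ModelArguments
from sft_trainer import SFTTrainer  # assumed local import of the wrapper defined in sft_trainer.py below

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"

model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Assumed dataset/prompt setup (the output_dir name suggests GSM8K-style Q/A data).
dataset = load_dataset("gsm8k", "main", split="train")


def formatting_prompts_func(example):
    # Illustrative batched formatting: one "Question: ... Answer: ..." string per row.
    return [
        f"Question: {q}\nAnswer: {a}"
        for q, a in zip(example["question"], example["answer"])
    ]


response_template = "Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Hypothetical recipe path; any recipe accepted by SessionManagerMixIn works here.
recipe = "example_recipe.yaml"

# TRL SFT options now travel as a plain dict; max_seq_length belongs here,
# not in llm-compressor's TrainingArguments.
trl_sft_config_args = dict(
    output_dir=output_dir,
    num_train_epochs=0.6,
    logging_steps=50,
    gradient_checkpointing=True,
    max_seq_length=512,
)
model_args = ModelArguments(model=model)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` kwarg
    recipe=recipe,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    trl_sft_config_args=trl_sft_config_args,
    model_args=model_args,
)
trainer.train()
```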
examples/trl_mixin/sft_trainer.py

Lines changed: 5 additions & 19 deletions
```diff
@@ -1,20 +1,17 @@
+from typing import Dict, Optional
+
 from trl import SFTConfig as TRLSFTConfig
 from trl import SFTTrainer as TRLSFTTrainer
 
-from llmcompressor.args import TrainingArguments
 from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn
 
 __all__ = ["SFTTrainer"]
 
 
 class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
-    def __init__(self, *args, **kwargs):
-        sft_config_args = kwargs.get("args")
-        if (
-            sft_config_args is not None
-            and sft_config_args.__class__.__name__ == "TrainingArguments"
-        ):
-            kwargs["args"] = SFTConfig(**sft_config_args.to_dict())
+    def __init__(self, trl_sft_config_args: Optional[Dict] = None, *args, **kwargs):
+        if trl_sft_config_args is not None:
+            kwargs["args"] = TRLSFTConfig(**trl_sft_config_args)
         super().__init__(*args, **kwargs)
 
     def _prepare_dataset(self, dataset, *args, **kwargs):
@@ -23,14 +20,3 @@ def _prepare_dataset(self, dataset, *args, **kwargs):
             return dataset
 
         return super()._prepare_dataset(dataset, *args, **kwargs)
-
-
-class SFTConfig(TrainingArguments, TRLSFTConfig):
-    """
-    This class is needed to wrap the llmcompressor.transformers.TrainingArguments
-    and TRLSFTConfig classes. This allows for the use of arguments and
-    configurations from both classes when training a model.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
```

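The key change here is that the wrapper no longer needs its own `SFTConfig` bridge class: a plain dict of options is expanded straight into TRL's `SFTConfig`, which already inherits from `transformers.TrainingArguments`. A minimal standalone sketch of what `__init__` now does with that dict (values are illustrative):

```python
from trl import SFTConfig as TRLSFTConfig

# What SFTTrainer.__init__ does internally when trl_sft_config_args is provided:
trl_sft_config_args = dict(
    output_dir="./out",
    num_train_epochs=0.6,
    max_seq_length=512,  # accepted here; plain TrainingArguments would reject it
)
args = TRLSFTConfig(**trl_sft_config_args)

print(args.max_seq_length)    # 512
print(args.num_train_epochs)  # 0.6
```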
src/llmcompressor/transformers/finetune/session_mixin.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -66,8 +66,8 @@ class SessionManagerMixIn:
     def __init__(
         self,
         recipe: str,
-        data_args: "DatasetArguments",
         model_args: "ModelArguments",
+        data_args: Optional["DatasetArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
```

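The practical effect of reordering these parameters is that `model_args` stays required while `data_args` can be omitted entirely. A toy sketch of the new calling convention (not the real class, just the signature shape):

```python
from typing import Any, Dict, Optional


class MixInSignatureSketch:
    """Toy stand-in mirroring the updated SessionManagerMixIn signature."""

    def __init__(
        self,
        recipe: str,
        model_args: Any,                  # required, per the commit's fix for problem 3
        data_args: Optional[Any] = None,  # optional; defaults to None
        teacher: Optional[Any] = None,
        recipe_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        self.recipe = recipe
        self.model_args = model_args
        self.data_args = data_args


# data_args can now be dropped; model_args must always be supplied:
obj = MixInSignatureSketch("my_recipe.yaml", model_args=object())
assert obj.data_args is None
```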