sft_trainer.py
from trl import SFTConfig as TRLSFTConfig
from trl import SFTTrainer as TRLSFTTrainer

from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn
from llmcompressor.transformers.utils.arg_parser import TrainingArguments

__all__ = ["SFTTrainer"]
class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    """
    Drop-in replacement for trl's SFTTrainer that mixes in llmcompressor's
    SessionManagerMixIn, so a compression session is managed alongside the
    supervised fine-tuning loop.
    """

    def __init__(self, *args, **kwargs):
        # if plain TrainingArguments were passed in, promote them to the
        # combined SFTConfig so TRL-specific fields are also available
        sft_config_args = kwargs.get("args")
        if (
            sft_config_args is not None
            and sft_config_args.__class__.__name__ == "TrainingArguments"
        ):
            kwargs["args"] = SFTConfig(**sft_config_args.to_dict())
        super().__init__(*args, **kwargs)

    def _prepare_dataset(self, dataset, *args, **kwargs):
        if "input_ids" in dataset.column_names:
            # dataset is already tokenized, skip preprocessing
            return dataset
        return super()._prepare_dataset(dataset, *args, **kwargs)
class SFTConfig(TrainingArguments, TRLSFTConfig):
    """
    Wraps the llmcompressor.transformers.TrainingArguments and trl's SFTConfig
    classes so that arguments and configuration options from both can be used
    when training a model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
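
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# model name, dataset, and recipe path below are assumptions chosen for the
# demo, and keyword names such as ``tokenizer`` have changed across trl
# releases, so verify them against the installed version before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "facebook/opt-125m"  # hypothetical small model for the demo
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # any dataset with a text column works; trl tokenizes it inside
    # _prepare_dataset unless an "input_ids" column is already present
    # (see the override above)
    dataset = load_dataset("stanfordnlp/imdb", split="train[:64]")

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        recipe="recipe.yaml",  # hypothetical llmcompressor recipe path
        train_dataset=dataset,
        args=SFTConfig(
            output_dir="./sft_output",
            dataset_text_field="text",
            max_steps=5,
        ),
    )
    trainer.train()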