@@ -1,7 +1,7 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

-from llmcompressor.args import DatasetArguments, TrainingArguments
+from llmcompressor.args import DatasetArguments, ModelArguments
 from llmcompressor.transformers import TextGenerationDataset

 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
@@ -16,18 +16,19 @@
 )

 tokenizer = AutoTokenizer.from_pretrained(model_path)
+max_seq_length = 512

 # Load gsm8k using SparseML dataset tools
 data_args = DatasetArguments(
-    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
+    dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
     data_args.dataset,
     data_args=data_args,
     split="train",
-    tokenizer=tokenizer,
+    processor=tokenizer,
 )
-train_dataset = dataset_manager.tokenize_and_process()
+train_dataset = dataset_manager()
 print(f"--> Training Set Length = {len(train_dataset)}")

 # recipe for maintaining model sparsity during finetuning
@@ -48,25 +49,28 @@
 """

 data_collator = DefaultDataCollator()
-training_args = TrainingArguments(
+trl_sft_config_args = dict(
     output_dir=output_dir,
     num_train_epochs=0.6,
     logging_steps=50,
     gradient_checkpointing=True,
     bf16=True,
     save_safetensors=False,  # workaround for shared tensors
+    max_seq_length=max_seq_length,
+    packing=True,
 )
+model_args = ModelArguments(model=model, distill_teacher=teacher)
+
 trainer = SFTTrainer(
     model=model,
     teacher=teacher,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     recipe=recipe,
     train_dataset=train_dataset,
     data_collator=data_collator,
-    args=training_args,
+    trl_sft_config_args=trl_sft_config_args,
     data_args=data_args,
-    max_seq_length=data_args.max_seq_length,
-    packing=True,
+    model_args=model_args,
 )
 trainer.train()
-trainer.save_model()
+trainer.save_model(output_dir)
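Assembled as plain Python, the post-change trainer wiring looks roughly as follows. This is only a sketch of the lines touched by this diff: model, teacher, recipe, output_dir, tokenizer, max_seq_length, train_dataset, and data_collator are assumed to be defined earlier in the example, in the portions not shown here.

# Sketch of the new-style wiring after this change; assumes model, teacher,
# recipe, output_dir, tokenizer, max_seq_length, train_dataset, and
# data_collator are already defined as in the rest of the example script.
from llmcompressor.args import ModelArguments

# SFTConfig-style settings are now passed through as a plain dict
# instead of a TrainingArguments instance.
trl_sft_config_args = dict(
    output_dir=output_dir,
    num_train_epochs=0.6,
    logging_steps=50,
    gradient_checkpointing=True,
    bf16=True,
    save_safetensors=False,  # workaround for shared tensors
    max_seq_length=max_seq_length,
    packing=True,
)
# The student/teacher pair is additionally described via ModelArguments.
model_args = ModelArguments(model=model, distill_teacher=teacher)

trainer = SFTTrainer(
    model=model,
    teacher=teacher,
    processing_class=tokenizer,  # replaces the old tokenizer= keyword
    recipe=recipe,
    train_dataset=train_dataset,
    data_collator=data_collator,
    trl_sft_config_args=trl_sft_config_args,  # replaces args=training_args
    data_args=data_args,
    model_args=model_args,
)
trainer.train()
trainer.save_model(output_dir)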