@@ -15,7 +15,7 @@
 from datasets import Dataset
 from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data import DataCollatorWithPadding
 
 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
@@ -115,22 +115,29 @@ def get_calibration_dataloader(
     )
 
     calibration_dataset = datasets.get("calibration")
+    tokenizer = getattr(processor, "tokenizer", processor)
+    collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer)
+    if dataset_args.batch_size > 1 and (
+        tokenizer.pad_token is None or tokenizer.pad_token_id < 0
+    ):
+        logger.warning("Could not find padding token. Setting PAD token to EOS token")
+        tokenizer.pad_token = tokenizer.eos_token
 
     return format_calibration_data(
         tokenized_dataset=calibration_dataset,
+        collate_fn=collate_fn,
         num_calibration_samples=dataset_args.num_calibration_samples,
         batch_size=dataset_args.batch_size,
         do_shuffle=dataset_args.shuffle_calibration_samples,
-        collate_fn=dataset_args.data_collator,
     )
 
 
 def format_calibration_data(
     tokenized_dataset: Dataset,
+    collate_fn: Callable,
     num_calibration_samples: int | None = None,
     batch_size: int = 1,
     do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
 ) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
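A minimal standalone sketch (not part of this diff) of how the new collate path is expected to behave: DataCollatorWithPadding pads variable-length tokenized samples so batch_size > 1 works, and the EOS token is reused as padding when the tokenizer defines none. The checkpoint name and sample texts below are hypothetical.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.data import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical checkpoint
if tokenizer.pad_token is None:
    # Mirrors the fallback added above: reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

# Variable-length tokenized samples, standing in for a calibration split.
samples = [tokenizer(t) for t in ["short text", "a somewhat longer calibration sample"]]

collate_fn = DataCollatorWithPadding(tokenizer)
loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))
# input_ids and attention_mask are padded to the longest sample in the batch.
print(batch["input_ids"].shape)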