-
Notifications
You must be signed in to change notification settings - Fork 457
Expand file tree
/
Copy pathbase.py
More file actions
330 lines (280 loc) · 12.5 KB
/
base.py
File metadata and controls
330 lines (280 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import inspect
from functools import cached_property
from inspect import _ParameterKind as Kind
from typing import Any, Callable, Dict, List, Union
from compressed_tensors.registry import RegistryMixin
from datasets import Dataset, IterableDataset
from datasets.formatting.formatting import LazyRow
from loguru import logger
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
get_raw_dataset,
)
from llmcompressor.transformers.utils.arg_parser import DatasetArguments
from llmcompressor.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
)
from llmcompressor.typing import DatasetType, Processor
from llmcompressor.utils import import_from_path
class TextGenerationDataset(RegistryMixin):
    """
    Base class for text datasets. Applies the following transformations to a dataset
    in order to prepare the dataset to be loaded by a dataloader
    1. Load dataset from huggingface or local cache
    2. Preprocess dataset according to preprocess function or chat/dataset template
    3. Tokenize dataset using model tokenizer/processor
    4. Apply post processing such as grouping text and/or adding labels for finetuning

    Subclasses are registered through ``RegistryMixin`` and may override
    ``dataset_template`` to provide a dataset-specific preprocessing function.

    :param data_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    # used to mask out the prompt so prompt tokens do not contribute to training loss
    PROMPT_KEY = "prompt"

    def __init__(
        self,
        data_args: DatasetArguments,
        split: str,
        processor: Processor,
    ):
        self.data_args = data_args
        self.split = split
        self.processor = processor

        # get tokenizer (multimodal processors wrap one; a plain tokenizer has
        # no `tokenizer` attribute and is used directly)
        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
        if self.tokenizer is not None:
            # fill in pad token
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # configure sequence length, capped at the model's maximum
            max_seq_length = data_args.max_seq_length
            if data_args.max_seq_length > self.tokenizer.model_max_length:
                logger.warning(
                    f"The max_seq_length passed ({max_seq_length}) is larger than "
                    f"maximum length for model ({self.tokenizer.model_max_length}). "
                    f"Using max_seq_length={self.tokenizer.model_max_length}."
                )
            self.max_seq_length = min(
                data_args.max_seq_length, self.tokenizer.model_max_length
            )

            # configure padding: never pad when concatenating, since grouped
            # samples are re-chunked to exactly max_seq_length anyway
            self.padding = (
                False
                if self.data_args.concatenate_data
                else "max_length"
                if self.data_args.pad_to_max_length
                else False
            )
        else:
            # no tokenizer available; downstream tokenization is skipped
            self.max_seq_length = None
            self.padding = False

    def __call__(self, add_labels: bool = True) -> DatasetType:
        """
        Run the full pipeline: load, preprocess, rename columns, tokenize,
        then postprocess (text grouping and/or label creation).

        :param add_labels: if True, add a ``labels`` column with prompt and
            padding positions masked out for finetuning; if False, drop the
            intermediate prompt column instead
        :return: the processed dataset, ready to be loaded by a dataloader
        """
        dataset = self.data_args.dataset
        if isinstance(dataset, str):
            # load dataset: load from huggingface or disk
            dataset = self.load_dataset()
        logger.debug(f"Raw dataset: {get_columns(dataset)}")

        if self.preprocess is not None:
            # preprocess: apply template or preprocessing function
            dataset = self.map(
                dataset,
                self.preprocess,
                batched=False,
                num_proc=self.data_args.preprocessing_num_workers,
                desc="Preprocessing",
            )
            logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")

        # rename and remove columns match processor kwargs
        dataset = self.rename_columns(dataset)
        logger.debug(f"Dataset after column renaming: {get_columns(dataset)}")

        # use processor.model_input_names to determine if the ds is already tokenized
        model_input_names = getattr(self.processor, "model_input_names", ["input_ids"])
        if not any(col_name in model_input_names for col_name in get_columns(dataset)):
            # tokenize/ process
            dataset = self.filter_tokenizer_args(dataset)
            logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}")

            dataset = self.map(
                dataset,
                self.tokenize,
                batched=False,  # batching is not well supported for vision processors
                keep_in_memory=True,  # bug occurs when not batched and not in memory,
                # subsequent ds.map calls are always batched,
                # regardless of `batched` argument
                remove_columns=get_columns(dataset),  # assumes that input names
                # and output names are disjoint
                num_proc=self.data_args.preprocessing_num_workers,
                load_from_cache_file=not self.data_args.overwrite_cache,
                desc="Tokenizing",
            )
            logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")

        if self.data_args.concatenate_data:
            # postprocess: group text
            dataset = self.map(
                dataset,
                self.group_text,
                batched=True,
                num_proc=self.data_args.preprocessing_num_workers,
                load_from_cache_file=not self.data_args.overwrite_cache,
                desc="Concatenating data",
            )
            logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")

        if add_labels:
            # postprocess: add labels
            dataset = self.map(
                dataset,
                self.add_labels,
                batched=False,  # not compatible with batching, need row lengths
                num_proc=self.data_args.preprocessing_num_workers,
                load_from_cache_file=not self.data_args.overwrite_cache,
                desc="Adding labels",
            )
            logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
        elif self.PROMPT_KEY in get_columns(dataset):
            # prompt column is only needed for label masking; drop it otherwise
            dataset = dataset.remove_columns(self.PROMPT_KEY)
            logger.debug("Removed prompt key")

        logger.debug(f"Model kwargs after postprocessing: {get_columns(dataset)}")
        return dataset

    def load_dataset(self):
        """
        Load the raw dataset from Hugging Face, using cached copy if available.

        Local files pointed to by ``data_args.dataset_path`` are supported,
        optionally fetched through DVC when ``data_args.dvc_data_repository``
        is set.

        :return: the requested dataset
        """
        if self.data_args.dataset_path is not None:
            if self.data_args.dvc_data_repository is not None:
                # route file access through the DVC remote
                self.data_args.raw_kwargs["storage_options"] = {
                    "url": self.data_args.dvc_data_repository
                }
                self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
            else:
                self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
                    self.data_args.dataset_path,
                    self.data_args.dataset
                    if hasattr(self.data_args, "dataset")
                    else self.data_args.dataset_name,
                )

        logger.debug(f"Loading dataset {self.data_args.dataset}")
        return get_raw_dataset(
            self.data_args,
            None,
            split=self.split,
            streaming=self.data_args.streaming,
            **self.data_args.raw_kwargs,
        )

    @cached_property
    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
        """
        Resolve the preprocessing function for this dataset.

        Resolution order: a callable from data_args, a "path.py:func_name"
        import string, a registered function name, then the subclass's
        ``dataset_template`` (None by default).

        The function must return keys which correspond to processor/tokenizer
        kwargs, optionally including PROMPT_KEY
        """
        preprocessing_func = self.data_args.preprocessing_func
        if callable(preprocessing_func):
            return preprocessing_func

        if isinstance(preprocessing_func, str):
            if ":" in preprocessing_func:
                # load func_name from "/path/to/file.py:func_name"
                return import_from_path(preprocessing_func)
            else:
                # load from the registry
                return PreprocessingFunctionRegistry.get_value_from_registry(
                    name=preprocessing_func
                )

        return self.dataset_template

    @property
    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
        # overridden by registered subclasses to provide a dataset-specific
        # preprocessing function; None means "no preprocessing template"
        return None

    def rename_columns(self, dataset: DatasetType) -> DatasetType:
        """
        Rename ``data_args.text_column`` to ``text`` so it matches the
        tokenizer's expected keyword argument.

        :param dataset: dataset whose columns may be renamed
        :return: dataset with a standardized text column name
        """
        # rename columns to match processor/tokenizer kwargs
        column_names = get_columns(dataset)
        if self.data_args.text_column in column_names and "text" not in column_names:
            logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
            dataset = dataset.rename_column(self.data_args.text_column, "text")

        return dataset

    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
        """
        Remove dataset columns that are not keyword arguments of the
        processor's ``__call__`` (PROMPT_KEY is always kept).

        :param dataset: dataset to filter
        :return: dataset containing only processor-compatible columns
        """
        # assumes that inputs are not passed via self.processor.__call__ args and kwargs
        signature = inspect.signature(self.processor.__call__)
        tokenizer_args = set(
            key
            for key, param in signature.parameters.items()
            if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
        )
        logger.debug(
            f"Found processor args `{tokenizer_args}`. Removing all other columns"
        )

        column_names = get_columns(dataset)
        return dataset.remove_columns(
            list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
        )

    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
        """
        Tokenize one row with the processor; any prompt is tokenized
        separately and unpadded so its exact length can later be used to mask
        labels.

        :param data: row whose columns are processor kwargs
        :return: tokenized model kwargs, plus PROMPT_KEY token ids if present
        """
        # separate prompt
        prompt = data.pop(self.PROMPT_KEY, None)

        # tokenize
        data = self.processor(
            **data,
            padding=self.padding,
            max_length=self.max_seq_length,
            truncation=True,
        )

        # store unpadded prompt so we can mask out correct number of elements in labels
        if prompt is not None:
            data[self.PROMPT_KEY] = self.processor(
                text=prompt,
                max_length=self.max_seq_length,
                truncation=True,
            )["input_ids"]

        return data

    def group_text(self, data: LazyRow) -> Dict[str, Any]:
        """
        Concatenate a batch of tokenized rows and re-chunk each column into
        contiguous sequences of exactly ``max_seq_length`` tokens.

        :param data: batch of tokenized rows (each column a list of lists)
        :return: batch of fixed-length chunks
        """
        concatenated_data = {k: sum(data[k], []) for k in data.keys()}
        total_length = len(concatenated_data[list(data.keys())[0]])
        # truncate so every chunk is exactly max_seq_length; the tail remainder
        # is dropped
        total_length = (total_length // self.max_seq_length) * self.max_seq_length
        result = {
            k: [
                t[i : i + self.max_seq_length]
                for i in range(0, total_length, self.max_seq_length)
            ]
            for k, t in concatenated_data.items()
        }
        return result

    def add_labels(self, data: LazyRow) -> LazyRow:
        """
        Add a ``labels`` column copied from ``input_ids``, with prompt tokens
        and padding positions replaced by LABELS_MASK_VALUE so they do not
        contribute to the training loss.

        :param data: tokenized row
        :return: row with a ``labels`` column added
        :raises NotImplementedError: for vision rows (``pixel_values`` present)
        """
        if "pixel_values" in data:
            raise NotImplementedError(
                "Label masking for vision datasets has not been implemented yet"
            )

        # if the dataset uses prompts, mask them out so they don't contribute
        # to the loss calculation
        prompt_len = 0
        if self.PROMPT_KEY in data:
            prompt_len = len(data[self.PROMPT_KEY])
        data["labels"] = data["input_ids"].copy()
        data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

        # mask out padding in the labels as well
        padding = len(data["attention_mask"]) - sum(data["attention_mask"])
        if padding > 0:
            data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding

        return data

    def map(
        self,
        dataset: Union[Dataset, IterableDataset],
        function: Callable[[Any], Any],
        **kwargs,
    ) -> Union[Dataset, IterableDataset]:
        """
        Wrapper function around Dataset.map and IterableDataset.map.

        If the dataset is streaming (in the case of IterableDataset), non-applicable
        arguments are ignored and the dataset features are resolved
        """
        if isinstance(dataset, IterableDataset):
            # remove arguments that don't apply to streaming
            kwargs.pop("num_proc", None)
            kwargs.pop("load_from_cache_file", None)
            kwargs.pop("desc", None)
            kwargs.pop("keep_in_memory", None)

        dataset = dataset.map(function, **kwargs)

        if isinstance(dataset, IterableDataset):
            # streaming datasets have unknown features until resolved
            dataset = dataset._resolve_features()

        return dataset
def get_columns(dataset: DatasetType) -> List[str]:
    """
    Return the dataset's column names as a flat list.

    DatasetDict-style objects expose ``column_names`` as a mapping from split
    name to column list; those are flattened across splits in insertion order.
    """
    names = dataset.column_names
    if isinstance(names, dict):
        names = [column for split_columns in names.values() for column in split_columns]
    return names