# dataset_arguments.py
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union

from transformers import DefaultDataCollator

from llmcompressor.pipelines.registry import PIPELINES


@dataclass
class DVCDatasetArguments:
    """
    Arguments for training using DVC
    """

    dvc_data_repository: Optional[str] = field(
        default=None,
        metadata={"help": "Path to repository used for dvc_dataset_path"},
    )


@dataclass
class CustomDatasetArguments(DVCDatasetArguments):
    """
    Arguments for training using custom datasets
    """

    dataset_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to the custom dataset. Supports json, csv, dvc. "
                "For DVC, the path to the dvc dataset to load, of the format "
                "dvc://path. For csv or json, the path containing the dataset. "
            ),
        },
    )
    text_column: str = field(
        default="text",
        metadata={
            "help": (
                "Optional key to be used as the `text` input to tokenizer/processor "
                "after dataset preprocessing"
            )
        },
    )
    remove_columns: Union[None, str, List] = field(
        default=None,
        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
    )
    preprocessing_func: Union[None, str, Callable] = field(
        default=None,
        metadata={
            "help": (
                "Typically a function which applies a chat template. Can take the form "
                "of either a function to apply to the dataset, a name defined in "
                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
                "a path to a function definition of the form /path/to/file.py:func"
            )
        },
    )
    data_collator: Callable[[Any], Any] = field(
        default_factory=lambda: DefaultDataCollator(),
        metadata={"help": "The function used to form a batch from the dataset"},
    )
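

# Illustrative usage sketch (not part of the original file): one form the
# `preprocessing_func` described above can take is a callable that applies a
# chat template. The `tokenizer` variable and the "messages" column below are
# assumptions for the example, not part of this module:
#
#     def apply_chat_template(example):
#         example["text"] = tokenizer.apply_chat_template(
#             example["messages"], tokenize=False
#         )
#         return example
#
#     args = CustomDatasetArguments(preprocessing_func=apply_chat_template)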


@dataclass
class DatasetArguments(CustomDatasetArguments):
    """
    Arguments pertaining to what data we are going to input into our model for
    calibration and training.

    Using `HfArgumentParser` we can turn this class into argparse arguments to
    be able to specify them on the command line
    """

    dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The name of the dataset to use (via the datasets library). "
                "Supports input as a string or DatasetDict from HF"
            )
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use",
        },
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will "
            "be padded."
        },
    )
    concatenate_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to concatenate datapoints to fill max_seq_length"
        },
    )
    raw_kwargs: Dict = field(
        default_factory=dict,
        metadata={"help": "Additional keyword args to pass to datasets load_dataset"},
    )
    splits: Union[None, str, List, Dict] = field(
        default=None,
        metadata={"help": "Optional percentages of each split to download"},
    )
    num_calibration_samples: Optional[int] = field(
        default=512,
        metadata={"help": "Number of samples to use for one-shot calibration"},
    )
    shuffle_calibration_samples: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to shuffle the dataset before selecting calibration data"
        },
    )
    streaming: Optional[bool] = field(
        default=False,
        metadata={"help": "True to stream data from a cloud dataset"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. If False, "
            "will pad the samples dynamically when batching to the maximum length "
            "in the batch (which can be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number "
            "of training examples to this value if set."
        },
    )
    min_tokens_per_module: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "The minimum percentage of tokens (out of the total number) "
                "that the module should 'receive' throughout the forward "
                "pass of the calibration. If a module receives fewer tokens, "
                "a warning will be logged. Defaults to 1/num_of_experts. "
                "Note: this argument is only relevant for MoE models."
            ),
        },
    )
    trust_remote_code_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for datasets defined on the Hub using "
            "a dataset script. This option should only be set to True for "
            "repositories you trust and in which you have read the code, as it "
            "will execute code present on the Hub on your local machine."
        },
    )
    pipeline: Optional[str] = field(
        default="independent",
        metadata={
            "help": "Calibration pipeline used to calibrate model. "
            f"Options: {PIPELINES.keys()}"
        },
    )
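

# Illustrative sketch (not part of the original file): parsing DatasetArguments
# from the command line with `HfArgumentParser`, as the class docstring
# describes. Example invocation (hypothetical flags):
#     python dataset_arguments.py --dataset wikitext --max_seq_length 512
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser(DatasetArguments)
    # parse_args_into_dataclasses returns one instance per dataclass given
    (data_args,) = parser.parse_args_into_dataclasses()
    print(data_args.dataset, data_args.num_calibration_samples)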