Skip to content

Commit 9e26589

Browse files
author
George Ohashi
committed
rename data_arguments to dataset_arguments
1 parent d50baba commit 9e26589

File tree

1 file changed

+189
-0
lines changed

1 file changed

+189
-0
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Callable, Dict, List, Optional, Union
3+
4+
from transformers import DefaultDataCollator
5+
6+
7+
@dataclass
class DVCDatasetArguments:
    """Arguments for loading training data tracked with DVC."""

    # Repository that hosts the DVC-tracked dataset referenced by
    # dvc_dataset_path; None means no DVC repository is used.
    dvc_data_repository: Optional[str] = field(
        metadata={"help": "Path to repository used for dvc_dataset_path"},
        default=None,
    )
17+
18+
19+
@dataclass
class CustomDatasetArguments(DVCDatasetArguments):
    """
    Arguments for training using custom datasets.

    Extends DVCDatasetArguments with options describing where the dataset
    lives, which column holds the text, and how samples are preprocessed
    and collated into batches.
    """

    # Location of the dataset: a json/csv path or a dvc://path reference.
    dataset_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to the custom dataset. Supports json, csv, dvc. "
                "For DVC, the path to the dvc dataset to load, of format dvc://path. "
                "For csv or json, the path containing the dataset. "
            ),
        },
    )

    # Column of the raw dataset fed to the tokenizer/processor as `text`.
    text_column: str = field(
        default="text",
        metadata={
            "help": (
                "Optional key to be used as the `text` input to tokenizer/processor "
                "after dataset preprocessing"
            )
        },
    )

    # Kept for backward compatibility; prefer letting preprocessing drop columns.
    remove_columns: Union[None, str, List] = field(
        default=None,
        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
    )

    # Either a callable, a registered function name, or "/path/to/file.py:func".
    preprocessing_func: Union[None, str, Callable] = field(
        default=None,
        metadata={
            "help": (
                "Typically a function which applies a chat template. Can take the form "
                "of either a function to apply to the dataset, a name defined in "
                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
                "a path to a function definition of the form /path/to/file.py:func"
            )
        },
    )

    # default_factory ensures each instance gets a fresh collator
    # (a shared mutable default would be a bug in a dataclass field).
    data_collator: Callable[[Any], Any] = field(
        default_factory=lambda: DefaultDataCollator(),
        metadata={"help": "The function used to form a batch from the dataset"},
    )
67+
68+
69+
@dataclass
class DatasetArguments(CustomDatasetArguments):
    """
    Arguments pertaining to what data we are going to input our model for
    calibration, training or eval.

    Using `HfArgumentParser` we can turn this class into argparse
    arguments to be able to specify them on the command line.
    """

    # Name of a HF-hub dataset, or a pre-loaded DatasetDict.
    dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The name of the dataset to use (via the datasets library). "
                "Supports input as a string or DatasetDict from HF"
            )
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": ("The configuration name of the dataset to use"),
        },
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will "
            "be padded."
        },
    )
    concatenate_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to concatenate datapoints to fill max_seq_length"
        },
    )
    # Fixed typo: "keyboard args" -> "keyword args".
    raw_kwargs: Dict = field(
        default_factory=dict,
        metadata={"help": "Additional keyword args to pass to datasets load_data"},
    )
    splits: Union[None, str, List, Dict] = field(
        default=None,
        metadata={"help": "Optional percentages of each split to download"},
    )
    num_calibration_samples: Optional[int] = field(
        default=512,
        metadata={"help": "Number of samples to use for one-shot calibration"},
    )
    shuffle_calibration_samples: Optional[bool] = field(
        default=True,
        metadata={
            "help": "whether to shuffle the dataset before selecting calibration data"
        },
    )
    streaming: Optional[bool] = field(
        default=False,
        metadata={"help": "True to stream data from a cloud dataset"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. If False, "
            "will pad the samples dynamically when batching to the maximum length "
            "in the batch (which can be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number "
            "of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number "
            "of evaluation examples to this value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of "
                "prediction examples to this value if set."
            ),
        },
    )
    # Fixed missing space in the implicit string concatenation: the original
    # rendered as "...num_of_experts.note: this argument...".
    min_tokens_per_module: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "The minimum percentage of tokens (out of the total number) "
                "that the module should 'receive' throughout the forward "
                "pass of the calibration. If a module receives fewer tokens, "
                "a warning will be logged. Defaults to 1/num_of_experts. "
                "Note: this argument is only relevant for MoE models"
            ),
        },
    )
    trust_remote_code_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for datasets defined on the Hub using "
            "a dataset script. This option should only be set to True for "
            "repositories you trust and in which you have read the code, as it "
            "will execute code present on the Hub on your local machine."
        },
    )

0 commit comments

Comments
 (0)