@@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin):
     3. Tokenize dataset using model tokenizer/processor
     4. Apply post processing such as grouping text and/or adding labels for finetuning
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
     :param processor: processor or tokenizer to use on dataset
     """
@@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin):
 
     def __init__(
         self,
-        data_args: DatasetArguments,
+        dataset_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
-        self.data_args = data_args
+        self.dataset_args = dataset_args
         self.split = split
         self.processor = processor
 
@@ -58,23 +58,23 @@ def __init__(
                 self.tokenizer.pad_token = self.tokenizer.eos_token
 
             # configure sequence length
-            max_seq_length = data_args.max_seq_length
-            if data_args.max_seq_length > self.tokenizer.model_max_length:
+            max_seq_length = dataset_args.max_seq_length
+            if dataset_args.max_seq_length > self.tokenizer.model_max_length:
                 logger.warning(
                     f"The max_seq_length passed ({max_seq_length}) is larger than "
                     f"maximum length for model ({self.tokenizer.model_max_length}). "
                     f"Using max_seq_length={self.tokenizer.model_max_length}."
                 )
             self.max_seq_length = min(
-                data_args.max_seq_length, self.tokenizer.model_max_length
+                dataset_args.max_seq_length, self.tokenizer.model_max_length
             )
 
             # configure padding
             self.padding = (
                 False
-                if self.data_args.concatenate_data
+                if self.dataset_args.concatenate_data
                 else "max_length"
-                if self.data_args.pad_to_max_length
+                if self.dataset_args.pad_to_max_length
                 else False
             )
 
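The nested conditional expression above is the existing padding logic, only renamed by this change. An equivalent spelled-out form (an editor's sketch, not part of this diff) makes the precedence explicit:

    # Sketch only: equivalent to the self.padding expression above.
    # Concatenated data is never padded; otherwise pad to "max_length"
    # only when pad_to_max_length is set.
    if dataset_args.concatenate_data:
        padding = False
    elif dataset_args.pad_to_max_length:
        padding = "max_length"
    else:
        padding = False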
@@ -83,7 +83,7 @@ def __init__(
             self.padding = False
 
     def __call__(self, add_labels: bool = True) -> DatasetType:
-        dataset = self.data_args.dataset
+        dataset = self.dataset_args.dataset
 
         if isinstance(dataset, str):
             # load dataset: load from huggingface or disk
@@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
                 dataset,
                 self.preprocess,
                 batched=False,
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Preprocessing",
             )
             logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")
@@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
             # regardless of `batched` argument
             remove_columns=get_columns(dataset),  # assumes that input names
             # and output names are disjoint
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
+            num_proc=self.dataset_args.preprocessing_num_workers,
+            load_from_cache_file=not self.dataset_args.overwrite_cache,
             desc="Tokenizing",
         )
         logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")
 
-        if self.data_args.concatenate_data:
+        if self.dataset_args.concatenate_data:
             # postprocess: group text
             dataset = self.map(
                 dataset,
                 self.group_text,
                 batched=True,
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Concatenating data",
             )
             logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")
@@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
                 dataset,
                 self.add_labels,
                 batched=False,  # not compatible with batching, need row lengths
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Adding labels",
             )
             logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
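Every `self.map` call in `__call__` threads the same two `dataset_args` fields through. A possible follow-up refactor (hypothetical, not part of this PR) would hoist them once:

    # Hypothetical refactor, not in this diff: share the kwargs that every
    # preprocessing step derives from dataset_args.
    common_map_kwargs = {
        "num_proc": self.dataset_args.preprocessing_num_workers,
        "load_from_cache_file": not self.dataset_args.overwrite_cache,
    }
    dataset = self.map(
        dataset,
        self.add_labels,
        batched=False,
        desc="Adding labels",
        **common_map_kwargs,
    )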
@@ -165,27 +165,31 @@ def load_dataset(self):
         :param cache_dir: disk location to search for cached dataset
         :return: the requested dataset
         """
-        if self.data_args.dataset_path is not None:
-            if self.data_args.dvc_data_repository is not None:
-                self.data_args.raw_kwargs["storage_options"] = {
-                    "url": self.data_args.dvc_data_repository
+        if self.dataset_args.dataset_path is not None:
+            if self.dataset_args.dvc_data_repository is not None:
+                self.dataset_args.raw_kwargs["storage_options"] = {
+                    "url": self.dataset_args.dvc_data_repository
                 }
-                self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    self.dataset_args.dataset_path
+                )
             else:
-                self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
-                    self.data_args.dataset_path,
-                    self.data_args.dataset
-                    if hasattr(self.data_args, "dataset")
-                    else self.data_args.dataset_name,
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    get_custom_datasets_from_path(
+                        self.dataset_args.dataset_path,
+                        self.dataset_args.dataset
+                        if hasattr(self.dataset_args, "dataset")
+                        else self.dataset_args.dataset_name,
+                    )
                 )
 
-        logger.debug(f"Loading dataset {self.data_args.dataset}")
+        logger.debug(f"Loading dataset {self.dataset_args.dataset}")
         return get_raw_dataset(
-            self.data_args,
+            self.dataset_args,
             None,
             split=self.split,
-            streaming=self.data_args.streaming,
-            **self.data_args.raw_kwargs,
+            streaming=self.dataset_args.streaming,
+            **self.dataset_args.raw_kwargs,
         )
 
     @cached_property
@@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
         """
-        preprocessing_func = self.data_args.preprocessing_func
+        preprocessing_func = self.dataset_args.preprocessing_func
 
         if callable(preprocessing_func):
             return preprocessing_func
@@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
         # rename columns to match processor/tokenizer kwargs
         column_names = get_columns(dataset)
-        if self.data_args.text_column in column_names and "text" not in column_names:
-            logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
-            dataset = dataset.rename_column(self.data_args.text_column, "text")
+        if self.dataset_args.text_column in column_names and "text" not in column_names:
+            logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
+            dataset = dataset.rename_column(self.dataset_args.text_column, "text")
 
         return dataset
 
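After this rename, callers construct the class with the `dataset_args` keyword instead of `data_args`. A minimal usage sketch, assuming only the `DatasetArguments` fields that appear in this diff (`dataset`, `max_seq_length`, `concatenate_data`, `pad_to_max_length`) and a Hugging Face tokenizer as the processor; the dataset name and model path are placeholder values:

    from transformers import AutoTokenizer

    # Field names are taken from this diff; values are illustrative.
    dataset_args = DatasetArguments(
        dataset="wikitext",
        max_seq_length=512,
        concatenate_data=False,
        pad_to_max_length=True,
    )
    text_dataset = TextGenerationDataset(
        dataset_args=dataset_args,  # renamed keyword (was data_args)
        split="train[:5%]",
        processor=AutoTokenizer.from_pretrained("path/to/model"),
    )
    tokenized = text_dataset(add_labels=True)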