@@ -21,7 +21,7 @@
 import json
 import os
 
-from datasets import load_from_disk, Dataset
+from datasets import load_dataset, Dataset
 from datasets.distributed import split_dataset_by_node
 from peft import LoraConfig, get_peft_model
 import transformers
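A note on the import swap above: `load_from_disk` reads a dataset directory previously written with `Dataset.save_to_disk`, while `load_dataset` with the `json` builder parses a raw JSON/JSON Lines file directly. A minimal sketch of the two call shapes, not part of the patch and using hypothetical paths:

    from datasets import load_dataset, load_from_disk

    # Old loader: expects a directory produced by Dataset.save_to_disk()
    # ds = load_from_disk("/data/saved_dataset")

    # New loader: reads a raw JSON Lines file via the "json" builder
    ds = load_dataset("json", data_files="/data/train.jsonl")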
@@ -71,28 +71,25 @@ def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
     return model, tokenizer
 
 # This function is a modified version of the original.
-def load_and_preprocess_data(dataset_dir, transformer_type, tokenizer):
+def load_and_preprocess_data(dataset_file, transformer_type, tokenizer):
     # Load and preprocess the dataset
     logger.info("Load and preprocess dataset")
 
-    file_path = os.path.realpath(dataset_dir)
+    file_path = os.path.realpath(dataset_file)
 
-    if transformer_type != AutoModelForImageClassification:
-        dataset = load_from_disk(file_path)
+    dataset = load_dataset('json', data_files=file_path)
 
+    if transformer_type != AutoModelForImageClassification:
         logger.info(f"Dataset specification: {dataset}")
         logger.info("-" * 40)
 
         logger.info("Tokenize dataset")
         # TODO (andreyvelich): Discuss how user should set the tokenizer function.
-        num_cores = os.cpu_count()
         dataset = dataset.map(
-            lambda x: tokenizer(x["text"], padding=True, truncation=True, max_length=128),
+            lambda x: tokenizer(x["output"], padding=True, truncation=True, max_length=128),
             batched=True,
-            num_proc=num_cores
+            keep_in_memory=True
         )
-    else:
-        dataset = load_from_disk(file_path)
 
     # Check if dataset contains `train` key. Otherwise, load full dataset to train_data.
     if "train" in dataset:
@@ -175,7 +172,7 @@ def parse_arguments():
     parser.add_argument("--model_uri", help="model uri")
     parser.add_argument("--transformer_type", help="model transformer type")
     parser.add_argument("--model_dir", help="directory containing model")
-    parser.add_argument("--dataset_dir", help="directory containing dataset")
+    parser.add_argument("--dataset_file", help="dataset file path")
     parser.add_argument("--lora_config", help="lora_config")
     parser.add_argument(
         "--training_parameters", help="hugging face training parameters"
@@ -197,7 +194,7 @@ def parse_arguments():
 
     logger.info("Preprocess dataset")
     train_data, eval_data = load_and_preprocess_data(
-        args.dataset_dir, transformer_type, tokenizer
+        args.dataset_file, transformer_type, tokenizer
     )
 
     logger.info("Setup LoRA config for model")