@@ -34,7 +34,7 @@ def parse_args():
3434 parser .add_argument (
3535 "--dataset" ,
3636 type = str ,
37- choices = ["ultrachat" , "sharegpt" ],
37+ choices = ["ultrachat" , "sharegpt" , "opc" ],
3838 help = "The demo dataset to quickly run the training for speculative decoding" ,
3939 )
4040 parser .add_argument (
@@ -108,6 +108,20 @@ def load_dataset_from_path(data_path: Path):
108108 return ds
109109
110110
111+ import hashlib
112+
113+
def process_opc_sft_stage1(row) -> Dict:
    """Map one OpenCoder opc-sft-stage1 row to the shared conversation schema.

    Args:
        row: A dataset record with ``"instruction"`` and ``"output"`` string
            fields (schema assumed from usage here — confirm against the
            ``OpenCoder-LLM/opc-sft-stage1`` dataset card).

    Returns:
        A dict with a deterministic ``"id"`` (MD5 hexdigest of
        instruction + output, so identical rows get identical ids) and a
        two-turn ``"conversations"`` list: user instruction followed by the
        assistant output.
    """
    user_text = row["instruction"]
    assistant_text = row["output"]
    # MD5 here is a stable content fingerprint, not a security measure.
    digest = hashlib.md5((user_text + assistant_text).encode()).hexdigest()
    turns = [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": assistant_text},
    ]
    return {"id": digest, "conversations": turns}
123+
124+
111125def main ():
112126 args = parse_args ()
113127 # load dataset
@@ -121,6 +135,11 @@ def main():
121135 print ("Loading dataset from custom data path: " , args .data_path )
122136 ds = load_dataset_from_path (Path (args .data_path ))
123137 proc_fn = process_sharegpt_row
138+ elif args .dataset == "opc" :
139+ ds = load_dataset (
140+ "OpenCoder-LLM/opc-sft-stage1" , "largescale_diverse_instruct"
141+ )["train" ]
142+ proc_fn = process_opc_sft_stage1
124143 else :
125144 raise ValueError (
126145 f"This script only supports ultrachat_200k and sharegpt datasets for demo purpose, if you wish to use other datasets, please modify this script."
0 commit comments