import argparse
import json
import os
+import subprocess
from pathlib import Path
from typing import Dict, Tuple

-from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm

+from datasets import concatenate_datasets, config, load_dataset
+
"""
This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format:
{
@@ -88,7 +90,49 @@ def parse_args():
    return parser.parse_args()


-def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
+def get_cache_dir(dataset_name: str) -> str:
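+    """Map a supported dataset name to its local Hugging Face cache directory."""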
+    cache_dir = None
+    if dataset_name == "sharegpt4v":
+        raise Exception("Downloading sharegpt4v is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = os.path.join(
+            config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA"
+        )
+    else:
+        raise Exception(f"Unsupported dataset: {dataset_name}")
+    return cache_dir
+
+
+def download_vlm_dataset(dataset_name: str) -> None:
+    """Download a VLM dataset's images, e.g. for sharegpt4v or allava4v."""
+    if dataset_name == "sharegpt4v":
+        raise Exception("Downloading sharegpt4v is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = get_cache_dir(dataset_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        script_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "datasets",
+            "download_laion.sh",
+        )
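+        # Make sure the download script is executable before invoking it.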
+        os.chmod(script_path, 0o755)
+        if not os.path.exists(os.path.join(cache_dir, "allava_laion")):
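+            # Run the script from inside the cache dir so downloaded files land there.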
+            result = subprocess.run(
+                ["bash", script_path],
+                cwd=cache_dir,
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Image dataset download failed: {result.stderr}")
+            print("##### allava4v dataset download complete #####")
+        else:
+            print("##### allava4v dataset already exists #####")
+    else:
+        raise Exception(f"Unsupported dataset: {dataset_name}")
+
+
+def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """Process a row from the ultrachat dataset.

    The function expects a row with the following schema:
@@ -110,7 +154,7 @@ def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
    return row, 0


-def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
+def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
    """
    sharegpt dataset schema:
    {
@@ -138,7 +182,7 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
    return row, skipped_count


-def process_sharegpt4v_row(row) -> Dict:
+def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict:
    """
    sharegpt4v dataset schema:
    {
@@ -153,8 +197,9 @@ def process_sharegpt4v_row(row) -> Dict:
        ]
    }
    """
+    cache_dir = get_cache_dir(dataset_name)
    conversations = row["conversations"]
-    image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}'
+    image = os.path.join(cache_dir, row["image"])
    if not os.path.exists(image):
        print(f"Image path {image} does not exist, skipping this sample.")
        return None, None
@@ -194,7 +239,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
    with open(train_output_jsonl_path, "w") as f:
        for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
            if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                if row is None:
                    continue
                total_skipped_count += skipped_count
@@ -207,7 +252,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
    with open(test_output_jsonl_path, "w") as f:
        for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
            if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                if row is None:
                    continue
                total_skipped_count += skipped_count
@@ -292,11 +337,14 @@ def main():
        proc_fn = process_sharegpt_row
    elif args.dataset == "sharegpt4v":
        ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
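+        # ShareGPT4V image download is not wired up yet, so bail out before processing.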
+        raise Exception("sharegpt4v is not supported yet")
+        download_vlm_dataset(args.dataset)
        proc_fn = process_sharegpt4v_row
    elif args.dataset == "allava4v":
        ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
            "instruct"
        ]
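+        # Fetch the image files referenced by this split if they are not cached yet.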
+        download_vlm_dataset(args.dataset)
        proc_fn = process_sharegpt4v_row
    elif args.dataset == "opc":
        if args.opc_subset == "all":
@@ -318,7 +366,6 @@ def main():
        raise ValueError(
            f"This script only supports the ultrachat, sharegpt, sharegpt4v, allava4v, opc, and perfect-blend-gptoss-20B datasets for demo purposes; if you wish to use other datasets, please modify this script."
        )
-
    # filter and split dataset
    if args.sample_size is not None and args.sample_size < len(ds):
        ds = ds.select(range(args.sample_size))