 import argparse
 import json
 import os
+import subprocess
 from pathlib import Path
 from typing import Dict, Tuple
 
-from datasets import concatenate_datasets, load_dataset
 from tqdm import tqdm
 
+from datasets import concatenate_datasets, config, load_dataset
+
 """
 This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format:
 {
@@ -88,7 +90,53 @@ def parse_args():
     return parser.parse_args()
 
 
-def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
+def get_cache_dir(dataset_name):
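+    """Return the local cache directory for a downloadable VLM dataset."""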
+    cache_dir = None
+    if dataset_name == "sharegpt4v":
+        raise ValueError("Downloading 'sharegpt4v' is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = os.path.join(
+            config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA"
+        )
+    else:
+        raise ValueError(
+            f"Dataset '{dataset_name}' is not a supported VLM dataset for download."
+        )
+    return cache_dir
+
+
+def download_vlm_dataset(dataset_name: str) -> None:
+    """Download a VLM dataset such as sharegpt4v or allava4v."""
+    if dataset_name == "sharegpt4v":
+        raise ValueError("Downloading 'sharegpt4v' is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = get_cache_dir(dataset_name)
+        os.makedirs(cache_dir, exist_ok=True)
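+        # Locate datasets/download_laion.sh relative to this script (../datasets/download_laion.sh).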
+        script_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "datasets",
+            "download_laion.sh",
+        )
+        os.chmod(script_path, 0o755)
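+        # Skip the download when the first image chunk is already present in the cache.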
+        if not os.path.exists(
+            os.path.join(cache_dir, "allava_laion", "image_chunks", "images_0.zip")
+        ):
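+            # Run the script from inside the cache directory so the downloaded files land there.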
+            result = subprocess.run(
+                ["bash", script_path],
+                cwd=cache_dir,
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Image dataset download failed: {result.stderr}")
+            print("##### allava4v dataset download complete #####")
+        else:
+            print("##### allava4v dataset already exists #####")
+    else:
+        raise ValueError(f"Dataset '{dataset_name}' is not a supported VLM dataset for download.")
+
+
+def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """Process a row from the ultrachat dataset.
 
     The function expects a row with the following schema:
@@ -110,7 +158,7 @@ def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
     return row, 0
 
 
-def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
+def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """
     sharegpt dataset schema:
     {
@@ -138,7 +186,7 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
     return row, skipped_count
 
 
-def process_sharegpt4v_row(row) -> Dict:
+def process_sharegpt4v_row(row, dataset_name: str = None) -> Tuple[Dict, int]:
     """
     sharegpt4v dataset schema:
     {
@@ -153,8 +201,9 @@ def process_sharegpt4v_row(row) -> Dict:
         ]
     }
     """
+    cache_dir = get_cache_dir(dataset_name)
     conversations = row["conversations"]
-    image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}'
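+    # Build the image path inside the local download cache rather than a hard-coded repo path.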
+    image = os.path.join(cache_dir, row["image"])
     if not os.path.exists(image):
         print(f"Image path {image} does not exist, skipping this sample.")
         return None, None
@@ -194,7 +243,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(train_output_jsonl_path, "w") as f:
         for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -207,7 +256,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(test_output_jsonl_path, "w") as f:
         for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -292,11 +341,14 @@ def main():
         proc_fn = process_sharegpt_row
     elif args.dataset == "sharegpt4v":
+        raise NotImplementedError("sharegpt4v is not supported yet.")
         ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "allava4v":
         ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
             "instruct"
         ]
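+        # Download the ALLaVA-LAION image chunks that process_sharegpt4v_row expects on disk.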
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "opc":
         if args.opc_subset == "all":
@@ -318,7 +370,6 @@ def main():
         raise ValueError(
             f"This script only supports the ultrachat, sharegpt, sharegpt4v, allava4v, opc, and perfect-blend-gptoss-20B datasets for demo purposes; to use other datasets, please modify this script."
         )
-
     # filter and split dataset
     if args.sample_size is not None and args.sample_size < len(ds):
         ds = ds.select(range(args.sample_size))