Commit 0f33450

support download allava4v dataset
1 parent 886ab9c commit 0f33450

7 files changed: +101 −12 lines changed

datasets/README.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+## Comprehensive Dataset Download Scripts
+
+| Dataset Name | GitHub | Hugging Face | Command |
+| -------- | -------- | -------- | -------- |
+| ALLaVA-4V | [link](https://github.com/FreedomIntelligence/ALLaVA) | [link](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V) | download_laion.sh |
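
Note: the Command column points at the shell script added below in datasets/. A minimal sketch (not part of the commit) of driving that script from Python instead of invoking it by hand; the helper name run_download_script and the ./data target directory are illustrative:

```python
import subprocess
from pathlib import Path


def run_download_script(script: Path, target_dir: Path) -> None:
    """Run a dataset download script with target_dir as the working directory."""
    target_dir.mkdir(parents=True, exist_ok=True)
    # download_laion.sh creates its own subdirectory (allava_laion/) under cwd.
    subprocess.run(["bash", str(script)], cwd=target_dir, check=True)


run_download_script(Path("datasets/download_laion.sh"), Path("./data"))
```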

datasets/download_laion.sh

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+
+
+laion_root="allava_laion"
+
+mkdir $laion_root
+cd $laion_root
+
+
+# 1. download annotation files
+## 1.1 caption
+wget -c -O ALLaVA-Caption-LAION-4V.json "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Caption-LAION-4V.json?download=true"
+
+## 1.2 instruction
+wget -c -O ALLaVA-Instruct-LAION-4V.json "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Instruct-LAION-4V.json?download=true"
+
+
+# 2. download and unzip images
+mkdir image_chunks
+
+## 2.1 download
+for ((i=0; i<10; i++))
+do
+    wget -c -O image_chunks/images_$i.zip "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/image_chunks/images_$i.zip?download=true" &
+done
+
+mkdir -p images/
+wait
+
+## 2.2 unzip
+for ((i=0; i<10; i++))
+do
+    unzip -j -o image_chunks/images_$i.zip -d images/ &  # wait patiently, it takes a while...
+done
+
+wait
+echo "All done!"

scripts/prepare_data.py

Lines changed: 55 additions & 8 deletions
@@ -1,12 +1,14 @@
 import argparse
 import json
 import os
+import subprocess
 from pathlib import Path
 from typing import Dict, Tuple

-from datasets import concatenate_datasets, load_dataset
 from tqdm import tqdm

+from datasets import concatenate_datasets, config, load_dataset
+
 """
 This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format:
 {
@@ -88,7 +90,49 @@ def parse_args():
     return parser.parse_args()


-def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
+def get_cache_dir(dataset_name):
+    cache_dir = None
+    if dataset_name == "sharegpt4v":
+        raise Exception("Downloading sharegpt4v is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = os.path.join(
+            config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA"
+        )
+    else:
+        raise Exception(f"Unsupported dataset: {dataset_name}")
+    return cache_dir
+
+
+def download_vlm_dataset(dataset_name: str) -> None:
+    """Download a VLM dataset such as sharegpt4v or allava4v."""
+    if dataset_name == "sharegpt4v":
+        raise Exception("Downloading sharegpt4v is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = get_cache_dir(dataset_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        script_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "datasets",
+            "download_laion.sh",
+        )
+        os.chmod(script_path, 0o755)
+        if not os.path.exists(os.path.join(cache_dir, "allava_laion")):
+            result = subprocess.run(
+                ["bash", script_path],
+                cwd=cache_dir,
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Image dataset download failed: {result.stderr}")
+            print("##### allava4v dataset Download Complete #####")
+        else:
+            print("##### allava4v dataset already exists. #####")
+    else:
+        raise Exception(f"Unsupported dataset: {dataset_name}")
+
+
+def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """Process a row from the ultrachat dataset.

     The function expects a row with the following schema:
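
Note: get_cache_dir anchors the allava4v download under the Hugging Face datasets cache via datasets.config.HF_DATASETS_CACHE (typically ~/.cache/huggingface/datasets unless HF_HOME or HF_DATASETS_CACHE is set). A minimal sketch of where the data ends up:

```python
import os

from datasets import config

# Same path construction as get_cache_dir("allava4v") above.
cache_dir = os.path.join(config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA")
print(cache_dir)

# download_vlm_dataset("allava4v") then runs datasets/download_laion.sh with
# cwd=cache_dir, so the files land in <cache_dir>/allava_laion/.
```
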
@@ -110,7 +154,7 @@ def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
     return row, 0


-def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
+def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """
     sharegpt dataset schema:
     {
@@ -138,7 +182,7 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
     return row, skipped_count


-def process_sharegpt4v_row(row) -> Dict:
+def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict:
     """
     sharegpt4v dataset schema:
     {
@@ -153,8 +197,9 @@ def process_sharegpt4v_row(row) -> Dict:
         ]
     }
     """
+    cache_dir = get_cache_dir(dataset_name)
     conversations = row["conversations"]
-    image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}'
+    image = os.path.join(cache_dir, row["image"])
     if not os.path.exists(image):
         print(f"Image path {image} does not exist, skipping this sample.")
         return None, None
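
Note: with this change the image field is resolved against the cache directory instead of the previous hard-coded FreedomIntelligence/ALLaVA-4V/ prefix, which only worked relative to a specific working directory. A tiny sketch with a hypothetical row (the sample file name is illustrative, not taken from the dataset):

```python
import os

cache_dir = os.path.expanduser(
    "~/.cache/huggingface/datasets/FreedomIntelligence/ALLaVA"
)  # example location
row = {"image": "allava_laion/images/000000001.jpeg"}  # hypothetical relative path

image = os.path.join(cache_dir, row["image"])
print(image, "exists:", os.path.exists(image))
```
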
@@ -194,7 +239,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(train_output_jsonl_path, "w") as f:
         for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -207,7 +252,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(test_output_jsonl_path, "w") as f:
         for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -292,11 +337,14 @@ def main():
         proc_fn = process_sharegpt_row
     elif args.dataset == "sharegpt4v":
         ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
+        raise Exception("sharegpt4v is not supported yet")
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "allava4v":
         ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
             "instruct"
         ]
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "opc":
         if args.opc_subset == "all":
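
Note: a minimal sketch (not part of the commit) of what the allava4v branch above loads, useful for inspecting the row schema that process_sharegpt4v_row expects; field names are taken from that function, and running this triggers a sizeable download:

```python
from datasets import load_dataset

# Same call as in main() above; "instruct" is the split used for allava4v.
ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")["instruct"]

row = ds[0]
print(row.keys())    # expected to include "image" and "conversations"
print(row["image"])  # relative path, resolved against get_cache_dir("allava4v")
```
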
@@ -318,7 +366,6 @@ def main():
         raise ValueError(
             f"This script only supports ultrachat, sharegpt, sharegpt4v, allava4v, opc, and perfect-blend-gptoss-20B datasets for demo purpose, if you wish to use other datasets, please modify this script."
         )
-
     # filter and split dataset
     if args.sample_size is not None and args.sample_size < len(ds):
         ds = ds.select(range(args.sample_size))

scripts/prepare_hidden_states.py

Lines changed: 1 addition & 1 deletion
@@ -43,10 +43,10 @@

 import torch
 import torch.distributed as dist
-from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer

+from datasets import load_dataset
 from specforge.args import SGLangBackendArgs
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.distributed import (

scripts/train_eagle3.py

Lines changed: 1 addition & 1 deletion
@@ -10,14 +10,14 @@
 import torch.distributed as dist
 import torch.nn as nn
 from accelerate.utils import set_seed
-from datasets import load_dataset
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoTokenizer

+from datasets import load_dataset
 from specforge import (
     AutoDraftModelConfig,
     AutoEagle3DraftModel,

specforge/data/preprocessing.py

Lines changed: 2 additions & 1 deletion
@@ -27,10 +27,11 @@
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from datasets import Dataset as HFDataset
 from tqdm import tqdm
 from transformers import ImageProcessingMixin, PreTrainedTokenizer

+from datasets import Dataset as HFDataset
+
 try:
     from qwen_vl_utils import process_vision_info
3637

specforge/data/utils.py

Lines changed: 1 addition & 1 deletion
@@ -23,10 +23,10 @@

 import torch
 import torch.distributed as dist
-from datasets import Dataset
 from torch.utils.data import DataLoader, DistributedSampler

+from datasets import Dataset
 from specforge.distributed import get_draft_sp_group


 class DataCollatorWithPadding:
