
Commit c8e48ba

Merge branch 'sgl-project:main' into tmp/sp
2 parents 2e9bdff + b828c32

33 files changed: +1911 -121 lines

configs/qwen3-8b-dflash.json

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
{
  "architectures": [
    "DFlashDraftModel"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoModel": "modeling_dflash.DFlashDraftModel"
  },
  "block_size": 16,
  "bos_token_id": 151643,
  "dtype": "bfloat16",
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 12288,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "max_position_embeddings": 40960,
  "max_window_layers": 5,
  "model_type": "qwen3",
  "num_attention_heads": 32,
  "num_hidden_layers": 5,
  "num_key_value_heads": 8,
  "num_target_layers": 36,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

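The draft config above exposes the `DFlashDraftModel` architecture through `auto_map` as a 5-layer, Qwen3-style draft model with `block_size` 16 and `num_target_layers` 36. A minimal sanity check for the file, as a hedged sketch that assumes the repository root as the working directory:

```bash
# Sketch only: parse the new draft config and print the fields the DFlash
# training script depends on. Assumes this is run from the repository root.
python - <<'EOF'
import json

with open("configs/qwen3-8b-dflash.json") as f:
    cfg = json.load(f)

assert cfg["architectures"] == ["DFlashDraftModel"]
print(f'{cfg["num_hidden_layers"]} draft layers, block_size={cfg["block_size"]}, '
      f'num_target_layers={cfg["num_target_layers"]}')
EOF
```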
datasets/README.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
## Comprehensive Dataset Download Scripts

| Dataset Name | GitHub | Hugging Face | Command |
| -------- | -------- | -------- | -------- |
| ALLaVA-4V | [link](https://github.com/FreedomIntelligence/ALLaVA) | [link](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V) | download_laion.sh |

datasets/download_laion.sh

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
laion_root="allava_laion"

mkdir $laion_root
cd $laion_root

# 1. download annotation files
## 1.1 caption
wget -c -O ALLaVA-Caption-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Caption-LAION-4V.json?download=true

## 1.2 instruction
wget -c -O ALLaVA-Instruct-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Instruct-LAION-4V.json?download=true

# 2. download and unzip images
mkdir image_chunks

## 2.1 download all 10 image chunks in parallel
for ((i=0; i<10; i++))
do
    wget -c -O image_chunks/images_$i.zip https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/image_chunks/images_$i.zip?download=true &
done

mkdir -p images/
wait

## 2.2 unzip
for ((i=0; i<10; i++))
do
    unzip -j -o image_chunks/images_$i.zip -d images/ &  # wait patiently, it takes a while...
done

wait
echo "All done!"

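With this change, `prepare_data.py` (see its diff below) runs this script automatically from inside the Hugging Face datasets cache. A hedged sketch of a manual run; the cache location and checkout path are assumptions:

```bash
# Sketch only: prepare_data.py executes this script with cwd set to
# <HF_DATASETS_CACHE>/FreedomIntelligence/ALLaVA. The paths below assume the
# default cache location and a checkout at ~/SpecForge.
mkdir -p ~/.cache/huggingface/datasets/FreedomIntelligence/ALLaVA
cd ~/.cache/huggingface/datasets/FreedomIntelligence/ALLaVA
bash ~/SpecForge/datasets/download_laion.sh
```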
docs/basic_usage/training.md

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ python scripts/prepare_data.py --dataset sharegpt
 
 ```bash
 # train llama3-8B-instruct
-bash ./examples/run_llama3_eagle3.1_8b_online.sh
+bash ./examples/run_llama3.1_8b_eagle3_online.sh
 ```
 
 ## 💨 Offline Training
@@ -49,10 +49,10 @@ Same as above
 
 ```bash
 # train llama3-8B-instruct in an offline manner
-bash ./examples/run_llama3_eagle3.1_8b_offline.sh
+bash ./examples/run_llama3.1_8b_eagle3_offline.sh
 ```
 
-It is important to note that the `run_llama3_eagle3_offline.sh` script consists of two steps:
+It is important to note that the `run_llama3.1_8b_eagle3_offline.sh` script consists of two steps:
 
 1. Generate the hidden states using the `prepare_hidden_states.py` script. This script will generate the hidden states for the test and train datasets and save them to the disk.
 2. Train the model: suppling the `--train-hidden-states-path` argument to the script so that the script will load the hidden states from the disk during training.
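As a quick reference for the two-step flow described above, the renamed offline example runs both steps in sequence:

```bash
# Runs prepare_hidden_states.py to cache hidden states, then launches training
# with --train-hidden-states-path pointing at the cached states.
bash ./examples/run_llama3.1_8b_eagle3_offline.sh
```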
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=32
NUM_GPUS=${1:-1}

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_dflash.py \
    --target-model-path Qwen/Qwen3-8B \
    --draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
    --output-dir $ROOT_DIR/outputs/qwen3-8b-dflash-sharegpt \
    --num-epochs 20 \
    --batch-size 4 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 1000 \
    --report-to wandb \
    --wandb-project specforge-qwen3-8b-dflash \
    --wandb-name qwen3-8b-dflash-sharegpt

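A usage sketch for the example script above: the first positional argument sets `NUM_GPUS` (default 1). The script's file name is not visible in this view, so the path below is a placeholder, and it assumes `sharegpt_train.jsonl` has already been produced under `cache/dataset/`:

```bash
# Sketch only: launch DFlash draft training for Qwen3-8B on 4 GPUs.
# "run_qwen3_8b_dflash.sh" is a placeholder name; use the actual example
# script added by this commit.
bash examples/run_qwen3_8b_dflash.sh 4
```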
scripts/prepare_data.py

Lines changed: 59 additions & 8 deletions
@@ -1,12 +1,14 @@
 import argparse
 import json
 import os
+import subprocess
 from pathlib import Path
 from typing import Dict, Tuple
 
-from datasets import concatenate_datasets, load_dataset
 from tqdm import tqdm
 
+from datasets import concatenate_datasets, config, load_dataset
+
 """
 This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format:
 {
@@ -88,7 +90,53 @@ def parse_args():
     return parser.parse_args()
 
 
-def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
+def get_cache_dir(dataset_name):
+    cache_dir = None
+    if dataset_name == "sharegpt4v":
+        raise ValueError("Downloading 'sharegpt4v' is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = os.path.join(
+            config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA"
+        )
+    else:
+        raise ValueError(
+            f"Dataset '{dataset_name}' is not a supported VLM dataset for download."
+        )
+    return cache_dir
+
+
+def download_vlm_dataset(dataset_name: str) -> None:
+    """Download a VLM dataset such as sharegpt4v or allava4v."""
+    if dataset_name == "sharegpt4v":
+        raise Exception("Downloading 'sharegpt4v' is not supported.")
+    elif dataset_name == "allava4v":
+        cache_dir = get_cache_dir(dataset_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        script_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+            "datasets",
+            "download_laion.sh",
+        )
+        os.chmod(script_path, 0o755)
+        if not os.path.exists(
+            os.path.join(cache_dir, "allava_laion", "image_chunks", "images_0.zip")
+        ):
+            result = subprocess.run(
+                ["bash", script_path],
+                cwd=cache_dir,
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Download image dataset failed: {result.stderr}")
+            print("##### allava4v dataset Download Complete #####")
+        else:
+            print("##### allava4v dataset already exists #####")
+    else:
+        raise Exception(f"Dataset '{dataset_name}' is not supported for download.")
+
+
+def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """Process a row from the ultrachat dataset.
 
     The function expects a row with the following schema:
@@ -110,7 +158,7 @@ def process_ultrachat_row(row: Dict) -> Tuple[Dict, int]:
     return row, 0
 
 
-def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
+def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]:
     """
     sharegpt dataset schema:
     {
@@ -138,7 +186,7 @@ def process_sharegpt_row(row: Dict) -> Tuple[Dict, int]:
     return row, skipped_count
 
 
-def process_sharegpt4v_row(row) -> Dict:
+def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict:
     """
     sharegpt4v dataset schema:
     {
@@ -153,8 +201,9 @@ def process_sharegpt4v_row(row) -> Dict:
         ]
     }
     """
+    cache_dir = get_cache_dir(dataset_name)
     conversations = row["conversations"]
-    image = f'FreedomIntelligence/ALLaVA-4V/{row["image"]}'
+    image = os.path.join(cache_dir, row["image"])
     if not os.path.exists(image):
         print(f"Image path {image} does not exist, skipping this sample.")
         return None, None
@@ -194,7 +243,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(train_output_jsonl_path, "w") as f:
         for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -207,7 +256,7 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name):
     with open(test_output_jsonl_path, "w") as f:
         for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"):
             if proc_fn is not None:
-                row, skipped_count = proc_fn(item)
+                row, skipped_count = proc_fn(item, dataset_name)
                 if row is None:
                     continue
                 total_skipped_count += skipped_count
@@ -292,11 +341,14 @@ def main():
         proc_fn = process_sharegpt_row
     elif args.dataset == "sharegpt4v":
         ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"]
+        raise Exception("sharegpt4v is not supported yet")
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "allava4v":
         ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
             "instruct"
         ]
+        download_vlm_dataset(args.dataset)
         proc_fn = process_sharegpt4v_row
     elif args.dataset == "opc":
         if args.opc_subset == "all":
@@ -318,7 +370,6 @@ def main():
         raise ValueError(
             f"This script only supports ultrachat, sharegpt, sharegpt4v, allava4v, opc, and perfect-blend-gptoss-20B datasets for demo purpose, if you wish to use other datasets, please modify this script."
         )
-
     # filter and split dataset
     if args.sample_size is not None and args.sample_size < len(ds):
         ds = ds.select(range(args.sample_size))

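With the changes above, preparing the ALLaVA-4V data also triggers the image download on first use. A minimal invocation sketch:

```bash
# Sketch only: convert allava4v to the jsonl conversation schema. On the first
# run, download_vlm_dataset() invokes datasets/download_laion.sh inside the
# Hugging Face datasets cache before rows are processed.
python scripts/prepare_data.py --dataset allava4v
```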
scripts/prepare_hidden_states.py

Lines changed: 1 addition & 1 deletion
@@ -43,10 +43,10 @@
 
 import torch
 import torch.distributed as dist
-from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 
+from datasets import load_dataset
 from specforge.args import SGLangBackendArgs
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.distributed import (
