Skip to content

Commit 1dea3cf

Browse files
authored
Merge branch 'llama3-mlp-for-draft-decoder' into llama3-basic-multi-layer-decoder
2 parents fd82bf2 + 2a3969c commit 1dea3cf

File tree

14 files changed

+821
-209
lines changed

14 files changed

+821
-209
lines changed

benchmarks/bench_eagle3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def main():
201201
assert len(benchmark_list) != 0, "the number of benchmark list is 0"
202202

203203
base_url = f"http://localhost:{args.port}"
204+
204205
results = {}
206+
results["model"] = server_args.speculative_draft_model_path
205207

206208
def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int):
207209
for benchmark_name, num_prompts, subset in benchmark_list:

scripts/prepare_hidden_states.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import argparse
3636
import gc
37+
import gzip
3738
import hashlib
3839
import os
3940
from concurrent.futures import ThreadPoolExecutor
@@ -46,7 +47,7 @@
4647
from tqdm import tqdm
4748
from transformers import AutoConfig, AutoProcessor, AutoTokenizer
4849

49-
from datasets import load_dataset
50+
from datasets import Dataset
5051
from specforge.args import SGLangBackendArgs
5152
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
5253
from specforge.distributed import (
@@ -57,7 +58,12 @@
5758
is_tp_rank_0,
5859
)
5960
from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
60-
from specforge.utils import print_with_rank, rank_0_priority
61+
from specforge.utils import (
62+
print_args_with_dots,
63+
print_with_rank,
64+
rank_0_priority,
65+
safe_conversations_generator,
66+
)
6167

6268

6369
@dataclass
@@ -119,8 +125,8 @@ def parse_args():
119125
others_group.add_argument(
120126
"--num-io-threads",
121127
type=int,
122-
default=4,
123-
help="Number of threads for async I/O operations",
128+
default=None,
129+
help="Number of threads for async I/O operations (default: all of CPU cores).",
124130
)
125131
others_group.add_argument(
126132
"--num-workers", type=int, default=4, help="Number of workers for DataLoader"
@@ -137,6 +143,17 @@ def parse_args():
137143
default=2000,
138144
help="Number of files per subdirectory.",
139145
)
146+
others_group.add_argument(
147+
"--compress",
148+
action="store_true",
149+
help="Compress hidden state files on disk (gzip).",
150+
)
151+
others_group.add_argument(
152+
"--compression-level",
153+
type=int,
154+
default=6,
155+
help="Gzip compression level (1-9).",
156+
)
140157

141158
sglang_group = parser.add_argument_group("sglang")
142159
SGLangBackendArgs.add_args(sglang_group)
@@ -211,6 +228,8 @@ def __init__(
211228
num_io_threads: int = 4,
212229
io_queue_size: int = 50,
213230
file_group_size: int = 2000,
231+
compress: bool = False,
232+
compression_level: int = 6,
214233
):
215234
"""
216235
Args:
@@ -227,6 +246,9 @@ def __init__(
227246
self.num_io_threads = num_io_threads
228247
self.io_queue_size = io_queue_size
229248
self.file_group_size = file_group_size
249+
self.compress = compress
250+
self.compression_level = compression_level
251+
self.file_extension = ".ckpt.gz" if self.compress else ".ckpt"
230252

231253
# progress bar should only shown on TP rank = 0
232254
self.show_progress = dist.get_rank(get_tp_group()) == 0
@@ -278,7 +300,13 @@ def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None:
278300
)
279301
return
280302

281-
torch.save(asdict(data_point), output_file)
303+
if self.compress:
304+
with gzip.open(
305+
output_file, "wb", compresslevel=self.compression_level
306+
) as f:
307+
torch.save(asdict(data_point), f)
308+
else:
309+
torch.save(asdict(data_point), output_file)
282310

283311
def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None:
284312
"""
@@ -361,14 +389,22 @@ def _check_existing_files_batch(
361389
return [False] * len(global_indices)
362390

363391
def check_single_file(idx):
364-
return os.path.exists(self._get_file_path(output_path, idx))
392+
if os.path.exists(self._get_file_path(output_path, idx)):
393+
return True
394+
legacy_ckpt = self._get_file_path(output_path, idx, extension=".ckpt")
395+
compressed_ckpt = self._get_file_path(
396+
output_path, idx, extension=".ckpt.gz"
397+
)
398+
return os.path.exists(legacy_ckpt) or os.path.exists(compressed_ckpt)
365399

366400
# Parallel file existence check
367401
with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor:
368402
exists = list(executor.map(check_single_file, global_indices))
369403
return exists
370404

371-
def _get_file_path(self, output_path: str, idx: int) -> str:
405+
def _get_file_path(
406+
self, output_path: str, idx: int, extension: Optional[str] = None
407+
) -> str:
372408
"""
373409
A helper function to get the standard file path for the data point with the given index.
374410
@@ -379,9 +415,10 @@ def _get_file_path(self, output_path: str, idx: int) -> str:
379415
Returns:
380416
str: The file path for the data point.
381417
"""
418+
ext = self.file_extension if extension is None else extension
382419
group_idx = (idx // self.file_group_size) * self.file_group_size
383420
grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}"
384-
return os.path.join(output_path, grouped_subdir, f"data_{idx}.ckpt")
421+
return os.path.join(output_path, grouped_subdir, f"data_{idx}{ext}")
385422

386423
@torch.no_grad()
387424
def generate(
@@ -469,7 +506,6 @@ def generate(
469506
filtered_batch_gpu = {
470507
k: v.cuda(non_blocking=True) for k, v in filtered_batch.items()
471508
}
472-
473509
_, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend(
474510
**filtered_batch_gpu,
475511
return_last_hidden_states=True,
@@ -550,9 +586,12 @@ def main():
550586
args.aux_hidden_states_layers = [
551587
int(x) for x in args.aux_hidden_states_layers.split(",")
552588
]
553-
589+
if args.num_io_threads is None:
590+
cpu_cores = os.cpu_count() or 1
591+
args.num_io_threads = max(1, cpu_cores)
554592
# Initialize distributed environment (TP + DP)
555593
init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
594+
print_args_with_dots(args)
556595

557596
# Build target model (with TP)
558597
target_model_config = AutoConfig.from_pretrained(
@@ -574,10 +613,17 @@ def main():
574613
assert os.path.exists(
575614
args.data_path
576615
), f"Dataset path {args.data_path} does not exist"
577-
dataset = load_dataset("json", data_files=args.data_path)["train"]
616+
dataset = Dataset.from_generator(
617+
generator=safe_conversations_generator,
618+
gen_kwargs={"file_path": args.data_path},
619+
cache_dir=os.path.join(
620+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
621+
"cache",
622+
"hf_dataset",
623+
),
624+
)
578625
if args.num_samples is not None:
579626
dataset = dataset.select(range(args.num_samples))
580-
581627
# Tokenizer and cache key
582628
tokenizer = AutoTokenizer.from_pretrained(
583629
args.target_model_path, trust_remote_code=True
@@ -643,10 +689,12 @@ def main():
643689
# Pass configurable arguments from args if needed
644690
with HiddenStatesGenerator(
645691
target_model,
646-
args.enable_aux_hidden_states,
692+
enable_aux_hidden_states=args.enable_aux_hidden_states,
647693
num_io_threads=args.num_io_threads,
648694
io_queue_size=args.io_queue_size,
649695
file_group_size=args.file_group_size,
696+
compress=args.compress,
697+
compression_level=args.compression_level,
650698
# Other params like io_queue_size can also be added to argparse
651699
) as hidden_states_generator:
652700

scripts/regenerate_train_data.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
This script will re-generate the dataset from target model,
33
which better aligns the draft model with the target model’s output distribution.
44
@@ -29,6 +29,7 @@
2929

3030
import argparse
3131
import json
32+
import os
3233
import random
3334
from concurrent.futures import ThreadPoolExecutor
3435
from typing import Any, Dict, List
@@ -113,6 +114,11 @@ def parse_arguments():
113114
default=None,
114115
help="The number of samples to regenerate, if not provided, all samples will be regenerated",
115116
)
117+
data_group.add_argument(
118+
"--resume",
119+
action="store_true",
120+
help="Resume from existing output file, skip already processed samples",
121+
)
116122

117123
# sglang server
118124
server_group = parser.add_argument_group("sglang server")
@@ -252,9 +258,29 @@ def main():
252258
print(f" API URL: {args.server_address}")
253259
print(f" Input file: {args.input_file_path}")
254260
print(f" Output file: {args.output_file_path}")
261+
print(f" Resume mode: {args.resume}")
255262
print("-" * 50)
256263
total_lines = sum(1 for _ in open(args.input_file_path))
257264

265+
skip_lines = 0
266+
error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl")
267+
268+
if args.resume and os.path.exists(args.output_file_path):
269+
existing_success = sum(1 for _ in open(args.output_file_path))
270+
existing_error = 0
271+
if os.path.exists(error_file_path):
272+
existing_error = sum(1 for _ in open(error_file_path))
273+
skip_lines = existing_success + existing_error
274+
print(f"Resume mode enabled:")
275+
print(f" Found {existing_success} successful samples in output file")
276+
print(f" Found {existing_error} error samples in error file")
277+
print(f" Skipping first {skip_lines} input samples")
278+
print("-" * 50)
279+
280+
if skip_lines >= total_lines:
281+
print(f"All {total_lines} samples already processed. Nothing to do.")
282+
return
283+
258284
# test all server addresses
259285
valid_server_addresses = []
260286
for server_address in args.server_address:
@@ -279,11 +305,14 @@ def main():
279305
)
280306
print("-" * 50)
281307

282-
# create error file path if not exists
283-
error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl")
308+
# Determine file open mode based on resume flag
309+
file_mode = "a" if (args.resume and skip_lines > 0) else "w"
284310
print(
285311
f"Regenerating dataset and saving the output to {args.output_file_path} and error log to {error_file_path}"
286312
)
313+
print(
314+
f"File open mode: {file_mode} ({'append' if file_mode == 'a' else 'overwrite'})"
315+
)
287316
print("-" * 50)
288317
context_token_sum = 0
289318
context_token_min = None
@@ -294,18 +323,24 @@ def main():
294323
# Create progress bar
295324
with (
296325
open(args.input_file_path, "r") as input_file,
297-
open(args.output_file_path, "w") as output_file_handle,
298-
open(error_file_path, "w") as error_file_handle,
326+
open(args.output_file_path, file_mode) as output_file_handle,
327+
open(error_file_path, file_mode) as error_file_handle,
299328
):
300329
executor = ThreadPoolExecutor(
301330
max_workers=args.concurrency * len(valid_server_addresses)
302331
)
303332
waiting_queue = {
304333
server_address: [] for server_address in valid_server_addresses
305334
}
306-
pbar = tqdm(total=total_lines, desc="Processing")
335+
pbar = tqdm(total=total_lines, desc="Processing", initial=skip_lines)
307336
start_server_index = 0
308337

338+
if skip_lines > 0:
339+
print(f"Skipping {skip_lines} already processed samples...")
340+
for _ in range(skip_lines):
341+
next(input_file, None)
342+
print(f"Resuming from sample {skip_lines + 1}")
343+
309344
for line in input_file:
310345
if (
311346
args.num_samples is not None
@@ -398,9 +433,18 @@ def main():
398433
else:
399434
print("No successful examples to compute context length statistics.")
400435

401-
print(
402-
f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed."
403-
)
436+
total_processed = success_samples + error_samples
437+
if skip_lines > 0:
438+
print(f"\nResume processing completed!")
439+
print(f" Previously processed: {skip_lines}")
440+
print(
441+
f" Newly processed: {total_processed} ({success_samples} success, {error_samples} failed)"
442+
)
443+
print(f" Total: {skip_lines + total_processed}")
444+
else:
445+
print(
446+
f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed."
447+
)
404448

405449

406450
if __name__ == "__main__":

0 commit comments

Comments
 (0)