3434
3535import argparse
3636import gc
37+ import gzip
3738import hashlib
3839import os
3940from concurrent .futures import ThreadPoolExecutor
4647from tqdm import tqdm
4748from transformers import AutoConfig , AutoProcessor , AutoTokenizer
4849
49- from datasets import load_dataset
50+ from datasets import Dataset
5051from specforge .args import SGLangBackendArgs
5152from specforge .data import build_eagle3_dataset , prepare_dp_dataloaders
5253from specforge .distributed import (
5758 is_tp_rank_0 ,
5859)
5960from specforge .modeling .target import Eagle3TargetModel , get_eagle3_target_model
60- from specforge .utils import print_with_rank , rank_0_priority
61+ from specforge .utils import (
62+ print_args_with_dots ,
63+ print_with_rank ,
64+ rank_0_priority ,
65+ safe_conversations_generator ,
66+ )
6167
6268
6369@dataclass
@@ -119,8 +125,8 @@ def parse_args():
119125 others_group .add_argument (
120126 "--num-io-threads" ,
121127 type = int ,
122- default = 4 ,
123- help = "Number of threads for async I/O operations" ,
128+ default = None ,
129+ help = "Number of threads for async I/O operations (default: all of CPU cores). " ,
124130 )
125131 others_group .add_argument (
126132 "--num-workers" , type = int , default = 4 , help = "Number of workers for DataLoader"
@@ -137,6 +143,17 @@ def parse_args():
137143 default = 2000 ,
138144 help = "Number of files per subdirectory." ,
139145 )
146+ others_group .add_argument (
147+ "--compress" ,
148+ action = "store_true" ,
149+ help = "Compress hidden state files on disk (gzip)." ,
150+ )
151+ others_group .add_argument (
152+ "--compression-level" ,
153+ type = int ,
154+ default = 6 ,
155+ help = "Gzip compression level (1-9)." ,
156+ )
140157
141158 sglang_group = parser .add_argument_group ("sglang" )
142159 SGLangBackendArgs .add_args (sglang_group )
@@ -211,6 +228,8 @@ def __init__(
211228 num_io_threads : int = 4 ,
212229 io_queue_size : int = 50 ,
213230 file_group_size : int = 2000 ,
231+ compress : bool = False ,
232+ compression_level : int = 6 ,
214233 ):
215234 """
216235 Args:
@@ -227,6 +246,9 @@ def __init__(
227246 self .num_io_threads = num_io_threads
228247 self .io_queue_size = io_queue_size
229248 self .file_group_size = file_group_size
249+ self .compress = compress
250+ self .compression_level = compression_level
251+ self .file_extension = ".ckpt.gz" if self .compress else ".ckpt"
230252
231253 # progress bar should only be shown on TP rank = 0
232254 self .show_progress = dist .get_rank (get_tp_group ()) == 0
@@ -278,7 +300,13 @@ def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None:
278300 )
279301 return
280302
281- torch .save (asdict (data_point ), output_file )
303+ if self .compress :
304+ with gzip .open (
305+ output_file , "wb" , compresslevel = self .compression_level
306+ ) as f :
307+ torch .save (asdict (data_point ), f )
308+ else :
309+ torch .save (asdict (data_point ), output_file )
282310
283311 def _save_tensor_async (self , data_point : DataPoint , output_file : str ) -> None :
284312 """
@@ -361,14 +389,22 @@ def _check_existing_files_batch(
361389 return [False ] * len (global_indices )
362390
363391 def check_single_file (idx ):
364- return os .path .exists (self ._get_file_path (output_path , idx ))
392+ if os .path .exists (self ._get_file_path (output_path , idx )):
393+ return True
394+ legacy_ckpt = self ._get_file_path (output_path , idx , extension = ".ckpt" )
395+ compressed_ckpt = self ._get_file_path (
396+ output_path , idx , extension = ".ckpt.gz"
397+ )
398+ return os .path .exists (legacy_ckpt ) or os .path .exists (compressed_ckpt )
365399
366400 # Parallel file existence check
367401 with ThreadPoolExecutor (max_workers = self .num_io_threads ) as executor :
368402 exists = list (executor .map (check_single_file , global_indices ))
369403 return exists
370404
371- def _get_file_path (self , output_path : str , idx : int ) -> str :
405+ def _get_file_path (
406+ self , output_path : str , idx : int , extension : Optional [str ] = None
407+ ) -> str :
372408 """
373409 A helper function to get the standard file path for the data point with the given index.
374410
@@ -379,9 +415,10 @@ def _get_file_path(self, output_path: str, idx: int) -> str:
379415 Returns:
380416 str: The file path for the data point.
381417 """
418+ ext = self .file_extension if extension is None else extension
382419 group_idx = (idx // self .file_group_size ) * self .file_group_size
383420 grouped_subdir = f"rows_{ group_idx } -{ group_idx + self .file_group_size } "
384- return os .path .join (output_path , grouped_subdir , f"data_{ idx } .ckpt " )
421+ return os .path .join (output_path , grouped_subdir , f"data_{ idx } { ext } " )
385422
386423 @torch .no_grad ()
387424 def generate (
@@ -469,7 +506,6 @@ def generate(
469506 filtered_batch_gpu = {
470507 k : v .cuda (non_blocking = True ) for k , v in filtered_batch .items ()
471508 }
472-
473509 _ , _ , aux_hidden_states_list , last_hidden_states_list = self .model .extend (
474510 ** filtered_batch_gpu ,
475511 return_last_hidden_states = True ,
@@ -550,9 +586,12 @@ def main():
550586 args .aux_hidden_states_layers = [
551587 int (x ) for x in args .aux_hidden_states_layers .split ("," )
552588 ]
553-
589+ if args .num_io_threads is None :
590+ cpu_cores = os .cpu_count () or 1
591+ args .num_io_threads = max (1 , cpu_cores )
554592 # Initialize distributed environment (TP + DP)
555593 init_distributed (timeout = args .dist_timeout , tp_size = args .tp_size )
594+ print_args_with_dots (args )
556595
557596 # Build target model (with TP)
558597 target_model_config = AutoConfig .from_pretrained (
@@ -574,10 +613,17 @@ def main():
574613 assert os .path .exists (
575614 args .data_path
576615 ), f"Dataset path { args .data_path } does not exist"
577- dataset = load_dataset ("json" , data_files = args .data_path )["train" ]
616+ dataset = Dataset .from_generator (
617+ generator = safe_conversations_generator ,
618+ gen_kwargs = {"file_path" : args .data_path },
619+ cache_dir = os .path .join (
620+ os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))),
621+ "cache" ,
622+ "hf_dataset" ,
623+ ),
624+ )
578625 if args .num_samples is not None :
579626 dataset = dataset .select (range (args .num_samples ))
580-
581627 # Tokenizer and cache key
582628 tokenizer = AutoTokenizer .from_pretrained (
583629 args .target_model_path , trust_remote_code = True
@@ -643,10 +689,12 @@ def main():
643689 # Pass configurable arguments from args if needed
644690 with HiddenStatesGenerator (
645691 target_model ,
646- args .enable_aux_hidden_states ,
692+ enable_aux_hidden_states = args .enable_aux_hidden_states ,
647693 num_io_threads = args .num_io_threads ,
648694 io_queue_size = args .io_queue_size ,
649695 file_group_size = args .file_group_size ,
696+ compress = args .compress ,
697+ compression_level = args .compression_level ,
650698 # Other params like io_queue_size can also be added to argparse
651699 ) as hidden_states_generator :
652700
0 commit comments