@@ -24,9 +24,16 @@ class ModelConfig:
2424 downloading the model and tokenizer.
2525 download_dir: Directory to download and load the weights, default to the
2626 default cache directory of huggingface.
27- use_np_weights: Save a numpy copy of model weights for faster loading.
28- This can increase the disk usage by up to 2x.
29- use_dummy_weights: Use dummy values for model weights (for profiling).
27+ load_format: The format of the model weights to load:
28+ "auto" will try to load the weights in the safetensors format and
29+ fall back to the pytorch bin format if safetensors format is
30+ not available.
31+ "pt" will load the weights in the pytorch bin format.
32+ "safetensors" will load the weights in the safetensors format.
33+ "npcache" will load the weights in pytorch format and store
34+ a numpy cache to speed up the loading.
35+ "dummy" will initialize the weights with random values, which is
36+ mainly for profiling.
3037 dtype: Data type for model weights and activations. The "auto" option
3138 will use FP16 precision for FP32 and FP16 models, and BF16 precision
3239 for BF16 models.
@@ -40,8 +47,7 @@ def __init__(
4047 tokenizer_mode : str ,
4148 trust_remote_code : bool ,
4249 download_dir : Optional [str ],
43- use_np_weights : bool ,
44- use_dummy_weights : bool ,
50+ load_format : str ,
4551 dtype : str ,
4652 seed : int ,
4753 ) -> None :
@@ -50,14 +56,24 @@ def __init__(
5056 self .tokenizer_mode = tokenizer_mode
5157 self .trust_remote_code = trust_remote_code
5258 self .download_dir = download_dir
53- self .use_np_weights = use_np_weights
54- self .use_dummy_weights = use_dummy_weights
59+ self .load_format = load_format
5560 self .seed = seed
5661
5762 self .hf_config = get_config (model , trust_remote_code )
5863 self .dtype = _get_and_verify_dtype (self .hf_config , dtype )
64+ self ._verify_load_format ()
5965 self ._verify_tokenizer_mode ()
6066
67+ def _verify_load_format (self ) -> None :
68+ load_format = self .load_format .lower ()
69+ if load_format not in [
70+ "auto" , "pt" , "safetensors" , "npcache" , "dummy"
71+ ]:
72+ raise ValueError (
73+ f"Unknown load format: { self .load_format } . Must be one of "
74+ "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." )
75+ self .load_format = load_format
76+
6177 def _verify_tokenizer_mode (self ) -> None :
6278 tokenizer_mode = self .tokenizer_mode .lower ()
6379 if tokenizer_mode not in ["auto" , "slow" ]:
0 commit comments