@@ -24,9 +24,16 @@ class ModelConfig:
             downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
-        use_np_weights: Save a numpy copy of model weights for faster loading.
-            This can increase the disk usage by up to 2x.
-        use_dummy_weights: Use dummy values for model weights (for profiling).
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
@@ -40,8 +47,7 @@ def __init__(
         tokenizer_mode: str,
         trust_remote_code: bool,
         download_dir: Optional[str],
-        use_np_weights: bool,
-        use_dummy_weights: bool,
+        load_format: str,
         dtype: str,
         seed: int,
     ) -> None:
@@ -50,14 +56,24 @@ def __init__(
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
-        self.use_np_weights = use_np_weights
-        self.use_dummy_weights = use_dummy_weights
+        self.load_format = load_format
         self.seed = seed
 
         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self._verify_load_format()
         self._verify_tokenizer_mode()
 
+    def _verify_load_format(self) -> None:
+        load_format = self.load_format.lower()
+        if load_format not in [
+                "auto", "pt", "safetensors", "npcache", "dummy"
+        ]:
+            raise ValueError(
+                f"Unknown load format: {self.load_format}. Must be one of "
+                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        self.load_format = load_format
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow"]:
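For reference, the following is a small, self-contained Python sketch (not part of this diff) that mirrors the validation logic introduced by _verify_load_format. The DemoConfig class is hypothetical and exists only to show how the accepted load_format values are normalized and checked:

# Hypothetical standalone sketch -- mirrors the _verify_load_format check
# added in this diff, outside the real ModelConfig class.
class DemoConfig:
    def __init__(self, load_format: str) -> None:
        self.load_format = load_format
        self._verify_load_format()

    def _verify_load_format(self) -> None:
        # Lowercase the value first, then reject anything outside the
        # supported set of weight-loading formats.
        load_format = self.load_format.lower()
        if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        self.load_format = load_format


print(DemoConfig("SafeTensors").load_format)  # -> "safetensors"
try:
    DemoConfig("np")
except ValueError as exc:
    print(exc)  # unsupported values raise with a descriptive message

As in the diff, the value is lowercased before the membership check, so mixed-case input such as "Auto" is accepted and stored in normalized form.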