T = TypeVar('T')
no_default = object()

-def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T:
+def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T:
+
+    expected_types = expected_type if isinstance(expected_type, list) else [expected_type]

    if isinstance(keys, str): keys = [keys]

@@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str],
            if expected_type == int and isinstance(x, float) and x == int(x):
                x = int(x)

-            if isinstance(x, expected_type):
-                return cast(T, x)
-            else:
-                raise TypeError(f"Value for {key} is not of expected type {expected_type}")
+            for t in expected_types:
+                if isinstance(x, t):
+                    return cast(T, x)
+            raise TypeError(f"Value for {key} is not of expected type {expected_type}")

    if default != no_default: return default
    raise ValueError(f"Missing any of the following keys: {keys}")
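
With this change, expected_type can be either a single type or a list of acceptable types; the value is returned as soon as it matches any of them. A minimal usage sketch (the config dict and values below are made up for illustration, not taken from a real model config):

    cfg = {"vocab_size": 32000, "eos_token_id": [2, 32000]}
    vocab = read(cfg, int, "vocab_size")                # single accepted type, as before
    eos = read(cfg, [int, list], "eos_token_id", None)  # value may be an int or a list
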
@@ -105,7 +107,10 @@ class ExLlamaV2Config:
    attn_logit_softcapping: float | None
    sliding_window: int
    norm_head: int | None
-
+    l3_rope_factor: float | None
+    l3_rope_low_freq_factor: float | None
+    l3_rope_high_freq_factor: float | None
+    l3_rope_original_max_position_embeddings: int | None
    checkpoint_fused_mlp: bool
    checkpoint_offset_qzeros: bool

@@ -191,10 +196,13 @@ def prepare(self, no_tensors: bool = False):
        # Vocab params

        self.bos_token_id = read(read_config, int, "bos_token_id", None)  # 1
-        self.eos_token_id = read(read_config, int, "eos_token_id", None)  # 2
+        self.eos_token_id = read(read_config, [int, list], "eos_token_id", None)  # 2
        self.pad_token_id = read(read_config, int, "pad_token_id", None)  # 0
        self.vocab_size = read(read_config, int, "vocab_size")

+        if isinstance(self.eos_token_id, list):
+            self.eos_token_id = self.eos_token_id[0]  # TODO: Figure out a way to maybe use all the EOS tokens somehow
+
        # Standard params

        self.initializer_range = read(read_config, float, ["initializer_range"])
@@ -287,6 +295,13 @@ def prepare(self, no_tensors: bool = False):
                self.alt_rope_method = "su"
            # if scaling_type == "yarn":
            #     self.scale_alpha_value = factor
+            rope_type = rs.get("rope_type", None)
+            if rope_type == "llama3":
+                self.alt_rope_method = "llama3"
+                self.l3_rope_factor = rs["factor"]
+                self.l3_rope_low_freq_factor = rs["low_freq_factor"]
+                self.l3_rope_high_freq_factor = rs["high_freq_factor"]
+                self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

        # Checkpoint format (for GPTQ models)

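
For reference, these four values describe Llama 3.1's "llama3" rope scaling, which rescales the rotary inverse frequencies as a function of wavelength. The sketch below follows the commonly published formula (Llama 3.1 reference code / Hugging Face transformers); it is illustrative only, not necessarily the exact form used when the l3_rope_* values are consumed downstream:

    import math

    def l3_scale_inv_freq(inv_freq, factor, low_freq_factor, high_freq_factor,
                          original_max_position_embeddings):
        # Wavelength thresholds derived from the original training context length
        low_freq_wavelen = original_max_position_embeddings / low_freq_factor
        high_freq_wavelen = original_max_position_embeddings / high_freq_factor
        scaled = []
        for f in inv_freq:
            wavelen = 2 * math.pi / f
            if wavelen < high_freq_wavelen:
                scaled.append(f)           # high-frequency bands: unchanged
            elif wavelen > low_freq_wavelen:
                scaled.append(f / factor)  # low-frequency bands: fully scaled
            else:
                # Smooth interpolation between the two regimes
                smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                scaled.append((1 - smooth) * f / factor + smooth * f)
        return scaled
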