@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
1919 ("head_dim" , ("hidden_size" , "num_attention_heads" ), "use_mambapy" ),
2020 "num_hidden_layers" ,
2121 ("num_key_value_heads" , "num_attention_heads" , "use_mambapy" ),
22- "intermediate_size" ,
2322 "hidden_size" ,
2423 "vocab_size" ,
2524 )
2625 if config .__class__ .__name__ == "FalconMambaConfig" :
27- check_hasattr (config , "conv_kernel" , "state_size" ) # 4 and 8
26+ check_hasattr (config , "conv_kernel" , "state_size" , "intermediate_size" ) # 4 and 8
2827 kwargs = dict (
2928 num_hidden_layers = min (config .num_hidden_layers , 2 ),
3029 intermediate_size = 256 if config is None else min (512 , config .intermediate_size ),
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
4443 if hasattr (config , "num_key_value_heads" )
4544 else config .num_attention_heads
4645 ),
47- intermediate_size = (
48- min (config .intermediate_size , 24576 // 4 )
49- if config .intermediate_size % 4 == 0
50- else config .intermediate_size
51- ),
5246 hidden_size = (
5347 min (config .hidden_size , 3072 // 4 )
5448 if config .hidden_size % 4 == 0
5549 else config .hidden_size
5650 ),
5751 )
52+ if config is None or hasattr (config , "intermediate_size" ):
53+ kwargs ["intermediate_size" ] = (
54+ min (config .intermediate_size , 24576 // 4 )
55+ if config .intermediate_size % 4 == 0
56+ else config .intermediate_size
57+ )
5858 update_config (config , kwargs )
5959 return kwargs
6060
@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
228228 "vocab_size" ,
229229 ("num_attention_heads" , "use_mambapy" ),
230230 ("num_key_value_heads" , "num_attention_heads" , "use_mambapy" ),
231- "intermediate_size" ,
232231 "hidden_size" ,
233232 )
234233 if config .__class__ .__name__ == "FalconMambaConfig" :
235- check_hasattr (config , "conv_kernel" , "state_size" ) # 4 and 8
234+ check_hasattr (config , "conv_kernel" , "state_size" , "intermediate_size" ) # 4 and 8
236235 kwargs = dict (
237236 batch_size = 2 ,
238237 sequence_length = 30 ,
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
263262 if config is None
264263 else _pick (config , "num_key_value_heads" , "num_attention_heads" )
265264 ),
266- intermediate_size = 1024 if config is None else config .intermediate_size ,
267265 hidden_size = 512 if config is None else config .hidden_size ,
268266 )
267+ if config is None or hasattr (config , "intermediate_size" ):
268+         kwargs ["intermediate_size" ] = (
269+             1024 if config is None else config .intermediate_size
270+         )
271+
269272 return kwargs , get_inputs
0 commit comments