Commit 943342a

Fix for missing intermediate_size
1 parent 16fddf2

File tree

1 file changed (+13, -10 lines)

onnx_diagnostic/tasks/text_generation.py

Lines changed: 13 additions & 10 deletions
@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
         "num_hidden_layers",
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
         "vocab_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, 2),
             intermediate_size=256 if config is None else min(512, config.intermediate_size),
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             if hasattr(config, "num_key_value_heads")
             else config.num_attention_heads
         ),
-        intermediate_size=(
-            min(config.intermediate_size, 24576 // 4)
-            if config.intermediate_size % 4 == 0
-            else config.intermediate_size
-        ),
         hidden_size=(
             min(config.hidden_size, 3072 // 4)
             if config.hidden_size % 4 == 0
             else config.hidden_size
         ),
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            min(config.intermediate_size, 24576 // 4)
+            if config.intermediate_size % 4 == 0
+            else config.intermediate_size
+        )
     update_config(config, kwargs)
     return kwargs

@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         "vocab_size",
         ("num_attention_heads", "use_mambapy"),
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             if config is None
             else _pick(config, "num_key_value_heads", "num_attention_heads")
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
         hidden_size=512 if config is None else config.hidden_size,
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            1024 if config is None else config.intermediate_size
+        )
+
     return kwargs, get_inputs
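
The change follows one pattern in both functions: instead of requiring intermediate_size on every config up front via check_hasattr, it is asserted only for FalconMambaConfig and otherwise copied into kwargs when the attribute actually exists. Below is a minimal runnable sketch of that guard, not the library's code; _TinyConfig and reduced_kwargs are hypothetical stand-ins, and the config-is-None branch is omitted so the example stays self-contained.

# Sketch of the guard this commit adds: intermediate_size is copied into
# kwargs only when the config defines it, so configs without the attribute
# no longer fail an up-front check.
from typing import Any, Dict


class _TinyConfig:
    # Hypothetical config lacking intermediate_size.
    hidden_size = 512


def reduced_kwargs(config: Any) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = dict(
        hidden_size=(
            min(config.hidden_size, 3072 // 4)
            if config.hidden_size % 4 == 0
            else config.hidden_size
        ),
    )
    # The guard: only set intermediate_size when the attribute is present.
    if hasattr(config, "intermediate_size"):
        kwargs["intermediate_size"] = (
            min(config.intermediate_size, 24576 // 4)
            if config.intermediate_size % 4 == 0
            else config.intermediate_size
        )
    return kwargs


print(reduced_kwargs(_TinyConfig()))  # {'hidden_size': 512}, no AttributeError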
