diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index 897c7a38..6057fc4e 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
         "num_hidden_layers",
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
         "vocab_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, 2),
             intermediate_size=256 if config is None else min(512, config.intermediate_size),
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             if hasattr(config, "num_key_value_heads")
             else config.num_attention_heads
         ),
-        intermediate_size=(
-            min(config.intermediate_size, 24576 // 4)
-            if config.intermediate_size % 4 == 0
-            else config.intermediate_size
-        ),
         hidden_size=(
             min(config.hidden_size, 3072 // 4)
             if config.hidden_size % 4 == 0
             else config.hidden_size
         ),
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            min(config.intermediate_size, 24576 // 4)
+            if config.intermediate_size % 4 == 0
+            else config.intermediate_size
+        )
     update_config(config, kwargs)
     return kwargs
 
@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         "vocab_size",
         ("num_attention_heads", "use_mambapy"),
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             if config is None
             else _pick(config, "num_key_value_heads", "num_attention_heads")
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
         hidden_size=512 if config is None else config.hidden_size,
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            1024 if config is None else config.intermediate_size
+        )
+
     return kwargs, get_inputs
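
The net effect of this change is that intermediate_size is no longer a hard requirement on the incoming config: both functions now set it only when the attribute exists (or when there is no config at all and a default is wanted), while FalconMambaConfig still requires it explicitly. Below is a minimal, self-contained sketch of the new guard pattern; build_kwargs is a hypothetical helper, SimpleNamespace stands in for a transformers PretrainedConfig, and the numbers are illustrative defaults rather than values from any real checkpoint.

# Sketch of the guarded-kwargs pattern introduced by this diff, under the
# assumptions stated above.
from types import SimpleNamespace
from typing import Any, Dict, Optional


def build_kwargs(config: Optional[Any]) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = dict(
        hidden_size=512 if config is None else config.hidden_size,
    )
    # New behaviour: only set intermediate_size when the config defines it,
    # or when there is no config at all and a default is needed.
    if config is None or hasattr(config, "intermediate_size"):
        kwargs["intermediate_size"] = (
            1024 if config is None else config.intermediate_size
        )
    return kwargs


# Config that defines the attribute: it is passed through as before.
assert build_kwargs(
    SimpleNamespace(hidden_size=4096, intermediate_size=11008)
) == dict(hidden_size=4096, intermediate_size=11008)

# Config without it (e.g. a GPT-2-style config exposing n_inner instead):
# the key is simply omitted, where the old code rejected such configs up
# front because "intermediate_size" was in the check_hasattr list.
assert "intermediate_size" not in build_kwargs(SimpleNamespace(hidden_size=4096))

# No config at all: fall back to the defaults.
assert build_kwargs(None) == dict(hidden_size=512, intermediate_size=1024)

Moving the key out of the dict(...) call and behind the hasattr guard also avoids evaluating config.intermediate_size for configs that lack the attribute, which is exactly what made the previous unconditional keyword argument fail.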