Skip to content

Commit b768139

Browse files
committed
Fix modelbuilder export
1 parent a6caa0a commit b768139

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,12 @@ def create_model_builder(
201201
arch_map = {
202202
"ChatGLMForConditionalGeneration": builder.ChatGLMModel,
203203
"ChatGLMModel": builder.ChatGLMModel,
204+
"Ernie4_5_ForCausalLM": builder.ErnieModel,
204205
"GemmaForCausalLM": builder.Gemma2Model,
205206
"Gemma3ForCausalLM": builder.Gemma3Model,
206207
"Gemma3ForConditionalGeneration": builder.Gemma3Model,
207208
"GraniteForCausalLM": builder.GraniteModel,
209+
"GptOssForCausalLM": builder.GPTOSSModel,
208210
"LlamaForCausalLM": builder.LlamaModel,
209211
"MistralForCausalLM": builder.MistralModel,
210212
"NemotronForCausalLM": builder.NemotronModel,
@@ -235,6 +237,7 @@ def create_model_builder(
235237
"Phi4MMForCausalLM": builder.Phi4MMModel,
236238
"Qwen2ForCausalLM": builder.QwenModel,
237239
"Qwen3ForCausalLM": builder.Qwen3Model,
240+
"SmolLM3ForCausalLM": builder.SmolLM3Model,
238241
}
239242

240243
assert config.architectures[0] in arch_map, (
@@ -276,6 +279,8 @@ def _post(onnx_model):
276279
for key in text_config:
277280
if not hasattr(config, key):
278281
setattr(config, key, getattr(text_config, key))
282+
elif config.architectures[0] == "GptOssForCausalLM":
283+
delattr(config, "quantization_config")
279284
elif (
280285
config.architectures[0] == "PhiMoEForCausalLM"
281286
and config.max_position_embeddings != config.original_max_position_embeddings

onnx_diagnostic/torch_models/validate.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,9 +496,15 @@ def validate_model(
496496
cpl = CoupleInputsDynamicShapes(
497497
tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
498498
)
499-
data[k] = cpl.change_dynamic_dimensions(
500-
desired_values=dict(batch=1), only_desired=True
501-
)
499+
if patch_kwargs.get("patch", False):
500+
with torch_export_patches(**patch_kwargs):
501+
data[k] = cpl.change_dynamic_dimensions(
502+
desired_values=dict(batch=1), only_desired=True
503+
)
504+
else:
505+
data[k] = cpl.change_dynamic_dimensions(
506+
desired_values=dict(batch=1), only_desired=True
507+
)
502508
if verbose:
503509
print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
504510

0 commit comments

Comments
 (0)