From fc28812f7327d735392a2775f81743d2750e2ffd Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Oct 2025 18:55:07 +0200 Subject: [PATCH 1/2] Add a set of inputs with batch size 1 --- _unittests/ut_tasks/test_tasks.py | 23 +++++++++++++++++++++++ onnx_diagnostic/tasks/text_generation.py | 15 +++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 6815e8dc..1be4b442 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -48,6 +48,7 @@ def test_text_generation(self): model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) + @hide_stdout() def test_text_generation_empty_cache(self): mid = "arnir0/Tiny-LLM" data = get_untrained_model_with_inputs(mid, add_second_input=True) @@ -69,6 +70,28 @@ def test_text_generation_empty_cache(self): got = ep.module()(**torch_deepcopy(inputs)) self.assertEqualArrayAny(expected, got) + @hide_stdout() + def test_text_generation_batch1(self): + mid = "arnir0/Tiny-LLM" + data = get_untrained_model_with_inputs(mid, add_second_input=True) + model, inputs = data["model"], data["inputs"] + self.assertIn("inputs_batch1", data) + empty_inputs = torch_deepcopy(data["inputs_batch1"]) + model(**torch_deepcopy(empty_inputs)) + expected = model(**torch_deepcopy(inputs)) + self.assertEqual( + {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) + ) + with torch_export_patches(patch_transformers=True, verbose=1): + ep = torch.export.export( + model, + (), + kwargs=torch_deepcopy(inputs), + dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), + ) + got = ep.module()(**torch_deepcopy(inputs)) + self.assertEqualArrayAny(expected, got) + @hide_stdout() def test_automatic_speech_recognition_float32(self): mid = "openai/whisper-tiny" diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 964e0462..db930b08 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -284,6 +284,21 @@ def get_inputs( add_second_input=0, **kwargs, )["inputs"] + res["inputs_batch1"] = get_inputs( + model=model, + config=config, + dummy_max_token_id=dummy_max_token_id, + num_hidden_layers=num_hidden_layers, + batch_size=1, + sequence_length=sequence_length, + sequence_length2=sequence_length2, + dynamic_rope=dynamic_rope, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + cls_cache=cls_cache, + add_second_input=0, + **kwargs, + )["inputs"] return res From 9918d12768dd3fbd95a0d21182e2669bc0bacd62 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 7 Oct 2025 14:25:23 +0200 Subject: [PATCH 2/2] misspelling --- onnx_diagnostic/torch_models/hghub/model_inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 2c65d1a4..b7f0226c 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -178,7 +178,7 @@ def get_untrained_model_with_inputs( if verbose: print( - f"[get_untrained_model_with_inputs] package_source={package_source.__name__} é" + f"[get_untrained_model_with_inputs] package_source={package_source.__name__} " f"from {package_source.__file__}" ) if use_pretrained: