
Commit fc28812

Add a set of inputs with batch size 1
1 parent 1006017 commit fc28812

2 files changed: +38, -0 lines changed

_unittests/ut_tasks/test_tasks.py

Lines changed: 23 additions & 0 deletions
@@ -48,6 +48,7 @@ def test_text_generation(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

+    @hide_stdout()
     def test_text_generation_empty_cache(self):
         mid = "arnir0/Tiny-LLM"
         data = get_untrained_model_with_inputs(mid, add_second_input=True)
@@ -69,6 +70,28 @@ def test_text_generation_empty_cache(self):
         got = ep.module()(**torch_deepcopy(inputs))
         self.assertEqualArrayAny(expected, got)

+    @hide_stdout()
+    def test_text_generation_batch1(self):
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, add_second_input=True)
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("inputs_batch1", data)
+        empty_inputs = torch_deepcopy(data["inputs_batch1"])
+        model(**torch_deepcopy(empty_inputs))
+        expected = model(**torch_deepcopy(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with torch_export_patches(patch_transformers=True, verbose=1):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=torch_deepcopy(inputs),
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+        got = ep.module()(**torch_deepcopy(inputs))
+        self.assertEqualArrayAny(expected, got)
+
     @hide_stdout()
     def test_automatic_speech_recognition_float32(self):
         mid = "openai/whisper-tiny"

onnx_diagnostic/tasks/text_generation.py

Lines changed: 15 additions & 0 deletions
@@ -284,6 +284,21 @@ def get_inputs(
             add_second_input=0,
             **kwargs,
         )["inputs"]
+        res["inputs_batch1"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=1,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
     return res
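As a side note, a batch-1 variant could in principle be derived by slicing the existing dummy inputs along the batch axis instead of calling get_inputs a second time. The sketch below is a hypothetical illustration, not part of this commit: it only handles a flat dict of tensors and ignores nested structures such as the past_key_values cache, which is one reason regenerating the inputs with batch_size=1, as done above, is the more robust approach.

import torch


def slice_batch1(inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Keep only the first element of the batch dimension of every tensor.

    Hypothetical helper, not part of the commit: real text-generation inputs
    also contain a cache object (past_key_values) that this flat version
    does not handle.
    """
    return {name: t[:1] for name, t in inputs.items()}


# Dummy tensors shaped roughly like the text-generation inputs (batch size 2).
dummy = {
    "input_ids": torch.randint(0, 100, (2, 3)),
    "attention_mask": torch.ones((2, 33), dtype=torch.int64),
    "position_ids": torch.arange(3).unsqueeze(0).expand(2, -1),
}
batch1 = slice_batch1(dummy)
assert all(t.shape[0] == 1 for t in batch1.values())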

0 commit comments