Commit 9568e18: add a test
Parent: 6fea147

File tree: 3 files changed (+83, -4 lines)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import unittest
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
+from onnx_diagnostic.torch_models.validate import validate_model
+
+
+class TestTasksMaskGeneration(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    def test_text_generation(self):
+        mid = "microsoft/phi-2"
+        summary, data = validate_model(
+            mid,
+            do_run=True,
+            verbose=10,
+            exporter="onnx-dynamo",
+            dump_folder="dump_test/microsoft_phi-2",
+            inputs2=True,
+            patch=True,
+        )
+        self.assertIsInstance(summary, dict)
+        # token generation
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
+        # prompt processing
+        self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2)
+        # multi-turn conversation
+        self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2)
+        self.assertIsInstance(data, dict)
+        onnx_filename = data["onnx_filename"]
+        self.assertExists(onnx_filename)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
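
The three discrepancy keys checked above correspond to the three input sets prepared for the text-generation task: inputs (token generation), inputs2 (prompt processing) and inputs3 (multi-turn conversation). As a minimal sketch (not part of the commit), the same validation can be run outside the unittest harness with the options used in the test; printing the discrepancy entries is illustrative and assumes the run completed with do_run=True.

# Minimal sketch: run the same validation outside the unittest harness.
# Assumes onnx_diagnostic, transformers >= 4.53 and a recent torch/onnxruntime
# are installed.
from onnx_diagnostic.torch_models.validate import validate_model

summary, data = validate_model(
    "microsoft/phi-2",
    do_run=True,                              # run torch and the exported ONNX model
    verbose=1,
    exporter="onnx-dynamo",
    dump_folder="dump_test/microsoft_phi-2",
    inputs2=True,                             # also build the extra input sets
    patch=True,
)
for key in (
    "disc_onnx_ort_run_abs",    # token generation (inputs)
    "disc_onnx_ort_run2_abs",   # prompt processing (inputs2)
    "disc_onnx_ort_run3_abs",   # multi-turn conversation (inputs3)
):
    print(key, summary.get(key))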

onnx_diagnostic/tasks/text_generation.py

Lines changed: 18 additions & 0 deletions

@@ -278,6 +278,24 @@ def get_inputs(
             add_second_input=0,
             **kwargs,
         )["inputs"]
+        # multi-turn conversation
+        # prompt-processing -> token-generation(loop output) ->
+        # prompt-processing from the loop output
+        res["inputs3"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=1,
+            past_sequence_length=32,
+            sequence_length=8,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
     return res
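
For illustration only (not part of the commit), the shapes implied by batch_size=1, past_sequence_length=32 and sequence_length=8 amount to an 8-token follow-up prompt processed on top of a 32-token key/value cache left by a previous generation loop. The tensor names and head dimensions below are hypothetical placeholders, not values taken from the library.

import torch

# Hypothetical shapes for the third input set (multi-turn conversation):
# a new 8-token prompt with a 32-token past key/value cache per layer.
batch_size, past_len, seq_len = 1, 32, 8
num_key_value_heads, head_dim = 4, 64      # placeholder values, model dependent

input_ids = torch.randint(0, 1000, (batch_size, seq_len))
attention_mask = torch.ones((batch_size, past_len + seq_len), dtype=torch.int64)
past_key = torch.randn(batch_size, num_key_value_heads, past_len, head_dim)
past_value = torch.randn(batch_size, num_key_value_heads, past_len, head_dim)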
onnx_diagnostic/torch_models/validate.py

Lines changed: 26 additions & 4 deletions

@@ -573,6 +573,7 @@ def validate_model(
     if verbose:
         print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
         print(f"[validate_model] new dynamic_hapes: {string_type(data['dynamic_shapes'])}")
+    # NOTE: The dynamic_shapes is always the same across inputs sets
     if inputs2:
         assert (
             "inputs2" in data
@@ -583,6 +584,14 @@
             model=data["model"],
             dynamic_shapes=data["dynamic_shapes"],
         )
+        # NOTE: text-generation tests 3rd inputs for multi-turn conversation
+        if "inputs3" in data:
+            data["inputs3"], _ = filter_inputs(
+                data["inputs3"],
+                drop_names=drop_inputs,
+                model=data["model"],
+                dynamic_shapes=data["dynamic_shapes"],
+            )

     if not empty(dtype):
         if isinstance(dtype, str):
@@ -594,6 +603,8 @@
         summary["model_dtype"] = str(dtype)
         if "inputs2" in data:
             data["inputs2"] = to_any(data["inputs2"], dtype)  # type: ignore
+        if "inputs3" in data:
+            data["inputs3"] = to_any(data["inputs3"], dtype)  # type: ignore

     if not empty(device):
         if verbose:
@@ -603,6 +614,8 @@
         summary["model_device"] = str(device)
         if "inputs2" in data:
             data["inputs2"] = to_any(data["inputs2"], device)  # type: ignore
+        if "inputs3" in data:
+            data["inputs3"] = to_any(data["inputs3"], device)  # type: ignore

     for k in ["task", "size", "n_weights"]:
         summary[f"model_{k.replace('_','')}"] = data[k]
@@ -638,10 +651,12 @@
         _validate_do_run_model(
             data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
         )
-        if inputs2:
-            _validate_do_run_model(
-                data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet
-            )
+        _validate_do_run_model(
+            data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet
+        )
+        _validate_do_run_model(
+            data, summary, "inputs3", "run3", "run_expected3", verbose, 1, 0, quiet
+        )

     if exporter:
         print(
@@ -899,6 +914,10 @@ def _validate_do_run_model(
     if verbose:
         print(f"[validate_model] -- run the model inputs={key!r}...")
         print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}")
+    if key not in data:
+        if verbose:
+            print(f"[validate_model] input; {key!r} not defined, skip.")
+        return
     # We make a copy of the input just in case the model modifies them inplace
     hash_inputs = string_type(data[key], with_shape=True)
     inputs = torch_deepcopy(data[key])
@@ -1329,6 +1348,9 @@ def _mk(key, flavour=flavour):
     keys = [("inputs", "run_expected", "")]
     if inputs2:
         keys.append(("inputs2", "run_expected2", "2"))
+    # text-generation tests multi-turn conversation as 3rd inputs
+    if "inputs3" in data:
+        keys.append(("inputs3", "run_expected3", "3"))
     for k_input, k_expected, suffix in keys:
         # make_feeds
         if verbose:
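
The net effect of these hunks is that the second and third runs are no longer gated at the call site; _validate_do_run_model is always invoked for inputs2 and inputs3 and skips any input set that was never built (inputs3 only exists for text-generation). A simplified sketch of that dispatch pattern, with placeholder names standing in for the real helpers:

# Simplified sketch of the skip-if-missing dispatch; the function name and the
# summary entries below are placeholders, not the library's real API.
def run_all_input_sets(data: dict, summary: dict) -> None:
    for key, prefix in (("inputs", "run"), ("inputs2", "run2"), ("inputs3", "run3")):
        if key not in data:
            continue  # e.g. "inputs3" is only built for text-generation tasks
        summary[f"{prefix}_done"] = True  # stands in for the real validation work

summary: dict = {}
run_all_input_sets({"inputs": 0, "inputs2": 0}, summary)   # no "inputs3" here
print(summary)   # {'run_done': True, 'run2_done': True}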
