Skip to content

Commit 5ea455b

Browse files
committed
find names pattern
1 parent 15c41fe commit 5ea455b

File tree

4 files changed

+42
-7
lines changed

4 files changed

+42
-7
lines changed

CHANGELOGS.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ Change Logs
44
0.8.0
55
+++++
66

7-
* :pr:`276`: implements onnx_generate which implements method generate for an onnx model,
8-
changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
7+
* :pr:`278`: implements ``onnx_generate_with_genai``
8+
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
9+
* :pr:`276`: implements ``onnx_generate`` which implements method generate for an onnx model,
910
* :pr:`275`: fixes function ``patched_vmap``
1011

1112
0.7.16

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import_model_builder,
1212
create_model_builder,
1313
save_model_builder,
14+
find_names_pattern,
1415
)
1516
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1617
from onnx_diagnostic.helpers.rt_helper import make_feeds
@@ -63,6 +64,11 @@ def test_model_builder_id(self):
6364
raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1")
6465
self.assertEqualAny(expected, got)
6566

67+
def test_find_names_pattern(self):
68+
pats = ["past_key_values_key_0", "past_key_values_key_1"]
69+
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats))
70+
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats[:1]))
71+
6672

6773
if __name__ == "__main__":
6874
unittest.main(verbosity=2)

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ def test_onnx_generate(self):
138138
model, input_ids[:1], 2, max_new_tokens=10, session=session
139139
)
140140
self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
141-
print("******", res)
142-
print("******", expected)
143141
self.assertEqual(expected.dtype, torch.int64)
144142
self.assertEqual(expected.shape, (1, 13))
145143
self.assertEqualArray(expected, res)

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import copy
22
import importlib.util
33
import os
4+
import re
45
import requests
56
import sys
67
from pathlib import Path
7-
from typing import Any, Dict, Optional, Union
8+
from typing import Any, Dict, List, Optional, Union
89
from urllib.parse import urlparse
910
from onnx import ModelProto, TensorProto, load as load_model
1011

@@ -340,6 +341,26 @@ def _post(onnx_model):
340341
return onnx_model
341342

342343

344+
def find_names_pattern(names: List[str]) -> str:
345+
"""
346+
Finds a repeatable patterns in a list of names.
347+
It tries to locate the figures.
348+
349+
.. runpython::
350+
:showcode:
351+
352+
from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
353+
pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
354+
print(pattern)
355+
"""
356+
patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
357+
unique = set(patterns)
358+
assert (
359+
len(unique) == 1
360+
), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
361+
return patterns[0]
362+
363+
343364
def make_genai_config(
344365
config,
345366
onnx_filename: str,
@@ -398,8 +419,17 @@ def make_genai_config(
398419
"filename": onnx_filename,
399420
"head_size": shape[-1],
400421
"hidden_size": config.hidden_size,
401-
"inputs": input_names,
402-
"outputs": output_names,
422+
"inputs": {
423+
"input_ids": input_names[0],
424+
"attention_mask": input_names[1],
425+
"past_key_names": find_names_pattern(input_names[2::2]),
426+
"past_value_names": find_names_pattern(input_names[3::2]),
427+
},
428+
"outputs": {
429+
"logits": output_names[0],
430+
"present_key_names": find_names_pattern(output_names[1::2]),
431+
"present_value_names": find_names_pattern(output_names[2::2]),
432+
},
403433
"num_attention_heads": config.num_attention_heads,
404434
"num_hidden_layers": len(past_key_values) // 2,
405435
"num_key_value_heads": shape[1],

0 commit comments

Comments
 (0)