Skip to content

Commit c61e539

Browse files
authored
Implements a function to use onnxruntime-genai (#278)
* Changes Cache serialization * mypy * fix * other fixes * fix other tests * fix modelbuilder * disable two examples * fix some issues * fix caches * more tests * fix version * fix issues * mypy * import * fix issues * onnx_generate_with_genai * find names pattern * add genai * doc * fix doc
1 parent 6963db2 commit c61e539

File tree

9 files changed

+272
-10
lines changed

9 files changed

+272
-10
lines changed

CHANGELOGS.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ Change Logs
44
0.8.0
55
+++++
66

7-
* :pr:`276`: implements onnx_generate which implements method generate for an onnx model,
8-
changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
7+
* :pr:`278`: implements ``onnx_generate_with_genai``
8+
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
9+
* :pr:`276`: implements ``onnx_generate`` which implements method generate for an onnx model,
910
* :pr:`275`: fixes function ``patched_vmap``
1011

1112
0.7.16

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def linkcode_resolve(domain, info):
239239
"ONNX": "https://onnx.ai/",
240240
"ONNX Operators": "https://onnx.ai/onnx/operators/",
241241
"onnxruntime": "https://onnxruntime.ai/",
242+
"onnxruntime-genai": "https://github.com/microsoft/onnxruntime-genai",
242243
"onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
243244
"onnxruntime kernels": "https://onnxruntime.ai/docs/reference/operators/OperatorKernels.html",
244245
"onnx-array-api": "https://sdpython.github.io/doc/onnx-array-api/dev/",

_doc/technical/plot_generate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@
9494
# %%
9595
# Custom method generate
9696
# ======================
97+
#
98+
# Let's implement a simple function replicating what method
99+
# ``generate`` does.
97100

98101

99102
def simple_generate_with_cache(

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import_model_builder,
1212
create_model_builder,
1313
save_model_builder,
14+
find_names_pattern,
1415
)
1516
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1617
from onnx_diagnostic.helpers.rt_helper import make_feeds
@@ -63,6 +64,11 @@ def test_model_builder_id(self):
6364
raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1")
6465
self.assertEqualAny(expected, got)
6566

67+
def test_find_names_pattern(self):
68+
pats = ["past_key_values_key_0", "past_key_values_key_1"]
69+
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats))
70+
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats[:1]))
71+
6672

6773
if __name__ == "__main__":
6874
unittest.main(verbosity=2)

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import torch
44
from onnx_diagnostic.ext_test_case import (
55
ExtTestCase,
6+
has_onnxruntime_genai,
67
hide_stdout,
78
requires_transformers,
89
requires_torch,
910
)
10-
from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
11+
from onnx_diagnostic.helpers.rt_helper import (
12+
onnx_generate,
13+
generate_and_validate,
14+
onnx_generate_with_genai,
15+
)
1116
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1217
from onnx_diagnostic.torch_export_patches import torch_export_patches
1318
from onnx_diagnostic.export.api import to_onnx
@@ -22,6 +27,7 @@ def test_onnx_generate(self):
2227
print("-- test_onnx_generate: get model")
2328
data = get_untrained_model_with_inputs(mid)
2429
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
30+
configuration = data["configuration"]
2531
del inputs["position_ids"]
2632
del ds["position_ids"]
2733
input_ids = inputs["input_ids"]
@@ -53,12 +59,23 @@ def test_onnx_generate(self):
5359
model, input_ids[:1], 2, max_new_tokens=10, session=session
5460
)
5561
self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
56-
print("******", res)
57-
print("******", expected)
5862
self.assertEqual(expected.dtype, torch.int64)
5963
self.assertEqual(expected.shape, (1, 13))
6064
self.assertEqualArray(expected, res)
6165

66+
if not has_onnxruntime_genai():
67+
raise unittest.SkipTest("onnxruntime_genai is missing")
68+
69+
res, session = onnx_generate_with_genai(
70+
model_name,
71+
input_ids[:1],
72+
max_new_tokens=10,
73+
return_session=True,
74+
transformers_config=configuration,
75+
)
76+
self.assertNotEmpty(session)
77+
self.assertEqualArray(expected, res)
78+
6279

6380
if __name__ == "__main__":
6481
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
630630
return True
631631

632632

633+
def has_onnxruntime_genai():
634+
"""Tells if onnxruntime_genai is installed."""
635+
try:
636+
import onnxruntime_genai # noqa: F401
637+
638+
return True
639+
except ImportError:
640+
# onnxruntime_genai is not installed
641+
return False
642+
643+
633644
def requires_onnxruntime_training(
634645
push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
635646
) -> Callable:

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import copy
12
import importlib.util
23
import os
4+
import re
35
import requests
46
import sys
57
from pathlib import Path
6-
from typing import Any, Optional, Union
8+
from typing import Any, Dict, List, Optional, Union
79
from urllib.parse import urlparse
8-
from onnx import ModelProto, TensorProto
10+
from onnx import ModelProto, TensorProto, load as load_model
911

1012
CACHE_SUBDIR = "onnx-diagnostic"
1113

@@ -337,3 +339,131 @@ def _post(onnx_model):
337339
# onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
338340
# onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
339341
return onnx_model
342+
343+
344+
def find_names_pattern(names: List[str]) -> str:
345+
"""
346+
Finds a repeatable patterns in a list of names.
347+
It tries to locate the digits and replace them with ``%d``.
348+
349+
.. runpython::
350+
:showcode:
351+
352+
from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
353+
pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
354+
print(pattern)
355+
"""
356+
patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
357+
unique = set(patterns)
358+
assert (
359+
len(unique) == 1
360+
), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
361+
return patterns[0]
362+
363+
364+
def make_genai_config(
365+
config,
366+
onnx_filename: str,
367+
) -> Dict:
368+
"""
369+
Creates genai config file for a model.
370+
371+
:param config: configuration from transformers
372+
:param onnx_filename: onnx configuration
373+
:return: configuration
374+
"""
375+
onx = load_model(onnx_filename, load_external_data=False)
376+
config = copy.deepcopy(config)
377+
defaults = {
378+
"bos_token_id": None,
379+
"do_sample": False,
380+
"eos_token_id": None,
381+
"pad_token_id": None,
382+
"temperature": 1.0,
383+
"top_k": 50,
384+
"top_p": 1.0,
385+
}
386+
for key, default_val in defaults.items():
387+
if not hasattr(config, key):
388+
setattr(config, key, default_val)
389+
390+
bos_token_id = (
391+
config.bos_token_id
392+
if hasattr(config, "bos_token_id") and config.bos_token_id is not None
393+
else 1
394+
)
395+
eos_token_id = config.eos_token_id
396+
pad_token_id = (
397+
config.pad_token_id
398+
if hasattr(config, "pad_token_id") and config.pad_token_id is not None
399+
else (
400+
config.eos_token_id[0]
401+
if isinstance(config.eos_token_id, list)
402+
else config.eos_token_id
403+
)
404+
)
405+
input_names = [i.name for i in onx.graph.input]
406+
output_names = [i.name for i in onx.graph.output]
407+
past_key_values = [s for s in input_names if s.startswith("past_key_value")]
408+
first = [i for i in onx.graph.input if i.name == past_key_values[0]][0] # noqa: RUF015
409+
shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
410+
return {
411+
"model": {
412+
"bos_token_id": bos_token_id,
413+
"context_length": config.max_position_embeddings,
414+
"decoder": {
415+
"session_options": {
416+
"log_id": "onnxruntime-genai",
417+
"provider_options": [],
418+
},
419+
"filename": os.path.split(onnx_filename)[-1],
420+
"head_size": shape[-1],
421+
"hidden_size": config.hidden_size,
422+
"inputs": {
423+
"input_ids": input_names[0],
424+
"attention_mask": input_names[1],
425+
"past_key_names": find_names_pattern(input_names[2::2]),
426+
"past_value_names": find_names_pattern(input_names[3::2]),
427+
},
428+
"outputs": {
429+
"logits": output_names[0],
430+
"present_key_names": find_names_pattern(output_names[1::2]),
431+
"present_value_names": find_names_pattern(output_names[2::2]),
432+
},
433+
"num_attention_heads": config.num_attention_heads,
434+
"num_hidden_layers": len(past_key_values) // 2,
435+
"num_key_value_heads": shape[1],
436+
},
437+
"eos_token_id": eos_token_id,
438+
"pad_token_id": pad_token_id,
439+
"type": config.model_type,
440+
# if "For" in self.model_type else len(self.model_type)].lower(),
441+
"vocab_size": config.vocab_size,
442+
},
443+
"search": {
444+
"diversity_penalty": (
445+
config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
446+
),
447+
"do_sample": config.do_sample if hasattr(config, "do_sample") else False,
448+
"early_stopping": True,
449+
"length_penalty": (
450+
config.length_penalty if hasattr(config, "length_penalty") else 1.0
451+
),
452+
"max_length": config.max_position_embeddings,
453+
"min_length": 0,
454+
"no_repeat_ngram_size": (
455+
config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
456+
),
457+
"num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
458+
"num_return_sequences": (
459+
config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
460+
),
461+
"past_present_share_buffer": False,
462+
"repetition_penalty": (
463+
config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
464+
),
465+
"temperature": config.temperature if hasattr(config, "temperature") else 1.0,
466+
"top_k": config.top_k if hasattr(config, "top_k") else 50,
467+
"top_p": config.top_p if hasattr(config, "top_p") else 1.0,
468+
},
469+
}

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import os
13
from typing import Any, Dict, List, Optional, Tuple, Union
24
import numpy as np
35
import onnx
@@ -283,7 +285,11 @@ def onnx_generate(
283285
284286
import os
285287
from onnx_diagnostic.helpers import string_type, string_diff
286-
from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
288+
from onnx_diagnostic.helpers.rt_helper import (
289+
onnx_generate,
290+
generate_and_validate,
291+
onnx_generate_with_genai,
292+
)
287293
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
288294
from onnx_diagnostic.torch_export_patches import torch_export_patches
289295
from onnx_diagnostic.export.api import to_onnx
@@ -313,18 +319,29 @@ def onnx_generate(
313319
exporter="custom", # custom, dynamo or onnx-dynamo, modelbuilder
314320
)
315321
316-
print("-- onnx_generate")
322+
print("-- generate with onnx")
317323
onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
318324
print("-- onnx output", onnx_outputs)
319325
320-
print("-- generate")
326+
# The example continues with other functions doing the same.
327+
print("-- generate with pytorch")
321328
torch_outputs, diffs = generate_and_validate(
322329
model, input_ids[:1], 2, max_new_tokens=10, session=model_name
323330
)
324331
print("-- torch output", torch_outputs)
325332
print("-- differences at each step:")
326333
for i, d in enumerate(diffs):
327334
print(f"iteration {i}: {string_diff(d)}")
335+
336+
print("-- generate with genai")
337+
genai_outputs, session = onnx_generate_with_genai(
338+
model_name,
339+
input_ids[:1],
340+
max_new_tokens=10,
341+
return_session=True,
342+
transformers_config=data["configuration"],
343+
)
344+
print("-- genai output", genai_outputs)
328345
"""
329346
if not isinstance(model_or_path, InferenceSessionForTorch):
330347
providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
@@ -382,3 +399,78 @@ def onnx_generate(
382399
if return_session:
383400
return input_ids, session
384401
return input_ids
402+
403+
404+
def onnx_generate_with_genai(
405+
model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
406+
input_ids: torch.Tensor,
407+
max_new_tokens=100,
408+
return_session: bool = False,
409+
transformers_config: Optional[Any] = None,
410+
) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
411+
"""
412+
Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
413+
for an ONNX model. The function does not expect any ``position_ids`` as input.
414+
415+
:param model_or_path: model or loaded model
416+
:param input_ids: input tokens
417+
:param eos_token_ids: token representing the end of an answer (NOTE(review): not in this function's signature — presumably a leftover from ``onnx_generate``; confirm and remove)
418+
:param max_new_tokens: stops after this number of generated tokens
419+
:param return_session: returns the instance of class
420+
:class:`InferenceSessionForTorch
421+
<onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
422+
created if necessary
423+
:param transformers_config: transformers configuration used to generate and write
424+
``genai_config.json`` when that file is missing next to the model
425+
:return: input tokens concatenated with new tokens
426+
427+
See example given with function :func:`onnx_generate
428+
<onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
429+
"""
430+
import onnxruntime_genai as og
431+
432+
if not isinstance(model_or_path, og.Model):
433+
from .model_builder_helper import make_genai_config
434+
435+
assert isinstance(
436+
model_or_path, str
437+
), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
438+
folder = os.path.dirname(model_or_path)
439+
assert os.path.exists(folder), f"Folder {folder!r} does not exist."
440+
assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
441+
config_file = os.path.join(folder, "genai_config.json")
442+
if not os.path.exists(config_file):
443+
if not transformers_config:
444+
raise FileNotFoundError(
445+
f"Folder {model_or_path!r} does not contain 'genai_config.json'."
446+
)
447+
config = make_genai_config(transformers_config, model_or_path)
448+
with open(config_file, "w") as f:
449+
json.dump(config, f, indent=4)
450+
451+
config = og.Config(os.path.dirname(config_file))
452+
if input_ids.is_cuda:
453+
config.clear_providers()
454+
config.append_provider("cuda")
455+
session = og.Model(config)
456+
else:
457+
session = model_or_path
458+
459+
params = og.GeneratorParams(session)
460+
params.set_search_options(
461+
max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
462+
)
463+
generator = og.Generator(session, params)
464+
465+
# First call: prefill
466+
cats = []
467+
generator.append_tokens(input_ids)
468+
while not generator.is_done():
469+
generator.generate_next_token()
470+
new_token = generator.get_next_tokens()[0]
471+
cats.append(int(new_token))
472+
473+
input_ids = torch.cat([input_ids, torch.tensor([cats], dtype=torch.int64)], dim=-1)
474+
if return_session:
475+
return input_ids, session
476+
return input_ids

0 commit comments

Comments
 (0)