5 changes: 3 additions & 2 deletions CHANGELOGS.rst
@@ -4,8 +4,9 @@ Change Logs
0.8.0
+++++

* :pr:`276`: implements onnx_generate which implements method generate for an onnx model,
changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
* :pr:`278`: implements ``onnx_generate_with_genai``
* :pr:`277`: changes the serialization for all caches to reorder the model outputs (key_1, value_1, key_2, ...)
* :pr:`276`: implements ``onnx_generate``, which replicates method ``generate`` for an ONNX model
* :pr:`275`: fixes function ``patched_vmap``

0.7.16
1 change: 1 addition & 0 deletions _doc/conf.py
@@ -239,6 +239,7 @@ def linkcode_resolve(domain, info):
"ONNX": "https://onnx.ai/",
"ONNX Operators": "https://onnx.ai/onnx/operators/",
"onnxruntime": "https://onnxruntime.ai/",
"onnxruntime-genai": "https://github.com/microsoft/onnxruntime-genai",
"onnxruntime-training": "https://onnxruntime.ai/docs/get-started/training-on-device.html",
"onnxruntime kernels": "https://onnxruntime.ai/docs/reference/operators/OperatorKernels.html",
"onnx-array-api": "https://sdpython.github.io/doc/onnx-array-api/dev/",
3 changes: 3 additions & 0 deletions _doc/technical/plot_generate.py
@@ -94,6 +94,9 @@
# %%
# Custom method generate
# ======================
#
# Let's implement a simple function replicating what method
# ``generate`` does.


def simple_generate_with_cache(
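The diff truncates the body of ``simple_generate_with_cache``. As a point of
reference only, a greedy decoding loop of this kind could look like the
following sketch (an assumption for illustration, not the code from this PR;
``model`` is any transformers causal LM):

# Hedged sketch of a greedy generation loop with a key/value cache.
# Not the PR's implementation; shapes assume a causal LM from transformers.
import torch

def sketch_generate_with_cache(model, input_ids, max_new_tokens=10):
    # Prefill: run the full prompt once and keep the cache.
    out = model(input_ids, use_cache=True)
    past = out.past_key_values
    token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    new_tokens = [token]
    # Decode: feed one token at a time, reusing the cache.
    for _ in range(max_new_tokens - 1):
        out = model(token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        new_tokens.append(token)
    return torch.cat([input_ids, *new_tokens], dim=-1)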
6 changes: 6 additions & 0 deletions _unittests/ut_helpers/test_model_builder_helper.py
@@ -11,6 +11,7 @@
import_model_builder,
create_model_builder,
save_model_builder,
find_names_pattern,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.helpers.rt_helper import make_feeds
@@ -63,6 +64,11 @@ def test_model_builder_id(self):
raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1")
self.assertEqualAny(expected, got)

def test_find_names_pattern(self):
pats = ["past_key_values_key_0", "past_key_values_key_1"]
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats))
self.assertEqual("past_key_values_key_%d", find_names_pattern(pats[:1]))


if __name__ == "__main__":
unittest.main(verbosity=2)
23 changes: 20 additions & 3 deletions _unittests/ut_helpers/test_rt_helper.py
@@ -3,11 +3,16 @@
import torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
has_onnxruntime_genai,
hide_stdout,
requires_transformers,
requires_torch,
)
from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
from onnx_diagnostic.helpers.rt_helper import (
onnx_generate,
generate_and_validate,
onnx_generate_with_genai,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.export.api import to_onnx
@@ -22,6 +27,7 @@ def test_onnx_generate(self):
print("-- test_onnx_generate: get model")
data = get_untrained_model_with_inputs(mid)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
configuration = data["configuration"]
del inputs["position_ids"]
del ds["position_ids"]
input_ids = inputs["input_ids"]
@@ -53,12 +59,23 @@ def test_onnx_generate(self):
model, input_ids[:1], 2, max_new_tokens=10, session=session
)
self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
print("******", res)
print("******", expected)
self.assertEqual(expected.dtype, torch.int64)
self.assertEqual(expected.shape, (1, 13))
self.assertEqualArray(expected, res)

if not has_onnxruntime_genai():
raise unittest.SkipTest("onnxruntime_genai is missing")

res, session = onnx_generate_with_genai(
model_name,
input_ids[:1],
max_new_tokens=10,
return_session=True,
transformers_config=configuration,
)
self.assertNotEmpty(session)
self.assertEqualArray(expected, res)


if __name__ == "__main__":
unittest.main(verbosity=2)
11 changes: 11 additions & 0 deletions onnx_diagnostic/ext_test_case.py
@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
return True


def has_onnxruntime_genai() -> bool:
"""Tells if onnxruntime_genai is installed."""
try:
import onnxruntime_genai # noqa: F401

return True
except ImportError:
# onnxruntime_genai is not installed
return False


def requires_onnxruntime_training(
push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
) -> Callable:
134 changes: 132 additions & 2 deletions onnx_diagnostic/helpers/model_builder_helper.py
@@ -1,11 +1,13 @@
import copy
import importlib.util
import os
import re
import requests
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from onnx import ModelProto, TensorProto
from onnx import ModelProto, TensorProto, load as load_model

CACHE_SUBDIR = "onnx-diagnostic"

@@ -337,3 +339,131 @@ def _post(onnx_model):
# onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
# onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
return onnx_model


def find_names_pattern(names: List[str]) -> str:
"""
Finds a repeatable pattern in a list of names.
It tries to locate the digits and replaces them with ``%d``.

.. runpython::
:showcode:

from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
print(pattern)
"""
patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
unique = set(patterns)
assert (
len(unique) == 1
), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
return patterns[0]
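
A short sketch of both behaviors, matching the implementation above:

# find_names_pattern replaces runs of digits with %d and asserts that
# all names collapse to the same pattern.
from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern

print(find_names_pattern(["present_key_0", "present_key_1"]))  # present_key_%d
try:
    find_names_pattern(["past_key_0", "present_value_1"])  # two distinct patterns
except AssertionError as e:
    print("no common pattern:", e)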


def make_genai_config(
config,
onnx_filename: str,
) -> Dict:
"""
Creates the genai configuration file (``genai_config.json``) for a model.

:param config: configuration from transformers
:param onnx_filename: path to the ONNX model file
:return: configuration as a dictionary
"""
onx = load_model(onnx_filename, load_external_data=False)
config = copy.deepcopy(config)
defaults = {
"bos_token_id": None,
"do_sample": False,
"eos_token_id": None,
"pad_token_id": None,
"temperature": 1.0,
"top_k": 50,
"top_p": 1.0,
}
for key, default_val in defaults.items():
if not hasattr(config, key):
setattr(config, key, default_val)

# The defaults loop above guarantees these attributes exist.
bos_token_id = config.bos_token_id if config.bos_token_id is not None else 1
eos_token_id = config.eos_token_id
pad_token_id = (
config.pad_token_id
if config.pad_token_id is not None
else (
config.eos_token_id[0]
if isinstance(config.eos_token_id, list)
else config.eos_token_id
)
)
input_names = [i.name for i in onx.graph.input]
output_names = [i.name for i in onx.graph.output]
past_key_values = [s for s in input_names if s.startswith("past_key_value")]
first = [i for i in onx.graph.input if i.name == past_key_values[0]][0] # noqa: RUF015
shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
return {
"model": {
"bos_token_id": bos_token_id,
"context_length": config.max_position_embeddings,
"decoder": {
"session_options": {
"log_id": "onnxruntime-genai",
"provider_options": [],
},
"filename": os.path.split(onnx_filename)[-1],
"head_size": shape[-1],
"hidden_size": config.hidden_size,
"inputs": {
"input_ids": input_names[0],
"attention_mask": input_names[1],
"past_key_names": find_names_pattern(input_names[2::2]),
"past_value_names": find_names_pattern(input_names[3::2]),
},
"outputs": {
"logits": output_names[0],
"present_key_names": find_names_pattern(output_names[1::2]),
"present_value_names": find_names_pattern(output_names[2::2]),
},
"num_attention_heads": config.num_attention_heads,
"num_hidden_layers": len(past_key_values) // 2,
"num_key_value_heads": shape[1],
},
"eos_token_id": eos_token_id,
"pad_token_id": pad_token_id,
"type": config.model_type,
# if "For" in self.model_type else len(self.model_type)].lower(),
"vocab_size": config.vocab_size,
},
"search": {
"diversity_penalty": (
config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
),
"do_sample": config.do_sample if hasattr(config, "do_sample") else False,
"early_stopping": True,
"length_penalty": (
config.length_penalty if hasattr(config, "length_penalty") else 1.0
),
"max_length": config.max_position_embeddings,
"min_length": 0,
"no_repeat_ngram_size": (
config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
),
"num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
"num_return_sequences": (
config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
),
"past_present_share_buffer": False,
"repetition_penalty": (
config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
),
"temperature": config.temperature if hasattr(config, "temperature") else 1.0,
"top_k": config.top_k if hasattr(config, "top_k") else 50,
"top_p": config.top_p if hasattr(config, "top_p") else 1.0,
},
}
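
A minimal usage sketch (the model id and file names below are placeholders;
the ONNX file is assumed to already exist and to expose cache inputs/outputs):

# Hypothetical usage: build and save genai_config.json next to the model.
import json
from transformers import AutoConfig
from onnx_diagnostic.helpers.model_builder_helper import make_genai_config

config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")  # placeholder model id
genai_config = make_genai_config(config, "model.onnx")  # model.onnx must exist
with open("genai_config.json", "w") as f:
    json.dump(genai_config, f, indent=4)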
98 changes: 95 additions & 3 deletions onnx_diagnostic/helpers/rt_helper.py
@@ -1,3 +1,5 @@
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
@@ -283,7 +285,11 @@ def onnx_generate(

import os
from onnx_diagnostic.helpers import string_type, string_diff
from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
from onnx_diagnostic.helpers.rt_helper import (
onnx_generate,
generate_and_validate,
onnx_generate_with_genai,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.export.api import to_onnx
@@ -313,18 +319,29 @@ def onnx_generate(
exporter="custom", # custom, dynamo or onnx-dynamo, modelbuilder
)

print("-- onnx_generate")
print("-- generate with onnx")
onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
print("-- onnx output", onnx_outputs)

print("-- generate")
# The example continues with other functions doing the same.
print("-- generate with pytorch")
torch_outputs, diffs = generate_and_validate(
model, input_ids[:1], 2, max_new_tokens=10, session=model_name
)
print("-- torch output", torch_outputs)
print("-- differences at each step:")
for i, d in enumerate(diffs):
print(f"iteration {i}: {string_diff(d)}")

print("-- generate with genai")
genai_outputs, session = onnx_generate_with_genai(
model_name,
input_ids[:1],
max_new_tokens=10,
return_session=True,
transformers_config=data["configuration"],
)
print("-- genai output", genai_outputs)
"""
if not isinstance(model_or_path, InferenceSessionForTorch):
providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
@@ -382,3 +399,78 @@ def onnx_generate(
if return_session:
return input_ids, session
return input_ids


def onnx_generate_with_genai(
model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
input_ids: torch.Tensor,
max_new_tokens: int = 100,
return_session: bool = False,
transformers_config: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
"""
Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
for an ONNX model. The function does not expect any ``position_ids`` as input.

:param model_or_path: model or loaded model
:param input_ids: input tokens
:param max_new_tokens: stops after this number of generated tokens
:param return_session: returns the instance of class
:class:`InferenceSessionForTorch
<onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
created if necessary
:param transformers_config: transformers configuration used to create
``genai_config.json`` when that file is missing
:return: input tokens concatenated with new tokens

See example given with function :func:`onnx_generate
<onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
"""
import onnxruntime_genai as og

if not isinstance(model_or_path, og.Model):
from .model_builder_helper import make_genai_config

assert isinstance(
model_or_path, str
), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
folder = os.path.dirname(model_or_path)
assert os.path.exists(folder), f"Folder {folder!r} does not exist."
assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
config_file = os.path.join(folder, "genai_config.json")
if not os.path.exists(config_file):
if not transformers_config:
raise FileNotFoundError(
f"Folder {folder!r} does not contain 'genai_config.json' "
f"and no transformers_config was given to create it."
)
config = make_genai_config(transformers_config, model_or_path)
with open(config_file, "w") as f:
json.dump(config, f, indent=4)

config = og.Config(os.path.dirname(config_file))
if input_ids.is_cuda:
config.clear_providers()
config.append_provider("cuda")
session = og.Model(config)
else:
session = model_or_path

params = og.GeneratorParams(session)
params.set_search_options(
max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
)
generator = og.Generator(session, params)

# Prefill: feed the whole prompt once.
generator.append_tokens(input_ids)

# Decoding loop: one new token per iteration until EOS or max_length is reached.
cats = []
while not generator.is_done():
generator.generate_next_token()
new_token = generator.get_next_tokens()[0]  # batch size is 1
cats.append(int(new_token))

input_ids = torch.cat([input_ids, torch.tensor([cats], dtype=torch.int64)], dim=-1)
if return_session:
return input_ids, session
return input_ids
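
A condensed usage sketch (``./model.onnx`` and the model id are placeholders;
``genai_config.json`` is created on the fly when missing):

# Hypothetical usage of onnx_generate_with_genai with an exported model.
import torch
from transformers import AutoConfig
from onnx_diagnostic.helpers.rt_helper import onnx_generate_with_genai

config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")  # placeholder model id
input_ids = torch.randint(0, 1000, (1, 3), dtype=torch.int64)  # dummy prompt
new_ids, session = onnx_generate_with_genai(
    "./model.onnx",  # exported with cache inputs, assumed to exist
    input_ids,
    max_new_tokens=10,
    return_session=True,
    transformers_config=config,
)
print(new_ids.shape)  # (1, 13) unless EOS is generated earlier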