
Commit 37bdbd6

onnx_generate_with_genai
1 parent ce4fb66 commit 37bdbd6

File tree

6 files changed: +229 −23 lines


_doc/technical/plot_generate.py

Lines changed: 3 additions & 0 deletions

@@ -94,6 +94,9 @@
 # %%
 # Custom method generate
 # ======================
+#
+# Let's implement a simple function replicating what the method
+# ``generate`` does.


 def simple_generate_with_cache(
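For context, a minimal sketch of such a function could look like the following.
This is not the implementation the documentation page defines, only the usual
greedy loop with a key/value cache, assuming a transformers-style model that
returns ``logits`` and ``past_key_values`` (the name
``greedy_generate_with_cache`` is illustrative).

import torch

def greedy_generate_with_cache(model, input_ids, max_new_tokens=10):
    # prefill: the first call processes the whole prompt, later calls
    # only feed the last generated token, the cache carries the rest
    past_key_values = None
    tokens = input_ids
    for _ in range(max_new_tokens):
        step = tokens if past_key_values is None else tokens[:, -1:]
        outputs = model(step, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        # greedy decoding: pick the most likely next token
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return tokens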

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 39 additions & 20 deletions

@@ -3,12 +3,17 @@
 import torch
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
+    has_onnxruntime_genai,
     hide_stdout,
     requires_transformers,
     requires_torch,
 )
 from onnx_diagnostic.helpers import max_diff, flatten_object
-from onnx_diagnostic.helpers.rt_helper import onnx_generate, make_empty_cache
+from onnx_diagnostic.helpers.rt_helper import (
+    onnx_generate,
+    onnx_generate_with_genai,
+    make_empty_cache,
+)
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch
 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
@@ -101,6 +106,7 @@ def test_onnx_generate(self):
         print("-- test_onnx_generate: get model")
         data = get_untrained_model_with_inputs(mid)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        configuration = data["configuration"]
         del inputs["position_ids"]
         del ds["position_ids"]
         input_ids = inputs["input_ids"]
@@ -118,25 +124,38 @@
             exporter="custom",
         )

-        print("-- test_onnx_generate: generate")
-        res, session = onnx_generate(
-            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
-        )
-        n_inputs = input_ids.shape[1]
-        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
-        self.assertEqual(res.dtype, torch.int64)
-        self.assertEqual(res.shape, (1, 13))
-        print("-- test_onnx_generate: done")
-        # expected = model.generate(input_ids[:1], max_new_tokens=10)
-        expected = self.simple_generate_with_cache(
-            model, input_ids[:1], 2, max_new_tokens=10, session=session
-        )
-        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
-        print("******", res)
-        print("******", expected)
-        self.assertEqual(expected.dtype, torch.int64)
-        self.assertEqual(expected.shape, (1, 13))
-        self.assertEqualArray(expected, res)
+        print("-- test_onnx_generate: generate")
+        res, session = onnx_generate(
+            model_name, input_ids[:1], 2, max_new_tokens=10, return_session=True
+        )
+        n_inputs = input_ids.shape[1]
+        self.assertEqualArray(input_ids[:1], res[:, :n_inputs])
+        self.assertEqual(res.dtype, torch.int64)
+        self.assertEqual(res.shape, (1, 13))
+        print("-- test_onnx_generate: done")
+        # expected = model.generate(input_ids[:1], max_new_tokens=10)
+        expected = self.simple_generate_with_cache(
+            model, input_ids[:1], 2, max_new_tokens=10, session=session
+        )
+        self.assertEqualArray(input_ids[:1], expected[:, :n_inputs])
+        print("******", res)
+        print("******", expected)
+        self.assertEqual(expected.dtype, torch.int64)
+        self.assertEqual(expected.shape, (1, 13))
+        self.assertEqualArray(expected, res)
+
+        if not has_onnxruntime_genai():
+            raise unittest.SkipTest("onnxruntime_genai is missing")
+
+        res, session = onnx_generate_with_genai(
+            model_name,
+            input_ids[:1],
+            max_new_tokens=10,
+            return_session=True,
+            transformers_config=configuration,
+        )
+        self.assertNotEmpty(session)
+        self.assertEqualArray(expected, res)


 if __name__ == "__main__":

onnx_diagnostic/ext_test_case.py

Lines changed: 11 additions & 0 deletions

@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
     return True


+def has_onnxruntime_genai():
+    """Tells if onnxruntime_genai is installed."""
+    try:
+        import onnxruntime_genai  # noqa: F401
+
+        return True
+    except ImportError:
+        # onnxruntime_genai is not installed
+        return False
+
+
 def requires_onnxruntime_training(
     push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
 ) -> Callable:
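The new helper mirrors has_onnxruntime_training and can guard tests the same
way. A hypothetical usage with the standard unittest.skipIf decorator (the
test class name is illustrative):

import unittest
from onnx_diagnostic.ext_test_case import has_onnxruntime_genai

@unittest.skipIf(not has_onnxruntime_genai(), "onnxruntime_genai is missing")
class TestWithGenai(unittest.TestCase):  # hypothetical test class
    def test_generate(self):
        ...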

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 102 additions & 2 deletions

@@ -1,11 +1,12 @@
+import copy
 import importlib.util
 import os
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union
 from urllib.parse import urlparse
-from onnx import ModelProto, TensorProto
+from onnx import ModelProto, TensorProto, load as load_model

 CACHE_SUBDIR = "onnx-diagnostic"

@@ -337,3 +338,102 @@ def _post(onnx_model):
     # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
     # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
     return onnx_model
+
+
+def make_genai_config(
+    config,
+    onnx_filename: str,
+) -> Dict:
+    """
+    Creates the genai configuration for a model.
+
+    :param config: configuration from transformers
+    :param onnx_filename: filename of the ONNX model
+    :return: configuration
+    """
+    onx = load_model(onnx_filename, load_external_data=False)
+    config = copy.deepcopy(config)
+    defaults = {
+        "bos_token_id": None,
+        "do_sample": False,
+        "eos_token_id": None,
+        "pad_token_id": None,
+        "temperature": 1.0,
+        "top_k": 50,
+        "top_p": 1.0,
+    }
+    for key, default_val in defaults.items():
+        if not hasattr(config, key):
+            setattr(config, key, default_val)
+
+    bos_token_id = (
+        config.bos_token_id
+        if hasattr(config, "bos_token_id") and config.bos_token_id is not None
+        else 1
+    )
+    eos_token_id = config.eos_token_id
+    pad_token_id = (
+        config.pad_token_id
+        if hasattr(config, "pad_token_id") and config.pad_token_id is not None
+        else (
+            config.eos_token_id[0]
+            if isinstance(config.eos_token_id, list)
+            else config.eos_token_id
+        )
+    )
+    input_names = [i.name for i in onx.graph.input]
+    output_names = [i.name for i in onx.graph.output]
+    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
+    first = [i for i in onx.graph.input if i.name == past_key_values[0]][0]  # noqa: RUF015
+    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
+    return {
+        "model": {
+            "bos_token_id": bos_token_id,
+            "context_length": config.max_position_embeddings,
+            "decoder": {
+                "session_options": {
+                    "log_id": "onnxruntime-genai",
+                    "provider_options": [],
+                },
+                "filename": onnx_filename,
+                "head_size": shape[-1],
+                "hidden_size": config.hidden_size,
+                "inputs": input_names,
+                "outputs": output_names,
+                "num_attention_heads": config.num_attention_heads,
+                "num_hidden_layers": len(past_key_values) // 2,
+                "num_key_value_heads": shape[1],
+            },
+            "eos_token_id": eos_token_id,
+            "pad_token_id": pad_token_id,
+            # "type": self.model_type[ : self.model_type.find("For")
+            # if "For" in self.model_type else len(self.model_type)].lower(),
+            "vocab_size": config.vocab_size,
+        },
+        "search": {
+            "diversity_penalty": (
+                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
+            ),
+            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
+            "early_stopping": True,
+            "length_penalty": (
+                config.length_penalty if hasattr(config, "length_penalty") else 1.0
+            ),
+            "max_length": config.max_position_embeddings,
+            "min_length": 0,
+            "no_repeat_ngram_size": (
+                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
+            ),
+            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
+            "num_return_sequences": (
+                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
+            ),
+            "past_present_share_buffer": False,
+            "repetition_penalty": (
+                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
+            ),
+            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
+            "top_k": config.top_k if hasattr(config, "top_k") else 50,
+            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
+        },
+    }
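A sketch of how make_genai_config could be used to materialize
genai_config.json next to an exported model. The checkpoint id and file paths
are placeholders, and the model is assumed to expose past_key_value* inputs
as the function requires:

import json
from transformers import AutoConfig
from onnx_diagnostic.helpers.model_builder_helper import make_genai_config

config = AutoConfig.from_pretrained("some/model-id")  # placeholder checkpoint
genai_config = make_genai_config(config, "exported/model.onnx")  # placeholder path
with open("exported/genai_config.json", "w") as f:
    json.dump(genai_config, f, indent=4)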

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 73 additions & 1 deletion

@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Tuple, Union
+import json
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
@@ -224,3 +226,73 @@ def onnx_generate(
     if return_session:
         return input_ids, session
     return input_ids
+
+
+def onnx_generate_with_genai(
+    model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+    input_ids: torch.Tensor,
+    max_new_tokens: int = 100,
+    return_session: bool = False,
+    transformers_config: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+    """
+    Uses :epkg:`onnxruntime-genai` to implement a simple ``generate`` method
+    for an ONNX model. The function does not expect any ``position_ids`` as input.
+
+    :param model_or_path: model or loaded model
+    :param input_ids: input tokens
+    :param max_new_tokens: stops after this number of generated tokens
+    :param return_session: returns the instance of class
+        :class:`InferenceSessionForTorch
+        <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
+        created if necessary
+    :param transformers_config: writes the configuration file
+        if it is missing and this configuration is provided
+    :return: input tokens concatenated with new tokens
+    """
+    import onnxruntime_genai as og
+
+    if not isinstance(model_or_path, og.Model):
+        from .model_builder_helper import make_genai_config
+
+        assert isinstance(
+            model_or_path, str
+        ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
+        folder = os.path.dirname(model_or_path)
+        assert os.path.exists(folder), f"Folder {folder!r} does not exist."
+        assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
+        config_file = os.path.join(folder, "genai_config.json")
+        if not os.path.exists(config_file):
+            if not transformers_config:
+                raise FileNotFoundError(
+                    f"Folder {folder!r} does not contain 'genai_config.json'."
+                )
+            config = make_genai_config(transformers_config, model_or_path)
+            with open(config_file, "w") as f:
+                json.dump(config, f, indent=4)
+
+        config = og.Config(os.path.dirname(config_file))
+        if input_ids.is_cuda:
+            config.clear_providers()
+            config.append_provider("cuda")
+        session = og.Model(config)
+    else:
+        session = model_or_path
+
+    params = og.GeneratorParams(session)
+    params.set_search_options(max_new_tokens=max_new_tokens, batch_size=input_ids.shape[0])
+    generator = og.Generator(session, params)
+
+    # first call: prefill with the prompt tokens
+    cats = [input_ids]
+    generator.append_tokens(input_ids)
+    while not generator.is_done():
+        generator.generate_next_token()
+        new_token = generator.get_next_tokens()[0]
+        cats.append(new_token)
+
+    input_ids = torch.cat(cats, dim=-1)
+    if return_session:
+        return input_ids, session
+    return input_ids
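A usage sketch of the new helper, mirroring the unit test above. The file
paths and the checkpoint id are placeholders; the configuration is only needed
when genai_config.json does not yet exist next to the model:

import torch
from transformers import AutoConfig
from onnx_diagnostic.helpers.rt_helper import onnx_generate_with_genai

config = AutoConfig.from_pretrained("some/model-id")  # placeholder checkpoint
prompt = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.int64)
tokens = onnx_generate_with_genai(
    "exported/model.onnx",  # placeholder path
    prompt,
    max_new_tokens=10,
    transformers_config=config,  # used to write genai_config.json if missing
)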

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ huggingface_hub
 matplotlib
 onnx-array-api>=0.3.1
 onnx
+onnxruntime-genai
 onnxscript
 openpyxl
 packaging
