-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from .helper import string_type, flatten_object
+from .helper import string_type, flatten_object, max_diff
+from .torch_helper import torch_deepcopy
 from .ort_session import InferenceSessionForTorch
 
 
@@ -147,6 +148,137 @@ def make_empty_cache(
     return feeds
 
 
+def generate_and_validate(
+    model,
+    input_ids: torch.Tensor,
+    eos_token_id: int,
+    max_new_tokens: int = 100,
+    session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
+    atol: float = 0.1,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
159+ """
160+ Implements a simple method ``generate`` for a torch model.
161+ The function does not expect any ``position_ids`` as input.
162+ The function also checks the outputs coming from an onnx model
163+ are close to the output the torch model produces.
164+
165+ :param model_or_path: model or loaded model
166+ :param input_ids: input tokens
167+ :param eos_token_ids: token representing the end of an answer
168+ :param max_new_tokens: stops after this number of generated tokens
169+ :param session: the onnx model
170+ :return: input tokens concatenated with new tokens,
171+ if session is not null, it also returns the maximum differences
172+ at every iterations
173+
174+ See example given with function :func:`onnx_generate
175+ <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
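+
+    A minimal usage sketch (names are illustrative: ``model`` is assumed to be
+    a loaded causal LM and ``"model.onnx"`` an exported copy of it):
+
+    .. code-block:: python
+
+        tokens, diffs = generate_and_validate(
+            model,
+            input_ids,
+            eos_token_id=2,
+            max_new_tokens=10,
+            session="model.onnx",  # hypothetical path to the exported model
+        )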
176+ """
177+ if session is not None :
178+ if not isinstance (session , InferenceSessionForTorch ):
179+ providers = ["CUDAExecutionProvider" ] if input_ids .is_cuda else []
180+ providers .append ("CPUExecutionProvider" )
181+ session = InferenceSessionForTorch (session , providers = providers )
182+
183+ # First call: prefill
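+    # The exported onnx model still expects cache inputs at this stage, so the
+    # prefill call feeds an empty cache built from the session signature.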
+    attention_mask = torch.ones(
+        input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+    )
+    if session:
+        feeds = {
+            **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
+            **make_empty_cache(
+                input_ids.shape[0],
+                session.input_names[2:],
+                session.input_shapes[2:],
+                session.input_types[2:],
+            ),
+        }
+        onnx_results = session.run(None, feeds)
+
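+    # Run the torch model on the same inputs so both sets of outputs can be compared.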
+    outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
+
+    if session:
+        diff = max_diff(outputs, onnx_results)
+        assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+            f"Unexpected issue with {type(model)}\ndiff={diff}"
+            f"\ninput_ids.shape={input_ids.shape}"
+            f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+            f"\ngot=\n"
+            f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+            f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+        )
+        diffs = [diff]
+
+    # Next calls: decode
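+    # Greedy decoding: feed only the newly generated token together with the
+    # cache produced by the previous call, and stop as soon as EOS is emitted.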
+    for iteration in range(max_new_tokens):
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+        if next_token_id.item() == eos_token_id:
+            break
+        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+        attention_mask = torch.ones(
+            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+        )
+        if session:
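+            # flatten_object turns [token, mask, cache] into the flat list of
+            # tensors matching the onnx input names; torch_deepcopy copies them
+            # so the original torch cache is left untouched by the onnx run.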
+            feeds = dict(
+                zip(
+                    session.input_names,
+                    [
+                        t.detach()
+                        for t in torch_deepcopy(
+                            flatten_object(
+                                [next_token_id, attention_mask, outputs.past_key_values]
+                            )
+                        )
+                    ],
+                )
+            )
+            onnx_results = session.run(None, feeds)
+        outputs = model(
+            next_token_id,
+            use_cache=True,
+            past_key_values=outputs.past_key_values,
+            attention_mask=attention_mask,
+        )
+        if session:
+            diff = max_diff(outputs, onnx_results)
+            assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+                f"Unexpected issue with {type(model)}, iteration={iteration}"
+                f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
+                f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+                f"\ngot=\n"
+                f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+                f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+            )
+            diffs.append(diff)
+    if session:
+        return input_ids, diffs
+    return input_ids
+
+
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
@@ -167,6 +299,56 @@ def onnx_generate(
         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
         created if necessary
     :return: input tokens concatenated with new tokens
+
+    .. runpython::
+        :showcode:
+
+        import os
+        from onnx_diagnostic.helpers import string_type, string_diff
+        from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.export.api import to_onnx
+
+        mid = "arnir0/Tiny-LLM"
+        print(f"-- get model for {mid!r}")
+        data = get_untrained_model_with_inputs(mid)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        del inputs["position_ids"]
+        del ds["position_ids"]
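+        # Neither generate_and_validate nor onnx_generate handles position_ids,
+        # so they are removed from the inputs and the dynamic shapes.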
+        input_ids = inputs["input_ids"]
+
+        print(f"-- input_ids={input_ids.shape}")
+        print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+        print(f"-- dynamic_shapes: {string_type(ds)}")
+        folder = "dump_test"
+        os.makedirs(folder, exist_ok=True)
+        model_name = os.path.join(folder, "model.onnx")
+        print("-- test_onnx_generate: export model")
+        with torch_export_patches(patch_transformers=True, patch_torch=False):
+            to_onnx(
+                model,
+                (),
+                kwargs=inputs,
+                dynamic_shapes=ds,
+                filename=model_name,
+                exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
+            )
+
+        print("-- onnx_generate")
+        onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+        print("-- onnx output", onnx_outputs)
+
+        print("-- generate")
+        torch_outputs, diffs = generate_and_validate(
+            model, input_ids[:1], 2, max_new_tokens=10, session=model_name
+        )
+        print("-- torch output", torch_outputs)
+        print("-- differences at each step:")
+        for i, d in enumerate(diffs):
+            print(f"iteration {i}: {string_diff(d)}")
170328 """
171329 if not isinstance (model_or_path , InferenceSessionForTorch ):
172330 providers = ["CUDAExecutionProvider" ] if input_ids .is_cuda else []