|
| 1 | +import json |
| 2 | +import os |
1 | 3 | from typing import Any, Dict, List, Optional, Tuple, Union |
2 | 4 | import numpy as np |
3 | 5 | import onnx |
@@ -283,7 +285,11 @@ def onnx_generate( |
283 | 285 |
|
284 | 286 | import os |
285 | 287 | from onnx_diagnostic.helpers import string_type, string_diff |
286 | | - from onnx_diagnostic.helpers.rt_helper import onnx_generate, generate_and_validate |
| 288 | + from onnx_diagnostic.helpers.rt_helper import ( |
| 289 | + onnx_generate, |
| 290 | + generate_and_validate, |
| 291 | + onnx_generate_with_genai, |
| 292 | + ) |
287 | 293 | from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs |
288 | 294 | from onnx_diagnostic.torch_export_patches import torch_export_patches |
289 | 295 | from onnx_diagnostic.export.api import to_onnx |
@@ -313,18 +319,29 @@ def onnx_generate( |
313 | 319 | exporter="custom", # custom, dynamo or onnx-dynamo, modelbuilder |
314 | 320 | ) |
315 | 321 |
|
316 | | - print("-- onnx_generate") |
| 322 | + print("-- generate with onnx") |
317 | 323 | onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10) |
318 | 324 | print("-- onnx output", onnx_outputs) |
319 | 325 |
|
320 | | - print("-- generate") |
| 326 | + # The example continues with other functions doing the same. |
| 327 | + print("-- generate with pytorch") |
321 | 328 | torch_outputs, diffs = generate_and_validate( |
322 | 329 | model, input_ids[:1], 2, max_new_tokens=10, session=model_name |
323 | 330 | ) |
324 | 331 | print("-- torch output", torch_outputs) |
325 | 332 | print("-- differences at each step:") |
326 | 333 | for i, d in enumerate(diffs): |
327 | 334 | print(f"iteration {i}: {string_diff(d)}") |
| 335 | +
|
| 336 | + print("-- generate with genai") |
| 337 | + genai_outputs, session = onnx_generate_with_genai( |
| 338 | + model_name, |
| 339 | + input_ids[:1], |
| 340 | + max_new_tokens=10, |
| 341 | + return_session=True, |
| 342 | + transformers_config=data["configuration"], |
| 343 | + ) |
| 344 | + print("-- genai output", genai_outputs) |
328 | 345 | """ |
329 | 346 | if not isinstance(model_or_path, InferenceSessionForTorch): |
330 | 347 | providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else [] |
@@ -382,3 +399,78 @@ def onnx_generate( |
382 | 399 | if return_session: |
383 | 400 | return input_ids, session |
384 | 401 | return input_ids |
| 402 | + |
| 403 | + |
def onnx_generate_with_genai(
    model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
    input_ids: torch.Tensor,
    max_new_tokens=100,
    return_session: bool = False,
    transformers_config: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
    """
    Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
    for an ONNX model. The function does not expect any ``position_ids`` as input.

    :param model_or_path: path to the ONNX model file, or an already loaded
        ``onnxruntime_genai`` model
    :param input_ids: input tokens
    :param max_new_tokens: stops after this number of generated tokens
    :param return_session: also returns the ``onnxruntime_genai`` model
        created if necessary
    :param transformers_config: writes ``genai_config.json`` next to the model
        if that file is missing and this configuration is provided
    :return: input tokens concatenated with new tokens

    See example given with function :func:`onnx_generate
    <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
    """
    import onnxruntime_genai as og

    if isinstance(model_or_path, og.Model):
        session = model_or_path
    else:
        from .model_builder_helper import make_genai_config

        assert isinstance(
            model_or_path, str
        ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
        folder = os.path.dirname(model_or_path)
        assert os.path.exists(folder), f"Folder {folder!r} does not exist."
        assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
        config_file = os.path.join(folder, "genai_config.json")
        if not os.path.exists(config_file):
            if not transformers_config:
                # Cannot build the genai configuration without a transformers
                # configuration to derive it from.
                raise FileNotFoundError(
                    f"Folder {folder!r} does not contain 'genai_config.json' "
                    f"and no transformers_config was provided to create it."
                )
            config = make_genai_config(transformers_config, model_or_path)
            with open(config_file, "w") as f:
                json.dump(config, f, indent=4)

        config = og.Config(folder)
        if input_ids.is_cuda:
            # Run the model on the same device as the input tokens.
            config.clear_providers()
            config.append_provider("cuda")
        session = og.Model(config)

    params = og.GeneratorParams(session)
    params.set_search_options(
        max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
    )
    generator = og.Generator(session, params)

    # Prefill with the prompt, then decode one token per iteration.
    new_tokens = []
    generator.append_tokens(input_ids)
    while not generator.is_done():
        generator.generate_next_token()
        # NOTE(review): only the first sequence of the batch is collected,
        # which assumes batch_size == 1 — confirm before using larger batches.
        new_tokens.append(int(generator.get_next_tokens()[0]))

    # Keep the generated tokens on the same device as the prompt, otherwise
    # torch.cat would fail for CUDA inputs.
    input_ids = torch.cat(
        [
            input_ids,
            torch.tensor([new_tokens], dtype=torch.int64, device=input_ids.device),
        ],
        dim=-1,
    )
    if return_session:
        return input_ids, session
    return input_ids
0 commit comments