| 1 | +import copy |
1 | 2 | import importlib.util |
2 | 3 | import os |
3 | 4 | import requests |
4 | 5 | import sys |
5 | 6 | from pathlib import Path |
6 | | -from typing import Any, Optional, Union |
| 7 | +from typing import Any, Dict, Optional, Union |
7 | 8 | from urllib.parse import urlparse |
8 | | -from onnx import ModelProto, TensorProto |
| 9 | +from onnx import ModelProto, TensorProto, load as load_model |
9 | 10 | |
10 | 11 | CACHE_SUBDIR = "onnx-diagnostic" |
11 | 12 | |
@@ -337,3 +338,102 @@ def _post(onnx_model): |
337 | 338 | # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir) |
338 | 339 | # onnx_model.save_processing(hf_name, extra_kwargs, output_dir) |
339 | 340 | return onnx_model |
| 341 | + |
| 342 | + |
| 343 | +def make_genai_config( |
| 344 | + config, |
| 345 | + onnx_filename: str, |
| 346 | +) -> Dict: |
| 347 | + """ |
| 348 | + Creates the genai configuration for a model. |
| 349 | + |
| 350 | + :param config: model configuration coming from transformers |
| 351 | + :param onnx_filename: path to the ONNX model |
| 352 | + :return: the genai configuration as a dictionary |
| 353 | + """ |
| 354 | + onx = load_model(onnx_filename, load_external_data=False)  # weights are not needed |
| 355 | + config = copy.deepcopy(config)  # avoid mutating the caller's config |
| 356 | + defaults = {  # generation attributes the config may be missing |
| 357 | + "bos_token_id": None, |
| 358 | + "do_sample": False, |
| 359 | + "eos_token_id": None, |
| 360 | + "pad_token_id": None, |
| 361 | + "temperature": 1.0, |
| 362 | + "top_k": 50, |
| 363 | + "top_p": 1.0, |
| 364 | + } |
| 365 | + for key, default_val in defaults.items(): |
| 366 | + if not hasattr(config, key): |
| 367 | + setattr(config, key, default_val) |
| 368 | + |
| 369 | + bos_token_id = ( |
| 370 | + config.bos_token_id |
| 371 | + if hasattr(config, "bos_token_id") and config.bos_token_id is not None |
| 372 | + else 1 |
| 373 | + ) |
| 374 | + eos_token_id = config.eos_token_id |
| 375 | + pad_token_id = ( |
| 376 | + config.pad_token_id |
| 377 | + if hasattr(config, "pad_token_id") and config.pad_token_id is not None |
| 378 | + else ( |
| 379 | + config.eos_token_id[0] |
| 380 | + if isinstance(config.eos_token_id, list) |
| 381 | + else config.eos_token_id |
| 382 | + ) |
| 383 | + ) |
| 384 | + input_names = [i.name for i in onx.graph.input] |
| 385 | + output_names = [i.name for i in onx.graph.output] |
| 386 | + past_key_values = [s for s in input_names if s.startswith("past_key_value")] |
| 387 | + first = next(i for i in onx.graph.input if i.name == past_key_values[0]) |
| 388 | + shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)  # (batch, n_kv_heads, seq, head_size) |
| 389 | + return { |
| 390 | + "model": { |
| 391 | + "bos_token_id": bos_token_id, |
| 392 | + "context_length": config.max_position_embeddings, |
| 393 | + "decoder": { |
| 394 | + "session_options": { |
| 395 | + "log_id": "onnxruntime-genai", |
| 396 | + "provider_options": [], |
| 397 | + }, |
| 398 | + "filename": onnx_filename, |
| 399 | + "head_size": shape[-1], |
| 400 | + "hidden_size": config.hidden_size, |
| 401 | + "inputs": input_names, |
| 402 | + "outputs": output_names, |
| 403 | + "num_attention_heads": config.num_attention_heads, |
| 404 | + "num_hidden_layers": len(past_key_values) // 2, |
| 405 | + "num_key_value_heads": shape[1], |
| 406 | + }, |
| 407 | + "eos_token_id": eos_token_id, |
| 408 | + "pad_token_id": pad_token_id, |
| 409 | + # "type": self.model_type[ : self.model_type.find("For") |
| 410 | + # if "For" in self.model_type else len(self.model_type)].lower(), |
| 411 | + "vocab_size": config.vocab_size, |
| 412 | + }, |
| 413 | + "search": { |
| 414 | + "diversity_penalty": ( |
| 415 | + config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0 |
| 416 | + ), |
| 417 | + "do_sample": config.do_sample if hasattr(config, "do_sample") else False, |
| 418 | + "early_stopping": True, |
| 419 | + "length_penalty": ( |
| 420 | + config.length_penalty if hasattr(config, "length_penalty") else 1.0 |
| 421 | + ), |
| 422 | + "max_length": config.max_position_embeddings, |
| 423 | + "min_length": 0, |
| 424 | + "no_repeat_ngram_size": ( |
| 425 | + config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0 |
| 426 | + ), |
| 427 | + "num_beams": config.num_beams if hasattr(config, "num_beams") else 1, |
| 428 | + "num_return_sequences": ( |
| 429 | + config.num_return_sequences if hasattr(config, "num_return_sequences") else 1 |
| 430 | + ), |
| 431 | + "past_present_share_buffer": False, |
| 432 | + "repetition_penalty": ( |
| 433 | + config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0 |
| 434 | + ), |
| 435 | + "temperature": config.temperature if hasattr(config, "temperature") else 1.0, |
| 436 | + "top_k": config.top_k if hasattr(config, "top_k") else 50, |
| 437 | + "top_p": config.top_p if hasattr(config, "top_p") else 1.0, |
| 438 | + }, |
| 439 | + } |
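
For reference, a minimal usage sketch, assuming a decoder-style model whose ONNX export exposes past_key_value* inputs; the model id and file paths below are placeholders:

    import json

    from transformers import AutoConfig

    # placeholder model id and paths; any decoder export with past_key_value*
    # inputs and the usual config attributes (hidden_size, vocab_size, ...) works
    config = AutoConfig.from_pretrained("arnir0/Tiny-LLM")
    genai_config = make_genai_config(config, "model.onnx")

    # onnxruntime-genai conventionally reads this dictionary from a
    # genai_config.json file stored next to the ONNX model
    with open("genai_config.json", "w", encoding="utf-8") as f:
        json.dump(genai_config, f, indent=4)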