Skip to content

Commit 022a766

Browse files
committed
fix ser
1 parent 77517df commit 022a766

File tree

9 files changed

+67
-50
lines changed

9 files changed

+67
-50
lines changed

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ onnx_diagnostic.helpers
1313
memory_peak
1414
onnx_helper
1515
ort_session
16+
rt_helper
1617
torch_test_helper
1718

1819
.. autofunction:: onnx_diagnostic.helpers.max_diff

_doc/api/helpers/rt_helper.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.rt_helper
3+
=================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.rt_helper
6+
:members:
7+
:no-undoc-members:

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from onnx_diagnostic import doc
2626
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
2727
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
28-
from onnx_diagnostic.helpers.ort_session import make_feeds
28+
from onnx_diagnostic.helpers.rt_helper import make_feeds
2929
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
3030
from onnx_diagnostic.torch_models.hghub import (
3131
get_untrained_model_with_inputs,

_unittests/ut_helpers/test_ort_session_tinyllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from onnxruntime.capi import _pybind_state as ORTC
88
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
99
from onnx_diagnostic.helpers import max_diff
10+
from onnx_diagnostic.rt_helper import make_feeds
1011
from onnx_diagnostic.helpers.ort_session import (
1112
InferenceSessionForNumpy,
1213
InferenceSessionForTorch,
13-
make_feeds,
1414
)
1515
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
1616
from onnx_diagnostic.torch_models.llms import get_tiny_llm

onnx_diagnostic/ext_test_case.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,8 @@ def assert_onnx_disc(
10811081
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
10821082
"""
10831083
from .helpers import string_type, string_diff, max_diff
1084-
from .helpers.ort_session import InferenceSessionForTorch, make_feeds
1084+
from .rt_helper import make_feeds
1085+
from .helpers.ort_session import InferenceSessionForTorch
10851086

10861087
kws = dict(with_shape=True, with_min_max=verbose > 1)
10871088
if verbose:

onnx_diagnostic/helpers/ort_session.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from torch._C import _from_dlpack
77
import onnxruntime
88
from onnxruntime.capi import _pybind_state as ORTC
9-
from .cache_helper import is_cache_dynamic_registered
10-
from .helper import size_type, string_type, flatten_object
9+
from .helper import size_type
1110
from .onnx_helper import (
1211
torch_dtype_to_onnx_dtype,
1312
onnx_dtype_to_np_dtype,
@@ -18,43 +17,6 @@
1817
DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
1918

2019

21-
def make_feeds(
22-
proto: Union[onnx.ModelProto, List[str]],
23-
inputs: Any,
24-
use_numpy: bool = False,
25-
copy: bool = False,
26-
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
27-
"""
28-
Serializes the inputs to produce feeds expected
29-
by :class:`onnxruntime.InferenceSession`.
30-
31-
:param proto: onnx model or list of names
32-
:param inputs: any kind of inputs
33-
:param use_numpy: if True, converts torch tensors into numpy arrays
34-
:param copy: a copy is made, this should be the case if the inputs is ingested
35-
by ``OrtValue``
36-
:return: feeds dictionary
37-
"""
38-
flat = flatten_object(inputs, drop_keys=True)
39-
assert (
40-
not all(isinstance(obj, torch.Tensor) for obj in flat)
41-
or not is_cache_dynamic_registered(fast=True)
42-
or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
43-
), (
44-
f"Unexpected number of flattened objects, "
45-
f"{string_type(flat, with_shape=True, limit=20)} != "
46-
f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True,limit=20)}"
47-
)
48-
if use_numpy:
49-
flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
50-
names = (
51-
[i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto
52-
)
53-
if copy:
54-
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
55-
return dict(zip(names, flat))
56-
57-
5820
class _InferenceSession:
5921

6022
@classmethod
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any, Dict, List, Union
2+
import numpy as np
3+
import onnx
4+
import torch
5+
from .helper import string_type, flatten_object
6+
from .cache_helper import is_cache_dynamic_registered
7+
8+
9+
def make_feeds(
10+
proto: Union[onnx.ModelProto, List[str]],
11+
inputs: Any,
12+
use_numpy: bool = False,
13+
copy: bool = False,
14+
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
15+
"""
16+
Serializes the inputs to produce feeds expected
17+
by :class:`onnxruntime.InferenceSession`.
18+
19+
:param proto: onnx model or list of names
20+
:param inputs: any kind of inputs
21+
:param use_numpy: if True, converts torch tensors into numpy arrays
22+
:param copy: a copy is made, this should be the case if the inputs is ingested
23+
by ``OrtValue``
24+
:return: feeds dictionary
25+
"""
26+
flat = flatten_object(inputs, drop_keys=True)
27+
assert (
28+
not all(isinstance(obj, torch.Tensor) for obj in flat)
29+
or not is_cache_dynamic_registered(fast=True)
30+
or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
31+
), (
32+
f"Unexpected number of flattened objects, "
33+
f"{string_type(flat, with_shape=True, limit=20)} != "
34+
f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True,limit=20)}"
35+
)
36+
if use_numpy:
37+
flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
38+
names = (
39+
[i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto
40+
)
41+
if copy:
42+
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
43+
return dict(zip(names, flat))

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,17 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
298298
kwargs = dict(
299299
batch_size=2,
300300
sequence_length=30,
301-
dummy_max_token_id=config.vocab_size,
302-
max_source_positions=config.max_source_positions,
303-
d_model=config.d_model,
301+
dummy_max_token_id=31000 if config is None else config.vocab_size,
302+
max_source_positions=1500 if config is None else config.max_source_positions,
303+
d_model=384 if config is None else config.d_model,
304304
num_hidden_layers=4 if config is None else config.num_hidden_layers,
305-
encoder_attention_heads=config.encoder_attention_heads,
306-
encoder_layers=config.encoder_layers,
307-
decoder_layers=config.decoder_layers,
308-
head_dim=config.d_model // config.encoder_attention_heads,
305+
encoder_attention_heads=6 if config is None else config.encoder_attention_heads,
306+
encoder_layers=4 if config is None else config.encoder_layers,
307+
decoder_attention_heads=6 if config is None else config.decoder_attention_heads,
308+
decoder_layers=4 if config is None else config.decoder_layers,
309+
head_dim=(
310+
64 if config is None else (config.d_model // config.encoder_attention_heads)
311+
),
309312
)
310313
fct = get_inputs_for_speech_automatic_recognition # type: ignore
311314
else:

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from ..helpers import max_diff, string_type, string_diff
88
from ..helpers.helper import flatten_object
9-
from ..helpers.ort_session import make_feeds
9+
from ..helpers.rt_helper import make_feeds
1010
from ..helpers.torch_test_helper import to_any, torch_deepcopy
1111
from ..torch_export_patches import bypass_export_some_errors
1212
from .hghub import get_untrained_model_with_inputs

0 commit comments

Comments
 (0)