|
6 | 6 | from torch._C import _from_dlpack |
7 | 7 | import onnxruntime |
8 | 8 | from onnxruntime.capi import _pybind_state as ORTC |
9 | | -from .cache_helper import is_cache_dynamic_registered |
10 | | -from .helper import size_type, string_type, flatten_object |
| 9 | +from .helper import size_type |
11 | 10 | from .onnx_helper import ( |
12 | 11 | torch_dtype_to_onnx_dtype, |
13 | 12 | onnx_dtype_to_np_dtype, |
|
18 | 17 | DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)} |
19 | 18 |
|
20 | 19 |
|
def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce feeds expected
    by :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model or list of names
    :param inputs: any kind of inputs
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: a copy is made, this should be the case if the inputs is ingested
        by ``OrtValue``
    :return: feeds dictionary
    """
    # Flatten the (possibly nested) inputs into a plain list of leaves.
    flat = flatten_object(inputs, drop_keys=True)
    # Sanity check: when every leaf is a tensor and the dynamic cache is
    # registered, pytree flattening must agree with flatten_object on the count.
    assert (
        not all(isinstance(obj, torch.Tensor) for obj in flat)
        or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True, limit=20)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True,limit=20)}"
    )
    if use_numpy:
        converted = []
        for item in flat:
            converted.append(
                item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
            )
        flat = converted
    # Input names come either from the model proto or are given directly.
    if isinstance(proto, onnx.ModelProto):
        names = [graph_input.name for graph_input in proto.graph.input]
    else:
        names = proto
    if copy:
        # numpy arrays expose ``copy``; torch tensors use ``clone`` instead.
        flat = [item.copy() if hasattr(item, "copy") else item.clone() for item in flat]
    return {name: value for name, value in zip(names, flat)}
56 | | - |
57 | | - |
58 | 20 | class _InferenceSession: |
59 | 21 |
|
60 | 22 | @classmethod |
|
0 commit comments