@@ -62,6 +62,87 @@ class TorchOnnxEvaluator:
     The class is not multithreaded. ``runtime_info`` gets updated
     by the class. The list of available kernels is returned by the
     function :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
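+    For instance (a minimal sketch, assuming ``get_kernels`` can be
+    called without arguments; only the function name comes from this
+    documentation):
+
+    .. code-block:: python
+
+        from onnx_diagnostic.reference.torch_evaluator import get_kernels
+
+        # Sketch: enumerate the kernels the evaluator can dispatch to.
+        kernels = get_kernels()
+        print(len(kernels))
+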
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import onnx
+        import onnx.helper as oh
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.reference import TorchOnnxEvaluator
+
+        TFLOAT = onnx.TensorProto.FLOAT
+
+        # The graph computes final = X * Y * sigmoid(Y).
+        proto = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("Sigmoid", ["Y"], ["sy"]),
+                    oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
+                    oh.make_node("Mul", ["X", "ysy"], ["final"]),
+                ],
+                "-nd-",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
+                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                ],
+                [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        sess = TorchOnnxEvaluator(proto)
+        # Inputs are rank-3 tensors matching the declared shapes
+        # [1, "b", "c"] and ["a", "b", "c"].
+        feeds = dict(X=torch.rand((1, 4, 5)), Y=torch.rand((3, 4, 5)))
+        result = sess.run(None, feeds)
+        print(string_type(result, with_shape=True, with_min_max=True))
+
+    Adding ``verbose=1`` shows which kernels are executed:
+
+    .. runpython::
+        :showcode:
+
+        import onnx
+        import onnx.helper as oh
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.reference import TorchOnnxEvaluator
+
+        TFLOAT = onnx.TensorProto.FLOAT
+
+        # Same model as above: final = X * Y * sigmoid(Y).
+        proto = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("Sigmoid", ["Y"], ["sy"]),
+                    oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
+                    oh.make_node("Mul", ["X", "ysy"], ["final"]),
+                ],
+                "-nd-",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
+                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                ],
+                [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        # verbose=1 prints each kernel as it is executed.
+        sess = TorchOnnxEvaluator(proto, verbose=1)
+        feeds = dict(X=torch.rand((1, 4, 5)), Y=torch.rand((3, 4, 5)))
+        result = sess.run(None, feeds)
+        print(string_type(result, with_shape=True, with_min_max=True))
+
+    It also shows when a result is not needed anymore; in that case,
+    the result is deleted to free the memory it holds.
+    The runtime can also execute the kernels of the ONNX model on CUDA.
+    It follows the same logic as :class:`onnxruntime.InferenceSession`:
+    ``providers=["CUDAExecutionProvider"]``.
+    In that case, it is better to move the inputs to CUDA as well.
+    The class tries to move every weight to CUDA but keeps any tensor
+    identified as a shape on CPU. Some bugs may remain, since torch
+    raises an exception when tensors expected to be on the same device are not.
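+    A CUDA session could look like this (a minimal sketch, not executed
+    here because it requires a CUDA device; ``proto`` is the model built
+    above):
+
+    .. code-block:: python
+
+        # Run the same model on CUDA; weights are moved by the class,
+        # inputs are moved explicitly to avoid device mismatches.
+        sess = TorchOnnxEvaluator(proto, providers=["CUDAExecutionProvider"])
+        feeds = dict(
+            X=torch.rand((1, 4, 5)).cuda(),
+            Y=torch.rand((3, 4, 5)).cuda(),
+        )
+        result = sess.run(None, feeds)
+        print(string_type(result, with_shape=True, with_min_max=True))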
65146 """
66147
67148 class IO :