Skip to content

Commit cede26a

Browse files
committed
documentation
1 parent 01cac8a commit cede26a

File tree

6 files changed

+121
-7
lines changed

6 files changed

+121
-7
lines changed

_doc/api/reference/torch_ops/control_flow_ops.rst

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.controlflow_ops
3+
===================================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.controlflow_ops
6+
:members:

_doc/api/reference/torch_ops/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ onnx_diagnostic.reference.torch_ops
99

1010
access_ops
1111
binary_ops
12-
control_flow_ops
12+
controlflow_ops
1313
generator_ops
1414
nn_ops
1515
other_ops

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33
from onnx_diagnostic.ext_test_case import (
44
ExtTestCase,
5+
ignore_errors,
56
requires_torch,
67
requires_transformers,
78
hide_stdout,
@@ -22,6 +23,7 @@ class TestModelBuilderHelper(ExtTestCase):
2223
# This is to limit impact on CI.
2324
@requires_transformers("4.52")
2425
@requires_torch("2.7.99")
26+
@ignore_errors(OSError) # connectivity issues
2527
def test_download_model_builder(self):
2628
path = download_model_builder_to_cache()
2729
self.assertExists(path)
@@ -32,6 +34,7 @@ def test_download_model_builder(self):
3234
@requires_transformers("4.52")
3335
@requires_torch("2.7.99")
3436
@hide_stdout()
37+
@ignore_errors(OSError) # connectivity issues
3538
def test_model_builder_id(self):
3639
# clear&&python ~/.cache/onnx-diagnostic/builder.py
3740
# --model arnir0/Tiny-LLM -p fp16 -c dump_cache -e cpu -o dump_model

k.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import onnx
2+
import onnx.helper as oh
3+
import torch
4+
from onnx_diagnostic.helpers import string_type
5+
from onnx_diagnostic.reference import TorchOnnxEvaluator
6+
7+
TFLOAT = onnx.TensorProto.FLOAT
8+
9+
proto = oh.make_model(
10+
oh.make_graph(
11+
[
12+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
13+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
14+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
15+
],
16+
"-nd-",
17+
[
18+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
19+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
20+
],
21+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
22+
),
23+
opset_imports=[oh.make_opsetid("", 18)],
24+
ir_version=9,
25+
)
26+
27+
sess = TorchOnnxEvaluator(proto, verbose=1)
28+
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
29+
result = sess.run(None, feeds)
30+
print(string_type(result, with_shape=True, with_min_max=True))

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,87 @@ class TorchOnnxEvaluator:
6262
The class is not multithreaded. `runtime_info` gets updated
6363
by the class. The list of available kernels is returned by function
6464
:func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
65+
Example:
66+
67+
.. runpython::
68+
:showcode:
69+
70+
import onnx
71+
import onnx.helper as oh
72+
import torch
73+
from onnx_diagnostic.helpers import string_type
74+
from onnx_diagnostic.reference import TorchOnnxEvaluator
75+
76+
TFLOAT = onnx.TensorProto.FLOAT
77+
78+
proto = oh.make_model(
79+
oh.make_graph(
80+
[
81+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
82+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
83+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
84+
],
85+
"-nd-",
86+
[
87+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
88+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
89+
],
90+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
91+
),
92+
opset_imports=[oh.make_opsetid("", 18)],
93+
ir_version=9,
94+
)
95+
96+
sess = TorchOnnxEvaluator(proto)
97+
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
98+
result = sess.run(None, feeds)
99+
print(string_type(result, with_shape=True, with_min_max=True))
100+
101+
Adding ``verbose=1`` shows which kernels are executed:
102+
103+
.. runpython::
104+
:showcode:
105+
106+
import onnx
107+
import onnx.helper as oh
108+
import torch
109+
from onnx_diagnostic.helpers import string_type
110+
from onnx_diagnostic.reference import TorchOnnxEvaluator
111+
112+
TFLOAT = onnx.TensorProto.FLOAT
113+
114+
proto = oh.make_model(
115+
oh.make_graph(
116+
[
117+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
118+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
119+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
120+
],
121+
"-nd-",
122+
[
123+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
124+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
125+
],
126+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
127+
),
128+
opset_imports=[oh.make_opsetid("", 18)],
129+
ir_version=9,
130+
)
131+
132+
sess = TorchOnnxEvaluator(proto, verbose=1)
133+
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
134+
result = sess.run(None, feeds)
135+
print(string_type(result, with_shape=True, with_min_max=True))
136+
137+
It also shows when a result is not needed anymore. In that case,
138+
it is deleted to free the memory it takes.
139+
The runtime can also execute the kernels of the onnx model on CUDA.
140+
It follows the same logic as :class:`onnxruntime.InferenceSession`:
141+
``providers=["CUDAExecutionProvider"]``.
142+
It is better in that case to move the inputs to CUDA. The class
143+
tries to move every weight to CUDA but tries to keep any tensor
144+
identified as a shape on CPU. Some bugs may remain as torch
145+
raises an exception when devices are expected to be the same.
65146
"""
66147

67148
class IO:

0 commit comments

Comments
 (0)