@@ -45,6 +45,8 @@ class TorchOnnxEvaluator:
4545 :param opsets: needed if proto is a graph
4646 :param functions: known local functions
4747 :param verbose: verbosity level
48+ :param custom_kernels: dictionary of kernels the user can define to override
49+ a specific implementation: ``("", "LayerNormalization"): CustomKernel``
4850
4951 The class holds the following attributes:
5052
@@ -98,7 +100,10 @@ class TorchOnnxEvaluator:
98100 result = sess.run(None, feeds)
99101 print(string_type(result, with_shape=True, with_min_max=True))
100102
101- Adding ``verbose=1`` shows which kernels is executed:
103+ With ``verbose=1``, the class prints out every kernel it runs and
104+ every result it deletes along the way.
105+ It shows when a result is not needed anymore; in that case,
106+ it is deleted to free the memory it occupies.
102107
103108 .. runpython::
104109 :showcode:
@@ -134,8 +139,6 @@ class TorchOnnxEvaluator:
134139 result = sess.run(None, feeds)
135140 print(string_type(result, with_shape=True, with_min_max=True))
136141
137- It also shows when a result is not needed anymore. In that case,
138- it is deleted to free the memory it takes.
139142 The runtime can also execute the kernels of the onnx model on CUDA.
140143 It follows the same logic as :class:`onnxruntime.InferenceSession`:
141144 ``providers=["CUDAExecutionProvider"]``.
@@ -144,6 +147,115 @@ class TorchOnnxEvaluator:
144147 identified as a shape in CPU. Some bugs may remain as torch
145148 raises an exception when devices are expected to be the same.
146149 The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
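+ A minimal sketch of CUDA execution (assuming a CUDA device is
+ available; ``model`` and ``feeds`` are defined as in the previous
+ examples):
+
+ .. code-block:: python
+
+     sess = TorchOnnxEvaluator(model, providers=["CUDAExecutionProvider"])
+     result = sess.run(None, feeds)
+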
150+ The next example shows how to replace a kernel with a different
151+ one based on :epkg:`onnxruntime`.
152+
153+ .. runpython::
154+ :showcode:
155+
156+ import numpy as np
157+ import onnx
158+ import onnx.helper as oh
159+ import onnxruntime
160+ import torch
161+ from onnx_diagnostic.helpers import string_type
162+ from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
163+ from onnx_diagnostic.reference import TorchOnnxEvaluator
164+ from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor
165+
166+ TFLOAT16 = onnx.TensorProto.FLOAT16
167+
168+ class LayerNormalizationOrt(OpRun):
169+ "LayerNormalization based on onnxruntime"
170+
171+ def __init__(self, node: onnx.NodeProto, version=None):
172+ super().__init__(node, version)
173+ self.axis = self.get_attribute_int(node, "axis", -1)
174+ self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
175+ self.stash_type = onnx_dtype_to_torch_dtype(
176+ self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
177+ )
178+ self.compute_std = len(node.output) > 1
179+ assert not self.compute_std, "This kernel only computes the first output."
180+ layer_model = oh.make_model(
181+ oh.make_graph(
182+ [
183+ oh.make_node(
184+ "LayerNormalization",
185+ ["X", "W", "B"],
186+ ["Z"],
187+ axis=-1,
188+ epsilon=9.999999974752427e-7,
189+ )
190+ ],
191+ "dummy",
192+ [
193+ oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
194+ oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
195+ oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
196+ ],
197+ [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
198+ ),
199+ ir_version=9,
200+ opset_imports=[oh.make_opsetid("", 17)],
201+ )
202+ self.ort_sess = onnxruntime.InferenceSession(
203+ layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
204+ )
205+
206+ def run(self, x, scale, bias=None):
207+ print(f"-- running {self.__class__.__name__}")
208+ feeds = dict(X=x, W=scale)
209+ if bias is not None:
210+ feeds["B"] = bias
211+ feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
212+ got = self.ort_sess.run(None, feeds)[0]
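+ # wrap the result back into an OpRunTensor on the input's dtype and device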
213+ return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
214+
215+ # This kernel is tested on this model.
216+ model = oh.make_model(
217+ oh.make_graph(
218+ [
219+ oh.make_node(
220+ "LayerNormalization",
221+ ["X", "W", "B"],
222+ ["ln"],
223+ axis=-1,
224+ epsilon=9.999999974752427e-7,
225+ ),
226+ oh.make_node(
227+ "Add", ["ln", "W"], ["Z"]
228+ ),
229+ ],
230+ "dummy",
231+ [
232+ oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
233+ oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
234+ oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
235+ ],
236+ [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
237+ ),
238+ ir_version=9,
239+ opset_imports=[oh.make_opsetid("", 17)],
240+ )
241+
242+ torch_sess = TorchOnnxEvaluator(
243+ model,
244+ custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
245+ verbose=1,
246+ )
247+ feeds = dict(
248+ zip(
249+ torch_sess.input_names,
250+ [
251+ torch.rand(3, 4, 5, dtype=torch.float16),
252+ torch.abs(torch.rand(5, dtype=torch.float16)),
253+ torch.rand(5, dtype=torch.float16),
254+ ],
255+ )
256+ )
257+ res = torch_sess.run(None, feeds)
258+ print(string_type(res, with_shape=True, with_min_max=True))
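+
+ As the changes to ``_build_kernels`` below show, the custom kernel
+ class is instantiated once per node when the evaluator builds its
+ kernels; its ``run`` method then receives ``OpRunTensor`` inputs
+ every time the node is executed.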
147259 """
148260
149261 class IO:
@@ -172,13 +284,15 @@ def __init__(
172284 opsets: Optional[Dict[str, int]] = None,
173285 local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
174286 verbose: int = 0,
287+ custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRun]]] = None,
175288 ):
176289 self.providers = providers
177290 self.constants: Dict[str, torch.Tensor] = {}
178291 self.kernels: List[Optional[torch_ops.OpRun]] = []
179292 self.functions = local_functions.copy() if local_functions else {}
180293 self.CPU = torch.tensor([0]).to("cpu").device
181294 self.verbose = verbose
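+ # user-provided kernels mapping (domain, op_type) to an OpRun subclass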
295+ self.custom_kernels = custom_kernels or {}
182296 dev = self._on_cuda(providers)
183297 if dev < 0:
184298 self.default_device = self.CPU
@@ -296,6 +410,16 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
296410 kernels = get_kernels()
297411 self.kernels.clear()
298412 for node in nodes:
413+ opset = self.opsets[node.domain]
414+ key = node.domain, node.op_type, opset
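+ # a kernel supplied by the user overrides any registered implementation
+ # for this (domain, op_type) pair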
415+ if key[:2] in self.custom_kernels:
416+ cls = self.custom_kernels[key[:2]]
417+ ags = [self.default_device] if cls.device_dependent() else []
418+ kws = dict(parent=self) if cls.has_subgraphs() else {}
419+ kernel2 = cls(node, opset, *ags, **kws)
420+ self.kernels.append(kernel2)
421+ continue
422+
299423 if (node.domain, node.op_type) in self.functions:
300424 kernel = torch_ops.OpRunFunction(
301425 self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
@@ -308,8 +432,6 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
308432 self.kernels.append(None)
309433 continue
310434
311- opset = self.opsets[node.domain]
312- key = node.domain, node.op_type, opset
313435 while key not in kernels and opset > 0:
314436 opset -= 1
315437 key = node.domain, node.op_type, opset
@@ -438,7 +560,9 @@ def run_with_values(
438560 context: Optional[Dict[str, RuntimeValue]] = None,
439561 ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
440562 """
441- Runs the ONNX model.
563+ Runs the ONNX model with a different signature than :meth:`run`.
564+ This method is called by every kernel holding a subgraph.
565+ The local variables are stored in ``context``.
442566
443567 :param args: inputs
444568 :param context: local context for the execution of subgraphs