
Commit 29b3a20

Support for custom kernels

1 parent 90ab957

File tree

2 files changed: +232 -6 lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 102 additions & 0 deletions
@@ -6,11 +6,14 @@
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
 from onnx_diagnostic.helpers.onnx_helper import from_array_extended
+from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
 from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
+from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor
 from onnx_diagnostic.reference.torch_evaluator import get_kernels


 TFLOAT = onnx.TensorProto.FLOAT
+TFLOAT16 = onnx.TensorProto.FLOAT16
 TINT64 = onnx.TensorProto.INT64


@@ -1369,6 +1372,105 @@ def test_tile(self):
             torch.tensor([2, 2], dtype=torch.int64),
         )

+    def test_custom_kernels(self):
+        class LayerNormalizationOrt(OpRun):
+            "LayerNormalization"
+
+            _shared = [0]
+
+            def __init__(self, node: onnx.NodeProto, version=None):
+                super().__init__(node, version)
+                self.axis = self.get_attribute_int(node, "axis", -1)
+                self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+                self.stash_type = onnx_dtype_to_torch_dtype(
+                    self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
+                )
+                self.compute_std = len(node.output) > 1
+                assert not self.compute_std
+                layer_model = oh.make_model(
+                    oh.make_graph(
+                        [
+                            oh.make_node(
+                                "LayerNormalization",
+                                ["X", "W", "B"],
+                                ["Z"],
+                                axis=-1,
+                                epsilon=9.999999974752427e-7,
+                            )
+                        ],
+                        "dummy",
+                        [
+                            oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                            oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                            oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                        ],
+                        [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+                    ),
+                    ir_version=9,
+                    opset_imports=[oh.make_opsetid("", 17)],
+                )
+                import onnxruntime
+
+                self.ort_sess = onnxruntime.InferenceSession(
+                    layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+
+            def run(self, x, scale, bias=None):
+                self._shared[0] += 1
+                feeds = dict(X=x, W=scale)
+                if bias is not None:
+                    feeds["B"] = bias
+                feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
+                got = self.ort_sess.run(None, feeds)[0]
+                return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
+
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "LayerNormalization",
+                        ["X", "W", "B"],
+                        ["ln"],
+                        axis=-1,
+                        epsilon=9.999999974752427e-7,
+                    ),
+                    oh.make_node(
+                        "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
+                    ),
+                ],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                    oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                    oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 17)],
+        )
+
+        torch_sess = TorchOnnxEvaluator(model, verbose=0)
+        torch_sess_custom = TorchOnnxEvaluator(
+            model,
+            verbose=0,
+            custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
+        )
+        feeds = dict(
+            zip(
+                torch_sess.input_names,
+                [
+                    torch.rand(3, 4, 5, dtype=torch.float16),
+                    torch.abs(torch.rand(5, dtype=torch.float16)),
+                    torch.rand(5, dtype=torch.float16),
+                ],
+            )
+        )
+        expected = torch_sess.run(None, feeds)
+        got = torch_sess_custom.run(None, feeds)
+        self.assertEqualAny(expected, got)
+        self.assertEqual([1], LayerNormalizationOrt._shared)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
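The test above delegates LayerNormalization to onnxruntime, but the same registration mechanism accepts any OpRun subclass. Below is a minimal sketch of a pure-torch replacement, assuming only the OpRun/OpRunTensor API exercised in the test; the class name LayerNormalizationTorch is hypothetical and not part of this commit.

import onnx
import torch
from onnx_diagnostic.reference import TorchOnnxEvaluator
from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor


class LayerNormalizationTorch(OpRun):
    "Hypothetical pure-torch kernel for LayerNormalization."

    def __init__(self, node: onnx.NodeProto, version=None):
        super().__init__(node, version)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)

    def run(self, x, scale, bias=None):
        # torch.nn.functional.layer_norm normalizes the trailing dimensions,
        # which matches the ONNX semantics when axis counts from the end.
        t = torch.nn.functional.layer_norm(
            x.tensor,
            x.tensor.shape[self.axis :],
            weight=scale.tensor,
            bias=None if bias is None else bias.tensor,
            eps=self.epsilon,
        )
        return OpRunTensor(t)


# Registration follows the same pattern as in the test:
# sess = TorchOnnxEvaluator(
#     model, custom_kernels={("", "LayerNormalization"): LayerNormalizationTorch}
# )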

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 130 additions & 6 deletions
@@ -45,6 +45,8 @@ class TorchOnnxEvaluator:
     :param opsets: needed if proto is a graph
     :param functions: known local functions
     :param verbose: verbosity level
+    :param custom_kernels: dictionary of kernels the user can define to overwrite
+        a specific implementation: ``("", "LayerNormalization"): CustomKernel``

     The class holds the following attributes:

@@ -98,7 +100,10 @@ class TorchOnnxEvaluator:
         result = sess.run(None, feeds)
         print(string_type(result, with_shape=True, with_min_max=True))

-    Adding ``verbose=1`` shows which kernels is executed:
+    With ``verbose=1``, the class prints out every kernel run and
+    every result deleted along the run. It shows when a result is
+    not needed anymore; in that case, it is deleted to free the
+    memory it takes.

     .. runpython::
         :showcode:
@@ -134,8 +139,6 @@ class TorchOnnxEvaluator:
         result = sess.run(None, feeds)
         print(string_type(result, with_shape=True, with_min_max=True))

-    It also shows when a result is not needed anymore. In that case,
-    it is deleted to free the memory it takes.
     The runtime can also execute the ONNX model on CUDA.
     It follows the same logic as :class:`onnxruntime.InferenceSession`:
     ``providers=["CUDAExecutionProvider"]``.
@@ -144,6 +147,115 @@
     identified as a shape in CPU. Some bugs may remain as torch
     raises an exception when devices are expected to be the same.
     The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
+    The next example shows how to replace a kernel with a different
+    one based on :epkg:`onnxruntime`.
+
+    .. runpython::
+        :showcode:
+
+        import numpy as np
+        import onnx
+        import onnx.helper as oh
+        import onnxruntime
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
+        from onnx_diagnostic.reference import TorchOnnxEvaluator
+        from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor
+
+        TFLOAT16 = onnx.TensorProto.FLOAT16
+
+        class LayerNormalizationOrt(OpRun):
+            "LayerNormalization based on onnxruntime"
+
+            def __init__(self, node: onnx.NodeProto, version=None):
+                super().__init__(node, version)
+                self.axis = self.get_attribute_int(node, "axis", -1)
+                self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+                self.stash_type = onnx_dtype_to_torch_dtype(
+                    self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
+                )
+                self.compute_std = len(node.output) > 1
+                assert not self.compute_std, "The kernel only computes the first output."
+                layer_model = oh.make_model(
+                    oh.make_graph(
+                        [
+                            oh.make_node(
+                                "LayerNormalization",
+                                ["X", "W", "B"],
+                                ["Z"],
+                                axis=-1,
+                                epsilon=9.999999974752427e-7,
+                            )
+                        ],
+                        "dummy",
+                        [
+                            oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                            oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                            oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                        ],
+                        [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+                    ),
+                    ir_version=9,
+                    opset_imports=[oh.make_opsetid("", 17)],
+                )
+                self.ort_sess = onnxruntime.InferenceSession(
+                    layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+
+            def run(self, x, scale, bias=None):
+                print(f"-- running {self.__class__.__name__}")
+                feeds = dict(X=x, W=scale)
+                if bias is not None:
+                    feeds["B"] = bias
+                feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
+                got = self.ort_sess.run(None, feeds)[0]
+                return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
+
+        # This kernel is tested on this model.
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "LayerNormalization",
+                        ["X", "W", "B"],
+                        ["ln"],
+                        axis=-1,
+                        epsilon=9.999999974752427e-7,
+                    ),
+                    oh.make_node(
+                        "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
+                    ),
+                ],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                    oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                    oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 17)],
+        )
+
+        torch_sess = TorchOnnxEvaluator(
+            model,
+            custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
+            verbose=1,
+        )
+        feeds = dict(
+            zip(
+                torch_sess.input_names,
+                [
+                    torch.rand(3, 4, 5, dtype=torch.float16),
+                    torch.abs(torch.rand(5, dtype=torch.float16)),
+                    torch.rand(5, dtype=torch.float16),
+                ],
+            )
+        )
+        res = torch_sess.run(None, feeds)
+        print(string_type(res, with_shape=True, with_min_max=True))
     """

     class IO:
@@ -172,13 +284,15 @@ def __init__(
         opsets: Optional[Dict[str, int]] = None,
         local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
         verbose: int = 0,
+        custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRun]]] = None,
     ):
         self.providers = providers
         self.constants: Dict[str, torch.Tensor] = {}
         self.kernels: List[Optional[torch_ops.OpRun]] = []
         self.functions = local_functions.copy() if local_functions else {}
         self.CPU = torch.tensor([0]).to("cpu").device
         self.verbose = verbose
+        self.custom_kernels = custom_kernels or {}
         dev = self._on_cuda(providers)
         if dev < 0:
             self.default_device = self.CPU
@@ -296,6 +410,16 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
         kernels = get_kernels()
         self.kernels.clear()
         for node in nodes:
+            opset = self.opsets[node.domain]
+            key = node.domain, node.op_type, opset
+            if key[:2] in self.custom_kernels:
+                cls = self.custom_kernels[key[:2]]
+                ags = [self.default_device] if cls.device_dependent() else []
+                kws = dict(parent=self) if cls.has_subgraphs() else {}
+                kernel2 = cls(node, opset, *ags, **kws)
+                self.kernels.append(kernel2)
+                continue
+
             if (node.domain, node.op_type) in self.functions:
                 kernel = torch_ops.OpRunFunction(
                     self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
@@ -308,8 +432,6 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
                 self.kernels.append(None)
                 continue

-            opset = self.opsets[node.domain]
-            key = node.domain, node.op_type, opset
             while key not in kernels and opset > 0:
                 opset -= 1
                 key = node.domain, node.op_type, opset
@@ -438,7 +560,9 @@ def run_with_values(
         context: Optional[Dict[str, RuntimeValue]] = None,
     ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
         """
-        Runs the ONNX model.
+        Runs the ONNX model. The signature differs from the main ``run`` method.
+        This method is called by every kernel holding a subgraph.
+        The local variables are stored in ``context``.

         :param args: inputs
         :param context: local context for the execution of subgraphs
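Taken together, the two _build_kernels hunks above resolve each node's kernel in three steps, with user-supplied custom kernels taking precedence over local functions and built-in implementations. The sketch below is a simplified, hypothetical restatement of that lookup order; resolve_kernel is not a function in the code base.

def resolve_kernel(evaluator, node, kernels):
    # 1. A user-supplied custom kernel wins over everything else.
    if (node.domain, node.op_type) in evaluator.custom_kernels:
        return evaluator.custom_kernels[node.domain, node.op_type]
    # 2. Local ONNX functions come next; the evaluator wraps them
    #    in OpRunFunction before execution.
    if (node.domain, node.op_type) in evaluator.functions:
        return evaluator.functions[node.domain, node.op_type]
    # 3. Built-in kernels last, falling back to the closest lower
    #    opset version that provides an implementation.
    opset = evaluator.opsets[node.domain]
    key = node.domain, node.op_type, opset
    while key not in kernels and opset > 0:
        opset -= 1
        key = node.domain, node.op_type, opset
    return kernels.get(key)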
