Skip to content

Commit 026d928

Browse files
committed
ci
1 parent 115a688 commit 026d928

File tree

2 files changed

+100
-11
lines changed

2 files changed

+100
-11
lines changed

_unittests/ut_helpers/test_doc_helper.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import onnx.helper as oh
44
import torch
55
from onnx_diagnostic.ext_test_case import ExtTestCase
6-
from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt
6+
from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt
77
from onnx_diagnostic.reference import TorchOnnxEvaluator
88

99
TFLOAT = onnx.TensorProto.FLOAT
1010
TFLOAT16 = onnx.TensorProto.FLOAT16
1111

1212

1313
class TestDocHelper(ExtTestCase):
14-
def test_custom_doc_kernels(self):
14+
def test_custom_doc_kernels_layer_normalization(self):
1515
model = oh.make_model(
1616
oh.make_graph(
1717
[
@@ -35,7 +35,7 @@ def test_custom_doc_kernels(self):
3535
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
3636
),
3737
ir_version=9,
38-
opset_imports=[oh.make_opsetid("", 17)],
38+
opset_imports=[oh.make_opsetid("", 18)],
3939
)
4040

4141
torch_sess = TorchOnnxEvaluator(model, verbose=0)
@@ -58,6 +58,40 @@ def test_custom_doc_kernels(self):
5858
got = torch_sess_custom.run(None, feeds)
5959
self.assertEqualAny(expected, got)
6060

61+
def test_custom_doc_kernels_matmul(self):
62+
model = oh.make_model(
63+
oh.make_graph(
64+
[oh.make_node("MatMul", ["X", "Y"], ["Z"])],
65+
"dummy",
66+
[
67+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
68+
oh.make_tensor_value_info("Y", TFLOAT16, ["b", "d", "e"]),
69+
],
70+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "e"])],
71+
),
72+
ir_version=9,
73+
opset_imports=[oh.make_opsetid("", 18)],
74+
)
75+
76+
torch_sess = TorchOnnxEvaluator(model, verbose=0)
77+
torch_sess_custom = TorchOnnxEvaluator(
78+
model,
79+
verbose=0,
80+
custom_kernels={("", "MatMul"): MatMulOrt},
81+
)
82+
feeds = dict(
83+
zip(
84+
torch_sess.input_names,
85+
[
86+
torch.rand(3, 4, 5, dtype=torch.float16),
87+
torch.rand(3, 5, 7, dtype=torch.float16),
88+
],
89+
)
90+
)
91+
expected = torch_sess.run(None, feeds)
92+
got = torch_sess_custom.run(None, feeds)
93+
self.assertEqualAny(expected, got)
94+
6195

6296
if __name__ == "__main__":
6397
unittest.main(verbosity=2)

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class LayerNormalizationOrt(OpRunKernel):
10-
"LayerNormalization"
10+
"LayerNormalization with onnxruntime"
1111

1212
@classmethod
1313
def device_dependent(cls) -> bool:
@@ -25,7 +25,7 @@ def __init__(
2525
self.axis = self.get_attribute_int(node, "axis", -1)
2626
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
2727
self.device = device
28-
self.stash_type = onnx_dtype_to_torch_dtype(
28+
self.stash_type = onnx_dtype_to_torch_dtype( # type: ignore[arg-type]
2929
self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
3030
)
3131
self.compute_std = len(node.output) > 1
@@ -36,7 +36,7 @@ def __init__(
3636
self._cache: Dict[Tuple[int, int], onnx.ModelProto] = {}
3737
self.is_cpu = torch.device("cpu") == self.device
3838

39-
def _make_model(self, dtype: int, rank: int) -> onnx.ModelProto:
39+
def _make_model(self, itype: int, rank: int) -> onnx.ModelProto:
4040
shape = [*["d{i}" for i in range(rank - 1)], "last"]
4141
layer_model = oh.make_model(
4242
oh.make_graph(
@@ -51,14 +51,14 @@ def _make_model(self, dtype: int, rank: int) -> onnx.ModelProto:
5151
],
5252
"dummy",
5353
[
54-
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT16, shape),
55-
oh.make_tensor_value_info("W", onnx.TensorProto.FLOAT16, ["last"]),
56-
oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT16, ["last"]),
54+
oh.make_tensor_value_info("X", itype, shape),
55+
oh.make_tensor_value_info("W", itype, ["last"]),
56+
oh.make_tensor_value_info("B", itype, ["last"]),
5757
],
58-
[oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT16, shape)],
58+
[oh.make_tensor_value_info("Z", itype, shape)],
5959
),
6060
ir_version=9,
61-
opset_imports=[oh.make_opsetid("", 17)],
61+
opset_imports=[oh.make_opsetid("", 18)],
6262
)
6363
import onnxruntime
6464

@@ -80,3 +80,58 @@ def run(self, x, scale, bias=None):
8080
feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
8181
got = sess.run(None, feeds)[0]
8282
return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
83+
84+
85+
class MatMulOrt(OpRunKernel):
    """MatMul executed with onnxruntime.

    For every distinct ``(element type, rank(A), rank(B))`` combination seen
    at run time, builds a minimal one-node ONNX model, wraps it in an
    ``onnxruntime.InferenceSession`` (cached), feeds the inputs as numpy
    arrays on CPU and returns the product as a torch tensor restored to the
    first input's dtype and device.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        "Device-independent: inputs are moved to CPU before calling onnxruntime."
        return False

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose=0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.device = device
        # Cache of InferenceSession objects keyed by (onnx dtype, rank(A), rank(B)).
        # (Was wrongly annotated as onnx.ModelProto: _make_model returns a session.)
        self._cache: Dict[Tuple[int, int, int], "onnxruntime.InferenceSession"] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(
        self, itype: int, ranka: int, rankb: int
    ) -> "onnxruntime.InferenceSession":
        """Builds an InferenceSession computing ``A @ B`` for the given
        element type *itype* and input ranks *ranka*, *rankb*."""
        # f-strings give every dimension its own symbolic name; the previous
        # plain "a{i}" literal made all dimensions of a tensor share one name.
        shapea = [f"a{i}" for i in range(ranka)]
        shapeb = [f"b{i}" for i in range(rankb)]
        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["A", "B"], ["C"])],
                "dummy",
                [
                    oh.make_tensor_value_info("A", itype, shapea),
                    oh.make_tensor_value_info("B", itype, shapeb),
                ],
                [oh.make_tensor_value_info("C", itype, shapec)],
            ),
            ir_version=9,
            # Opset 18 for consistency with the other kernels and tests in
            # this file (MatMul is unchanged between 17 and 18).
            opset_imports=[oh.make_opsetid("", 18)],
        )
        import onnxruntime

        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        return onnxruntime.InferenceSession(
            model.SerializeToString(), providers=[provider]
        )

    def run(self, a, b):
        """Computes ``a @ b`` through onnxruntime and returns an
        :class:`OpRunTensor` on *a*'s dtype and device."""
        itype = torch_dtype_to_onnx_dtype(a.dtype)
        ranka, rankb = len(a.shape), len(b.shape)
        key = itype, ranka, rankb
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, ranka, rankb)
        sess = self._cache[key]
        feeds = dict(A=a, B=b)
        # onnxruntime consumes numpy arrays: detach and move to CPU first.
        feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
        got = sess.run(None, feeds)[0]
        return OpRunTensor(torch.from_numpy(got).to(a.dtype).to(a.device))

0 commit comments

Comments
 (0)