Commit 81acb67

Commit message: f

1 parent: 73cbba2

5 files changed: +110 −6 lines

_unittests/ut_reference/test_torch_evaluator.py

Lines changed: 57 additions & 4 deletions

@@ -22,13 +22,20 @@ def test_kernels(self):
         kernel = ker[key]
         self.assertEqual("Add_1", kernel.__name__)
 
-    def _finalize_test(self, model, *args, atol: float = 0):
+    def _finalize_test(self, model, *args, atol: float = 0, use_ort: bool = False):
         onnx.checker.check_model(model)
         feeds = dict(zip([i.name for i in model.graph.input], args))
+        feeds_numpy = {k: v.numpy() for k, v in feeds.items()}
 
-        expected = ExtendedReferenceEvaluator(model).run(
-            None, {k: v.numpy() for k, v in feeds.items()}
-        )
+        if use_ort:
+            import onnxruntime
+
+            sess = onnxruntime.InferenceSession(
+                model.SerializeToString(), providers=["CPUExecutionProvider"]
+            )
+            expected = sess.run(None, feeds_numpy)
+        else:
+            expected = ExtendedReferenceEvaluator(model).run(None, feeds_numpy)
         rt = TorchEvaluator(model)
         got = rt.run(None, feeds)
         self.assertEqualAny(expected, [g.detach().numpy() for g in got], atol=atol)

@@ -453,6 +460,52 @@ def test_op_reduce_sum(self):
             atol=1e-5,
         )
 
+    def test_op_where(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("Greater", ["X", "Y"], ["cond"]),
+                    oh.make_node("Where", ["cond", "X", "Y"], ["Z"]),
+                ],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
+                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        self._finalize_test(
+            model,
+            torch.rand(3, 4, 5, dtype=torch.float32),
+            torch.rand(3, 4, 5, dtype=torch.float32),
+        )
+
+    def test_op_layer_normalization(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("LayerNormalization", ["X", "W", "B"], ["Z"], axis=-1)],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
+                    oh.make_tensor_value_info("W", TFLOAT, []),
+                    oh.make_tensor_value_info("B", TFLOAT, []),
+                ],
+                [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        self._finalize_test(
+            model,
+            torch.rand(3, 4, 5, dtype=torch.float32),
+            torch.abs(torch.rand(5, dtype=torch.float32)),
+            torch.rand(5, dtype=torch.float32),
+            use_ort=True,
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
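For readers who want to reproduce the cross-check outside the test harness, here is a minimal standalone sketch of the same pattern. It is not part of the commit: it rebuilds the test_op_where graph with plain onnx helpers and checks onnxruntime's output directly (Greater followed by Where selects the elementwise maximum, which gives a closed-form expectation):

import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime

TFLOAT = onnx.TensorProto.FLOAT

# Same Greater+Where graph as test_op_where, with concrete feeds.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Greater", ["X", "Y"], ["cond"]),
            oh.make_node("Where", ["cond", "X", "Y"], ["Z"]),
        ],
        "dummy",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
    ),
    ir_version=9,
    opset_imports=[oh.make_opsetid("", 18)],
)
onnx.checker.check_model(model)

feeds = {
    "X": np.random.rand(3, 4, 5).astype(np.float32),
    "Y": np.random.rand(3, 4, 5).astype(np.float32),
}
sess = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
(z,) = sess.run(None, feeds)
# where(X > Y, X, Y) is the elementwise maximum.
np.testing.assert_allclose(z, np.maximum(feeds["X"], feeds["Y"]))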

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -13,8 +13,8 @@
     Or_1,
     Sub_1,
 )
-from .other_ops import Cast_6, Concat_1, Transpose_1
-from .nn_ops import Softmax_13, Tanh_6
+from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
+from .other_ops import Cast_6, Concat_1, Transpose_1, Where_9
 from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
 from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13
 from .unary_ops import Neg_1, Not_1, Reciprocal_1
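The exported class names follow an <OpType>_<since_version> pattern: Where_9 is the opset-9 definition of Where, LayerNormalization_17 the opset-17 definition of LayerNormalization. As a hedged sketch of how such names can drive dispatch — the register/resolve helpers below are illustrative assumptions, not the library's actual API:

from typing import Dict, Optional, Tuple

# Hypothetical registry keyed by (op_type, since_version), filled from
# class names such as Where_9 or LayerNormalization_17.
KERNELS: Dict[Tuple[str, int], type] = {}

def register(cls: type) -> type:
    op_type, version = cls.__name__.rsplit("_", 1)
    KERNELS[op_type, int(version)] = cls
    return cls

def resolve(op_type: str, opset: int) -> Optional[type]:
    # Pick the most recent kernel whose since_version does not exceed
    # the opset declared by the model.
    versions = [v for (name, v) in KERNELS if name == op_type and v <= opset]
    return KERNELS[op_type, max(versions)] if versions else None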

onnx_diagnostic/reference/torch_ops/_op_run.py

Lines changed: 14 additions & 0 deletions

@@ -80,6 +80,20 @@ def _find_attribute(self, node: onnx.NodeProto, name: str):
             return att
         return None
 
+    def get_attribute_float(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
+    ) -> Optional[float]:
+        """
+        Returns an attribute as a float.
+
+        :param node: NodeProto
+        :param name: attribute name
+        :param default_value: value returned if the attribute is missing
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else float(att.f)
+
     def get_attribute_int(
         self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
     ) -> Optional[int]:
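The new helper mirrors get_attribute_int but reads the f field of the attribute proto. The same lookup can be reproduced with nothing but the onnx API; a small self-contained sketch (the node below is built ad hoc for illustration, and att.f is a 32-bit float, hence the tolerance):

import onnx
import onnx.helper as oh

node = oh.make_node("LayerNormalization", ["X", "W"], ["Z"], epsilon=1e-3)

def get_attribute_float(node: onnx.NodeProto, name: str, default_value=None):
    # Scan node.attribute and fall back to the default when absent.
    for att in node.attribute:
        if att.name == name:
            return float(att.f)
    return default_value

assert abs(get_attribute_float(node, "epsilon") - 1e-3) < 1e-6
assert get_attribute_float(node, "missing", 0.5) == 0.5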

onnx_diagnostic/reference/torch_ops/nn_ops.py

Lines changed: 30 additions & 0 deletions

@@ -5,6 +5,36 @@
 from . import OpRun, OpRunValue
 
 
+class LayerNormalization_17(OpRun):
+    "LayerNormalization"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+        self.stash_type = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
+        )
+        self.compute_std = len(node.output) > 1
+
+    def run(self, x, scale, bias=None):
+        original_dtype = x.dtype
+        xt = x.tensor.to(self.stash_type)
+        res = torch.nn.functional.layer_norm(
+            xt,
+            xt.shape[self.axis :],  # normalized_shape: the trailing dims from axis
+            weight=scale.tensor,
+            bias=None if bias is None else bias.tensor,
+            eps=self.epsilon,
+        )
+        if not self.compute_std:
+            return res.to(original_dtype)
+        axes = tuple(range(len(xt.shape)))[self.axis :]
+        var, mean = torch.var_mean(xt, dim=axes, keepdim=False, correction=0)
+        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
+        return res.to(original_dtype), mean, x_inv_std_dev
+
+
 class Softmax_13(OpRun):
     "Softmax"
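As a sanity check on the math above, the following standalone snippet (assuming only torch; it is not part of the commit) reproduces layer_norm by hand with the biased population variance that the ONNX specification prescribes, which is also what torch.var_mean with correction=0 computes:

import torch

x = torch.rand(3, 4, 5)
axis, eps = -1, 1e-5

# Reference result through the functional API, no affine parameters.
res = torch.nn.functional.layer_norm(x, x.shape[axis:], eps=eps)

# Manual computation: biased variance over the trailing (normalized) axes.
axes = tuple(range(x.ndim))[axis:]
var, mean = torch.var_mean(x, dim=axes, keepdim=True, correction=0)
inv_std = torch.reciprocal(torch.sqrt(var + eps))
torch.testing.assert_close(res, (x - mean) * inv_std)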

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 7 additions & 0 deletions

@@ -40,3 +40,10 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
 
     def run(self, data: OpRunValue) -> OpRunValue:
         return OpRunValue(torch.permute(data.tensor, self.perm))
+
+
+class Where_9(OpRun):
+    "Where"
+
+    def run(self, cond: OpRunValue, x: OpRunValue, y: OpRunValue) -> OpRunValue:
+        return OpRunValue(torch.where(cond.tensor, x.tensor, y.tensor))
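ONNX Where uses numpy-style broadcasting across all three inputs, and torch.where implements the same rules, so the one-line kernel needs no extra shape handling. A quick illustration (not from the commit):

import torch

cond = torch.tensor([[True], [False]])  # shape (2, 1)
x = torch.tensor([[1.0, 2.0]])          # shape (1, 2)
y = torch.tensor([[10.0, 20.0]])        # shape (1, 2)

# All three inputs broadcast to (2, 2), as ONNX Where requires.
z = torch.where(cond, x, y)
assert z.tolist() == [[1.0, 2.0], [10.0, 20.0]]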
