Skip to content

Commit 6dbde53

Browse files
committed
add more ops
1 parent 821f9be commit 6dbde53

File tree

7 files changed

+296
-8
lines changed

7 files changed

+296
-8
lines changed

_unittests/ut_reference/test_torch_evaluator.py renamed to _unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,212 @@ def test_op_layer_normalization_big_eps(self):
559559
atol=1e-4,
560560
)
561561

562+
def test_op_range_float(self):
563+
model = oh.make_model(
564+
oh.make_graph(
565+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
566+
"dummy",
567+
[
568+
oh.make_tensor_value_info("start", TFLOAT, []),
569+
oh.make_tensor_value_info("limit", TFLOAT, []),
570+
oh.make_tensor_value_info("delta", TFLOAT, []),
571+
],
572+
[oh.make_tensor_value_info("Z", TFLOAT, ["a"])],
573+
),
574+
ir_version=9,
575+
opset_imports=[oh.make_opsetid("", 18)],
576+
)
577+
self._finalize_test(
578+
model,
579+
torch.tensor(2.1, dtype=torch.float32),
580+
torch.tensor(5.1, dtype=torch.float32),
581+
torch.tensor(1, dtype=torch.float32),
582+
)
583+
584+
def test_op_range_int64(self):
585+
model = oh.make_model(
586+
oh.make_graph(
587+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
588+
"dummy",
589+
[
590+
oh.make_tensor_value_info("start", TINT64, []),
591+
oh.make_tensor_value_info("limit", TINT64, []),
592+
oh.make_tensor_value_info("delta", TINT64, []),
593+
],
594+
[oh.make_tensor_value_info("Z", TINT64, ["a"])],
595+
),
596+
ir_version=9,
597+
opset_imports=[oh.make_opsetid("", 18)],
598+
)
599+
self._finalize_test(
600+
model,
601+
torch.tensor(2, dtype=torch.int64),
602+
torch.tensor(5, dtype=torch.int64),
603+
torch.tensor(1, dtype=torch.int64),
604+
)
605+
606+
def test_op_range_int64_h2(self):
607+
model = oh.make_model(
608+
oh.make_graph(
609+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
610+
"dummy",
611+
[
612+
oh.make_tensor_value_info("start", TINT64, []),
613+
oh.make_tensor_value_info("limit", TINT64, []),
614+
oh.make_tensor_value_info("delta", TINT64, []),
615+
],
616+
[oh.make_tensor_value_info("Z", TINT64, ["a"])],
617+
),
618+
ir_version=9,
619+
opset_imports=[oh.make_opsetid("", 18)],
620+
)
621+
self._finalize_test(
622+
model,
623+
torch.tensor(2, dtype=torch.int64),
624+
torch.tensor(5, dtype=torch.int64),
625+
torch.tensor(2, dtype=torch.int64),
626+
)
627+
628+
def test_op_expand(self):
629+
model = oh.make_model(
630+
oh.make_graph(
631+
[oh.make_node("Expand", ["X", "shape"], ["Y"])],
632+
"dummy",
633+
[
634+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c", "d"]),
635+
oh.make_tensor_value_info("shape", TINT64, ["f"]),
636+
],
637+
[oh.make_tensor_value_info("Y", TFLOAT, ["aa", "ba", "ca", "da"])],
638+
),
639+
ir_version=9,
640+
opset_imports=[oh.make_opsetid("", 18)],
641+
)
642+
self._finalize_test(
643+
model,
644+
torch.rand((1, 5, 6, 7), dtype=torch.float32),
645+
torch.tensor([4, 5, 1, 1], dtype=torch.int64),
646+
)
647+
648+
def test_op_unary_op(self):
649+
model = oh.make_model(
650+
oh.make_graph(
651+
[
652+
oh.make_node("Cos", ["X"], ["nx"]),
653+
oh.make_node("Sin", ["nx"], ["t"]),
654+
oh.make_node("Exp", ["t"], ["u"]),
655+
oh.make_node("Log", ["u"], ["Z"]),
656+
],
657+
"dummy",
658+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
659+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
660+
),
661+
ir_version=9,
662+
opset_imports=[oh.make_opsetid("", 18)],
663+
)
664+
onnx.checker.check_model(model)
665+
self._finalize_test(model, torch.abs(torch.rand(3, 4, dtype=torch.float32)), atol=1e-6)
666+
667+
def test_op_pow_op(self):
668+
model = oh.make_model(
669+
oh.make_graph(
670+
[oh.make_node("Pow", ["X", "Y"], ["Z"])],
671+
"dummy",
672+
[
673+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
674+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b"]),
675+
],
676+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
677+
),
678+
ir_version=9,
679+
opset_imports=[oh.make_opsetid("", 18)],
680+
)
681+
onnx.checker.check_model(model)
682+
self._finalize_test(
683+
model,
684+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
685+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
686+
atol=1e-7,
687+
)
688+
689+
def test_op_pow_op_int(self):
690+
model = oh.make_model(
691+
oh.make_graph(
692+
[oh.make_node("Pow", ["X", "Y"], ["Z"])],
693+
"dummy",
694+
[
695+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
696+
oh.make_tensor_value_info("Y", TINT64, ["a", "b"]),
697+
],
698+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
699+
),
700+
ir_version=9,
701+
opset_imports=[oh.make_opsetid("", 18)],
702+
)
703+
onnx.checker.check_model(model)
704+
self._finalize_test(
705+
model,
706+
torch.rand(3, 4, 5, dtype=torch.float32),
707+
torch.tensor([2], dtype=torch.int64),
708+
atol=1e-7,
709+
)
710+
711+
def test_op_sqrt_op(self):
712+
model = oh.make_model(
713+
oh.make_graph(
714+
[oh.make_node("Sqrt", ["X"], ["Z"])],
715+
"dummy",
716+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
717+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
718+
),
719+
ir_version=9,
720+
opset_imports=[oh.make_opsetid("", 18)],
721+
)
722+
onnx.checker.check_model(model)
723+
self._finalize_test(model, torch.abs(torch.rand(3, 4, dtype=torch.float32)), atol=1e-6)
724+
725+
def test_op_split_op(self):
726+
model = oh.make_model(
727+
oh.make_graph(
728+
[oh.make_node("Split", ["X"], ["Z1", "Z2"], axis=1, num_outputs=2)],
729+
"dummy",
730+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
731+
[
732+
oh.make_tensor_value_info("Z1", TFLOAT, ["a", "b1"]),
733+
oh.make_tensor_value_info("Z2", TFLOAT, ["a", "b2"]),
734+
],
735+
),
736+
ir_version=9,
737+
opset_imports=[oh.make_opsetid("", 18)],
738+
)
739+
onnx.checker.check_model(model)
740+
self._finalize_test(model, torch.rand(3, 5, dtype=torch.float32), use_ort=True)
741+
self._finalize_test(model, torch.rand(3, 6, dtype=torch.float32), use_ort=True)
742+
743+
def test_op_split_op_sizes(self):
744+
model = oh.make_model(
745+
oh.make_graph(
746+
[oh.make_node("Split", ["X", "split"], ["Z1", "Z2"], axis=1)],
747+
"dummy",
748+
[
749+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
750+
oh.make_tensor_value_info("split", TINT64, [2]),
751+
],
752+
[
753+
oh.make_tensor_value_info("Z1", TFLOAT, ["a", "b1"]),
754+
oh.make_tensor_value_info("Z2", TFLOAT, ["a", "b2"]),
755+
],
756+
),
757+
ir_version=9,
758+
opset_imports=[oh.make_opsetid("", 18)],
759+
)
760+
onnx.checker.check_model(model)
761+
self._finalize_test(
762+
model,
763+
torch.rand(3, 5, dtype=torch.float32),
764+
torch.tensor([2, 3], dtype=torch.int64),
765+
use_ort=True,
766+
)
767+
562768

563769
if __name__ == "__main__":
564770
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,6 @@ def test_validate_model_custom_torch(self):
186186
self.assertIsInstance(data, dict)
187187
self.assertIn("disc_onnx_ort_run_abs", summary)
188188
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
189-
onnx_filename = data["onnx_filename"]
190-
output_path = f"{onnx_filename}.ortopt.onnx"
191-
run_ort_fusion(
192-
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
193-
)
194189

195190
def test_filter_inputs(self):
196191
inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
MatMul_1,
1212
Mul_1,
1313
Or_1,
14+
Pow_12,
1415
Sub_1,
1516
)
17+
from .generator_ops import Range_11
1618
from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
1719
from .other_ops import Cast_6, Concat_1, Transpose_1, Where_9
1820
from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
19-
from .shape_ops import Reshape_14, Shape_15, Squeeze_13, Unsqueeze_13
20-
from .unary_ops import Neg_1, Not_1, Reciprocal_1
21+
from .shape_ops import Expand_8, Reshape_14, Shape_15, Squeeze_13, Split_18, Unsqueeze_13
22+
from .unary_ops import Cos_1, Exp_1, Log_1, Neg_1, Not_1, Reciprocal_1, Sin_1, Sqrt_1

onnx_diagnostic/reference/torch_ops/binary_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from . import OpRun, OpRunValue
23

34

@@ -78,6 +79,13 @@ def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
7879
return OpRunValue(x.tensor | y.tensor)
7980

8081

82+
class Pow_12(OpRun):
83+
"""Pow"""
84+
85+
def run(self, x: OpRunValue, y: OpRunValue) -> OpRunValue:
86+
return OpRunValue(torch.pow(x.tensor, y.tensor))
87+
88+
8189
class Sub_1(OpRun):
8290
"""Sub"""
8391

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from . import OpRun, OpRunValue
3+
4+
5+
class Range_11(OpRun):
6+
"""Range"""
7+
8+
def run(self, starts: OpRunValue, limit: OpRunValue, delta: OpRunValue) -> OpRunValue:
9+
return OpRunValue(
10+
torch.arange(starts.tensor, limit.tensor, delta.tensor, dtype=starts.dtype)
11+
)

onnx_diagnostic/reference/torch_ops/shape_ops.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
from . import OpRun, OpRunValue
55

66

7+
class Expand_8(OpRun):
8+
"Expand"
9+
10+
def run(self, data: OpRunValue, shape: OpRunValue) -> OpRunValue:
11+
ishape = tuple(-1 if i == 1 else i for i in shape.as_tuple_int)
12+
return OpRunValue(data.tensor.expand(ishape))
13+
14+
715
class Reshape_14(OpRun):
816
"Reshape"
917

@@ -19,7 +27,7 @@ def run(self, data: OpRunValue, shape: OpRunValue) -> OpRunValue:
1927
for i, s in enumerate(ishape):
2028
new_shape.append(xshape[i] if s == 0 else s)
2129
return OpRunValue(data.tensor.reshape(new_shape))
22-
return OpRunValue(data.tensor.reshape(shape.as_tuple_int))
30+
return OpRunValue(data.tensor.reshape(ishape))
2331

2432

2533
class Shape_15(OpRun):
@@ -34,6 +42,29 @@ def run(self, data: OpRunValue) -> OpRunValue:
3442
return OpRunValue(torch.tensor(sh, dtype=torch.int64), is_constant=True)
3543

3644

45+
class Split_18(OpRun):
46+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
47+
super().__init__(node, version)
48+
self.axis = self.get_attribute_int(node, "axis", 0)
49+
self.num_outputs = self.get_attribute_int(node, "num_outputs", None)
50+
51+
def run(self, data: OpRunValue, split: Optional[OpRunValue] = None) -> OpRunValue:
52+
if split is None:
53+
assert isinstance(
54+
self.num_outputs, int
55+
), f"Incompatibilies: split is None and num_outputs={self.num_outputs}"
56+
size = data.tensor.shape[self.axis]
57+
split_size = (
58+
size // self.num_outputs
59+
if size % self.num_outputs == 0
60+
else size // self.num_outputs + 1
61+
)
62+
spl = torch.split(data.tensor, split_size, dim=self.axis)
63+
else:
64+
spl = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
65+
return tuple(OpRunValue(t) for t in spl)
66+
67+
3768
class Squeeze_13(OpRun):
3869
"Squeeze"
3970

onnx_diagnostic/reference/torch_ops/unary_ops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
11
from . import OpRun, OpRunValue
22

33

4+
class Cos_1(OpRun):
5+
"""Cos"""
6+
7+
def run(self, x: OpRunValue) -> OpRunValue:
8+
return OpRunValue(x.tensor.cos())
9+
10+
11+
class Exp_1(OpRun):
12+
"""Exp"""
13+
14+
def run(self, x: OpRunValue) -> OpRunValue:
15+
return OpRunValue(x.tensor.exp())
16+
17+
18+
class Log_1(OpRun):
19+
"""Log"""
20+
21+
def run(self, x: OpRunValue) -> OpRunValue:
22+
return OpRunValue(x.tensor.log())
23+
24+
425
class Neg_1(OpRun):
526
"""Neg"""
627

@@ -20,3 +41,17 @@ class Reciprocal_1(OpRun):
2041

2142
def run(self, x: OpRunValue) -> OpRunValue:
2243
return OpRunValue(1 / x.tensor)
44+
45+
46+
class Sin_1(OpRun):
47+
"""Sin"""
48+
49+
def run(self, x: OpRunValue) -> OpRunValue:
50+
return OpRunValue(x.tensor.sin())
51+
52+
53+
class Sqrt_1(OpRun):
54+
"""Sqrt"""
55+
56+
def run(self, x: OpRunValue) -> OpRunValue:
57+
return OpRunValue(x.tensor.sqrt())

0 commit comments

Comments
 (0)