Skip to content

Commit 49a41ef

Browse files
authored
backend torch-onnx (#120)
* torch-onnx * add more ops * more ops * mypy * mypy * mypy * add constant of shape * add trilu * add equal
1 parent 10afdd3 commit 49a41ef

File tree

14 files changed

+617
-28
lines changed

14 files changed

+617
-28
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Change Logs
44
0.6.1
55
+++++
66

7-
* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`: first steps for TorchOnnxEvaluator
7+
* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`:
8+
first steps for TorchOnnxEvaluator
89
* :pr:`114`: extends the list of known rewritings
910
* :pr:`113`: fixes a couple of issues with ModelBuilder
1011

_unittests/ut_reference/test_torch_evaluator.py renamed to _unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 296 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import onnx.helper as oh
55
import onnx.numpy_helper as onh
66
import torch
7-
from onnx_diagnostic.ext_test_case import ExtTestCase
7+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
8+
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
89
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, TorchOnnxEvaluator
910
from onnx_diagnostic.reference.torch_evaluator import get_kernels
1011

@@ -97,11 +98,13 @@ def test_op_binary_cmp(self):
9798
[
9899
oh.make_node("Neg", ["X"], ["nx"]),
99100
oh.make_node("Reciprocal", ["nx"], ["rnx"]),
101+
oh.make_node("Equal", ["X", "Y"], ["ae"]),
100102
oh.make_node("Greater", ["X", "rnx"], ["a"]),
101103
oh.make_node("GreaterOrEqual", ["X", "Y"], ["b"]),
102104
oh.make_node("Less", ["X", "Y"], ["c"]),
103105
oh.make_node("LessOrEqual", ["X", "Y"], ["d"]),
104-
oh.make_node("And", ["a", "b"], ["ab"]),
106+
oh.make_node("And", ["ae", "a"], ["aa"]),
107+
oh.make_node("And", ["aa", "b"], ["ab"]),
105108
oh.make_node("Or", ["c", "d"], ["cd"]),
106109
oh.make_node("Not", ["cd"], ["ncd"]),
107110
oh.make_node("And", ["ab", "ncd"], ["Z"]),
@@ -559,6 +562,297 @@ def test_op_layer_normalization_big_eps(self):
559562
atol=1e-4,
560563
)
561564

565+
def test_op_range_float(self):
566+
model = oh.make_model(
567+
oh.make_graph(
568+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
569+
"dummy",
570+
[
571+
oh.make_tensor_value_info("start", TFLOAT, []),
572+
oh.make_tensor_value_info("limit", TFLOAT, []),
573+
oh.make_tensor_value_info("delta", TFLOAT, []),
574+
],
575+
[oh.make_tensor_value_info("Z", TFLOAT, ["a"])],
576+
),
577+
ir_version=9,
578+
opset_imports=[oh.make_opsetid("", 18)],
579+
)
580+
self._finalize_test(
581+
model,
582+
torch.tensor(2.1, dtype=torch.float32),
583+
torch.tensor(5.1, dtype=torch.float32),
584+
torch.tensor(1, dtype=torch.float32),
585+
)
586+
587+
def test_op_range_int64(self):
588+
model = oh.make_model(
589+
oh.make_graph(
590+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
591+
"dummy",
592+
[
593+
oh.make_tensor_value_info("start", TINT64, []),
594+
oh.make_tensor_value_info("limit", TINT64, []),
595+
oh.make_tensor_value_info("delta", TINT64, []),
596+
],
597+
[oh.make_tensor_value_info("Z", TINT64, ["a"])],
598+
),
599+
ir_version=9,
600+
opset_imports=[oh.make_opsetid("", 18)],
601+
)
602+
self._finalize_test(
603+
model,
604+
torch.tensor(2, dtype=torch.int64),
605+
torch.tensor(5, dtype=torch.int64),
606+
torch.tensor(1, dtype=torch.int64),
607+
)
608+
609+
def test_op_range_int64_h2(self):
610+
model = oh.make_model(
611+
oh.make_graph(
612+
[oh.make_node("Range", ["start", "limit", "delta"], ["Z"])],
613+
"dummy",
614+
[
615+
oh.make_tensor_value_info("start", TINT64, []),
616+
oh.make_tensor_value_info("limit", TINT64, []),
617+
oh.make_tensor_value_info("delta", TINT64, []),
618+
],
619+
[oh.make_tensor_value_info("Z", TINT64, ["a"])],
620+
),
621+
ir_version=9,
622+
opset_imports=[oh.make_opsetid("", 18)],
623+
)
624+
self._finalize_test(
625+
model,
626+
torch.tensor(2, dtype=torch.int64),
627+
torch.tensor(5, dtype=torch.int64),
628+
torch.tensor(2, dtype=torch.int64),
629+
)
630+
631+
def test_op_expand(self):
632+
model = oh.make_model(
633+
oh.make_graph(
634+
[oh.make_node("Expand", ["X", "shape"], ["Y"])],
635+
"dummy",
636+
[
637+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c", "d"]),
638+
oh.make_tensor_value_info("shape", TINT64, ["f"]),
639+
],
640+
[oh.make_tensor_value_info("Y", TFLOAT, ["aa", "ba", "ca", "da"])],
641+
),
642+
ir_version=9,
643+
opset_imports=[oh.make_opsetid("", 18)],
644+
)
645+
self._finalize_test(
646+
model,
647+
torch.rand((1, 5, 6, 7), dtype=torch.float32),
648+
torch.tensor([4, 5, 1, 1], dtype=torch.int64),
649+
)
650+
651+
def test_op_unary(self):
652+
model = oh.make_model(
653+
oh.make_graph(
654+
[
655+
oh.make_node("Cos", ["X"], ["nx"]),
656+
oh.make_node("Sin", ["nx"], ["t"]),
657+
oh.make_node("Exp", ["t"], ["u"]),
658+
oh.make_node("Log", ["u"], ["Z"]),
659+
],
660+
"dummy",
661+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
662+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
663+
),
664+
ir_version=9,
665+
opset_imports=[oh.make_opsetid("", 18)],
666+
)
667+
onnx.checker.check_model(model)
668+
self._finalize_test(model, torch.abs(torch.rand(3, 4, dtype=torch.float32)), atol=1e-6)
669+
670+
def test_op_pow(self):
671+
model = oh.make_model(
672+
oh.make_graph(
673+
[oh.make_node("Pow", ["X", "Y"], ["Z"])],
674+
"dummy",
675+
[
676+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
677+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b"]),
678+
],
679+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
680+
),
681+
ir_version=9,
682+
opset_imports=[oh.make_opsetid("", 18)],
683+
)
684+
onnx.checker.check_model(model)
685+
self._finalize_test(
686+
model,
687+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
688+
torch.abs(torch.rand(3, 4, 5, dtype=torch.float32)),
689+
atol=1e-7,
690+
)
691+
692+
def test_op_pow_op_int(self):
693+
model = oh.make_model(
694+
oh.make_graph(
695+
[oh.make_node("Pow", ["X", "Y"], ["Z"])],
696+
"dummy",
697+
[
698+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
699+
oh.make_tensor_value_info("Y", TINT64, ["a", "b"]),
700+
],
701+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
702+
),
703+
ir_version=9,
704+
opset_imports=[oh.make_opsetid("", 18)],
705+
)
706+
onnx.checker.check_model(model)
707+
self._finalize_test(
708+
model,
709+
torch.rand(3, 4, 5, dtype=torch.float32),
710+
torch.tensor([2], dtype=torch.int64),
711+
atol=1e-7,
712+
)
713+
714+
def test_op_sqrt(self):
715+
model = oh.make_model(
716+
oh.make_graph(
717+
[oh.make_node("Sqrt", ["X"], ["Z"])],
718+
"dummy",
719+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
720+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
721+
),
722+
ir_version=9,
723+
opset_imports=[oh.make_opsetid("", 18)],
724+
)
725+
onnx.checker.check_model(model)
726+
self._finalize_test(model, torch.abs(torch.rand(3, 4, dtype=torch.float32)), atol=1e-6)
727+
728+
def test_op_sigmoid(self):
729+
model = oh.make_model(
730+
oh.make_graph(
731+
[oh.make_node("Sigmoid", ["X"], ["Z"])],
732+
"dummy",
733+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
734+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
735+
),
736+
ir_version=9,
737+
opset_imports=[oh.make_opsetid("", 18)],
738+
)
739+
onnx.checker.check_model(model)
740+
self._finalize_test(model, torch.abs(torch.rand(3, 4, dtype=torch.float32)), atol=1e-6)
741+
742+
def test_op_split(self):
743+
model = oh.make_model(
744+
oh.make_graph(
745+
[oh.make_node("Split", ["X"], ["Z1", "Z2"], axis=1, num_outputs=2)],
746+
"dummy",
747+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
748+
[
749+
oh.make_tensor_value_info("Z1", TFLOAT, ["a", "b1"]),
750+
oh.make_tensor_value_info("Z2", TFLOAT, ["a", "b2"]),
751+
],
752+
),
753+
ir_version=9,
754+
opset_imports=[oh.make_opsetid("", 18)],
755+
)
756+
onnx.checker.check_model(model)
757+
self._finalize_test(model, torch.rand(3, 5, dtype=torch.float32), use_ort=True)
758+
self._finalize_test(model, torch.rand(3, 6, dtype=torch.float32), use_ort=True)
759+
760+
def test_op_split_op_sizes(self):
761+
model = oh.make_model(
762+
oh.make_graph(
763+
[oh.make_node("Split", ["X", "split"], ["Z1", "Z2"], axis=1)],
764+
"dummy",
765+
[
766+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
767+
oh.make_tensor_value_info("split", TINT64, [2]),
768+
],
769+
[
770+
oh.make_tensor_value_info("Z1", TFLOAT, ["a", "b1"]),
771+
oh.make_tensor_value_info("Z2", TFLOAT, ["a", "b2"]),
772+
],
773+
),
774+
ir_version=9,
775+
opset_imports=[oh.make_opsetid("", 18)],
776+
)
777+
onnx.checker.check_model(model)
778+
self._finalize_test(
779+
model,
780+
torch.rand(3, 5, dtype=torch.float32),
781+
torch.tensor([2, 3], dtype=torch.int64),
782+
use_ort=True,
783+
)
784+
785+
def test_op_constant_of_shape(self):
786+
model = oh.make_model(
787+
oh.make_graph(
788+
[
789+
oh.make_node(
790+
"ConstantOfShape",
791+
["shape"],
792+
["Z"],
793+
value=from_array_extended(np.array([2], dtype=np.float16)),
794+
)
795+
],
796+
"dummy",
797+
[oh.make_tensor_value_info("shape", TINT64, ["a"])],
798+
[oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT16, ["a", "b"])],
799+
),
800+
ir_version=9,
801+
opset_imports=[oh.make_opsetid("", 18)],
802+
)
803+
onnx.checker.check_model(model)
804+
self._finalize_test(model, torch.tensor([4, 5], dtype=torch.int64))
805+
806+
def test_op_trilu(self):
807+
model = oh.make_model(
808+
oh.make_graph(
809+
[oh.make_node("Trilu", ["X"], ["Z"])],
810+
"dummy",
811+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
812+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
813+
),
814+
ir_version=9,
815+
opset_imports=[oh.make_opsetid("", 18)],
816+
)
817+
onnx.checker.check_model(model)
818+
self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32))
819+
820+
def test_op_trilu_1(self):
821+
model = oh.make_model(
822+
oh.make_graph(
823+
[oh.make_node("Trilu", ["X"], ["Z"], upper=0)],
824+
"dummy",
825+
[oh.make_tensor_value_info("X", TFLOAT, ["a", "b"])],
826+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
827+
),
828+
ir_version=9,
829+
opset_imports=[oh.make_opsetid("", 18)],
830+
)
831+
onnx.checker.check_model(model)
832+
self._finalize_test(model, torch.rand((4, 4), dtype=torch.float32))
833+
834+
@ignore_warnings(DeprecationWarning)
835+
def test_op_trilu_k(self):
836+
model = oh.make_model(
837+
oh.make_graph(
838+
[oh.make_node("Trilu", ["X", "k"], ["Z"], upper=1)],
839+
"dummy",
840+
[
841+
oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
842+
oh.make_tensor_value_info("k", TINT64, []),
843+
],
844+
[oh.make_tensor_value_info("Z", TFLOAT, ["a", "b"])],
845+
),
846+
ir_version=9,
847+
opset_imports=[oh.make_opsetid("", 18)],
848+
)
849+
onnx.checker.check_model(model)
850+
self._finalize_test(
851+
model,
852+
torch.rand((6, 6), dtype=torch.float32),
853+
torch.tensor([2], dtype=torch.int64),
854+
)
855+
562856

563857
if __name__ == "__main__":
564858
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,29 @@ def test_validate_model_custom(self):
164164
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
165165
)
166166

167+
@requires_torch("2.7")
168+
@hide_stdout()
169+
@ignore_warnings(FutureWarning)
170+
@requires_experimental()
171+
def test_validate_model_custom_torch(self):
172+
mid = "arnir0/Tiny-LLM"
173+
summary, data = validate_model(
174+
mid,
175+
do_run=True,
176+
verbose=10,
177+
exporter="custom-inline",
178+
dump_folder="dump_test_validate_model_custom_torch",
179+
patch=True,
180+
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
181+
optimization="default",
182+
quiet=False,
183+
runtime="torch",
184+
)
185+
self.assertIsInstance(summary, dict)
186+
self.assertIsInstance(data, dict)
187+
self.assertIn("disc_onnx_ort_run_abs", summary)
188+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
189+
167190
def test_filter_inputs(self):
168191
inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}
169192
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"])

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,12 @@ def get_parser_validate() -> ArgumentParser:
352352
action=BooleanOptionalAction,
353353
help="validate the trained model (requires downloading)",
354354
)
355+
parser.add_argument(
356+
"--runtime",
357+
choices=["onnxruntime", "torch", "ref"],
358+
default="onnxruntime",
359+
help="onnx runtime to use, ",
360+
)
355361
parser.add_argument(
356362
"-o",
357363
"--dump-folder",
@@ -453,6 +459,7 @@ def _cmd_validate(argv: List[Any]):
453459
model_options=args.mop,
454460
subfolder=args.subfolder,
455461
opset=args.opset,
462+
runtime=args.runtime,
456463
)
457464
print("")
458465
print("-- summary --")

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ def make_feeds(
5555
names = (
5656
[i.name for i in proto.graph.input]
5757
if isinstance(proto, onnx.ModelProto)
58-
else ([i.name for i in proto.get_inputs()] if hasattr(proto, "get_inputs") else proto)
58+
else (
59+
[i.name for i in proto.get_inputs()]
60+
if hasattr(proto, "get_inputs")
61+
else (proto.input_names if hasattr(proto, "input_names") else proto)
62+
)
5963
)
6064
assert (
6165
isinstance(names, list)

0 commit comments

Comments
 (0)