|
| 1 | +import unittest |
| 2 | +import numpy as np |
| 3 | +import onnx |
| 4 | +import onnx.helper as oh |
| 5 | +import onnx.numpy_helper as onh |
| 6 | +import torch |
| 7 | +from onnx_diagnostic.ext_test_case import ExtTestCase |
| 8 | +from onnx_diagnostic.torch_onnx.runtime_info import ( |
| 9 | + first_used_last_used, |
| 10 | + RuntimeValue, |
| 11 | + RuntimeValueKind, |
| 12 | + RuntimeDevice, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +class TestRuntimeInfo(ExtTestCase): |
| 17 | + def test_runtime_info(self): |
| 18 | + rt = RuntimeValue("e", is_shape=True, value=torch.Tensor([0])) |
| 19 | + r = repr(rt) |
| 20 | + self.assertEqual("RuntimeValue(name=e, is_shape=True, value=T1s1)", r) |
| 21 | + |
| 22 | + def test_runtime_kind(self): |
| 23 | + h = RuntimeValueKind.INPUT |
| 24 | + self.assertEqual(h.to_str(), "INPUT") |
| 25 | + |
| 26 | + def test_runtime_device(self): |
| 27 | + h = RuntimeDevice.CPU |
| 28 | + self.assertEqual(h.to_str(), "CPU") |
| 29 | + |
| 30 | + def test_runtime_values(self): |
| 31 | + def _mkv_(name): |
| 32 | + value_info_proto = onnx.ValueInfoProto() |
| 33 | + value_info_proto.name = name |
| 34 | + return value_info_proto |
| 35 | + |
| 36 | + model = oh.make_model( |
| 37 | + oh.make_graph( |
| 38 | + [ |
| 39 | + oh.make_node("ReduceSum", ["0X"], ["1Xred"]), |
| 40 | + oh.make_node("Add", ["0X", "0two"], ["2X0"]), |
| 41 | + oh.make_node("Add", ["2X0", "0zero"], ["3X00"]), |
| 42 | + oh.make_node("CastLike", ["0one", "1Xred"], ["4one_c"]), |
| 43 | + oh.make_node("Greater", ["1Xred", "4one_c"], ["5cond"]), |
| 44 | + oh.make_node( |
| 45 | + "If", |
| 46 | + ["5cond"], |
| 47 | + ["6Z_c"], |
| 48 | + then_branch=oh.make_graph( |
| 49 | + [ |
| 50 | + oh.make_node("Constant", [], ["0two"], value_floats=[2.1]), |
| 51 | + oh.make_node("Add", ["3X00", "0two"], ["11Y"]), |
| 52 | + ], |
| 53 | + "then", |
| 54 | + [], |
| 55 | + [_mkv_("11Y")], |
| 56 | + ), |
| 57 | + else_branch=oh.make_graph( |
| 58 | + [ |
| 59 | + oh.make_node("Constant", [], ["0two"], value_floats=[2.2]), |
| 60 | + oh.make_node("Sub", ["2X0", "0two"], ["12Y"]), |
| 61 | + ], |
| 62 | + "else", |
| 63 | + [], |
| 64 | + [_mkv_("12Y")], |
| 65 | + ), |
| 66 | + ), |
| 67 | + oh.make_node("CastLike", ["6Z_c", "0X"], ["7Z"]), |
| 68 | + ], |
| 69 | + "test", |
| 70 | + [ |
| 71 | + oh.make_tensor_value_info("0X", onnx.TensorProto.FLOAT, ["N"]), |
| 72 | + oh.make_tensor_value_info("0one", onnx.TensorProto.FLOAT, ["N"]), |
| 73 | + ], |
| 74 | + [oh.make_tensor_value_info("7Z", onnx.TensorProto.UNDEFINED, ["N"])], |
| 75 | + [ |
| 76 | + onh.from_array(np.array([0], dtype=np.float32), name="0zero"), |
| 77 | + onh.from_array(np.array([2], dtype=np.float32), name="0two"), |
| 78 | + ], |
| 79 | + ), |
| 80 | + opset_imports=[oh.make_operatorsetid("", 18)], |
| 81 | + ir_version=10, |
| 82 | + ) |
| 83 | + rt_values = first_used_last_used(model) |
| 84 | + self.assertEqual( |
| 85 | + { |
| 86 | + "2X0", |
| 87 | + "0two", |
| 88 | + "5cond", |
| 89 | + "1Xred", |
| 90 | + "0zero", |
| 91 | + "0X", |
| 92 | + "4one_c", |
| 93 | + "7Z", |
| 94 | + "6Z_c", |
| 95 | + "0one", |
| 96 | + "3X00", |
| 97 | + }, |
| 98 | + set(rt_values), |
| 99 | + ) |
| 100 | + for name, node in rt_values.items(): |
| 101 | + self.assertEqual(name, node.name) |
| 102 | + if name != "7Z": |
| 103 | + self.assertIsInstance(node.first_used, int) |
| 104 | + self.assertIsInstance(node.last_used, int) |
| 105 | + self.assertIsInstance(node.created, int, msg=f"{name!r} missing 'created'") |
| 106 | + self.assertIsInstance(node.kind, int) |
| 107 | + self.assertEqual( |
| 108 | + int(name[0]) - 1, node.created, msg=f"{name!r} created is wrong {node.created}" |
| 109 | + ) |
| 110 | + if name != "7Z": |
| 111 | + self.assertGreater(node.first_used, node.created) |
| 112 | + self.assertGreaterOrEqual(node.last_used, node.first_used) |
| 113 | + |
| 114 | + |
| 115 | +if __name__ == "__main__": |
| 116 | + unittest.main(verbosity=2) |
0 commit comments