|
1 | 1 | import unittest |
| 2 | +from typing import Any, Dict, List |
2 | 3 | import numpy as np |
3 | 4 | import onnx.helper as oh |
4 | 5 | import onnx.numpy_helper as onh |
5 | | -from onnx import TensorProto |
| 6 | +from onnx import TensorProto, FunctionProto, ValueInfoProto |
6 | 7 | from onnx.checker import check_model |
| 8 | +import torch |
7 | 9 | from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout |
8 | 10 | from onnx_diagnostic.helpers.onnx_helper import ( |
9 | 11 | onnx_lighten, |
10 | 12 | onnx_unlighten, |
11 | 13 | onnx_find, |
12 | 14 | _validate_function, |
13 | 15 | check_model_ort, |
| 16 | + iterator_initializer_constant, |
| 17 | + from_array_extended, |
| 18 | + tensor_statistics, |
14 | 19 | ) |
15 | 20 |
|
16 | 21 |
|
17 | 22 | TFLOAT = TensorProto.FLOAT |
18 | 23 |
|
19 | 24 |
|
20 | | -class TestOnnxTools(ExtTestCase): |
| 25 | +class TestOnnxHelper(ExtTestCase): |
21 | 26 |
|
22 | 27 | def _get_model(self): |
23 | 28 | model = oh.make_model( |
@@ -122,6 +127,130 @@ def test_check_model_ort(self): |
122 | 127 | ) |
123 | 128 | check_model_ort(model) |
124 | 129 |
|
| 130 | + def test_iterate_init(self): |
| 131 | + itype = TensorProto.FLOAT |
| 132 | + cst = np.arange(6).astype(np.float32) |
| 133 | + model = oh.make_model( |
| 134 | + oh.make_graph( |
| 135 | + [ |
| 136 | + oh.make_node("IsNaN", ["x"], ["xi"]), |
| 137 | + oh.make_node("IsNaN", ["y"], ["yi"]), |
| 138 | + oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64), |
| 139 | + oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64), |
| 140 | + oh.make_node("Add", ["xii", "yii"], ["gggg"]), |
| 141 | + oh.make_node("Cast", ["gggg"], ["final"], to=itype), |
| 142 | + ], |
| 143 | + "dummy", |
| 144 | + [oh.make_tensor_value_info("x", itype, [None, None])], |
| 145 | + [oh.make_tensor_value_info("final", itype, [None, None])], |
| 146 | + [from_array_extended(cst, name="y")], |
| 147 | + ), |
| 148 | + opset_imports=[oh.make_opsetid("", 20)], |
| 149 | + ir_version=10, |
| 150 | + ) |
| 151 | + li = list(iterator_initializer_constant(model)) |
| 152 | + self.assertEqual(len(li), 1) |
| 153 | + self.assertEqual(li[0][0], "y") |
| 154 | + self.assertEqualArray(li[0][1], cst) |
| 155 | + li = list(iterator_initializer_constant(model, use_numpy=False)) |
| 156 | + self.assertEqual(len(li), 1) |
| 157 | + self.assertEqual(li[0][0], "y") |
| 158 | + self.assertEqualArray(li[0][1], cst) |
| 159 | + self.assertIsInstance(li[0][1], torch.Tensor) |
| 160 | + |
| 161 | + def _get_cdist_implementation( |
| 162 | + self, |
| 163 | + node_inputs: List[str], |
| 164 | + node_outputs: List[str], |
| 165 | + opsets: Dict[str, int], |
| 166 | + **kwargs: Any, |
| 167 | + ) -> FunctionProto: |
| 168 | + """ |
| 169 | + Returns the CDist implementation as a function. |
| 170 | + """ |
| 171 | + assert len(node_inputs) == 2 |
| 172 | + assert len(node_outputs) == 1 |
| 173 | + assert opsets |
| 174 | + assert "" in opsets |
| 175 | + assert set(kwargs) == {"metric"}, f"kwargs={kwargs}" |
| 176 | + metric = kwargs["metric"] |
| 177 | + assert metric in ("euclidean", "sqeuclidean") |
| 178 | + # subgraph |
| 179 | + nodes = [ |
| 180 | + oh.make_node("Sub", ["next", "next_in"], ["diff"]), |
| 181 | + oh.make_node("Constant", [], ["axis"], value_ints=[1]), |
| 182 | + oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0), |
| 183 | + oh.make_node("Identity", ["next_in"], ["next_out"]), |
| 184 | + ] |
| 185 | + |
| 186 | + def make_value(name): |
| 187 | + value = ValueInfoProto() |
| 188 | + value.name = name |
| 189 | + return value |
| 190 | + |
| 191 | + graph = oh.make_graph( |
| 192 | + nodes, |
| 193 | + "loop", |
| 194 | + [make_value("next_in"), make_value("next")], |
| 195 | + [make_value("next_out"), make_value("scan_out")], |
| 196 | + ) |
| 197 | + |
| 198 | + scan = oh.make_node( |
| 199 | + "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph |
| 200 | + ) |
| 201 | + final = ( |
| 202 | + oh.make_node("Sqrt", ["zout"], ["z"]) |
| 203 | + if metric == "euclidean" |
| 204 | + else oh.make_node("Identity", ["zout"], ["z"]) |
| 205 | + ) |
| 206 | + return oh.make_function( |
| 207 | + "npx", |
| 208 | + f"CDist_{metric}", |
| 209 | + ["xa", "xb"], |
| 210 | + ["z"], |
| 211 | + [scan, final], |
| 212 | + [oh.make_opsetid("", opsets[""])], |
| 213 | + ) |
| 214 | + |
| 215 | + def test_iterate_function(self): |
| 216 | + itype = TensorProto.FLOAT |
| 217 | + proto = self._get_cdist_implementation( |
| 218 | + ["X", "Y"], ["Z"], opsets={"": 18}, metric="euclidean" |
| 219 | + ) |
| 220 | + model = oh.make_model( |
| 221 | + oh.make_graph( |
| 222 | + [ |
| 223 | + oh.make_node(proto.name, ["X", "Y"], ["Z"]), |
| 224 | + ], |
| 225 | + "dummy", |
| 226 | + [ |
| 227 | + oh.make_tensor_value_info("X", itype, [None, None]), |
| 228 | + oh.make_tensor_value_info("Y", itype, [None, None]), |
| 229 | + ], |
| 230 | + [oh.make_tensor_value_info("final", itype, [None, None])], |
| 231 | + ), |
| 232 | + opset_imports=[oh.make_opsetid("", 18)], |
| 233 | + ir_version=10, |
| 234 | + ) |
| 235 | + model.functions.append(proto) |
| 236 | + li = list(iterator_initializer_constant(model)) |
| 237 | + self.assertEqual(len(li), 1) |
| 238 | + self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis") |
| 239 | + self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64)) |
| 240 | + li = list(iterator_initializer_constant(model, use_numpy=False)) |
| 241 | + self.assertEqual(len(li), 1) |
| 242 | + self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis") |
| 243 | + self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64)) |
| 244 | + self.assertIsInstance(li[0][1], torch.Tensor) |
| 245 | + |
| 246 | + def test_statistics(self): |
| 247 | + rnd = np.random.rand(40, 50).astype(np.float16) |
| 248 | + stat = tensor_statistics(rnd) |
| 249 | + self.assertEqual(stat["stype"], "FLOAT16") |
| 250 | + rnd = np.random.rand(40, 50).astype(np.float32) |
| 251 | + stat = tensor_statistics(rnd) |
| 252 | + self.assertEqual(stat["stype"], "FLOAT") |
| 253 | + |
125 | 254 |
|
126 | 255 | if __name__ == "__main__": |
127 | 256 | unittest.main(verbosity=2) |
0 commit comments