|
4 | 4 | import ml_dtypes |
5 | 5 | from onnx import ModelProto, TensorProto |
6 | 6 | from onnx.checker import check_model |
| 7 | +import onnx |
7 | 8 | import onnx.helper as oh |
8 | 9 | import onnx.numpy_helper as onh |
9 | 10 | import torch |
@@ -248,6 +249,71 @@ def test_init_torch_bfloat16(self): |
248 | 249 | self.assertIsInstance(got[0], (torch.Tensor, np.ndarray)) |
249 | 250 | self.assertEqualArray(expected[0], got[0]) |
250 | 251 |
|
| 252 | + @hide_stdout() |
| 253 | + def test_if(self): |
| 254 | + |
| 255 | + def _mkv_(name): |
| 256 | + value_info_proto = onnx.ValueInfoProto() |
| 257 | + value_info_proto.name = name |
| 258 | + return value_info_proto |
| 259 | + |
| 260 | + model = oh.make_model( |
| 261 | + oh.make_graph( |
| 262 | + [ |
| 263 | + oh.make_node("ReduceSum", ["X"], ["Xred"]), |
| 264 | + oh.make_node("Add", ["X", "two"], ["X0"]), |
| 265 | + oh.make_node("Add", ["X0", "zero"], ["X00"]), |
| 266 | + oh.make_node("CastLike", ["one", "Xred"], ["one_c"]), |
| 267 | + oh.make_node("Greater", ["Xred", "one_c"], ["cond"]), |
| 268 | + oh.make_node( |
| 269 | + "If", |
| 270 | + ["cond"], |
| 271 | + ["Z_c"], |
| 272 | + then_branch=oh.make_graph( |
| 273 | + [ |
| 274 | + oh.make_node("Constant", [], ["two"], value_floats=[2.1]), |
| 275 | + oh.make_node("Add", ["X00", "two"], ["Y"]), |
| 276 | + ], |
| 277 | + "then", |
| 278 | + [], |
| 279 | + [_mkv_("Y")], |
| 280 | + ), |
| 281 | + else_branch=oh.make_graph( |
| 282 | + [ |
| 283 | + oh.make_node("Constant", [], ["two"], value_floats=[2.2]), |
| 284 | + oh.make_node("Sub", ["X0", "two"], ["Y"]), |
| 285 | + ], |
| 286 | + "else", |
| 287 | + [], |
| 288 | + [_mkv_("Y")], |
| 289 | + ), |
| 290 | + ), |
| 291 | + oh.make_node("CastLike", ["Z_c", "X"], ["Z"]), |
| 292 | + ], |
| 293 | + "test", |
| 294 | + [ |
| 295 | + oh.make_tensor_value_info("X", TensorProto.FLOAT, ["N"]), |
| 296 | + oh.make_tensor_value_info("one", TensorProto.FLOAT, ["N"]), |
| 297 | + ], |
| 298 | + [oh.make_tensor_value_info("Z", TensorProto.UNDEFINED, ["N"])], |
| 299 | + [ |
| 300 | + onh.from_array(np.array([0], dtype=np.float32), name="zero"), |
| 301 | + onh.from_array(np.array([2], dtype=np.float32), name="two"), |
| 302 | + ], |
| 303 | + ), |
| 304 | + opset_imports=[oh.make_operatorsetid("", 18)], |
| 305 | + ir_version=10, |
| 306 | + ) |
| 307 | + feeds = { |
| 308 | + "X": np.array([1, 2, 3], dtype=np.float32), |
| 309 | + "one": np.array([1], dtype=np.float32), |
| 310 | + } |
| 311 | + ref = ExtendedReferenceEvaluator(model, verbose=10) |
| 312 | + expected = ref.run(None, feeds)[0] |
| 313 | + sess = OnnxruntimeEvaluator(model, verbose=10) |
| 314 | + got = sess.run(None, feeds)[0] |
| 315 | + self.assertEqualArray(expected[0], got[0]) |
| 316 | + |
251 | 317 |
|
252 | 318 | if __name__ == "__main__": |
253 | 319 | unittest.main(verbosity=2) |
0 commit comments