|
| 1 | +""" |
| 2 | +.. _l-plot-failing-onnxruntime-evaluator: |
| 3 | +
|
| 4 | +Running OnnxruntimeEvaluator on a failing model |
| 5 | +=============================================== |
| 6 | +
|
| 7 | +Example :ref:`l-plot-failing-reference-evaluator` demonstrated |
| 8 | +how to run a python runtime on a model but it may be very slow sometimes |
| 9 | +and it could show some discrepancies if the only provider is not CPU. |
| 10 | +Let's use :class:`OnnxruntimeEvaluator <onnx_diagnostic.reference.OnnxruntimeEvaluator>`. |
| 11 | +It splits the model into nodes and runs them independently until it succeeds |
| 12 | +or fails. This class converts every node into a model based on the types |
| 13 | +discovered during the execution. It relies on :class:`InferenceSessionForTorch |
| 14 | +<onnx_diagnostic.ort_session.InferenceSessionForTorch>` or |
| 15 | +:class:`InferenceSessionForNumpy |
| 16 | +<onnx_diagnostic.ort_session.InferenceSessionForNumpy>` |
| 17 | +for the execution. This example uses torch tensor and |
| 18 | +bfloat16. |
| 19 | +
|
| 20 | +A failing model |
| 21 | ++++++++++++++++ |
| 22 | +
|
| 23 | +The issue here is an operator ``Cast`` trying to convert a result |
| 24 | +into a non-existing type. |
| 25 | +""" |
| 26 | + |
| 27 | +import onnx |
| 28 | +import onnx.helper as oh |
| 29 | +import torch |
| 30 | +import onnxruntime |
| 31 | +from onnx_diagnostic.ext_test_case import has_cuda |
| 32 | +from onnx_diagnostic.helpers import from_array_extended |
| 33 | +from onnx_diagnostic.reference import OnnxruntimeEvaluator |
| 34 | + |
# Element type used throughout the example: bfloat16 exercises a dtype
# that the onnxruntime CPU provider does not always support.
TBFLOAT16 = onnx.TensorProto.BFLOAT16

# The node named "failing" casts to element type 999, which does not
# exist in the ONNX type system — this is the deliberate defect.
_nodes = [
    oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
    oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
    oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
    oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
    oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
]
_inputs = [
    oh.make_tensor_value_info("X", TBFLOAT16, ["a", "b", "c"]),
    oh.make_tensor_value_info("Y", TBFLOAT16, ["a", "b", "c"]),
]
_outputs = [oh.make_tensor_value_info("Z", TBFLOAT16, ["a", "b", "c"])]
_initializers = [
    from_array_extended(torch.tensor([1], dtype=torch.bfloat16), name="one")
]

model = oh.make_model(
    oh.make_graph(_nodes, "nd", _inputs, _outputs, _initializers),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)
| 57 | + |
# %%
# We check the model cannot even be loaded by onnxruntime on CPU:
# the invalid ``Cast`` makes session creation raise ``Fail``.

try:
    onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
    print(e)
| 65 | + |
| 66 | + |
# %%
# OnnxruntimeEvaluator
# ++++++++++++++++++++
#
# This class extends :class:`onnx.reference.ReferenceEvaluator`
# with operators outside the standard but defined by :epkg:`onnxruntime`.
# ``verbose=10`` tells the class to print as much as possible,
# ``verbose=0`` prints nothing; intermediate values give more or
# less verbosity.

ref = OnnxruntimeEvaluator(model, verbose=10)
feeds = {
    "X": torch.rand((3, 4), dtype=torch.bfloat16),
    "Y": torch.rand((3, 4), dtype=torch.bfloat16),
}
try:
    ref.run(None, feeds)
except Exception as e:
    print("ERROR", type(e), e)
| 84 | + |
| 85 | + |
# %%
# :epkg:`onnxruntime` may not support bfloat16 on CPU.
# See :epkg:`onnxruntime kernels`.
# Retry on CUDA when a GPU is available.

if has_cuda():
    ref = OnnxruntimeEvaluator(model, providers="cuda", verbose=10)
    feeds = {
        "X": torch.rand((3, 4), dtype=torch.bfloat16),
        "Y": torch.rand((3, 4), dtype=torch.bfloat16),
    }
    try:
        ref.run(None, feeds)
    except Exception as e:
        print("ERROR", type(e), e)
| 99 | + |
| 100 | +# %% |
| 101 | +# We can see it run until it reaches `Cast` and stops. |
| 102 | +# The error message is not always obvious to interpret. |
| 103 | +# It gets improved from time to time. |
| 104 | +# This runtime is useful when it fails for a numerical reason. |
| 105 | +# It is possible to insert prints in the python code to print |
| 106 | +# more information or debug if needed. |
0 commit comments