From 8c1b68d5d8ba06655a4c49285d7b1d02cbe9848a Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 3 Oct 2025 16:18:16 +0200
Subject: [PATCH] fix for bfloat16

---
 _unittests/ut_reference/test_ort_evaluator.py | 10 +++++++
 .../ut_torch_models/test_validate_models.py   | 26 +++++++++++++++++++
 onnx_diagnostic/helpers/helper.py             | 12 ++++++---
 .../helpers/model_builder_helper.py           |  1 +
 onnx_diagnostic/helpers/rt_helper.py          |  3 ++-
 onnx_diagnostic/helpers/torch_helper.py       |  6 ++---
 onnx_diagnostic/reference/torch_evaluator.py  |  4 +--
 onnx_diagnostic/torch_onnx/sbs.py             |  3 ++-
 8 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/_unittests/ut_reference/test_ort_evaluator.py b/_unittests/ut_reference/test_ort_evaluator.py
index 51c1cfca..4f7fa270 100644
--- a/_unittests/ut_reference/test_ort_evaluator.py
+++ b/_unittests/ut_reference/test_ort_evaluator.py
@@ -239,6 +239,16 @@ def test_init_numpy_bfloat16(self):
         self.assertIsInstance(got[0], np.ndarray)
         self.assertEqualArray(expected[0], got[0])
 
+    def test_init_numpy_bfloat16_whole(self):
+        model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16)
+        wrap = OnnxruntimeEvaluator(model, providers="cpu", whole=True)
+        got = wrap.run(
+            None, {k: v.to(float).numpy().astype(ml_dtypes.bfloat16) for k, v in feeds.items()}
+        )
+        self.assertIsInstance(got[0], np.ndarray)
+        self.assertEqualArray(expected[0], got[0])
+        self.assertEqual(got[0].dtype, ml_dtypes.bfloat16)
+
     @hide_stdout()
     def test_init_torch_afloat32(self):
         model, feeds, expected = self._get_model_init(TensorProto.FLOAT)
diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py
index f1fe4f92..6bbf60ee 100644
--- a/_unittests/ut_torch_models/test_validate_models.py
+++ b/_unittests/ut_torch_models/test_validate_models.py
@@ -7,11 +7,37 @@
     requires_torch,
     requires_experimental,
     requires_transformers,
+    requires_cuda,
 )
 from onnx_diagnostic.torch_models.validate import validate_model
 
 
 class TestValidateModel(ExtTestCase):
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    @requires_experimental()
+    @requires_cuda()
+    @hide_stdout()
+    def test_validate_tiny_llms_bfloat16(self):
+        # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
+        # --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
+        summary, data = validate_model(
+            "arnir0/Tiny-LLM",
+            do_run=True,
+            verbose=2,
+            exporter="custom",
+            do_same=True,
+            patch=True,
+            rewrite=True,
+            stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
+            dump_folder="dump_test/validate_tiny_llm",
+            dtype="bfloat16",
+            device="cuda",
+            runtime="orteval",
+        )
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2)
+        self.assertIn("onnx_filename", data)
+
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
     @requires_experimental()
diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py
index 7879a1fd..00c7f8a3 100644
--- a/onnx_diagnostic/helpers/helper.py
+++ b/onnx_diagnostic/helpers/helper.py
@@ -516,8 +516,10 @@ def string_type(
                 print(f"[string_type] V2:{type(obj)}")
             return "OV(NOTENSOR)"
         if with_min_max:
+            from .torch_helper import to_numpy
+
             try:
-                t = obj.numpy()
+                t = to_numpy(obj)
             except Exception:
                 # pass unable to convert into numpy (bfloat16, ...)
                 if verbose:
@@ -1233,9 +1235,13 @@ def max_diff(
 
     if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
         if isinstance(expected, torch.Tensor):
-            expected = expected.detach().cpu().numpy()
+            from .torch_helper import to_numpy
+
+            expected = to_numpy(expected)
         if isinstance(got, torch.Tensor):
-            got = got.detach().cpu().numpy()
+            from .torch_helper import to_numpy
+
+            got = to_numpy(got)
         if verbose >= 6:
             print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py
index 6df97cad..8ee33abe 100644
--- a/onnx_diagnostic/helpers/model_builder_helper.py
+++ b/onnx_diagnostic/helpers/model_builder_helper.py
@@ -203,6 +203,7 @@ def create_model_builder(
         "ChatGLMModel": builder.ChatGLMModel,
         "Ernie4_5_ForCausalLM": builder.ErnieModel,
         "GemmaForCausalLM": builder.Gemma2Model,
+        "Gemma2ForCausalLM": builder.Gemma2Model,
         "Gemma3ForCausalLM": builder.Gemma3Model,
         "Gemma3ForConditionalGeneration": builder.Gemma3Model,
         "GraniteForCausalLM": builder.GraniteModel,
diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py
index eb65063f..3a896d03 100644
--- a/onnx_diagnostic/helpers/rt_helper.py
+++ b/onnx_diagnostic/helpers/rt_helper.py
@@ -3,6 +3,7 @@
 import onnx
 import torch
 from .helper import string_type, flatten_object
+from .torch_helper import to_numpy
 from .cache_helper import is_cache_dynamic_registered
 
 
@@ -56,7 +57,7 @@ def make_feeds(
             f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
         )
     if use_numpy:
-        flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
+        flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
         if isinstance(proto, onnx.ModelProto)
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index 734aba76..efd6cc74 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -464,10 +464,10 @@ def is_torchdynamo_exporting() -> bool:
         return False
 
 
-def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
+def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
-        return tensor.numpy()
+        return tensor.detach().cpu().numpy()
     except TypeError:
         # We try with ml_dtypes
         pass
@@ -476,7 +476,7 @@ def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
 
     conv = {torch.bfloat16: ml_dtypes.bfloat16}
     assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
-    return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
+    return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
 
 
 def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
diff --git a/onnx_diagnostic/reference/torch_evaluator.py b/onnx_diagnostic/reference/torch_evaluator.py
index fa964658..58e7e89c 100644
--- a/onnx_diagnostic/reference/torch_evaluator.py
+++ b/onnx_diagnostic/reference/torch_evaluator.py
@@ -3,7 +3,7 @@
 import numpy as np
 import onnx
 import torch
-from ..helpers.torch_helper import to_tensor
+from ..helpers.torch_helper import to_tensor, to_numpy
 from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
 from .report_results_comparison import ReportResultComparison
 from . import torch_ops
@@ -578,7 +578,7 @@ def run(
                     print(f"- clean {o}")
 
         if use_numpy:
-            return [None if a is None else a.detach().cpu().numpy() for a in fres]
+            return [None if a is None else to_numpy(a) for a in fres]
         return fres
 
     def run_with_values(
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 1c2451fc..01e36080 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -3,6 +3,7 @@
 import torch
 from ..helpers import string_type, string_diff, max_diff
 from ..helpers.onnx_helper import to_array_extended
+from ..helpers.torch_helper import to_numpy
 
 
 def validate_fx_tensor(
@@ -296,7 +297,7 @@ def post_process(obs):
         )
 
     for inp, v in zip(onx.graph.input, args):
-        onnx_results[inp.name] = v.cpu().numpy()
+        onnx_results[inp.name] = to_numpy(v)
         if verbose:
             print(
                 f"[run_aligned] +onnx-input: {inp.name}: "
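
Note (not part of the patch): a minimal sketch of the conversion path the reworked
to_numpy helper takes for a bfloat16 tensor, assuming the ml_dtypes package is
installed. numpy has no native bfloat16 dtype, so tensor.numpy() raises TypeError;
the helper therefore detaches, upcasts to float32 on CPU, then casts the result to
ml_dtypes.bfloat16, which is what the patched call sites now rely on:

    import ml_dtypes
    import numpy as np
    import torch

    t = torch.tensor([1.5, -2.25], dtype=torch.bfloat16, requires_grad=True)
    # torch refuses to expose bfloat16 storage to numpy, hence the fallback path:
    a = t.detach().to(torch.float32).cpu().numpy().astype(ml_dtypes.bfloat16)
    assert a.dtype == ml_dtypes.bfloat16
    assert np.allclose(a.astype(np.float32), [1.5, -2.25])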