Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions _unittests/ut_reference/test_ort_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ def test_init_numpy_bfloat16(self):
self.assertIsInstance(got[0], np.ndarray)
self.assertEqualArray(expected[0], got[0])

def test_init_numpy_bfloat16_whole(self):
    """Run the BFLOAT16 init model through OnnxruntimeEvaluator with ``whole=True``.

    The feeds are converted to numpy arrays of dtype ``ml_dtypes.bfloat16``
    and the first output is checked to be a numpy array, equal to the
    expected value and of dtype ``ml_dtypes.bfloat16``.
    """
    onx, raw_feeds, reference = self._get_model_init(TensorProto.BFLOAT16)
    evaluator = OnnxruntimeEvaluator(onx, providers="cpu", whole=True)

    # Go through float first, then cast the numpy array to bfloat16.
    np_feeds = {}
    for name, value in raw_feeds.items():
        np_feeds[name] = value.to(float).numpy().astype(ml_dtypes.bfloat16)

    outputs = evaluator.run(None, np_feeds)
    first = outputs[0]
    self.assertIsInstance(first, np.ndarray)
    self.assertEqualArray(reference[0], first)
    self.assertEqual(first.dtype, ml_dtypes.bfloat16)

@hide_stdout()
def test_init_torch_afloat32(self):
model, feeds, expected = self._get_model_init(TensorProto.FLOAT)
Expand Down
26 changes: 26 additions & 0 deletions _unittests/ut_torch_models/test_validate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,37 @@
requires_torch,
requires_experimental,
requires_transformers,
requires_cuda,
)
from onnx_diagnostic.torch_models.validate import validate_model


class TestValidateModel(ExtTestCase):
@requires_transformers("4.53")
@requires_torch("2.7.99")
@requires_experimental()
@requires_cuda()
@hide_stdout()
def test_validate_tiny_llms_bfloat16(self):
    """Validate ``arnir0/Tiny-LLM`` in bfloat16 on CUDA with the ``orteval`` runtime.

    Checks that the reported ``disc_onnx_ort_run_abs`` discrepancy stays
    below ``2e-2`` and that the returned data contains ``onnx_filename``.
    """
    # NOTE(review): the command line below names a different model than the
    # one validated by this test -- confirm it is only meant as an example.
    # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
    # --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
    torch_is_recent = pv.Version(torch.__version__) > pv.Version("2.6.1")
    options = dict(
        do_run=True,
        verbose=2,
        exporter="custom",
        do_same=True,
        patch=True,
        rewrite=True,
        stop_if_static=2 if torch_is_recent else 0,
        dump_folder="dump_test/validate_tiny_llm",
        dtype="bfloat16",
        device="cuda",
        runtime="orteval",
    )
    summary, data = validate_model("arnir0/Tiny-LLM", **options)

    self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2)
    self.assertIn("onnx_filename", data)

@requires_transformers("4.53")
@requires_torch("2.7.99")
@requires_experimental()
Expand Down
12 changes: 9 additions & 3 deletions onnx_diagnostic/helpers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,10 @@ def string_type(
print(f"[string_type] V2:{type(obj)}")
return "OV(NOTENSOR)"
if with_min_max:
from .torch_helper import to_numpy

try:
t = obj.numpy()
t = to_numpy(obj)
except Exception:
# pass unable to convert into numpy (bfloat16, ...)
if verbose:
Expand Down Expand Up @@ -1233,9 +1235,13 @@ def max_diff(

if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
if isinstance(expected, torch.Tensor):
expected = expected.detach().cpu().numpy()
from .torch_helper import to_numpy

expected = to_numpy(expected)
if isinstance(got, torch.Tensor):
got = got.detach().cpu().numpy()
from .torch_helper import to_numpy

got = to_numpy(got)
if verbose >= 6:
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")

Expand Down
1 change: 1 addition & 0 deletions onnx_diagnostic/helpers/model_builder_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def create_model_builder(
"ChatGLMModel": builder.ChatGLMModel,
"Ernie4_5_ForCausalLM": builder.ErnieModel,
"GemmaForCausalLM": builder.Gemma2Model,
"Gemma2ForCausalLM": builder.Gemma2Model,
"Gemma3ForCausalLM": builder.Gemma3Model,
"Gemma3ForConditionalGeneration": builder.Gemma3Model,
"GraniteForCausalLM": builder.GraniteModel,
Expand Down
3 changes: 2 additions & 1 deletion onnx_diagnostic/helpers/rt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import onnx
import torch
from .helper import string_type, flatten_object
from .torch_helper import to_numpy
from .cache_helper import is_cache_dynamic_registered


Expand Down Expand Up @@ -56,7 +57,7 @@ def make_feeds(
f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
)
if use_numpy:
flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
names = (
[i.name for i in proto.graph.input]
if isinstance(proto, onnx.ModelProto)
Expand Down
6 changes: 3 additions & 3 deletions onnx_diagnostic/helpers/torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,10 @@ def is_torchdynamo_exporting() -> bool:
return False


def to_numpy(tensor: "torch.Tensor"): # noqa: F821
def to_numpy(tensor: "torch.Tensor") -> np.ndarray: # noqa: F821
"""Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
try:
return tensor.numpy()
return tensor.detach().cpu().numpy()
except TypeError:
# We try with ml_dtypes
pass
Expand All @@ -490,7 +490,7 @@ def to_numpy(tensor: "torch.Tensor"): # noqa: F821

conv = {torch.bfloat16: ml_dtypes.bfloat16}
assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])


def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions onnx_diagnostic/reference/torch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import onnx
import torch
from ..helpers.torch_helper import to_tensor
from ..helpers.torch_helper import to_tensor, to_numpy
from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
from .report_results_comparison import ReportResultComparison
from . import torch_ops
Expand Down Expand Up @@ -578,7 +578,7 @@ def run(
print(f"- clean {o}")

if use_numpy:
return [None if a is None else a.detach().cpu().numpy() for a in fres]
return [None if a is None else to_numpy(a) for a in fres]
return fres

def run_with_values(
Expand Down
3 changes: 2 additions & 1 deletion onnx_diagnostic/torch_onnx/sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from ..helpers import string_type, string_diff, max_diff
from ..helpers.onnx_helper import to_array_extended
from ..helpers.torch_helper import to_numpy


def validate_fx_tensor(
Expand Down Expand Up @@ -296,7 +297,7 @@ def post_process(obs):
)

for inp, v in zip(onx.graph.input, args):
onnx_results[inp.name] = v.cpu().numpy()
onnx_results[inp.name] = to_numpy(v)
if verbose:
print(
f"[run_aligned] +onnx-input: {inp.name}: "
Expand Down
Loading