Skip to content

Commit 6a336be

Browse files
committed
Merge branch 'main' of https://github.com/sdpython/onnx-diagnostic into empty
2 parents 977c839 + be1f54f commit 6a336be

File tree

8 files changed

+55
-10
lines changed

8 files changed

+55
-10
lines changed

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,16 @@ def test_init_numpy_bfloat16(self):
239239
self.assertIsInstance(got[0], np.ndarray)
240240
self.assertEqualArray(expected[0], got[0])
241241

242+
def test_init_numpy_bfloat16_whole(self):
243+
model, feeds, expected = self._get_model_init(TensorProto.BFLOAT16)
244+
wrap = OnnxruntimeEvaluator(model, providers="cpu", whole=True)
245+
got = wrap.run(
246+
None, {k: v.to(float).numpy().astype(ml_dtypes.bfloat16) for k, v in feeds.items()}
247+
)
248+
self.assertIsInstance(got[0], np.ndarray)
249+
self.assertEqualArray(expected[0], got[0])
250+
self.assertEqual(got[0].dtype, ml_dtypes.bfloat16)
251+
242252
@hide_stdout()
243253
def test_init_torch_afloat32(self):
244254
model, feeds, expected = self._get_model_init(TensorProto.FLOAT)

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,37 @@
77
requires_torch,
88
requires_experimental,
99
requires_transformers,
10+
requires_cuda,
1011
)
1112
from onnx_diagnostic.torch_models.validate import validate_model
1213

1314

1415
class TestValidateModel(ExtTestCase):
16+
@requires_transformers("4.53")
17+
@requires_torch("2.7.99")
18+
@requires_experimental()
19+
@requires_cuda()
20+
@hide_stdout()
21+
def test_validate_tiny_llms_bfloat16(self):
22+
# python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
23+
# --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
24+
summary, data = validate_model(
25+
"arnir0/Tiny-LLM",
26+
do_run=True,
27+
verbose=2,
28+
exporter="custom",
29+
do_same=True,
30+
patch=True,
31+
rewrite=True,
32+
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
33+
dump_folder="dump_test/validate_tiny_llm",
34+
dtype="bfloat16",
35+
device="cuda",
36+
runtime="orteval",
37+
)
38+
self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2)
39+
self.assertIn("onnx_filename", data)
40+
1541
@requires_transformers("4.53")
1642
@requires_torch("2.7.99")
1743
@requires_experimental()

onnx_diagnostic/helpers/helper.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -516,8 +516,10 @@ def string_type(
516516
print(f"[string_type] V2:{type(obj)}")
517517
return "OV(NOTENSOR)"
518518
if with_min_max:
519+
from .torch_helper import to_numpy
520+
519521
try:
520-
t = obj.numpy()
522+
t = to_numpy(obj)
521523
except Exception:
522524
# pass unable to convert into numpy (bfloat16, ...)
523525
if verbose:
@@ -1233,9 +1235,13 @@ def max_diff(
12331235

12341236
if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
12351237
if isinstance(expected, torch.Tensor):
1236-
expected = expected.detach().cpu().numpy()
1238+
from .torch_helper import to_numpy
1239+
1240+
expected = to_numpy(expected)
12371241
if isinstance(got, torch.Tensor):
1238-
got = got.detach().cpu().numpy()
1242+
from .torch_helper import to_numpy
1243+
1244+
got = to_numpy(got)
12391245
if verbose >= 6:
12401246
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
12411247

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,7 @@ def create_model_builder(
203203
"ChatGLMModel": builder.ChatGLMModel,
204204
"Ernie4_5_ForCausalLM": builder.ErnieModel,
205205
"GemmaForCausalLM": builder.Gemma2Model,
206+
"Gemma2ForCausalLM": builder.Gemma2Model,
206207
"Gemma3ForCausalLM": builder.Gemma3Model,
207208
"Gemma3ForConditionalGeneration": builder.Gemma3Model,
208209
"GraniteForCausalLM": builder.GraniteModel,

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import onnx
44
import torch
55
from .helper import string_type, flatten_object
6+
from .torch_helper import to_numpy
67
from .cache_helper import is_cache_dynamic_registered
78

89

@@ -56,7 +57,7 @@ def make_feeds(
5657
f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
5758
)
5859
if use_numpy:
59-
flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
60+
flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
6061
names = (
6162
[i.name for i in proto.graph.input]
6263
if isinstance(proto, onnx.ModelProto)

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -478,10 +478,10 @@ def is_torchdynamo_exporting() -> bool:
478478
return False
479479

480480

481-
def to_numpy(tensor: "torch.Tensor"): # noqa: F821
481+
def to_numpy(tensor: "torch.Tensor") -> np.ndarray: # noqa: F821
482482
"""Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
483483
try:
484-
return tensor.numpy()
484+
return tensor.detach().cpu().numpy()
485485
except TypeError:
486486
# We try with ml_dtypes
487487
pass
@@ -490,7 +490,7 @@ def to_numpy(tensor: "torch.Tensor"): # noqa: F821
490490

491491
conv = {torch.bfloat16: ml_dtypes.bfloat16}
492492
assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
493-
return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
493+
return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
494494

495495

496496
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import onnx
55
import torch
6-
from ..helpers.torch_helper import to_tensor
6+
from ..helpers.torch_helper import to_tensor, to_numpy
77
from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
88
from .report_results_comparison import ReportResultComparison
99
from . import torch_ops
@@ -578,7 +578,7 @@ def run(
578578
print(f"- clean {o}")
579579

580580
if use_numpy:
581-
return [None if a is None else a.detach().cpu().numpy() for a in fres]
581+
return [None if a is None else to_numpy(a) for a in fres]
582582
return fres
583583

584584
def run_with_values(

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import torch
44
from ..helpers import string_type, string_diff, max_diff
55
from ..helpers.onnx_helper import to_array_extended
6+
from ..helpers.torch_helper import to_numpy
67

78

89
def validate_fx_tensor(
@@ -296,7 +297,7 @@ def post_process(obs):
296297
)
297298

298299
for inp, v in zip(onx.graph.input, args):
299-
onnx_results[inp.name] = v.cpu().numpy()
300+
onnx_results[inp.name] = to_numpy(v)
300301
if verbose:
301302
print(
302303
f"[run_aligned] +onnx-input: {inp.name}: "

0 commit comments

Comments
 (0)