File tree Expand file tree Collapse file tree 8 files changed +55
-10
lines changed Expand file tree Collapse file tree 8 files changed +55
-10
lines changed Original file line number Diff line number Diff line change @@ -239,6 +239,16 @@ def test_init_numpy_bfloat16(self):
239239 self .assertIsInstance (got [0 ], np .ndarray )
240240 self .assertEqualArray (expected [0 ], got [0 ])
241241
242+ def test_init_numpy_bfloat16_whole (self ):
243+ model , feeds , expected = self ._get_model_init (TensorProto .BFLOAT16 )
244+ wrap = OnnxruntimeEvaluator (model , providers = "cpu" , whole = True )
245+ got = wrap .run (
246+ None , {k : v .to (float ).numpy ().astype (ml_dtypes .bfloat16 ) for k , v in feeds .items ()}
247+ )
248+ self .assertIsInstance (got [0 ], np .ndarray )
249+ self .assertEqualArray (expected [0 ], got [0 ])
250+ self .assertEqual (got [0 ].dtype , ml_dtypes .bfloat16 )
251+
242252 @hide_stdout ()
243253 def test_init_torch_afloat32 (self ):
244254 model , feeds , expected = self ._get_model_init (TensorProto .FLOAT )
Original file line number Diff line number Diff line change 77 requires_torch ,
88 requires_experimental ,
99 requires_transformers ,
10+ requires_cuda ,
1011)
1112from onnx_diagnostic .torch_models .validate import validate_model
1213
1314
1415class TestValidateModel (ExtTestCase ):
16+ @requires_transformers ("4.53" )
17+ @requires_torch ("2.7.99" )
18+ @requires_experimental ()
19+ @requires_cuda ()
20+ @hide_stdout ()
21+ def test_validate_tiny_llms_bfloat16 (self ):
22+ # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
23+ # --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
24+ summary , data = validate_model (
25+ "arnir0/Tiny-LLM" ,
26+ do_run = True ,
27+ verbose = 2 ,
28+ exporter = "custom" ,
29+ do_same = True ,
30+ patch = True ,
31+ rewrite = True ,
32+ stop_if_static = 2 if pv .Version (torch .__version__ ) > pv .Version ("2.6.1" ) else 0 ,
33+ dump_folder = "dump_test/validate_tiny_llm" ,
34+ dtype = "bfloat16" ,
35+ device = "cuda" ,
36+ runtime = "orteval" ,
37+ )
38+ self .assertLess (summary ["disc_onnx_ort_run_abs" ], 2e-2 )
39+ self .assertIn ("onnx_filename" , data )
40+
1541 @requires_transformers ("4.53" )
1642 @requires_torch ("2.7.99" )
1743 @requires_experimental ()
Original file line number Diff line number Diff line change @@ -516,8 +516,10 @@ def string_type(
516516 print (f"[string_type] V2:{ type (obj )} " )
517517 return "OV(NOTENSOR)"
518518 if with_min_max :
519+ from .torch_helper import to_numpy
520+
519521 try :
520- t = obj . numpy ( )
522+ t = to_numpy ( obj )
521523 except Exception :
522524 # pass unable to convert into numpy (bfloat16, ...)
523525 if verbose :
@@ -1233,9 +1235,13 @@ def max_diff(
12331235
12341236 if isinstance (expected , np .ndarray ) or isinstance (got , np .ndarray ):
12351237 if isinstance (expected , torch .Tensor ):
1236- expected = expected .detach ().cpu ().numpy ()
1238+ from .torch_helper import to_numpy
1239+
1240+ expected = to_numpy (expected )
12371241 if isinstance (got , torch .Tensor ):
1238- got = got .detach ().cpu ().numpy ()
1242+ from .torch_helper import to_numpy
1243+
1244+ got = to_numpy (got )
12391245 if verbose >= 6 :
12401246 print (f"[max_diff] tensor: { string_type (expected )} ? { string_type (got )} " )
12411247
Original file line number Diff line number Diff line change @@ -203,6 +203,7 @@ def create_model_builder(
203203 "ChatGLMModel" : builder .ChatGLMModel ,
204204 "Ernie4_5_ForCausalLM" : builder .ErnieModel ,
205205 "GemmaForCausalLM" : builder .Gemma2Model ,
206+ "Gemma2ForCausalLM" : builder .Gemma2Model ,
206207 "Gemma3ForCausalLM" : builder .Gemma3Model ,
207208 "Gemma3ForConditionalGeneration" : builder .Gemma3Model ,
208209 "GraniteForCausalLM" : builder .GraniteModel ,
Original file line number Diff line number Diff line change 33import onnx
44import torch
55from .helper import string_type , flatten_object
6+ from .torch_helper import to_numpy
67from .cache_helper import is_cache_dynamic_registered
78
89
@@ -56,7 +57,7 @@ def make_feeds(
5657 f"{ string_type (torch .utils ._pytree .tree_flatten (inputs )[0 ], with_shape = True )} "
5758 )
5859 if use_numpy :
59- flat = [t . detach (). cpu (). numpy ( ) if isinstance (t , torch .Tensor ) else t for t in flat ]
60+ flat = [to_numpy ( t ) if isinstance (t , torch .Tensor ) else t for t in flat ]
6061 names = (
6162 [i .name for i in proto .graph .input ]
6263 if isinstance (proto , onnx .ModelProto )
Original file line number Diff line number Diff line change @@ -478,10 +478,10 @@ def is_torchdynamo_exporting() -> bool:
478478 return False
479479
480480
481- def to_numpy (tensor : "torch.Tensor" ): # noqa: F821
481+ def to_numpy (tensor : "torch.Tensor" ) -> np . ndarray : # noqa: F821
482482 """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
483483 try :
484- return tensor .numpy ()
484+ return tensor .detach (). cpu (). numpy ()
485485 except TypeError :
486486 # We try with ml_dtypes
487487 pass
@@ -490,7 +490,7 @@ def to_numpy(tensor: "torch.Tensor"): # noqa: F821
490490
491491 conv = {torch .bfloat16 : ml_dtypes .bfloat16 }
492492 assert tensor .dtype in conv , f"Unsupported type { tensor .dtype } , not in { conv } "
493- return tensor .to (torch .float32 ).numpy ().astype (conv [tensor .dtype ])
493+ return tensor .detach (). to (torch .float32 ). cpu ( ).numpy ().astype (conv [tensor .dtype ])
494494
495495
496496def replace_string_by_dynamic (dynamic_shapes : Any ) -> Any :
Original file line number Diff line number Diff line change 33import numpy as np
44import onnx
55import torch
6- from ..helpers .torch_helper import to_tensor
6+ from ..helpers .torch_helper import to_tensor , to_numpy
77from ..torch_onnx .runtime_info import first_used_last_used , RuntimeValue
88from .report_results_comparison import ReportResultComparison
99from . import torch_ops
@@ -578,7 +578,7 @@ def run(
578578 print (f"- clean { o } " )
579579
580580 if use_numpy :
581- return [None if a is None else a . detach (). cpu (). numpy ( ) for a in fres ]
581+ return [None if a is None else to_numpy ( a ) for a in fres ]
582582 return fres
583583
584584 def run_with_values (
Original file line number Diff line number Diff line change 33import torch
44from ..helpers import string_type , string_diff , max_diff
55from ..helpers .onnx_helper import to_array_extended
6+ from ..helpers .torch_helper import to_numpy
67
78
89def validate_fx_tensor (
@@ -296,7 +297,7 @@ def post_process(obs):
296297 )
297298
298299 for inp , v in zip (onx .graph .input , args ):
299- onnx_results [inp .name ] = v . cpu (). numpy ( )
300+ onnx_results [inp .name ] = to_numpy ( v )
300301 if verbose :
301302 print (
302303 f"[run_aligned] +onnx-input: { inp .name } : "
You can’t perform that action at this time.
0 commit comments