2222 InferenceSessionForNumpy ,
2323 _InferenceSession ,
2424)
25- from .report_results_comparison import ReportResultsComparison
25+ from ..helpers .torch_helper import to_tensor
26+ from .report_results_comparison import ReportResultComparison
2627from .evaluator import ExtendedReferenceEvaluator
2728
2829
@@ -51,6 +52,8 @@ class OnnxruntimeEvaluator:
5152 :param ir_version: ir version to use when unknown
5253 :param opsets: opsets to use when unknown
5354 :param whole: if True, do not split node by node
55+ :param torch_or_numpy: force the use of one of them, Ture for torch,
56+ False for numpy, None to let the class choose
5457 """
5558
5659 def __init__ (
@@ -73,6 +76,7 @@ def __init__(
7376 ir_version : int = 10 ,
7477 opsets : Optional [Union [int , Dict [str , int ]]] = None ,
7578 whole : bool = False ,
79+ torch_or_numpy : Optional [bool ] = None ,
7680 ):
7781 if isinstance (proto , str ):
7882 self .proto : Proto = load (proto )
@@ -104,8 +108,10 @@ def __init__(
104108 disable_aot_function_inlining = disable_aot_function_inlining ,
105109 use_training_api = use_training_api ,
106110 )
111+ self .to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
107112
108113 self .verbose = verbose
114+ self .torch_or_numpy = torch_or_numpy
109115 self .sess_ : Optional [_InferenceSession ] = None
110116 if whole :
111117 self .nodes : Optional [List [NodeProto ]] = None
@@ -124,7 +130,10 @@ def __init__(
124130 )
125131 )
126132 self .rt_inits_ = (
127- {init .name : to_array_extended (init ) for init in self .proto .graph .initializer }
133+ {
134+ init .name : self .to_tensor_or_array (init )
135+ for init in self .proto .graph .initializer
136+ }
128137 if hasattr (self .proto , "graph" )
129138 else {}
130139 )
@@ -192,13 +201,14 @@ def _log_arg(self, a: Any) -> Any:
192201 return a
193202 device = f"D{ a .get_device ()} :" if hasattr (a , "detach" ) else ""
194203 if hasattr (a , "shape" ):
204+ prefix = "A:" if hasattr (a , "astype" ) else "T:"
195205 if self .verbose < 4 : # noqa: PLR2004
196- return f"{ device } { a .dtype } :{ a .shape } in [{ a .min ()} , { a .max ()} ]"
206+ return f"{ prefix } { device } { a .dtype } :{ a .shape } in [{ a .min ()} , { a .max ()} ]"
197207 elements = a .ravel ().tolist ()
198208 if len (elements ) > 10 : # noqa: PLR2004
199209 elements = elements [:10 ]
200- return f"{ device } { a .dtype } :{ a .shape } :{ ',' .join (map (str , elements ))} ..."
201- return f"{ device } { a .dtype } :{ a .shape } :{ elements } "
210+ return f"{ prefix } { device } { a .dtype } :{ a .shape } :{ ',' .join (map (str , elements ))} ..."
211+ return f"{ prefix } { device } { a .dtype } :{ a .shape } :{ elements } "
202212 if hasattr (a , "append" ):
203213 return ", " .join (map (self ._log_arg , a ))
204214 return a
@@ -216,7 +226,7 @@ def run(
216226 outputs : Optional [List [str ]],
217227 feed_inputs : Dict [str , Any ],
218228 intermediate : bool = False ,
219- report_cmp : Optional [ReportResultsComparison ] = None ,
229+ report_cmp : Optional [ReportResultComparison ] = None ,
220230 ) -> Union [Dict [str , Any ], List [Any ]]:
221231 """
222232 Runs the model.
@@ -228,7 +238,7 @@ def run(
228238 :param report_cmp: used as a reference,
229239 every intermediate results is compare to every existing one,
230240 if not empty, it is an instance of
231- :class:`onnx_diagnostic.reference.ReportResultsComparison `
241+ :class:`onnx_diagnostic.reference.ReportResultComparison `
232242 :return: outputs, as a list if return_all is False,
233243 as a dictionary if return_all is True
234244 """
@@ -437,8 +447,12 @@ def _get_sess(
437447 cls = (
438448 InferenceSessionForNumpy
439449 if any (isinstance (i , np .ndarray ) for i in inputs )
450+ and (not isinstance (self .torch_or_numpy , bool ) or not self .torch_or_numpy )
440451 else InferenceSessionForTorch
441452 )
453+ assert (
454+ cls is InferenceSessionForTorch
455+ ), f"ERROR: { string_type (inputs , with_shape = True )} "
442456 try :
443457 sess = cls (onx , ** self .session_kwargs )
444458 except (
@@ -497,6 +511,7 @@ def _get_sess_if(
497511 verbose = self .verbose ,
498512 ir_version = self .ir_version ,
499513 opsets = self .opsets ,
514+ torch_or_numpy = self .torch_or_numpy ,
500515 ** self .session_kwargs ,
501516 )
502517 return onx , sess
@@ -511,6 +526,7 @@ def _get_sess_local(
511526 verbose = self .verbose ,
512527 ir_version = self .ir_version ,
513528 opsets = self .opsets ,
529+ torch_or_numpy = self .torch_or_numpy ,
514530 ** self .session_kwargs ,
515531 )
516532 return ev .proto , sess
@@ -586,6 +602,7 @@ def _get_sess_scan(
586602 verbose = self .verbose ,
587603 ir_version = self .ir_version ,
588604 opsets = self .opsets ,
605+ torch_or_numpy = self .torch_or_numpy ,
589606 whole = True ,
590607 ** self .session_kwargs ,
591608 )
0 commit comments