@@ -29,15 +29,15 @@ def test_run_aligned_record(self):
2929 onnx_name = "B" ,
3030 ep_target = "C" ,
3131 onnx_op_type = "D" ,
32- shape_type = "E" ,
32+ ep_shape_type = "E" ,
3333 err_abs = 0.1 ,
3434 err_rel = 0.2 ,
3535 err_dev = 0.3 ,
3636 err_nan = 0.4 ,
3737 )
3838 sr = str (r )
3939 self .assertIn ("RunAlignedRecord(" , sr )
40- self .assertIn ("shape_type ='E'" , sr )
40+ self .assertIn ("ep_shape_type ='E'" , sr )
4141
4242 @hide_stdout ()
4343 @unittest .skipIf (to_onnx is None , "to_onnx not installed" )
@@ -69,7 +69,7 @@ def forward(self, x):
6969 verbose = 10 ,
7070 ),
7171 )
72- self .assertEqual (len (results ), 6 )
72+ self .assertEqual (len (results ), 7 )
7373
7474 @hide_stdout ()
7575 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -104,7 +104,7 @@ def forward(self, x):
104104 verbose = 10 ,
105105 ),
106106 )
107- self .assertEqual (len (results ), 5 )
107+ self .assertEqual (len (results ), 6 )
108108
109109 @hide_stdout ()
110110 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -136,7 +136,7 @@ def forward(self, x):
136136 verbose = 10 ,
137137 ),
138138 )
139- self .assertEqual (len (results ), 5 )
139+ self .assertEqual (len (results ), 6 )
140140
141141 @hide_stdout ()
142142 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -167,7 +167,7 @@ def forward(self, x):
167167 verbose = 11 ,
168168 ),
169169 )
170- self .assertEqual (len (results ), 6 )
170+ self .assertEqual (len (results ), 7 )
171171
172172 @hide_stdout ()
173173 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -199,7 +199,7 @@ def forward(self, x):
199199 use_tensor = True ,
200200 ),
201201 )
202- self .assertEqual (len (results ), 7 )
202+ self .assertEqual (len (results ), 8 )
203203
204204 @hide_stdout ()
205205 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -232,7 +232,7 @@ def forward(self, x):
232232 use_tensor = True ,
233233 ),
234234 )
235- self .assertEqual (len (results ), 7 )
235+ self .assertEqual (len (results ), 8 )
236236
237237 @hide_stdout ()
238238 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -267,7 +267,7 @@ def forward(self, x):
267267 use_tensor = True ,
268268 ),
269269 )
270- self .assertEqual (len (results ), 8 )
270+ self .assertEqual (len (results ), 14 )
271271
272272 @hide_stdout ()
273273 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -301,9 +301,9 @@ def forward(self, x):
301301 use_tensor = True ,
302302 ),
303303 )
304- self .assertEqual (len (results ), 8 )
304+ self .assertEqual (len (results ), 14 )
305305 self .assertEqual (
306- [None , None , 0 , 0 , 0 , 0 , 0 , 0 ],
306+ [None , None , None , None , None , None , None , None , 0 , 0 , 0 , 0 , 0 , 0 ],
307307 [r .err_dev for r in results ],
308308 )
309309
@@ -349,29 +349,27 @@ def forward(self, x):
349349 [
350350 "ep_id_node" ,
351351 "ep_name" ,
352+ "ep_shape_type" ,
352353 "ep_target" ,
353354 "ep_time_run" ,
354355 "err_abs" ,
355356 "err_dev" ,
357+ "err_h01" ,
356358 "err_nan" ,
357359 "err_rel" ,
358360 "onnx_id_node" ,
359361 "onnx_id_output" ,
360362 "onnx_name" ,
361363 "onnx_op_type" ,
364+ "onnx_shape_type" ,
362365 "onnx_time_run" ,
363- "shape_type" ,
364366 ],
365367 sorted (df .columns ),
366368 )
367369 self .assertEqual (len (results ), 8 )
370+ self .assertEqual ([0 , 0 , 0 , 0 , None , 0 , 0 , 0 ], [r .err_dev for r in results ])
368371 self .assertEqual (
369- [None , None , None , None , None , 0 , 0 , 0 ],
370- [r .err_dev for r in results ],
371- )
372- self .assertEqual (
373- [- 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
374- df ["onnx_id_node" ].fillna (- 10 ).tolist (),
372+ [- 1 , - 1 , - 1 , - 1 , - 1 , 0 , 1 , 2 ], df ["onnx_id_node" ].fillna (- 10 ).tolist ()
375373 )
376374 self .clean_dump ()
377375
@@ -417,29 +415,27 @@ def forward(self, x):
417415 [
418416 "ep_id_node" ,
419417 "ep_name" ,
418+ "ep_shape_type" ,
420419 "ep_target" ,
421420 "ep_time_run" ,
422421 "err_abs" ,
423422 "err_dev" ,
423+ "err_h01" ,
424424 "err_nan" ,
425425 "err_rel" ,
426426 "onnx_id_node" ,
427427 "onnx_id_output" ,
428428 "onnx_name" ,
429429 "onnx_op_type" ,
430+ "onnx_shape_type" ,
430431 "onnx_time_run" ,
431- "shape_type" ,
432432 ],
433433 sorted (df .columns ),
434434 )
435435 self .assertEqual (len (results ), 8 )
436+ self .assertEqual ([0 , 0 , 0 , 0 , None , 0 , 0 , 0 ], [r .err_dev for r in results ])
436437 self .assertEqual (
437- [None , None , None , None , None , 0 , 0 , 0 ],
438- [r .err_dev for r in results ],
439- )
440- self .assertEqual (
441- [- 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
442- df ["onnx_id_node" ].fillna (- 10 ).tolist (),
438+ [- 1 , - 1 , - 1 , - 1 , - 1 , 0 , 1 , 2 ], df ["onnx_id_node" ].fillna (- 10 ).tolist ()
443439 )
444440 self .clean_dump ()
445441
@@ -466,7 +462,7 @@ def forward(self, x):
466462 use_tensor = True ,
467463 ),
468464 )
469- self .assertEqual (len (results ), 2 )
465+ self .assertEqual (len (results ), 5 )
470466
471467
472468if __name__ == "__main__" :
0 commit comments