1010)
1111from onnx_diagnostic .reference import ExtendedReferenceEvaluator , OnnxruntimeEvaluator
1212from onnx_diagnostic .torch_export_patches .patch_inputs import use_dyn_not_str
13- from onnx_diagnostic .torch_onnx .sbs import run_aligned , post_process_run_aligned_obs
13+ from onnx_diagnostic .torch_onnx .sbs import run_aligned , RunAlignedRecord
1414from onnx_diagnostic .export .api import to_onnx
1515
1616
@@ -21,6 +21,24 @@ def setUpClass(cls):
2121
2222 cls .torch = torch
2323
24+ def test_run_aligned_record (self ):
25+ r = RunAlignedRecord (
26+ ep_id_node = - 1 ,
27+ onnx_id_node = - 1 ,
28+ ep_name = "A" ,
29+ onnx_name = "B" ,
30+ ep_target = "C" ,
31+ onnx_op_type = "D" ,
32+ shape_type = "E" ,
33+ err_abs = 0.1 ,
34+ err_rel = 0.2 ,
35+ err_dev = 0.3 ,
36+ err_nan = 0.4 ,
37+ )
38+ sr = str (r )
39+ self .assertIn ("RunAlignedRecord(" , sr )
40+ self .assertIn ("shape_type='E'" , sr )
41+
2442 @hide_stdout ()
2543 @unittest .skipIf (to_onnx is None , "to_onnx not installed" )
2644 @ignore_errors (OSError ) # connectivity issues
@@ -48,7 +66,7 @@ def forward(self, x):
4866 run_cls = ExtendedReferenceEvaluator ,
4967 atol = 1e-5 ,
5068 rtol = 1e-5 ,
51- verbose = 1 ,
69+ verbose = 10 ,
5270 ),
5371 )
5472 self .assertEqual (len (results ), 7 )
@@ -83,7 +101,7 @@ def forward(self, x):
83101 run_cls = ExtendedReferenceEvaluator ,
84102 atol = 1e-5 ,
85103 rtol = 1e-5 ,
86- verbose = 1 ,
104+ verbose = 10 ,
87105 ),
88106 )
89107 self .assertEqual (len (results ), 6 )
@@ -115,7 +133,7 @@ def forward(self, x):
115133 run_cls = ExtendedReferenceEvaluator ,
116134 atol = 1e-5 ,
117135 rtol = 1e-5 ,
118- verbose = 1 ,
136+ verbose = 10 ,
119137 ),
120138 )
121139 self .assertEqual (len (results ), 6 )
@@ -285,7 +303,10 @@ def forward(self, x):
285303 ),
286304 )
287305 self .assertEqual (len (results ), 14 )
288- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 14 )
306+ self .assertEqual (
307+ [r .err_dev for r in results ],
308+ [None , None , None , None , None , None , None , None , 0 , 0 , 0 , 0 , 0 , 0 ],
309+ )
289310
290311 @hide_stdout ()
291312 @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
@@ -323,7 +344,7 @@ def forward(self, x):
323344 use_tensor = True ,
324345 ),
325346 )
326- df = pandas .DataFrame (list (map ( post_process_run_aligned_obs , results ) ))
347+ df = pandas .DataFrame (list (results ))
327348 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom.xlsx" ))
328349 self .assertEqual (
329350 [
@@ -332,6 +353,7 @@ def forward(self, x):
332353 "ep_target" ,
333354 "err_abs" ,
334355 "err_dev" ,
356+ "err_nan" ,
335357 "err_rel" ,
336358 "onnx_id_node" ,
337359 "onnx_name" ,
@@ -341,7 +363,10 @@ def forward(self, x):
341363 sorted (df .columns ),
342364 )
343365 self .assertEqual (len (results ), 12 )
344- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 12 )
366+ self .assertEqual (
367+ [r .err_dev for r in results ],
368+ [None , None , None , None , None , None , None , None , None , 0 , 0 , 0 ],
369+ )
345370 self .assertEqual (
346371 [- 1.0 , - 1.0 , - 1.0 , - 1.0 , - 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
347372 df ["onnx_id_node" ].fillna (- 10 ).tolist (),
@@ -384,7 +409,7 @@ def forward(self, x):
384409 use_tensor = True ,
385410 ),
386411 )
387- df = pandas .DataFrame (list (map ( post_process_run_aligned_obs , results ) ))
412+ df = pandas .DataFrame (list (results ))
388413 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_dynamo.xlsx" ))
389414 self .assertEqual (
390415 [
@@ -393,6 +418,7 @@ def forward(self, x):
393418 "ep_target" ,
394419 "err_abs" ,
395420 "err_dev" ,
421+ "err_nan" ,
396422 "err_rel" ,
397423 "onnx_id_node" ,
398424 "onnx_name" ,
@@ -402,7 +428,10 @@ def forward(self, x):
402428 sorted (df .columns ),
403429 )
404430 self .assertEqual (len (results ), 12 )
405- self .assertEqual ([r [- 1 ].get ("dev" , 0 ) for r in results ], [0 ] * 12 )
431+ self .assertEqual (
432+ [r .err_dev for r in results ],
433+ [None , None , None , None , None , None , None , None , None , 0 , 0 , 0 ],
434+ )
406435 self .assertEqual (
407436 [- 1.0 , - 1.0 , - 1.0 , - 1.0 , - 10.0 , - 10.0 , - 10.0 , - 10.0 , - 1.0 , 0.0 , 1.0 , 2.0 ],
408437 df ["onnx_id_node" ].fillna (- 10 ).tolist (),
0 commit comments