@@ -379,7 +379,7 @@ def forward(self, x):
379379 use_tensor = True ,
380380 ),
381381 )
382- df = pandas .DataFrame (list (results ))
382+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
383383 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom.xlsx" ))
384384 self .assertEqual (
385385 [
@@ -390,8 +390,8 @@ def forward(self, x):
390390 "ep_time_run" ,
391391 "err_abs" ,
392392 "err_dev" ,
393+ "err_h001" ,
393394 "err_h01" ,
394- "err_nan" ,
395395 "err_rel" ,
396396 "onnx_id_node" ,
397397 "onnx_id_output" ,
@@ -445,7 +445,7 @@ def forward(self, x):
445445 use_tensor = True ,
446446 ),
447447 )
448- df = pandas .DataFrame (list (results ))
448+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
449449 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_dynamo.xlsx" ))
450450 self .assertEqual (
451451 [
@@ -456,8 +456,8 @@ def forward(self, x):
456456 "ep_time_run" ,
457457 "err_abs" ,
458458 "err_dev" ,
459+ "err_h001" ,
459460 "err_h01" ,
460- "err_nan" ,
461461 "err_rel" ,
462462 "onnx_id_node" ,
463463 "onnx_id_output" ,
@@ -542,7 +542,7 @@ def forward(self, x):
542542 reset_names = ["linear" ],
543543 ),
544544 )
545- df = pandas .DataFrame (list (results ))
545+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
546546 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom_reset.xlsx" ))
547547 onnx_op_type = df ["onnx_op_type" ].tolist ()
548548 self .assertEqual (onnx_op_type .count ("reset" ), 1 )
@@ -593,10 +593,83 @@ def forward(self, x):
593593 ),
594594 ),
595595 )
596- df = pandas .DataFrame (list (results ))
596+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
597597 df .to_excel (self .get_dump_file ("test_sbs_replay.xlsx" ))
598- print (df )
599- # self.clean_dump()
598+ self .assertEqual (df .shape , (8 , 16 ))
599+ self .clean_dump ()
600+
601+ @hide_stdout ()
602+ @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
603+ def test_sbs_run_onnx_with_torch_inputs (self ):
604+ torch = self .torch
605+
606+ class Model (self .torch .nn .Module ):
607+ def __init__ (self ):
608+ super (Model , self ).__init__ ()
609+ self .fc1 = torch .nn .Linear (10 , 32 ) # input size 10 → hidden size 32
610+ self .relu = torch .nn .ReLU ()
611+ self .fc2 = torch .nn .Linear (32 , 1 ) # hidden → output
612+
613+ def forward (self , x ):
614+ x = self .relu (self .fc1 (x ))
615+ x = self .fc2 (x )
616+ return x
617+
618+ inputs = dict (x = self .torch .randn ((5 , 10 )))
619+ ds = dict (x = {0 : "batch" })
620+ Model ()(** inputs )
621+ ep = self .torch .export .export (
622+ Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
623+ )
624+ filename = self .get_dump_file ("test_sbs_run_onnx_with_torch_inputs.onnx" )
625+ to_onnx (ep , exporter = "custom" , filename = filename )
626+ onx = onnx .load (filename )
627+ results = list (
628+ run_aligned (
629+ ep ,
630+ onx ,
631+ kwargs = inputs ,
632+ run_cls = OnnxruntimeEvaluator ,
633+ verbose = 11 ,
634+ use_tensor = True ,
635+ run_onnx_with_torch_inputs = True ,
636+ ),
637+ )
638+ df = pandas .DataFrame (list (results )).dropna (axis = 1 , how = "all" )
639+ df .to_excel (self .get_dump_file ("test_sbs_run_onnx_with_torch_inputs.xlsx" ))
640+ self .assertEqual (
641+ [
642+ "comment" ,
643+ "ep_id_node" ,
644+ "ep_name" ,
645+ "ep_shape_type" ,
646+ "ep_target" ,
647+ "ep_time_run" ,
648+ "err_abs" ,
649+ "err_abs2" ,
650+ "err_dev" ,
651+ "err_dev2" ,
652+ "err_h001" ,
653+ "err_h0012" ,
654+ "err_h01" ,
655+ "err_h012" ,
656+ "err_rel" ,
657+ "err_rel2" ,
658+ "onnx_id_node" ,
659+ "onnx_id_output" ,
660+ "onnx_name" ,
661+ "onnx_op_type" ,
662+ "onnx_shape_type" ,
663+ "onnx_time_run" ,
664+ ],
665+ sorted (df .columns ),
666+ )
667+ self .assertEqual (len (results ), 8 )
668+ self .assertEqual ([0 , 0 , 0 , 0 , None , 0 , 0 , 0 ], [r .err_dev for r in results ])
669+ self .assertEqual (
670+ [- 1 , - 1 , - 1 , - 1 , - 1 , 0 , 1 , 2 ], df ["onnx_id_node" ].fillna (- 10 ).tolist ()
671+ )
672+ self .clean_dump ()
600673
601674
602675if __name__ == "__main__" :
0 commit comments