@@ -379,7 +379,7 @@ def forward(self, x):
379379 use_tensor = True ,
380380 ),
381381 )
382- df = pandas .DataFrame (list (results ))
382+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
383383 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom.xlsx" ))
384384 self .assertEqual (
385385 [
@@ -391,7 +391,6 @@ def forward(self, x):
391391 "err_abs" ,
392392 "err_dev" ,
393393 "err_h01" ,
394- "err_nan" ,
395394 "err_rel" ,
396395 "onnx_id_node" ,
397396 "onnx_id_output" ,
@@ -445,7 +444,7 @@ def forward(self, x):
445444 use_tensor = True ,
446445 ),
447446 )
448- df = pandas .DataFrame (list (results ))
447+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
449448 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_dynamo.xlsx" ))
450449 self .assertEqual (
451450 [
@@ -457,7 +456,6 @@ def forward(self, x):
457456 "err_abs" ,
458457 "err_dev" ,
459458 "err_h01" ,
460- "err_nan" ,
461459 "err_rel" ,
462460 "onnx_id_node" ,
463461 "onnx_id_output" ,
@@ -542,7 +540,7 @@ def forward(self, x):
542540 reset_names = ["linear" ],
543541 ),
544542 )
545- df = pandas .DataFrame (list (results ))
543+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
546544 df .to_excel (self .get_dump_file ("test_sbs_model_with_weights_custom_reset.xlsx" ))
547545 onnx_op_type = df ["onnx_op_type" ].tolist ()
548546 self .assertEqual (onnx_op_type .count ("reset" ), 1 )
@@ -593,10 +591,80 @@ def forward(self, x):
593591 ),
594592 ),
595593 )
596- df = pandas .DataFrame (list (results ))
594+ df = pandas .DataFrame (list (results )). dropna ( axis = 1 , how = "all" )
597595 df .to_excel (self .get_dump_file ("test_sbs_replay.xlsx" ))
598- print (df )
599- # self.clean_dump()
596+ self .assertEqual (df .shape , (8 , 15 ))
597+ self .clean_dump ()
598+
599+ @hide_stdout ()
600+ @ignore_warnings ((DeprecationWarning , FutureWarning , UserWarning ))
601+ def test_sbs_run_onnx_with_torch_inputs (self ):
602+ torch = self .torch
603+
604+ class Model (self .torch .nn .Module ):
605+ def __init__ (self ):
606+ super (Model , self ).__init__ ()
607+ self .fc1 = torch .nn .Linear (10 , 32 ) # input size 10 → hidden size 32
608+ self .relu = torch .nn .ReLU ()
609+ self .fc2 = torch .nn .Linear (32 , 1 ) # hidden → output
610+
611+ def forward (self , x ):
612+ x = self .relu (self .fc1 (x ))
613+ x = self .fc2 (x )
614+ return x
615+
616+ inputs = dict (x = self .torch .randn ((5 , 10 )))
617+ ds = dict (x = {0 : "batch" })
618+ Model ()(** inputs )
619+ ep = self .torch .export .export (
620+ Model (), (), kwargs = inputs , dynamic_shapes = use_dyn_not_str (ds )
621+ )
622+ filename = self .get_dump_file ("test_sbs_run_onnx_with_torch_inputs.onnx" )
623+ to_onnx (ep , exporter = "custom" , filename = filename )
624+ onx = onnx .load (filename )
625+ results = list (
626+ run_aligned (
627+ ep ,
628+ onx ,
629+ kwargs = inputs ,
630+ run_cls = OnnxruntimeEvaluator ,
631+ verbose = 11 ,
632+ use_tensor = True ,
633+ run_onnx_with_torch_inputs = True ,
634+ ),
635+ )
636+ df = pandas .DataFrame (list (results )).dropna (axis = 1 , how = "all" )
637+ df .to_excel (self .get_dump_file ("test_sbs_run_onnx_with_torch_inputs.xlsx" ))
638+ self .assertEqual (
639+ [
640+ "ep_id_node" ,
641+ "ep_name" ,
642+ "ep_shape_type" ,
643+ "ep_target" ,
644+ "ep_time_run" ,
645+ "err_abs" ,
646+ "err_abs2" ,
647+ "err_dev" ,
648+ "err_dev2" ,
649+ "err_h01" ,
650+ "err_h012" ,
651+ "err_rel" ,
652+ "err_rel2" ,
653+ "onnx_id_node" ,
654+ "onnx_id_output" ,
655+ "onnx_name" ,
656+ "onnx_op_type" ,
657+ "onnx_shape_type" ,
658+ "onnx_time_run" ,
659+ ],
660+ sorted (df .columns ),
661+ )
662+ self .assertEqual (len (results ), 8 )
663+ self .assertEqual ([0 , 0 , 0 , 0 , None , 0 , 0 , 0 ], [r .err_dev for r in results ])
664+ self .assertEqual (
665+ [- 1 , - 1 , - 1 , - 1 , - 1 , 0 , 1 , 2 ], df ["onnx_id_node" ].fillna (- 10 ).tolist ()
666+ )
667+ self .clean_dump ()
600668
601669
602670if __name__ == "__main__" :
0 commit comments