@@ -646,6 +646,18 @@ def call_torch_export_export(
646646 strict = "nostrict" not in exporter
647647 args , kwargs = split_args_kwargs (data ["inputs_export" ])
648648 ds = data .get ("dynamic_shapes" , None )
649+
650+ summary ["export_exporter" ] = exporter
651+ summary ["export_optimization" ] = optimization or ""
652+ summary ["export_strict" ] = strict
653+ summary ["export_args" ] = string_type (args , with_shape = True )
654+ summary ["export_kwargs" ] = string_type (kwargs , with_shape = True )
655+ summary ["export_dynamic_shapes" ] = string_type (ds )
656+
657+ # There is an issue with DynamicShape [[],[]] becomes []
658+ dse = CoupleInputsDynamicShapes (args , kwargs , ds ).replace_string_by ()
659+ summary ["export_dynamic_shapes_export_export" ] = string_type (dse )
660+
649661 if verbose :
650662 print (
651663 f"[call_torch_export_export] exporter={ exporter !r} , "
@@ -654,18 +666,14 @@ def call_torch_export_export(
654666 print (f"[call_torch_export_export] args={ string_type (args , with_shape = True )} " )
655667 print (f"[call_torch_export_export] kwargs={ string_type (kwargs , with_shape = True )} " )
656668 print (f"[call_torch_export_export] dynamic_shapes={ _ds_clean (ds )} " )
669+ print (f"[call_torch_export_export] dynamic_shapes_export_export={ string_type (dse )} " )
657670 print ("[call_torch_export_export] export..." )
658- summary ["export_exporter" ] = exporter
659- summary ["export_optimization" ] = optimization or ""
660- summary ["export_strict" ] = strict
661- summary ["export_args" ] = string_type (args , with_shape = True )
662- summary ["export_kwargs" ] = string_type (kwargs , with_shape = True )
663671
664672 begin = time .perf_counter ()
665673 if quiet :
666674 try :
667675 ep = torch .export .export (
668- data ["model" ], args , kwargs = kwargs , dynamic_shapes = ds , strict = strict
676+ data ["model" ], args , kwargs = kwargs , dynamic_shapes = dse , strict = strict
669677 )
670678 except Exception as e :
671679 summary ["ERR_export_export" ] = str (e )
@@ -674,7 +682,7 @@ def call_torch_export_export(
674682 return summary , data
675683 else :
676684 ep = torch .export .export (
677- data ["model" ], args , kwargs = kwargs , dynamic_shapes = ds , strict = strict
685+ data ["model" ], args , kwargs = kwargs , dynamic_shapes = dse , strict = strict
678686 )
679687
680688 summary ["time_export_export" ] = time .perf_counter () - begin
0 commit comments