@@ -43,8 +43,8 @@ def _loop_cmp(
     onnx_name: str,
     torch_result: torch.Tensor,
     verbose: int,
-    atol: float,
-    rtol: float,
+    atol: Optional[float],
+    rtol: Optional[float],
     i_torch: int,
     i_onnx: int,
     str_kws: Dict[str, bool],
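The hunk above relaxes `atol`/`rtol` from `float` to `Optional[float]`, so `None` can mean "pick a default". A minimal sketch of that resolution step, assuming dtype-based fallbacks (the concrete default values below are illustrative, not the library's):

    from typing import Optional, Tuple
    import torch

    def resolve_tolerances(
        atol: Optional[float], rtol: Optional[float], dtype: torch.dtype
    ) -> Tuple[float, float]:
        # Assumed fallbacks: tighter tolerances for float32, looser for
        # lower-precision dtypes; None means "use the default".
        if atol is None:
            atol = 1e-5 if dtype == torch.float32 else 1e-2
        if rtol is None:
            rtol = 1e-4 if dtype == torch.float32 else 1e-2
        return atol, rtol

`torch.testing.assert_close` follows the same idea: when neither tolerance is given, it derives defaults from the dtype.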
@@ -140,6 +140,7 @@ def _loop_onnx_node(
             f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
         )
     elif verbose == 1:
+        loop.update(i_torch + i_onnx)
         loop.set_description(
             f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
             f"{status.to_str()}"
@@ -435,7 +436,7 @@ def _preparation_with_onnx_model(
             t = torch_results[init.name]
             torch_names_to_onnx_names[init.name] = init.name
         elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
-            new_names = [  # type: ignore[assignment]
+            new_names = [
                 k
                 for k in rev_init_aliases[init.name]
                 if k in torch_results and k not in skip_mapping_torch_onnx
@@ -657,8 +658,8 @@ def forward(self, x):
         -v 1 --atol=0.1 --rtol=1
     """
     assert callable(run_cls), f"run_cls={run_cls} not a callable"
-    already_yielded = {}  # type: ignore[var-annotated]
-    reset_names = set(reset_names) if reset_names else set()  # type: ignore[assignment]
+    already_yielded = {}
+    reset_names = set(reset_names) if reset_names else set()
     str_kws = dict(with_shape=True, with_device=True)
     has_cuda = any(
         (isinstance(t, torch.Tensor) and t.is_cuda)
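The two dropped `# type: ignore` comments papered over an unannotated empty dict and a parameter re-bound to a different type. A sketch of the annotation-based alternative that keeps a type checker happy without ignores (the value types here are assumptions):

    from typing import Dict, Optional, Sequence, Set

    # Assumed types: annotating the empty dict and normalizing the parameter
    # through a typed helper avoids both ignore comments.
    already_yielded: Dict[str, int] = {}

    def normalize_reset_names(reset_names: Optional[Sequence[str]]) -> Set[str]:
        return set(reset_names) if reset_names else set()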
@@ -777,25 +778,28 @@ def forward(self, x):
     if verbose == 1:
         import tqdm

-        loop = tqdm.tqdm(list(enumerate(ep_graph_nodes)))
+        loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
     else:
-        loop = list(enumerate(ep_graph_nodes))
+        loop = None

     already_run: Set[int] = set()
     ep_durations = {}
     status = StatusRunAligned()
-    for i, node in loop:
+    for i_torch, node in enumerate(ep_graph_nodes):
         if verbose > 1:
             if node.op == "call_function":
                 print(
-                    f"[run_aligned] run ep.graph.nodes[{i}]: "
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
                     f"{node.op}[{node.target}] -> {node.name!r}"
                 )
             else:
-                print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")
+                print(
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
+                )
         elif verbose == 1:
+            loop.update(i_torch + last_position)
             loop.set_description(
-                f"ep {i}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
+                f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
                 f"{status.to_str()}"
             )

@@ -816,7 +820,7 @@ def forward(self, x):
                 print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
             # Otherwise, it is an input.
             record = RunAlignedRecord(
-                ep_id_node=i,
+                ep_id_node=i_torch,
                 onnx_id_node=-1,
                 ep_name=node.name,
                 onnx_name=torch_names_to_onnx_names[node.name],
@@ -848,7 +852,7 @@ def forward(self, x):
                     f"{node.name}={string_type(t, **str_kws)}"
                 )
             record = RunAlignedRecord(
-                ep_id_node=i,
+                ep_id_node=i_torch,
                 onnx_id_node=-1,
                 ep_name=node.name,
                 onnx_name=torch_names_to_onnx_names[node.name],
@@ -874,7 +878,7 @@ def forward(self, x):
                     f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
                 )
             yield RunAlignedRecord(
-                ep_id_node=i,
+                ep_id_node=i_torch,
                 ep_name=node.name,
                 ep_target="placeholder",
                 ep_shape_type=string_type(t, **str_kws),
@@ -886,7 +890,7 @@ def forward(self, x):
         begin = time.perf_counter()
         new_outputs = run_fx_node(node, args, kwargs)
         duration = time.perf_counter() - begin
-        ep_durations[i] = duration
+        ep_durations[i_torch] = duration
        if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
            new_outputs = (new_outputs,)

@@ -906,7 +910,7 @@ def forward(self, x):
                 if "onnx" in positions[n]:
                     max_pos = max(max_pos, positions[n]["onnx"])
                 if "fx" in positions[n]:
-                    if positions[n]["fx"] > i:
+                    if positions[n]["fx"] > i_torch:
                         max_pos = -2
                         break
             if max_pos == -2:
@@ -923,7 +927,7 @@ def forward(self, x):
             ep_behind = False
             for iname in node.output:
                 if iname in positions and "fx" in positions[iname]:
-                    if positions[iname]["fx"] > i:
+                    if positions[iname]["fx"] > i_torch:
                         ep_behind = True
                         break
             if ep_behind:
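The two hunks above only propagate the `i` to `i_torch` rename into the synchronization checks. Assuming `positions` maps a value name to the index of its producing node in each graph (keys `"fx"` and `"onnx"`, an assumption inferred from this diff alone), the check says: an onnx node is ahead of the torch walk when one of its outputs is produced later than `i_torch` on the fx side. A toy illustration:

    # Toy illustration with assumed semantics for `positions`.
    positions = {"x": {"fx": 3, "onnx": 5}, "y": {"fx": 9, "onnx": 6}}
    i_torch = 4
    node_outputs = ["x", "y"]
    ep_behind = any(
        positions[name]["fx"] > i_torch
        for name in node_outputs
        if name in positions and "fx" in positions[name]
    )
    assert ep_behind  # "y" is produced at fx index 9, after i_torch == 4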
@@ -937,7 +941,7 @@ def forward(self, x):
                 torch_results,
                 ep_durations,
                 use_tensor,
-                i,
+                i_torch,
                 i_onnx,
                 name_to_ep_node,
                 run_cls_kwargs,
@@ -978,7 +982,7 @@ def forward(self, x):
                 torch_results,
                 ep_durations,
                 use_tensor,
-                i,
+                i_torch,
                 i_onnx,
                 name_to_ep_node,
                 run_cls_kwargs,