@@ -567,10 +567,11 @@ def _loop_cmp(
567567 print (f"[run_aligned-nx] +inp: { inp .name } : { string_type (v , ** str_kws )} " )
568568
569569 placeholders = {node .name for node in ep .graph .nodes if node .op == "placeholder" }
570- ep_state_dict = {** ep .state_dict , ** dict (ep .named_buffers ())}
570+ ep_state_dict = {** ep .state_dict , ** dict (ep .named_buffers (), ** ep . tensor_constants )}
571571 placeholders_to_state_dict = {
572572 ** {f"p_{ name .replace ('.' , '_' )} " : name for name in ep .state_dict },
573573 ** {f"b_{ name .replace ('.' , '_' )} " : name for name , _ in ep .named_buffers ()},
574+ ** {f"c_{ name .replace ('.' , '_' )} " : name for name in ep .tensor_constants },
574575 }
575576 for n in onnx_results :
576577 if n not in placeholders :
@@ -588,6 +589,7 @@ def _loop_cmp(
588589 else :
589590 loop = list (enumerate (ep_graph_nodes ))
590591
592+ already_run = set ()
591593 ep_durations = {}
592594 yielded_nodes = 0
593595 max_abs = 0
@@ -641,8 +643,8 @@ def _loop_cmp(
641643 yield record
642644 else :
643645 assert node .name in placeholders_to_state_dict , (
644- f"Unable to find placeholder { node .name !r} in "
645- f"{ sorted (placeholders_to_state_dict )} "
646+ f"Unable to find placeholder { node .name !r} (node.op= { node . op !r } ), "
647+ f"existing: { sorted (placeholders_to_state_dict )} "
646648 )
647649 torch_results [node .name ] = ep_state_dict [placeholders_to_state_dict [node .name ]]
648650 if verbose > 1 :
@@ -683,6 +685,8 @@ def _loop_cmp(
683685 continue
684686
685687 for i_onnx in range (last_position , max_pos + 1 ):
688+ if i_onnx in already_run :
689+ continue
686690 node = onx .graph .node [i_onnx ]
687691 if verbose > 1 :
688692 print (
@@ -695,9 +699,16 @@ def _loop_cmp(
695699 f"mapped { yielded_nodes } maxabs { max_abs :1.5f} "
696700 )
697701 ref = run_cls (node , ** run_cls_kwargs )
698- feeds = {k : onnx_results [k ] for k in node .input }
702+ feeds = {k : onnx_results [k ] for k in node .input if k }
703+ assert "" not in feeds , f"Unexpected feeds={ string_type (feeds , ** str_kws )} "
699704 begin = time .perf_counter ()
700- res = ref .run (None , feeds ) # type: ignore[attr-defined]
705+ try :
706+ res = ref .run (None , feeds ) # type: ignore[attr-defined]
707+ except Exception as e :
708+ raise RuntimeError (
709+ f"Unable to run node { node .op_type } , domain={ node .domain } "
710+ f"with inputs={ node .input } , feeds={ string_type (feeds , ** str_kws )} "
711+ ) from e
701712 duration = time .perf_counter () - begin
702713 assert (
703714 not has_cuda
@@ -748,6 +759,7 @@ def _loop_cmp(
748759 if tmp .err_abs is not None :
749760 max_abs = max (max_abs , tmp .err_abs )
750761 yield tmp
762+ already_run .add (i_onnx )
751763
752764 last_position = max_pos + 1
753765
@@ -758,14 +770,17 @@ def _loop_cmp(
758770 f"to { len (onx .graph .node )} "
759771 )
760772 for i_onnx in range (last_position , len (onx .graph .node )):
773+ if i_onnx in already_run :
774+ continue
761775 node = onx .graph .node [i_onnx ]
762776 if verbose > 1 :
763777 print (
764778 f"[run_aligned] run onx.graph.node[{ i_onnx } ]: "
765779 f"{ node .op_type } ({ ', ' .join (node .input )} ) -> { ', ' .join (node .output )} "
766780 )
767781 ref = run_cls (node , ** run_cls_kwargs )
768- feeds = {k : onnx_results [k ] for k in node .input }
782+ feeds = {k : onnx_results [k ] for k in node .input if k }
783+ assert "" not in feeds , f"Unexpected feeds={ string_type (feeds , ** str_kws )} "
769784 begin = time .perf_counter ()
770785 res = ref .run (None , feeds ) # type: ignore[attr-defined]
771786 duration = time .perf_counter () - begin
@@ -800,6 +815,8 @@ def _loop_cmp(
800815 if tmp .err_abs is not None :
801816 max_abs = max (max_abs , tmp .err_abs )
802817 yield tmp
818+ already_run .add (i_onnx )
819+
803820 if verbose :
804821 print (f"[run_aligned] done with { yielded_nodes } mapped nodes" )
805822 print (f"[run_aligned] max absolution error={ max_abs } " )
0 commit comments