@@ -108,6 +108,21 @@ def run_fx_node(
108108 return args
109109 if node .op == "call_function" :
110110 assert callable (node .target ), f"{ node .target !r} not callable in node { node !r} "
111+ for a , ea in zip (args , node .args ):
112+ if isinstance (a , torch .Tensor ) and hasattr (ea , "meta" ) and "val" in ea .meta :
113+ ta = ea .meta ["val" ]
114+ # if not isinstance(ta, torch.Tensor):
115+ # print("******", args)
116+ # print("******", node.args)
117+ # print("******", node.kwargs)
118+ # print("******", node.meta)
119+ # print(ta)
120+ assert len (a .shape ) == len (ta .shape ) and a .dtype == ta .dtype , (
121+ f"Unable to run node { node !r} , target={ node .target !r} , "
122+ f"node.args={ node .args !r} , node.kwargs={ node .kwargs !r} , "
123+ f"args={ string_type (args , with_shape = True , with_device = True )} , "
124+ f"kwargs={ string_type (kwargs , with_shape = True , with_device = True )} "
125+ )
111126 try :
112127 outputs = node .target (* args , ** (kwargs or {}))
113128 except RuntimeError as e :
@@ -381,7 +396,7 @@ def forward(self, x):
381396 -v 1 --atol=0.1 --rtol=1
382397 """
383398 assert callable (run_cls ), f"run_cls={ run_cls } not a callable"
384- str_kws = dict (with_shape = True , with_device = True , with_min_max = True )
399+ str_kws = dict (with_shape = True , with_device = True )
385400 has_cuda = any (
386401 (isinstance (t , torch .Tensor ) and t .is_cuda )
387402 for t in flatten_object ([args , kwargs ], drop_keys = True )
@@ -592,7 +607,12 @@ def _loop_cmp(
592607 print (f"[run_aligned] run ep.graph.nodes[{ i } ]: { node .op } -> { node .name !r} " )
593608
594609 if node .op == "placeholder" :
595- if node .name in onnx_results :
610+ is_input = node .name in placeholders
611+ if node .name in onnx_results and (
612+ is_input
613+ or ep_state_dict [placeholders_to_state_dict [node .name ]].shape
614+ == onnx_results [node .name ]
615+ ):
596616 torch_results [node .name ] = (
597617 onnx_results [node .name ]
598618 if use_tensor
@@ -602,7 +622,6 @@ def _loop_cmp(
602622 t = torch_results [node .name ]
603623 print (f"[run_aligned-ep] =plh: { node .name } ={ string_type (t , ** str_kws )} " )
604624 # Otherwise, it is an input.
605- is_input = node .name in placeholders
606625 yield (
607626 - 1 ,
608627 - 1 ,
@@ -615,7 +634,8 @@ def _loop_cmp(
615634 {}
616635 if is_input
617636 else max_diff (
618- placeholders_to_state_dict [node .name ], onnx_results [node .name ]
637+ ep_state_dict [placeholders_to_state_dict [node .name ]],
638+ onnx_results [node .name ],
619639 )
620640 ),
621641 )
0 commit comments