@@ -111,13 +111,11 @@ def run_fx_node(
111111 for a , ea in zip (args , node .args ):
112112 if isinstance (a , torch .Tensor ) and hasattr (ea , "meta" ) and "val" in ea .meta :
113113 ta = ea .meta ["val" ]
114- # if not isinstance(ta, torch.Tensor):
115- # print("******", args)
116- # print("******", node.args)
117- # print("******", node.kwargs)
118- # print("******", node.meta)
119- # print(ta)
120- assert len (a .shape ) == len (ta .shape ) and a .dtype == ta .dtype , (
114+ assert (
115+ isinstance (ta , torch .Tensor )
116+ and len (a .shape ) == len (ta .shape )
117+ and a .dtype == ta .dtype
118+ ), (
121119 f"Unable to run node { node !r} , target={ node .target !r} , "
122120 f"node.args={ node .args !r} , node.kwargs={ node .kwargs !r} , "
123121 f"args={ string_type (args , with_shape = True , with_device = True )} , "
@@ -672,7 +670,7 @@ def _loop_cmp(
672670 outputs = [node .name ] if isinstance (node .name , str ) else list (node .name )
673671 args , kwargs = prepare_args_kwargs (torch_results , node )
674672 new_outputs = run_fx_node (node , args , kwargs )
675- if isinstance (new_outputs , (torch .Tensor , int , float , list )):
673+ if isinstance (new_outputs , (torch .Tensor , int , float , list , tuple )):
676674 new_outputs = (new_outputs ,)
677675
678676 if new_outputs is None :
0 commit comments