@@ -113,6 +113,23 @@ def _duplicated_values(d):
113113 return final
114114
115115
def _validation_nn_functional(
    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
) -> Optional[str]:
    """
    Cross-checks an onnx node against its ``torch.nn.functional`` counterpart.

    Only one pattern is handled: a ``Gemm`` node with three inputs whose
    effective attributes are ``alpha=1``, ``beta=1``, ``transA=0``,
    ``transB=1``. In that case the node inputs are passed, in order, to
    ``torch.nn.functional.linear`` and the result is compared to
    ``expected[0]``.

    :param node: onnx node to validate
    :param new_feeds: tensors indexed by onnx input name
    :param expected: expected outputs, only the first one is compared
    :return: ``"function.linear:"`` followed by the formatted discrepancy,
        or None if the node does not match the pattern or if the data
        needed to run the check is not available
    """
    if node.op_type != "Gemm" or len(node.input) != 3:
        return None

    # Start from the default values defined by the ONNX specification so that
    # a node storing them explicitly is recognized the same way as a node
    # omitting them.
    atts = {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
    for att in node.attribute:
        if att.name in ("alpha", "beta"):
            atts[att.name] = att.f
        elif att.name in ("transA", "transB"):
            atts[att.name] = att.i
    if atts != {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 1}:
        return None

    # This validation is only a side comment: a missing feed (for example an
    # omitted optional input named "") or a missing expected value must not
    # interrupt the caller.
    if not expected or any(name not in new_feeds for name in node.input):
        return None

    # NOTE(review): torch.nn.functional.linear documents a 1D (or scalar) bias
    # while Gemm accepts any C broadcastable to (M, N) -- confirm a 2D C never
    # reaches this point.
    res = torch.nn.functional.linear(*[new_feeds[name] for name in node.input])
    # max_diff and string_diff are helpers defined elsewhere in this file,
    # presumably computing and formatting the discrepancies.
    diff = max_diff(res, expected[0])
    return f"function.linear:{string_diff(diff)} "
131+
132+
116133def _loop_onnx_node (
117134 onx : onnx .ModelProto ,
118135 ep_graph_nodes : List [torch .fx .Node ],
@@ -201,6 +218,7 @@ def _loop_onnx_node(
201218 f"node is { pretty_onnx (node )} "
202219 )
203220
221+ comment = None
204222 cross = None
205223 if run_onnx_with_torch_inputs :
206224 # Let's run the operator with torch results if they are available
@@ -223,8 +241,13 @@ def _loop_onnx_node(
223241 cross = ref .run (None , new_feeds )
224242 if verbose > 1 :
225243 print (f"[run_aligned] got for second run={ string_type (cross , ** str_kws )} " )
244+ # Gemm (transB=1) = torch.nn.functional.linear, in that case, we just run it as well
245+ to = mapping_onnx_to_torch .get (node .output [0 ], node .output [0 ])
246+ if to in torch_results :
247+ comment = _validation_nn_functional (node , new_feeds , [torch_results [to ]])
226248 elif verbose > 1 :
227249 print (f"[run_aligned] second run not possible because of missing { removed } " )
250+
228251 if cross is None :
229252 cross = [None for _ in res ]
230253
@@ -264,6 +287,7 @@ def _loop_onnx_node(
264287 status .yielded_nodes += 1
265288 if tmp .err_abs is not None :
266289 status .update (tmp .err_abs )
290+ tmp .comment = comment
267291 yield tmp
268292
269293 # do we need to dump pieces of the graph the user can replay?
@@ -1055,6 +1079,7 @@ def forward(self, x):
10551079 if r :
10561080 yield r .check (already_yielded )
10571081
1058- loop .close ()
1082+ if loop is not None :
1083+ loop .close ()
10591084 if verbose :
10601085 print (f"[run_aligned] done with status={ status .to_str ()} " )
0 commit comments