@@ -41,13 +41,13 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
     return oh.make_model(
         oh.make_graph(
             [
-                oh.make_node("Gemm", ["A", "X", "B"], ["Ygemmfused"]),
+                oh.make_node("Gemm", ["A", "X", "B"], ["GemmOnly"]),
                 oh.make_node("Gemm", ["A", "X"], ["gmm"]),
-                oh.make_node("Add", ["gmm", "B"], ["Ygemm"]),
+                oh.make_node("Add", ["gmm", "B"], ["GemmAdd"]),
                 oh.make_node("MatMul", ["A", "X"], ["mm"]),
-                oh.make_node("Add", ["mm", "B"], ["Ymm"]),
+                oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
                 oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
-                oh.make_node("Add", ["fmm", "B"], ["Yfused"]),
+                oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
             ],
             "test",
             [
@@ -56,10 +56,10 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
                 oh.make_tensor_value_info("B", itype, ["c"]),
             ],
             [
-                oh.make_tensor_value_info("Ygemmfused", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Yfused", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Ygemm", itype, ["a", "c"]),
-                oh.make_tensor_value_info("Ymm", itype, ["a", "c"]),
+                oh.make_tensor_value_info("GemmOnly", itype, ["a", "c"]),
+                oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
+                oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
+                oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
             ],
         ),
         opset_imports=[oh.make_opsetid("", 22)],
@@ -140,13 +140,17 @@ def matrix_diff(tensors):
 # a lot higher.

 B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
-labels = ["torch", *[o.name for o in model.graph.output]]
+labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
+all_results = {}

 for itype, dtype, device in [
     (onnx.TensorProto.FLOAT, torch.float32, "cpu"),
     (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
+    # missing implementation in onnxruntime
+    # (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cpu"),
     (onnx.TensorProto.FLOAT, torch.float32, "cuda"),
     (onnx.TensorProto.FLOAT16, torch.float16, "cuda"),
+    (onnx.TensorProto.BFLOAT16, torch.bfloat16, "cuda"),
 ]:
     if device == "cuda" and not torch.cuda.is_available():
         continue
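
The bias values in ``B`` ramp up to 16384, a range where float16 becomes very coarse, which is why the discrepancies below grow with the bias. A quick side check (not part of the commit) prints the float16 rounding step at those magnitudes:

    import numpy

    # distance between adjacent float16 values at two bias magnitudes
    print(numpy.spacing(numpy.float16(16384)))  # 16.0
    print(numpy.spacing(numpy.float16(64)))     # 0.0625
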
@@ -163,8 +167,9 @@ def matrix_diff(tensors):
         graph_optimization_level=GraphOptimizationLevel.ORT_DISABLE_ALL,
         optimized_model_filepath=filename,
     )
+    results = [torch.nn.functional.linear(a, x.T, b), *sess.run(None, feeds), a @ x + b]
+    all_results[device, dtype] = results
     has_cast = "Cast" in [n.op_type for n in onnx.load(filename).graph.node]
-    results = [a @ x + b, *sess.run(None, feeds)]
     diffs = matrix_diff(results)
     df = pandas.DataFrame(diffs, columns=labels, index=labels)
     print(f"------ has_cast={has_cast}, dtype={dtype}, device={device!r}, max(b)={b.max()}")
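
The body of ``matrix_diff`` lies outside these hunks; the calls above only require it to return a square array of pairwise maximal absolute differences, in the order of ``labels``. A minimal sketch consistent with that usage (an assumption, not the file's actual implementation):

    import numpy
    import torch

    def matrix_diff(tensors):
        # compare everything in float32 on CPU, whatever the input type or device
        ts = [torch.as_tensor(t).to(torch.float32).cpu() for t in tensors]
        # entry (i, j) is the maximal absolute difference between tensors i and j
        return numpy.array([[float((t1 - t2).abs().max()) for t2 in ts] for t1 in ts])
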
@@ -176,18 +181,32 @@ def matrix_diff(tensors):
 #
 # bias value vs discrepancies
 # ===========================
-
-
-m1, m2 = results[0:2]
-diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
-print(f"max(diff)={diff.max()}")
-
-fig, ax = plt.subplots(1, 1, figsize=(5, 3))
-ax.plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * 0.5).tolist(), ".")
-ax.set_title("Discrepancies (y) VS Bias (x)")
+#
+# Let's compare GemmOnly (so the bias is fused into Gemm) with the unfused a @ x + b.
+
+i, j = 1, -1
+labs = labels[i], labels[j]
+
+fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(all_results)))
+for pos, ((device, dtype), results) in enumerate(all_results.items()):
+    m1, m2 = results[i], results[j]
+    diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
+    print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
+    expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
+    ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
+    ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")
+
+    corr = matrix_diff(results)
+    ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
+    # ax[pos, 1].colorbar(label=f'Discrepancies {device}/{dtype}')
+    ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
+    ax[pos, 1].set_yticks(range(len(labels)), labels)
+    ax[pos, 1].set_title(f"max={diff.max()}")
+fig.tight_layout()
 fig.savefig("plot_gemm_or_matmul_add.png")

 # %%
 # Discrepancies do not happen all the time, but they are very likely to happen.
-# Fused Gemm should be avoided when the bias is very different from the multiplied
-# matrix and avoided in the generic case.
+# Gemm with a non-null bias should only be used when torch fuses the bias the
+# same way, and the behaviour seems to depend on the element type as well.
+# The difference is even larger for bfloat16.
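
The conclusion is easier to see on a toy example: a fused Gemm can keep its accumulator in float32 and round once at the end, while MatMul followed by Add rounds the product to float16 before the bias is applied. A minimal sketch simulating both rounding orders in pure torch (an assumption about kernel behaviour, not code from this commit):

    import torch

    a = torch.rand(128, 256)
    x = torch.rand(256, 512)
    b = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384

    # fused order: accumulate in float32, round to float16 once at the end
    fused = (a @ x + b).to(torch.float16)
    # unfused order: the product is rounded to float16 before the bias is added
    unfused = (a @ x).to(torch.float16) + b.to(torch.float16)

    # the two rounding orders disagree for some columns
    print((fused.float() - unfused.float()).abs().max())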