@@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
             oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
             oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
+            oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
+            oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
+            oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
+            oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
+            oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
         ],
         "test",
         [
@@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
         ],
     ),
     opset_imports=[oh.make_opsetid("", 22)],
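
The new Cast/Gemm/Cast nodes build a float32 reference path: the inputs are
upcast to float32, the Gemm runs there (with the bias either fused or added by
a separate Add), and the result is cast back to ``itype``. A minimal numpy
sketch of what ``CastGemmAddCast`` computes (the helper name is mine, for
illustration only):

    import numpy as np

    def cast_gemm_cast(A, X, B, itype=np.float16):
        # upcast to float32, accumulate there, round only once at the end
        Af, Xf, Bf = (t.astype(np.float32) for t in (A, X, B))
        return (Af @ Xf + Bf).astype(itype)
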
@@ -85,7 +95,7 @@ def matrix_diff(tensors):
 dtype = np.float16
 model = make_model_gemm(itype)
 
-A = np.random.randn(512, 256).astype(dtype)
+A = np.random.randn(1280, 256).astype(dtype)
 X = np.random.randn(256, 256).astype(dtype)
 B = np.random.randn(256).astype(dtype)
 feeds = dict(A=A, X=X, B=B)
@@ -112,9 +122,9 @@ def matrix_diff(tensors):
 # %%
 # Let's try with CUDA and float32 if it is available.
 
-A = torch.randn((512, 512), dtype=torch.float32)
-X = torch.randn((512, 512), dtype=torch.float32)
-B = torch.randn((512), dtype=torch.float32)
+A = torch.randn((1280, 1280), dtype=torch.float32)
+X = torch.randn((1280, 1280), dtype=torch.float32)
+B = torch.randn((1280), dtype=torch.float32)
 
 for itype, dtype, device in [
     (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
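
Each ``(itype, dtype, device)`` entry presumably drives one onnxruntime
session over the same feeds; the loop body falls outside this hunk, so the
following is only a sketch of one plausible iteration:

    import onnxruntime

    providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
    sess = onnxruntime.InferenceSession(
        make_model_gemm(itype).SerializeToString(), providers=providers
    )
    feeds = {k: v.to(dtype).numpy() for k, v in dict(A=A, X=X, B=B).items()}
    results = sess.run(None, feeds)  # one tensor per graph output
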
@@ -144,8 +154,10 @@ def matrix_diff(tensors):
 # are similar to the other coefficients. What if we make them
 # a lot higher?
 
-B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
-labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"]
+A = A / A.max()
+X = X / X.max()
+B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
+labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
 all_results = {}
 
 for itype, dtype, device in [
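
Rescaling ``A`` and ``X`` to at most 1 while the bias grows up to 16 stresses
float16 rounding: near 16 the gap between consecutive float16 values is
already 2**-6, so a small matmul contribution survives or vanishes depending
on whether the bias is added inside a float32 accumulator or after the result
was rounded to float16. A quick check (my own illustration, not from the
script):

    import numpy as np

    b = np.float16(16.0)
    eps = np.float16(0.001)
    print(np.float16(b + eps) - b)        # 0.0: eps is below half an ulp at 16
    print(np.finfo(np.float16).eps * 16)  # ~0.0156, the float16 spacing near 16
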
@@ -187,28 +199,58 @@ def matrix_diff(tensors):
 # bias value vs discrepancies
 # ===========================
 #
-# Let's compare GemmOnly (so bias is included) and Gemm+Add.
-
-i, j = 1, -1
-labs = labels[i], labels[j]
-
-fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(results)))
-for pos, ((device, dtype), results) in enumerate(all_results.items()):
-    m1, m2 = results[i], results[j]
-    diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
-    print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
-    expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
-    ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".")
-    ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}")
-
-    corr = matrix_diff(results)
-    ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max())
-    # ax[pos, 1].colorbar(label=f"Discrepancies {device}/{dtype}")
-    ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45)
-    ax[pos, 1].set_yticks(range(len(labels)), labels)
-    ax[pos, 1].set_title(f"max={diff.max()}")
+# Let's compare torch linear with GemmOnly.
+
+
+def make_figure_axis(all_results, i, j):
+    labs = labels[i], labels[j]
+    fig, ax = plt.subplots(len(all_results), 2, figsize=(12, 4 * len(all_results)))
+    for pos, ((device, dtype), results) in enumerate(all_results.items()):
+        m1, m2 = results[i], results[j]
+        diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0]
+        print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
+        expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
+        ax[pos, 0].plot(
+            B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
+        )
+        ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)
+
+        corr = matrix_diff(results)
+        ax[pos, 1].imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max())
+        # ax[pos, 1].colorbar(label=f"Discrepancies {device}/{dtype}")
+        ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10)
+        ax[pos, 1].set_yticks(range(len(labels)), labels, fontsize=10)
+        ax[pos, 1].set_title(f"max={diff.max():1.2g}", fontsize=10)
+        for _i in range(corr.shape[0]):
+            for _j in range(corr.shape[1]):
+                ax[pos, 1].text(
+                    _j,
+                    _i,
+                    f"{corr[_i, _j]:1.1g}",
+                    ha="center",
+                    va="center",
+                    color="black",
+                    fontsize=8,
+                )
+    fig.suptitle(
+        f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
+        "Right column: max absolute error across all configurations\n"
+        "white is good, orange is not"
+    )
+    return fig, ax
+
+
+fig, ax = make_figure_axis(all_results, 0, 1)
 fig.tight_layout()
-fig.savefig("plot_gemm_or_matmul_add.png")
+fig.savefig("plot_gemm_or_matmul_add1.png")
+
+# %%
+# Let's compare with ``A @ X + B``.
+
+fig, ax = make_figure_axis(all_results, -1, 1)
+fig.tight_layout()
+fig.savefig("plot_gemm_or_matmul_add2.png")
+
 
 # %%
 # Discrepancies do not happen all the time, but they are very likely to happen.
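
For a standalone check outside this example, the fused and unfused paths can
be compared directly in torch; whether the gap is exactly zero depends on the
backend and hardware, hence only a sketch:

    import torch

    torch.manual_seed(0)
    A = torch.randn(1280, 256, dtype=torch.float16)
    X = torch.randn(256, 256, dtype=torch.float16)
    B = torch.randn(256, dtype=torch.float16)

    fused = torch.nn.functional.linear(A, X.t(), B)  # A @ X + B in one kernel
    unfused = A @ X + B                              # MatMul then Add
    print((fused.float() - unfused.float()).abs().max())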