@@ -145,7 +145,7 @@ def matrix_diff(tensors):
145145# a lot higher.
146146
147147B = (torch .arange (512 , dtype = torch .float32 ) + 1 ) / 512 * 16384
148- labels = ["linear" , * [o .name for o in model .graph .output ], "a @ x + b" ]
148+ labels = ["F. linear" , * [o .name for o in model .graph .output ], "a @ x + b" ]
149149all_results = {}
150150
151151for itype , dtype , device in [
@@ -187,28 +187,58 @@ def matrix_diff(tensors):
187187# bias value vs discrepancies
188188# ===========================
189189#
190- # Let's compare GemmOnly (so bias is included) and Gemm+Add.
191-
192- i , j = 1 , - 1
193- labs = labels [i ], labels [j ]
194-
195- fig , ax = plt .subplots (len (all_results ), 2 , figsize = (8 , 2.5 * len (results )))
196- for pos , ((device , dtype ), results ) in enumerate (all_results .items ()):
197- m1 , m2 = results [i ], results [j ]
198- diff = torch .abs (m1 .to (torch .float32 ) - m2 .to (torch .float32 )).max (dim = 0 )[0 ]
199- print (f"labels={ labs } , { device } /{ dtype } : max(diff)={ diff .max ()} " )
200- expand = 0.5 if diff .max () >= 1 else diff .max ().detach ().cpu () / 2
201- ax [pos , 0 ].plot (B .tolist (), (diff .detach ().cpu () + torch .rand (512 ) * expand ).tolist (), "." )
202- ax [pos , 0 ].set_title (f"{ labs [0 ]} -{ labs [1 ]} { device } /{ dtype } " )
203-
204- corr = matrix_diff (results )
205- ax [pos , 1 ].imshow (corr , cmap = "Blues" , vmin = 0 , vmax = corr .max ())
206- # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
207- ax [pos , 1 ].set_xticks (range (len (labels )), labels , rotation = 45 )
208- ax [pos , 1 ].set_yticks (range (len (labels )), labels )
209- ax [pos , 1 ].set_title (f"max={ diff .max ()} " )
190+ # Let's compare torch linear with GemmOnly.
191+
192+
193+ def make_figure_axis (all_results , i , j ):
194+ labs = labels [i ], labels [j ]
195+ fig , ax = plt .subplots (len (all_results ), 2 , figsize = (12 , 4 * len (all_results )))
196+ for pos , ((device , dtype ), results ) in enumerate (all_results .items ()):
197+ m1 , m2 = results [i ], results [j ]
198+ diff = torch .abs (m1 .to (torch .float32 ) - m2 .to (torch .float32 )).max (dim = 0 )[0 ]
199+ print (f"labels={ labs } , { device } /{ dtype } : max(diff)={ diff .max ()} " )
200+ expand = 0.5 if diff .max () >= 1 else diff .max ().detach ().cpu () / 2
201+ ax [pos , 0 ].plot (
202+ B .tolist (), (diff .detach ().cpu () + torch .rand (512 ) * expand ).tolist (), "."
203+ )
204+ ax [pos , 0 ].set_title (f"{ labs [0 ]} -{ labs [1 ]} { device } /{ dtype } " , fontsize = 10 )
205+
206+ corr = matrix_diff (results )
207+ ax [pos , 1 ].imshow (corr , cmap = "Wistia" , vmin = 0 , vmax = corr .max ())
208+ # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}')
209+ ax [pos , 1 ].set_xticks (range (len (labels )), labels , rotation = 45 , ha = "right" , fontsize = 10 )
210+ ax [pos , 1 ].set_yticks (range (len (labels )), labels , fontsize = 10 )
211+ ax [pos , 1 ].set_title (f"max={ diff .max ():1.2g} " , fontsize = 10 )
212+ for _i in range (corr .shape [0 ]):
213+ for _j in range (corr .shape [1 ]):
214+ ax [pos , 1 ].text (
215+ _j ,
216+ _i ,
217+ f"{ corr [_i , _j ]:1.1g} " ,
218+ ha = "center" ,
219+ va = "center" ,
220+ color = "black" ,
221+ fontsize = 8 ,
222+ )
223+ fig .suptitle (
224+ f"Left column: discrepancies { labs [0 ]} VS { labs [1 ]} \n "
225+ f"Right column: max absolute error, accross all configuration\n "
226+ f"white is good, orange is not"
227+ )
228+ return fig , ax
229+
230+
231+ fig , ax = make_figure_axis (all_results , 0 , 1 )
232+ fig .tight_layout ()
233+ fig .savefig ("plot_gemm_or_matmul_add1.png" )
234+
235+ # %%
236+ # Let's compare with ``a @ x + b``.
237+
238+ fig , ax = make_figure_axis (all_results , - 1 , 1 )
210239fig .tight_layout ()
211- fig .savefig ("plot_gemm_or_matmul_add.png" )
240+ fig .savefig ("plot_gemm_or_matmul_add2.png" )
241+
212242
213243# %%
214244# Discrepancies do not happen all the time but it is very likely to happen.
0 commit comments