@@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]),
             oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"),
             oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]),
+            oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT),
+            oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]),
+            oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]),
+            oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype),
+            oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]),
+            oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype),
         ],
         "test",
         [
@@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto:
             oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]),
             oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]),
+            oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]),
         ],
     ),
     opset_imports=[oh.make_opsetid("", 22)],
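
The added Cast/Gemm/Cast outputs make the rounding policy explicit: the float16 inputs are upcast to float32, the product (and, for CastGemmOnlyCast, the bias) is handled in float32, and a single cast back to float16 happens at the end. A minimal numpy sketch of the two extremes being compared (the helper names below are illustrative, not part of the script):

import numpy as np

def gemm_cast_reference(a16, x16, b16):
    # One rounding at the very end, like the CastGemmOnlyCast output.
    a, x, b = (t.astype(np.float32) for t in (a16, x16, b16))
    return (a @ x + b).astype(np.float16)

def gemm_pure_fp16(a16, x16, b16):
    # float16 at every step (numpy may still accumulate internally in
    # higher precision, which is exactly the effect under study here).
    return (a16 @ x16 + b16).astype(np.float16)
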
@@ -85,7 +95,7 @@ def matrix_diff(tensors):
 dtype = np.float16
 model = make_model_gemm(itype)

-A = np.random.randn(512, 256).astype(dtype)
+A = np.random.randn(1280, 256).astype(dtype)
 X = np.random.randn(256, 256).astype(dtype)
 B = np.random.randn(256).astype(dtype)
 feeds = dict(A=A, X=X, B=B)
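
The comparison itself happens in the unchanged part of the script (matrix_diff); a minimal way to run the model and inspect every Gemm variant, using the standard onnxruntime API, would look like this:

import onnxruntime

sess = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
results = dict(zip([o.name for o in sess.get_outputs()], sess.run(None, feeds)))
for name, value in results.items():
    print(name, value.shape, value.dtype)
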
@@ -112,9 +122,9 @@ def matrix_diff(tensors):
 # %%
 # Let's try with CUDA and float32 if it is available.

-A = torch.randn((512, 512), dtype=torch.float32)
-X = torch.randn((512, 512), dtype=torch.float32)
-B = torch.randn((512), dtype=torch.float32)
+A = torch.randn((1280, 1280), dtype=torch.float32)
+X = torch.randn((1280, 1280), dtype=torch.float32)
+B = torch.randn((1280), dtype=torch.float32)

 for itype, dtype, device in [
     (onnx.TensorProto.FLOAT16, torch.float16, "cpu"),
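
Inside this loop each configuration needs a torch reference; the labels used later ("F.linear" and "a @ x + b") name two of them. A sketch of that reference for a single configuration, assuming the ONNX model's a @ x + b convention (note that F.linear applies the transposed weight, hence x.T):

import torch
import torch.nn.functional as F

dtype, device = torch.float16, "cpu"  # one configuration from the loop
a, x, b = (t.to(dtype).to(device) for t in (A, X, B))
# F.linear(input, weight, bias) computes input @ weight.T + bias.
linear_result = F.linear(a, x.T, b)
manual_result = a @ x + b
print(device, dtype, (linear_result - manual_result).abs().max())
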
@@ -144,7 +154,9 @@ def matrix_diff(tensors):
 # are similar to the other coefficients. What if we make them
 # a lot higher.

-B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384
+A = A / A.max()
+X = X / X.max()
+B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16
 labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"]
 all_results = {}

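
The rescaling matters because float16 precision is relative: with a 10-bit mantissa, the gap between consecutive representable values around the old bias scale of 16384 = 2**14 is already 2**(14 - 10) = 16, so the product's contribution mostly vanishes into rounding, while around 16 the gap is only 2**-6. A quick check:

import numpy as np

print(np.spacing(np.float16(16384.0)))  # 16.0
print(np.spacing(np.float16(16.0)))     # 0.015625 = 2**-6
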
@@ -199,7 +211,7 @@ def make_figure_axis(all_results, i, j):
     print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}")
     expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2
     ax[pos, 0].plot(
-        B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), "."
+        B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "."
     )
     ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10)

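
The torch.rand(1280) * expand term appears to be vertical jitter for readability: many coefficients yield exactly the same difference, so without noise the dots would collapse onto a few horizontal lines. A standalone sketch of the trick with synthetic data:

import torch
import matplotlib.pyplot as plt

diff = torch.linspace(0, 0.25, 1280)  # synthetic stand-in for the measured diffs
expand = 0.5 if diff.max() >= 1 else diff.max() / 2
plt.plot(range(1280), (diff + torch.rand(1280) * expand).tolist(), ".")
plt.show()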