:ref:`l-plot-parallelized-reduction`, reduction operations
are sensitive to parallelization.

-We consider a small model including a layer normalization
-followed by a matrix multiplication and we show that replacing
-a kernel by another one may significantly impact the output.
+Methodology
++++++++++++
+
+We consider a simple model with a LayerNormalization followed by a MatMul.
+Each operator can be run with :epkg:`onnxruntime` or :epkg:`pytorch`.
+We compare the four combinations.

The model
+++++++++
"""

import itertools
+import numpy as np
import pandas
import onnx
import onnx.helper as oh
import onnxruntime
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
-from onnx_diagnostic.doc import rotate_align, save_fig
+from onnx_diagnostic.doc import rotate_align, save_fig, plot_histogram, title
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name, onnx_dtype_to_np_dtype
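
The methodology above compares the same LayerNorm(X) @ W + B computed with different kernels. As a minimal sketch of that kind of comparison (not the code of the example itself), the snippet below computes the result once with the torch kernel and once with a hand-written numpy implementation, using arbitrary shapes, and reports the maximum absolute difference:

import numpy as np
import torch

x = torch.randn(2, 8, 16)
w = torch.randn(16, 16)
b = torch.randn(16)

# torch kernels for LayerNormalization and MatMul
y_torch = torch.nn.functional.layer_norm(x, (16,)) @ w + b

# numpy "kernel": normalize over the last axis (eps matches the default 1e-5)
xn = x.numpy()
mean = xn.mean(axis=-1, keepdims=True)
var = xn.var(axis=-1, keepdims=True)
y_np = (xn - mean) / np.sqrt(var + 1e-5) @ w.numpy() + b.numpy()

print("max abs diff:", float(np.abs(y_torch.numpy() - y_np).max()))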
@@ -80,6 +84,8 @@ def make_feeds(last_dim: int):


def cast_feeds(itype, provider, feeds):
+    ttype = onnx_dtype_to_torch_dtype(itype)
+    np_dtype = onnx_dtype_to_np_dtype(itype)
    np_feeds = {k: v.detach().numpy() for k, v in feeds.items()}
    if provider == "CUDA":
        if not torch.cuda.is_available():
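
The two lines moved into ``cast_feeds`` map the ONNX element type to the matching torch and numpy dtypes. A rough illustration of what such helpers return, using the standard onnx helper rather than the onnx_diagnostic ones used by the example:

import onnx
import torch
from onnx.helper import tensor_dtype_to_np_dtype

# Illustrative mapping from ONNX element types to numpy/torch dtypes.
to_torch = {onnx.TensorProto.FLOAT: torch.float32, onnx.TensorProto.FLOAT16: torch.float16}
for itype in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16):
    print(itype, tensor_dtype_to_np_dtype(itype), to_torch[itype])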
@@ -102,8 +108,6 @@ def cast_feeds(itype, provider, feeds):
baseline = {}

for provider, itype in itertools.product(["CPU", "CUDA"], [TFLOAT, TFLOAT16]):
-    ttype = onnx_dtype_to_torch_dtype(itype)
-    np_dtype = onnx_dtype_to_np_dtype(itype)
    tch_feeds, ort_feeds = cast_feeds(itype, provider, feeds)
    if tch_feeds is None:
        continue
@@ -156,6 +160,22 @@ def cast_feeds(itype, provider, feeds):
# Let's see which operator is responsible for them,
# *LayerNormalization* or *MatMul*.

+# %%
+# Distribution of the results
+# +++++++++++++++++++++++++++
+
+tensor = baseline[TFLOAT16, "CPU", "ort"][0].ravel().astype(np.float32)
+print(pandas.DataFrame({"expected": tensor}).describe())
+
+# %%
+# Histogram.
+
+save_fig(
+    title(plot_histogram(tensor), "Distribution of the computed results"),
+    "plot_layer_norm_discrepancies_hist.png",
+)
+
+
# %%
# Where do the discrepancies come from?
# +++++++++++++++++++++++++++++++++++++
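
``plot_histogram`` and ``title`` come from ``onnx_diagnostic.doc``. For readers without that package, a rough matplotlib equivalent of the histogram step, with a stand-in array instead of the real baseline tensor, could look like this:

import matplotlib.pyplot as plt
import numpy as np

values = np.random.normal(size=4096).astype(np.float32)  # stand-in for the baseline tensor
fig, ax = plt.subplots(figsize=(6, 3))
ax.hist(values, bins=50)
ax.set_title("Distribution of the computed results")
fig.savefig("plot_layer_norm_discrepancies_hist.png")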
@@ -165,19 +185,18 @@ def cast_feeds(itype, provider, feeds):
data = []

for mod, provider, itype in itertools.product(
-    ["ORT-TORCH", "TORCH-ORT"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
+    ["ORT-ORT", "ORT-TORCH", "TORCH-ORT", "TORCH-TORCH"], ["CPU", "CUDA"], [TFLOAT, TFLOAT16]
):
    ttype = onnx_dtype_to_torch_dtype(itype)
    np_dtype = onnx_dtype_to_np_dtype(itype)
    tch_feeds, _ = cast_feeds(itype, provider, feeds)
    if tch_feeds is None:
        continue

+    ker1, ker2 = mod.split("-")
    custom_kernels = (
-        {("", "LayerNormalization"): LayerNormalizationOrt}
-        if mod == "ORT-TORCH"
-        else {("", "MatMul"): MatMulOrt}
-    )
+        {("", "LayerNormalization"): LayerNormalizationOrt} if ker1 == "ORT" else {}
+    ) | ({("", "MatMul"): MatMulOrt} if ker2 == "ORT" else {})

    model = get_model(itype)
    print()
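
The rewritten selection logic covers all four kernel combinations with a dictionary union. A small standalone check of which kernels get replaced for each combination, with placeholders standing in for the custom onnxruntime-backed kernels defined earlier in the example:

# Placeholders for the custom onnxruntime-backed kernels used in the example.
LayerNormalizationOrt, MatMulOrt = "LayerNormalizationOrt", "MatMulOrt"

for mod in ["ORT-ORT", "ORT-TORCH", "TORCH-ORT", "TORCH-TORCH"]:
    ker1, ker2 = mod.split("-")
    custom_kernels = (
        {("", "LayerNormalization"): LayerNormalizationOrt} if ker1 == "ORT" else {}
    ) | ({("", "MatMul"): MatMulOrt} if ker2 == "ORT" else {})
    # ORT-ORT replaces both kernels, TORCH-TORCH replaces none.
    print(mod, sorted(k[1] for k in custom_kernels))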
@@ -206,7 +225,7 @@ def cast_feeds(itype, provider, feeds):
    )

# %%
-df = pandas.DataFrame(data).set_index(["model", "provider", "dtype"])
+df = pandas.DataFrame(data).set_index(["dtype", "provider", "model"])
df = df.sort_index()
print(df)

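Reordering the index to ``("dtype", "provider", "model")`` means ``sort_index()`` groups the four kernel combinations measured for the same dtype and provider next to each other, which is easier to read. A toy illustration with placeholder values (not measured discrepancies):

import pandas

toy = pandas.DataFrame(
    [
        {"dtype": "FLOAT16", "provider": "CPU", "model": "TORCH-TORCH", "diff_ort": 0.0},
        {"dtype": "FLOAT16", "provider": "CPU", "model": "ORT-ORT", "diff_ort": 0.0},
        {"dtype": "FLOAT", "provider": "CPU", "model": "ORT-TORCH", "diff_ort": 0.0},
    ]
)
print(toy.set_index(["dtype", "provider", "model"]).sort_index())
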
@@ -216,8 +235,17 @@ def cast_feeds(itype, provider, feeds):
save_fig(
    rotate_align(
        df[["diff_ort", "diff_torch"]].plot.bar(
-            title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B"
+            title="ORT/Torch or Torch/ORT for LayerNorm(X) @ W + B",
+            figsize=(10, 4),
        )
    ),
    "plot_layer_norm_discrepancies_2.png",
)
+
+# %%
+# Conclusion
+# ++++++++++
+#
+# :epkg:`torch` seems to produce the same results when the same computation is
+# run multiple times. :epkg:`onnxruntime` only does so on CUDA.
+# With float16 on CUDA, LayerNormalization seems to introduce some discrepancies.
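
One hedged way to check the run-to-run reproducibility claim is to execute the same computation twice and compare the outputs; the same pattern applies to an onnxruntime ``InferenceSession`` by calling ``run()`` twice with identical inputs. A sketch with torch only, using arbitrary shapes:

import torch

x = torch.randn(2, 8, 16)
w = torch.randn(16, 16)
b = torch.randn(16)

def run_once():
    # Same LayerNorm(X) @ W + B pattern as in the example, default settings.
    return torch.nn.functional.layer_norm(x, (16,)) @ w + b

r1, r2 = run_once(), run_once()
print("max abs diff between two runs:", (r1 - r2).abs().max().item())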