1010 ExtTestCase ,
1111 ignore_warnings ,
1212 requires_cuda ,
13+ requires_onnxruntime ,
1314)
1415
1516
@@ -18,7 +19,9 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
1819 import onnxruntime
1920 from onnxruntime import InferenceSession , SessionOptions
2021
21- op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND"
22+ op_type = (
23+ "ScatterElements" if "ScatterElements" in str (expected_names ) else "ScatterND"
24+ )
2225 ndim = 2 if op_type == "ScatterElements" else 3
2326
2427 assert dtype in (np .float16 , np .float32 )
@@ -61,7 +64,10 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
6164 # onnxruntime might introduces some intermediate cast.
6265 if pv .Version (onnxruntime .__version__ ) <= pv .Version ("1.17.1" ):
6366 raise unittest .SkipTest ("float16 not supported on cpu" )
64- self .assertEqual (expected_names , names )
67+ if isinstance (expected_names , list ):
68+ self .assertEqual (names , expected_names )
69+ else :
70+ self .assertIn (names , expected_names )
6571
6672 sonx = str (onx ).replace (" " , "" ).replace ("\n " , "|" )
6773 sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype
@@ -126,24 +132,47 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
126132 (row .get ("args_provider" , None ), row .get ("args_op_name" , None ))
127133 )
128134 short_list = [(a , b ) for a , b in exe_providers if a is not None and b is not None ]
129- self .assertEqual (short_list , [("CUDAExecutionProvider" , o ) for o in expected_names ])
135+ if isinstance (expected_names , list ):
136+ self .assertEqual (
137+ short_list , [("CUDAExecutionProvider" , o ) for o in expected_names ]
138+ )
139+ else :
140+ self .assertIn (
141+ short_list ,
142+ tuple ([("CUDAExecutionProvider" , o ) for o in en ] for en in expected_names ),
143+ )
130144
131145 @unittest .skip ("https://github.com/sdpython/onnx-diagnostic/issues/240" )
132146 @requires_cuda ()
133147 @ignore_warnings (DeprecationWarning )
148+ @requires_onnxruntime ("1.23" )
134149 def test_scatterels_cuda (self ):
135- default_value = [
136- "Cast" ,
137- # "MemcpyToHost",
138- "ScatterElements" ,
139- # "MemcpyFromHost",
140- "Sub" ,
141- ]
150+ default_value = (
151+ [
152+ "Cast" ,
153+ # "MemcpyToHost",
154+ "ScatterElements" ,
155+ # "MemcpyFromHost",
156+ "Sub" ,
157+ ],
158+ [
159+ "Cast" ,
160+ "Cast" ,
161+ # "MemcpyToHost",
162+ "ScatterElements" ,
163+ # "MemcpyFromHost",
164+ "Sub" ,
165+ ],
166+ )
142167 expected = {
143168 (np .float32 , "none" ): default_value ,
144169 (np .float16 , "none" ): default_value ,
145170 (np .float32 , "add" ): default_value ,
146171 (np .float16 , "add" ): default_value ,
172+ (np .float32 , "min" ): default_value ,
173+ (np .float16 , "min" ): default_value ,
174+ (np .float32 , "max" ): default_value ,
175+ (np .float16 , "max" ): default_value ,
147176 }
148177 for opset , dtype , reduction in itertools .product (
149178 [16 , 18 ], [np .float32 , np .float16 ], ["none" , "add" , "min" , "max" ]
@@ -161,13 +190,23 @@ def test_scatterels_cuda(self):
161190 @requires_cuda ()
162191 @ignore_warnings (DeprecationWarning )
163192 def test_scatternd_cuda (self ):
164- default_value = [
165- "Cast" ,
166- # "MemcpyToHost",
167- "ScatterND" ,
168- # "MemcpyFromHost",
169- "Sub" ,
170- ]
193+ default_value = (
194+ [
195+ "Cast" ,
196+ # "MemcpyToHost",
197+ "ScatterND" ,
198+ # "MemcpyFromHost",
199+ "Sub" ,
200+ ],
201+ [
202+ "Cast" ,
203+ "Cast" ,
204+ # "MemcpyToHost",
205+ "ScatterND" ,
206+ # "MemcpyFromHost",
207+ "Sub" ,
208+ ],
209+ )
171210 expected = {
172211 (np .float32 , "none" ): default_value ,
173212 (np .float16 , "none" ): default_value ,
@@ -188,20 +227,30 @@ def test_scatternd_cuda(self):
188227
189228 @unittest .skip ("https://github.com/sdpython/onnx-diagnostic/issues/240" )
190229 @ignore_warnings (DeprecationWarning )
230+ @requires_onnxruntime ("1.23" )
191231 def test_scatterels_cpu (self ):
192232 default_value = [
193233 "Cast" ,
194234 "ScatterElements" ,
195235 "Sub" ,
196236 ]
197- default_value_16 = [
198- "Cast" ,
199- "Cast" ,
200- "ScatterElements" ,
201- "Cast" ,
202- "Sub" ,
203- "Cast" ,
204- ]
237+ default_value_16 = (
238+ [
239+ "Cast" ,
240+ "ScatterElements" ,
241+ "Cast" ,
242+ "Sub" ,
243+ "Cast" ,
244+ ],
245+ [
246+ "Cast" ,
247+ "Cast" ,
248+ "ScatterElements" ,
249+ "Cast" ,
250+ "Sub" ,
251+ "Cast" ,
252+ ],
253+ )
205254 expected = {
206255 (np .float32 , "none" ): default_value ,
207256 (np .float16 , "none" ): default_value_16 ,
@@ -222,20 +271,30 @@ def test_scatterels_cpu(self):
222271
223272 @unittest .skip ("https://github.com/sdpython/onnx-diagnostic/issues/240" )
224273 @ignore_warnings (DeprecationWarning )
274+ @requires_onnxruntime ("1.23" )
225275 def test_scatternd_cpu (self ):
226276 default_value = [
227277 "Cast" ,
228278 "ScatterND" ,
229279 "Sub" ,
230280 ]
231- default_value_16 = [
232- "Cast" ,
233- "Cast" ,
234- "ScatterND" ,
235- "Cast" ,
236- "Sub" ,
237- "Cast" ,
238- ]
281+ default_value_16 = (
282+ [
283+ "Cast" ,
284+ "ScatterND" ,
285+ "Cast" ,
286+ "Sub" ,
287+ "Cast" ,
288+ ],
289+ [
290+ "Cast" ,
291+ "Cast" ,
292+ "ScatterND" ,
293+ "Cast" ,
294+ "Sub" ,
295+ "Cast" ,
296+ ],
297+ )
239298 expected = {
240299 (np .float32 , "none" ): default_value ,
241300 (np .float16 , "none" ): default_value_16 ,
0 commit comments