@@ -19,7 +19,9 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
1919 import onnxruntime
2020 from onnxruntime import InferenceSession , SessionOptions
2121
22- op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND"
22+ op_type = (
23+ "ScatterElements" if "ScatterElements" in str (expected_names ) else "ScatterND"
24+ )
2325 ndim = 2 if op_type == "ScatterElements" else 3
2426
2527 assert dtype in (np .float16 , np .float32 )
@@ -62,7 +64,10 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
6264 # onnxruntime might introduces some intermediate cast.
6365 if pv .Version (onnxruntime .__version__ ) <= pv .Version ("1.17.1" ):
6466 raise unittest .SkipTest ("float16 not supported on cpu" )
65- self .assertEqual (expected_names , names )
67+ if isinstance (expected_names , list ):
68+ self .assertEqual (names , expected_names )
69+ else :
70+ self .assertIn (names , expected_names )
6671
6772 sonx = str (onx ).replace (" " , "" ).replace ("\n " , "|" )
6873 sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype
@@ -127,19 +132,37 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
127132 (row .get ("args_provider" , None ), row .get ("args_op_name" , None ))
128133 )
129134 short_list = [(a , b ) for a , b in exe_providers if a is not None and b is not None ]
130- self .assertEqual (short_list , [("CUDAExecutionProvider" , o ) for o in expected_names ])
135+ if isinstance (expected_names , list ):
136+ self .assertEqual (
137+ short_list , [("CUDAExecutionProvider" , o ) for o in expected_names ]
138+ )
139+ else :
140+ self .assertIn (
141+ short_list ,
142+ tuple ([("CUDAExecutionProvider" , o ) for o in en ] for en in expected_names ),
143+ )
131144
132145 @requires_cuda ()
133146 @ignore_warnings (DeprecationWarning )
134147 @requires_onnxruntime ("1.23" )
135148 def test_scatterels_cuda (self ):
136- default_value = [
137- "Cast" ,
138- # "MemcpyToHost",
139- "ScatterElements" ,
140- # "MemcpyFromHost",
141- "Sub" ,
142- ]
149+ default_value = (
150+ [
151+ "Cast" ,
152+ # "MemcpyToHost",
153+ "ScatterElements" ,
154+ # "MemcpyFromHost",
155+ "Sub" ,
156+ ],
157+ [
158+ "Cast" ,
159+ "Cast" ,
160+ # "MemcpyToHost",
161+ "ScatterElements" ,
162+ # "MemcpyFromHost",
163+ "Sub" ,
164+ ],
165+ )
143166 expected = {
144167 (np .float32 , "none" ): default_value ,
145168 (np .float16 , "none" ): default_value ,
@@ -165,13 +188,23 @@ def test_scatterels_cuda(self):
165188 @requires_cuda ()
166189 @ignore_warnings (DeprecationWarning )
167190 def test_scatternd_cuda (self ):
168- default_value = [
169- "Cast" ,
170- # "MemcpyToHost",
171- "ScatterND" ,
172- # "MemcpyFromHost",
173- "Sub" ,
174- ]
191+ default_value = (
192+ [
193+ "Cast" ,
194+ # "MemcpyToHost",
195+ "ScatterND" ,
196+ # "MemcpyFromHost",
197+ "Sub" ,
198+ ],
199+ [
200+ "Cast" ,
201+ "Cast" ,
202+ # "MemcpyToHost",
203+ "ScatterND" ,
204+ # "MemcpyFromHost",
205+ "Sub" ,
206+ ],
207+ )
175208 expected = {
176209 (np .float32 , "none" ): default_value ,
177210 (np .float16 , "none" ): default_value ,
@@ -198,13 +231,23 @@ def test_scatterels_cpu(self):
198231 "ScatterElements" ,
199232 "Sub" ,
200233 ]
201- default_value_16 = [
202- "Cast" ,
203- "ScatterElements" ,
204- "Cast" ,
205- "Sub" ,
206- "Cast" ,
207- ]
234+ default_value_16 = (
235+ [
236+ "Cast" ,
237+ "ScatterElements" ,
238+ "Cast" ,
239+ "Sub" ,
240+ "Cast" ,
241+ ],
242+ [
243+ "Cast" ,
244+ "Cast" ,
245+ "ScatterElements" ,
246+ "Cast" ,
247+ "Sub" ,
248+ "Cast" ,
249+ ],
250+ )
208251 expected = {
209252 (np .float32 , "none" ): default_value ,
210253 (np .float16 , "none" ): default_value_16 ,
@@ -231,13 +274,23 @@ def test_scatternd_cpu(self):
231274 "ScatterND" ,
232275 "Sub" ,
233276 ]
234- default_value_16 = [
235- "Cast" ,
236- "ScatterND" ,
237- "Cast" ,
238- "Sub" ,
239- "Cast" ,
240- ]
277+ default_value_16 = (
278+ [
279+ "Cast" ,
280+ "ScatterND" ,
281+ "Cast" ,
282+ "Sub" ,
283+ "Cast" ,
284+ ],
285+ [
286+ "Cast" ,
287+ "Cast" ,
288+ "ScatterND" ,
289+ "Cast" ,
290+ "Sub" ,
291+ "Cast" ,
292+ ],
293+ )
241294 expected = {
242295 (np .float32 , "none" ): default_value ,
243296 (np .float16 , "none" ): default_value_16 ,
0 commit comments