@@ -35,6 +35,13 @@ def __init__(self, name, typename, direction, role="default"):
3535 self .role = role
3636
3737
38+ no_role_kernels = [
39+ "awkward_NumpyArray_sort_asstrings_uint8" ,
40+ "awkward_argsort" ,
41+ "awkward_sort" ,
42+ ]
43+
44+
3845class Specification :
3946 def __init__ (self , templatized_kernel_name , spec , testdata , blacklisted ):
4047 self .templatized_kernel_name = templatized_kernel_name
@@ -51,6 +58,8 @@ def __init__(self, templatized_kernel_name, spec, testdata, blacklisted):
5158 )
5259 if blacklisted :
5360 self .tests = []
61+ elif templatized_kernel_name in no_role_kernels :
62+ self .tests = []
5463 else :
5564 self .tests = self .gettests (testdata )
5665
@@ -185,6 +194,7 @@ def gettests(self, testdata):
185194
186195def readspec ():
187196 specdict = {}
197+ specdict_unit = {}
188198 with open (os .path .join (CURRENT_DIR , ".." , "kernel-specification.yml" )) as f :
189199 loadfile = yaml .load (f , Loader = yaml .CSafeLoader )
190200
@@ -193,6 +203,13 @@ def readspec():
193203 data = json .load (f )["tests" ]
194204
195205 for spec in indspec :
206+ for childfunc in spec ["specializations" ]:
207+ specdict_unit [childfunc ["name" ]] = Specification (
208+ spec ["name" ],
209+ childfunc ,
210+ data ,
211+ not spec ["automatic-tests" ],
212+ )
196213 if "def " in spec ["definition" ]:
197214 for childfunc in spec ["specializations" ]:
198215 specdict [childfunc ["name" ]] = Specification (
@@ -201,7 +218,7 @@ def readspec():
201218 data ,
202219 not spec ["automatic-tests" ],
203220 )
204- return specdict
221+ return specdict , specdict_unit
205222
206223
207224def getdtypes (args ):
@@ -215,6 +232,8 @@ def getdtypes(args):
215232 typename = typename + "_"
216233 if count == 1 :
217234 dtypes .append ("cupy." + typename )
235+ elif count == 2 :
236+ dtypes .append ("cupy." + typename )
218237 return dtypes
219238
220239
@@ -239,7 +258,12 @@ def checkintrange(test_args, error, args):
239258 if "int" in typename or "uint" in typename :
240259 dtype = gettypename (typename )
241260 min_val , max_val = np .iinfo (dtype ).min , np .iinfo (dtype ).max
242- if "List" in typename :
261+ if "List[List" in typename :
262+ for row in val :
263+ for data in row :
264+ if not (min_val <= data <= max_val ):
265+ flag = False
266+ elif "List" in typename :
243267 for data in val :
244268 if not (min_val <= data <= max_val ):
245269 flag = False
@@ -687,6 +711,8 @@ def gencpuunittests(specdict):
687711 "awkward_RegularArray_getitem_next_range" ,
688712 "awkward_RegularArray_getitem_next_range_spreadadvanced" ,
689713 "awkward_RegularArray_getitem_next_array" ,
714+ "awkward_RegularArray_reduce_local_nextparents" ,
715+ "awkward_RegularArray_reduce_nonlocal_preparenext" ,
690716 "awkward_missing_repeat" ,
691717 "awkward_RegularArray_getitem_jagged_expand" ,
692718 "awkward_ListArray_getitem_jagged_expand" ,
@@ -733,6 +759,7 @@ def gencpuunittests(specdict):
733759 "awkward_reduce_sum_bool" ,
734760 "awkward_reduce_prod_bool" ,
735761 "awkward_reduce_countnonzero" ,
762+ "awkward_sorting_ranges_length" ,
736763]
737764
738765
@@ -973,8 +1000,12 @@ def gencudaunittests(specdict):
9731000 )
9741001 )
9751002 elif count == 2 :
976- raise NotImplementedError
977-
1003+ f .write (
1004+ " " * 4
1005+ + "{} = cupy.array({}, dtype=cupy.{})\n " .format (
1006+ arg , val , typename
1007+ )
1008+ )
9781009 cuda_string = (
9791010 "funcC = cupy_backend['"
9801011 + spec .templatized_kernel_name
@@ -1075,10 +1106,10 @@ def evalkernels():
10751106if __name__ == "__main__" :
10761107 genpykernels ()
10771108 evalkernels ()
1078- specdict = readspec ()
1109+ specdict , specdict_unit = readspec ()
10791110 genspectests (specdict )
10801111 gencpukerneltests (specdict )
1081- gencpuunittests (specdict )
1112+ gencpuunittests (specdict_unit )
10821113 genunittests ()
10831114 gencudakerneltests (specdict )
1084- gencudaunittests (specdict )
1115+ gencudaunittests (specdict_unit )
0 commit comments