3333__contact__ = "jerome.kieffer@esrf.eu"
3434__license__ = "MIT"
3535__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
36- __date__ = "07 /11/2024"
36+ __date__ = "12 /11/2024"
3737
3838import logging
3939import numpy
5151
5252@unittest .skipIf (UtilsTest .opencl is False , "User request to skip OpenCL tests" )
5353@unittest .skipUnless (ocl , "PyOpenCl is missing" )
54- class TestReduction (unittest .TestCase ):
54+ class TestGroupFunction (unittest .TestCase ):
5555
5656 @classmethod
5757 def setUpClass (cls ):
58- super (TestReduction , cls ).setUpClass ()
58+ super (TestGroupFunction , cls ).setUpClass ()
5959
6060 if ocl :
6161 cls .ctx = ocl .create_context ()
@@ -74,8 +74,8 @@ def setUpClass(cls):
7474
7575 @classmethod
7676 def tearDownClass (cls ):
77- super (TestReduction , cls ).tearDownClass ()
78- print ("Maximum valid workgroup size %s on device %s" % (cls .max_valid_wg , cls .ctx .devices [0 ]))
77+ super (TestGroupFunction , cls ).tearDownClass ()
78+ # print("Maximum valid workgroup size %s on device %s" % (cls.max_valid_wg, cls.ctx.devices[0]))
7979 cls .ctx = None
8080 cls .queue = None
8181
@@ -88,8 +88,8 @@ def setUp(self):
8888 self .data_d = pyopencl .array .to_device (self .queue , self .data )
8989 self .sum_d = pyopencl .array .zeros_like (self .data_d )
9090 self .program = pyopencl .Program (self .ctx , get_opencl_code ("pyfai:openCL/collective/reduction.cl" )+
91- get_opencl_code ("pyfai:openCL/collective/scan.cl" )
92- ).build ()
91+ get_opencl_code ("pyfai:openCL/collective/scan.cl" )+
92+ get_opencl_code ( "pyfai:openCL/collective/comb_sort.cl" ) ).build ()
9393
9494 def tearDown (self ):
9595 self .img = self .data = None
@@ -230,10 +230,94 @@ def test_Blelloch_multipass(self):
230230 logger .info ("Wg: %s result: cumsum good: %s" , wg , good )
231231 self .assertTrue (good , "calculation is correct for WG=%s" % wg )
232232
233+
234+ @unittest .skipUnless (ocl , "pyopencl is missing" )
235+ def test_sort (self ):
236+ """
237+ tests the sort of floating points in a workgroup
238+ """
239+ data = numpy .arange (self .shape ).astype (numpy .float32 )
240+ numpy .random .shuffle (data )
241+ data_d = pyopencl .array .to_device (self .queue , data )
242+
243+ maxi = int (round (numpy .log2 (self .shape )))+ 1
244+ for i in range (5 ,maxi ):
245+ wg = 1 << i
246+
247+ ref = data .reshape ((- 1 , wg ))
248+ positions = ((numpy .arange (ref .shape [0 ])+ 1 )* wg ).astype (numpy .int32 )
249+ positions_d = pyopencl .array .to_device (self .queue , positions )
250+ data_d = pyopencl .array .to_device (self .queue , data )
251+ # print(ref.shape, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), positions)
252+ try :
253+ evt = self .program .test_combsort_float (self .queue , (ref .shape [0 ],min (wg , self .max_valid_wg )), (1 , min (wg , self .max_valid_wg )),
254+ data_d .data ,
255+ positions_d .data ,
256+ pyopencl .LocalMemory (4 * min (wg , self .max_valid_wg )))
257+ evt .wait ()
258+ except Exception as error :
259+ logger .error ("Error %s on WG=%s: test_sort" , error , wg )
260+ break
261+ else :
262+ res = data_d .get ()
263+ ref = numpy .sort (ref )
264+ good = numpy .allclose (res , ref .ravel ())
265+ logger .info ("Wg: %s result: sort OK %s" , wg , good )
266+ if not good :
267+ print (res .reshape (ref .shape ))
268+ print (ref )
269+ print (numpy .where (res .reshape (ref .shape )- ref ))
270+
271+ self .assertTrue (good , "calculation is correct for WG=%s" % wg )
272+
273+ @unittest .skipUnless (ocl , "pyopencl is missing" )
274+ def test_sort4 (self ):
275+ """
276+ tests the sort of floating points in a workgroup
277+ """
278+ data = numpy .arange (self .shape ).astype (numpy .float32 )
279+ data = numpy .outer (data , numpy .ones (4 , numpy .float32 )).view (numpy .dtype ([("s0" ,"<f4" ),("s1" ,"<f4" ),("s2" ,"<f4" ),("s3" ,"<f4" )]))
280+ numpy .random .shuffle (data )
281+ data_d = pyopencl .array .to_device (self .queue , data )
282+
283+ maxi = int (round (numpy .log2 (self .shape )))+ 1
284+ for i in range (5 ,maxi ):
285+ wg = 1 << i
286+
287+ ref = data .reshape ((- 1 , wg ))
288+ positions = ((numpy .arange (ref .shape [0 ])+ 1 )* wg ).astype (numpy .int32 )
289+ positions_d = pyopencl .array .to_device (self .queue , positions )
290+ data_d = pyopencl .array .to_device (self .queue , data )
291+ # print(ref.shape, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), positions)
292+ try :
293+ evt = self .program .test_combsort_float4 (self .queue , (ref .shape [0 ],min (wg , self .max_valid_wg )), (1 , min (wg , self .max_valid_wg )),
294+ data_d .data ,
295+ positions_d .data ,
296+ pyopencl .LocalMemory (4 * min (wg , self .max_valid_wg )))
297+ evt .wait ()
298+ except Exception as error :
299+ logger .error ("Error %s on WG=%s: test_sort" , error , wg )
300+ break
301+ else :
302+ res = data_d .get ()
303+ # print(res.dtype)
304+ ref = numpy .sort (ref , order = "s0" )
305+ # print(ref.dtype)
306+ good = numpy .allclose (res .view (numpy .float32 ).ravel (), ref .view (numpy .float32 ).ravel ())
307+ logger .info ("Wg: %s result: sort OK %s" , wg , good )
308+ if not good :
309+ print (res .reshape (ref .shape ))
310+ print (ref )
311+ print (numpy .where (res .reshape (ref .shape )- ref ))
312+
313+ self .assertTrue (good , "calculation is correct for WG=%s" % wg )
314+
315+
316+
233317def suite ():
234318 loader = unittest .defaultTestLoader .loadTestsFromTestCase
235319 testSuite = unittest .TestSuite ()
236- testSuite .addTest (loader (TestReduction ))
320+ testSuite .addTest (loader (TestGroupFunction ))
237321 return testSuite
238322
239323
0 commit comments