77from collections import namedtuple
88
99import awkward
10+ import cupy
1011import numpy
1112
1213# Python 2 and 3 compatibility
2324
2425MaybeSumSlice = namedtuple ("MaybeSumSlice" , ["start" , "stop" , "sum" ])
2526
27+ _replace_nans = cupy .ElementwiseKernel ("T v" , "T x" , "x = isnan(x)?v:x" , "replace_nans" )
28+
29+ _clip_bins = cupy .ElementwiseKernel (
30+ "T Nbins, T lo, T hi, T id" ,
31+ "T idx" ,
32+ """
33+ const T floored = floor((id - lo) * float(Nbins) / (hi - lo)) + 1;
34+ idx = floored < 0 ? 0 : floored;
35+ idx = floored > Nbins ? Nbins + 1 : floored;
36+ """ ,
37+ "clip_bins" ,
38+ )
39+
2640
2741def assemble_blocks (array , ndslice , depth = 0 ):
2842 """
@@ -481,8 +495,8 @@ def __init__(self, name, label, n_or_arr, lo=None, hi=None):
481495 self ._lo = self ._bins [0 ]
482496 self ._hi = self ._bins [- 1 ]
483497 # to make searchsorted differentiate inf from nan
484- self ._bins = numpy .append (self ._bins , numpy .inf )
485- self ._interval_bins = numpy .r_ [- numpy .inf , self ._bins , numpy .nan ]
498+ self ._bins = cupy .append (self ._bins , cupy .inf )
499+ self ._interval_bins = cupy .r_ [- cupy .inf , self ._bins , cupy .nan ]
486500 self ._bin_names = numpy .full (self ._interval_bins [:- 1 ].size , None )
487501 elif isinstance (n_or_arr , numbers .Integral ):
488502 if lo is None or hi is None :
@@ -493,11 +507,11 @@ def __init__(self, name, label, n_or_arr, lo=None, hi=None):
493507 self ._lo = lo
494508 self ._hi = hi
495509 self ._bins = n_or_arr
496- self ._interval_bins = numpy .r_ [
497- - numpy .inf ,
498- numpy .linspace (self ._lo , self ._hi , self ._bins + 1 ),
499- numpy .inf ,
500- numpy .nan ,
510+ self ._interval_bins = cupy .r_ [
511+ - cupy .inf ,
512+ cupy .linspace (self ._lo , self ._hi , self ._bins + 1 ),
513+ cupy .inf ,
514+ cupy .nan ,
501515 ]
502516 self ._bin_names = numpy .full (self ._interval_bins [:- 1 ].size , None )
503517 else :
@@ -528,7 +542,7 @@ def __setstate__(self, d):
528542 if "_intervals" in d : # convert old hists to new serialization format
529543 _old_intervals = d .pop ("_intervals" )
530544 interval_bins = [i ._lo for i in _old_intervals ] + [_old_intervals [- 1 ]._hi ]
531- d ["_interval_bins" ] = numpy .array (interval_bins )
545+ d ["_interval_bins" ] = cupy .array (interval_bins )
532546 d ["_bin_names" ] = numpy .array (
533547 [interval ._label for interval in _old_intervals ]
534548 )
@@ -548,31 +562,36 @@ def index(self, identifier):
548562 Returns an integer corresponding to the index in the axis where the histogram would be filled.
549563 The integer range includes flow bins: ``0 = underflow, n+1 = overflow, n+2 = nanflow``
550564 """
551- isarray = isinstance (identifier , (awkward .Array , numpy .ndarray ))
565+ isarray = isinstance (identifier , (awkward .Array , cupy . ndarray , numpy .ndarray ))
552566 if isarray or isinstance (identifier , numbers .Number ):
553- if isarray :
554- identifier = numpy .asarray (identifier )
567+ identifier = awkward .to_cupy (identifier ) # cupy.asarray(identifier)
555568 if self ._uniform :
556- idx = numpy .clip (
557- numpy .floor (
558- (identifier - self ._lo )
559- * float (self ._bins )
560- / (self ._hi - self ._lo )
569+ idx = None
570+ if isarray :
571+ idx = cupy .zeros_like (identifier )
572+ _clip_bins (float (self ._bins ), self ._lo , self ._hi , identifier , idx )
573+ else :
574+ idx = numpy .clip (
575+ numpy .floor (
576+ (identifier - self ._lo )
577+ * float (self ._bins )
578+ / (self ._hi - self ._lo )
579+ )
580+ + 1 ,
581+ 0 ,
582+ self ._bins + 1 ,
561583 )
562- + 1 ,
563- 0 ,
564- self ._bins + 1 ,
565- )
566- if isinstance (idx , numpy .ndarray ):
567- idx [numpy .isnan (idx )] = self .size - 1
584+
585+ if isinstance (idx , (cupy .ndarray , numpy .ndarray )):
586+ _replace_nans (self .size - 1 , idx )
568587 idx = idx .astype (int )
569588 elif numpy .isnan (idx ):
570589 idx = self .size - 1
571590 else :
572591 idx = int (idx )
573592 return idx
574593 else :
575- return numpy .searchsorted (self ._bins , identifier , side = "right" )
594+ return cupy .searchsorted (self ._bins , identifier , side = "right" )
576595 elif isinstance (identifier , Interval ):
577596 if identifier .nan ():
578597 return self .size - 1
@@ -1095,7 +1114,9 @@ def __getitem__(self, keys):
10951114 dense_idx = tuple (dense_idx )
10961115
10971116 def dense_op (array ):
1098- return numpy .block (assemble_blocks (array , dense_idx ))
1117+ as_numpy = array .get ()
1118+ blocked = numpy .block (assemble_blocks (as_numpy , dense_idx ))
1119+ return cupy .asarray (blocked )
10991120
11001121 out = Hist (self ._label , * new_dims , dtype = self ._dtype )
11011122 if self ._sumw2 is not None :
@@ -1139,10 +1160,10 @@ def fill(self, **values):
11391160
11401161 """
11411162 weight = values .pop ("weight" , None )
1142- if isinstance (weight , (awkward .Array , numpy .ndarray )):
1143- weight = numpy . asarray (weight )
1163+ if isinstance (weight , (awkward .Array , cupy . ndarray , numpy .ndarray )):
1164+ weight = cupy . array (weight )
11441165 if isinstance (weight , numbers .Number ):
1145- weight = numpy .atleast_1d (weight )
1166+ weight = cupy .atleast_1d (weight )
11461167 if not all (d .name in values for d in self ._axes ):
11471168 missing = ", " .join (d .name for d in self ._axes if d .name not in values )
11481169 raise ValueError (
@@ -1161,44 +1182,46 @@ def fill(self, **values):
11611182
11621183 sparse_key = tuple (d .index (values [d .name ]) for d in self .sparse_axes ())
11631184 if sparse_key not in self ._sumw :
1164- self ._sumw [sparse_key ] = numpy .zeros (
1185+ self ._sumw [sparse_key ] = cupy .zeros (
11651186 shape = self ._dense_shape , dtype = self ._dtype
11661187 )
11671188 if self ._sumw2 is not None :
1168- self ._sumw2 [sparse_key ] = numpy .zeros (
1189+ self ._sumw2 [sparse_key ] = cupy .zeros (
11691190 shape = self ._dense_shape , dtype = self ._dtype
11701191 )
11711192
11721193 if self .dense_dim () > 0 :
11731194 dense_indices = tuple (
1174- d .index (values [d .name ]) for d in self ._axes if isinstance (d , DenseAxis )
1195+ cupy .asarray (d .index (values [d .name ]))
1196+ for d in self ._axes
1197+ if isinstance (d , DenseAxis )
11751198 )
1176- xy = numpy .atleast_1d (
1177- numpy .ravel_multi_index (dense_indices , self ._dense_shape )
1199+ xy = cupy .atleast_1d (
1200+ cupy .ravel_multi_index (dense_indices , self ._dense_shape )
11781201 )
11791202 if weight is not None :
1180- self ._sumw [sparse_key ][:] += numpy .bincount (
1203+ self ._sumw [sparse_key ][:] += cupy .bincount (
11811204 xy , weights = weight , minlength = numpy .array (self ._dense_shape ).prod ()
11821205 ).reshape (self ._dense_shape )
1183- self ._sumw2 [sparse_key ][:] += numpy .bincount (
1206+ self ._sumw2 [sparse_key ][:] += cupy .bincount (
11841207 xy ,
11851208 weights = weight ** 2 ,
11861209 minlength = numpy .array (self ._dense_shape ).prod (),
11871210 ).reshape (self ._dense_shape )
11881211 else :
1189- self ._sumw [sparse_key ][:] += numpy .bincount (
1212+ self ._sumw [sparse_key ][:] += cupy .bincount (
11901213 xy , weights = None , minlength = numpy .array (self ._dense_shape ).prod ()
11911214 ).reshape (self ._dense_shape )
11921215 if self ._sumw2 is not None :
1193- self ._sumw2 [sparse_key ][:] += numpy .bincount (
1216+ self ._sumw2 [sparse_key ][:] += cupy .bincount (
11941217 xy ,
11951218 weights = None ,
11961219 minlength = numpy .array (self ._dense_shape ).prod (),
11971220 ).reshape (self ._dense_shape )
11981221 else :
11991222 if weight is not None :
1200- self ._sumw [sparse_key ] += numpy .sum (weight )
1201- self ._sumw2 [sparse_key ] += numpy .sum (weight ** 2 )
1223+ self ._sumw [sparse_key ] += cupy .sum (weight )
1224+ self ._sumw2 [sparse_key ] += cupy .sum (weight ** 2 )
12021225 else :
12031226 self ._sumw [sparse_key ] += 1.0
12041227 if self ._sumw2 is not None :
@@ -1604,14 +1627,14 @@ def expandkey(key):
16041627 for sparse_key , sumw in values .items ():
16051628 index = tuple (expandkey (sparse_key ))
16061629 view = out .view (flow = True )
1607- view [index ] = sumw
1630+ view [index ] = sumw . get ()
16081631 else :
16091632 values = self .values (sumw2 = True , overflow = "all" )
16101633 for sparse_key , (sumw , sumw2 ) in values .items ():
16111634 index = tuple (expandkey (sparse_key ))
16121635 view = out .view (flow = True )
1613- view [index ].value = sumw
1614- view [index ].variance = sumw2
1636+ view [index ].value = sumw . get ()
1637+ view [index ].variance = sumw2 . get ()
16151638
16161639 return out
16171640
0 commit comments