Skip to content

Commit 9a837ef

Browse files
committed
resurrect main functionality tests, bits of cruft yet to fix, decent throughput on multi-dim histograms
1 parent dd34926 commit 9a837ef

File tree

5 files changed

+72
-44
lines changed

5 files changed

+72
-44
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
uv pip install --system xgboost
7474
uv pip install --system 'tritonclient[grpc,http]!=2.41.0'
7575
# install checked out coffea
76-
uv pip install --system -q '.[dev,parsl,dask,spark]' --upgrade
76+
uv pip install --system -q '.[dev,parsl,dask,spark,gpu]' --upgrade
7777
uv pip list --system
7878
java -version
7979
- name: Install dependencies (MacOS)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ rucio = [
9292
"rucio-clients>=32;python_version>'3.8'",
9393
"rucio-clients<32;python_version<'3.9'",
9494
]
95+
gpu = [
96+
"cupy>=13.1.0"
97+
]
9598
dev = [
9699
"pre-commit",
97100
"flake8",

src/coffea/jitters/hist/hist_tools.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import namedtuple
88

99
import awkward
10+
import cupy
1011
import numpy
1112

1213
# Python 2 and 3 compatibility
@@ -23,6 +24,19 @@
2324

2425
MaybeSumSlice = namedtuple("MaybeSumSlice", ["start", "stop", "sum"])
2526

27+
_replace_nans = cupy.ElementwiseKernel("T v", "T x", "x = isnan(x)?v:x", "replace_nans")
28+
29+
_clip_bins = cupy.ElementwiseKernel(
30+
"T Nbins, T lo, T hi, T id",
31+
"T idx",
32+
"""
33+
const T floored = floor((id - lo) * float(Nbins) / (hi - lo)) + 1;
34+
idx = floored < 0 ? 0 : floored;
35+
idx = floored > Nbins ? Nbins + 1 : floored;
36+
""",
37+
"clip_bins",
38+
)
39+
2640

2741
def assemble_blocks(array, ndslice, depth=0):
2842
"""
@@ -481,8 +495,8 @@ def __init__(self, name, label, n_or_arr, lo=None, hi=None):
481495
self._lo = self._bins[0]
482496
self._hi = self._bins[-1]
483497
# to make searchsorted differentiate inf from nan
484-
self._bins = numpy.append(self._bins, numpy.inf)
485-
self._interval_bins = numpy.r_[-numpy.inf, self._bins, numpy.nan]
498+
self._bins = cupy.append(self._bins, cupy.inf)
499+
self._interval_bins = cupy.r_[-cupy.inf, self._bins, cupy.nan]
486500
self._bin_names = numpy.full(self._interval_bins[:-1].size, None)
487501
elif isinstance(n_or_arr, numbers.Integral):
488502
if lo is None or hi is None:
@@ -493,11 +507,11 @@ def __init__(self, name, label, n_or_arr, lo=None, hi=None):
493507
self._lo = lo
494508
self._hi = hi
495509
self._bins = n_or_arr
496-
self._interval_bins = numpy.r_[
497-
-numpy.inf,
498-
numpy.linspace(self._lo, self._hi, self._bins + 1),
499-
numpy.inf,
500-
numpy.nan,
510+
self._interval_bins = cupy.r_[
511+
-cupy.inf,
512+
cupy.linspace(self._lo, self._hi, self._bins + 1),
513+
cupy.inf,
514+
cupy.nan,
501515
]
502516
self._bin_names = numpy.full(self._interval_bins[:-1].size, None)
503517
else:
@@ -528,7 +542,7 @@ def __setstate__(self, d):
528542
if "_intervals" in d: # convert old hists to new serialization format
529543
_old_intervals = d.pop("_intervals")
530544
interval_bins = [i._lo for i in _old_intervals] + [_old_intervals[-1]._hi]
531-
d["_interval_bins"] = numpy.array(interval_bins)
545+
d["_interval_bins"] = cupy.array(interval_bins)
532546
d["_bin_names"] = numpy.array(
533547
[interval._label for interval in _old_intervals]
534548
)
@@ -548,31 +562,36 @@ def index(self, identifier):
548562
Returns an integer corresponding to the index in the axis where the histogram would be filled.
549563
The integer range includes flow bins: ``0 = underflow, n+1 = overflow, n+2 = nanflow``
550564
"""
551-
isarray = isinstance(identifier, (awkward.Array, numpy.ndarray))
565+
isarray = isinstance(identifier, (awkward.Array, cupy.ndarray, numpy.ndarray))
552566
if isarray or isinstance(identifier, numbers.Number):
553-
if isarray:
554-
identifier = numpy.asarray(identifier)
567+
identifier = awkward.to_cupy(identifier) # cupy.asarray(identifier)
555568
if self._uniform:
556-
idx = numpy.clip(
557-
numpy.floor(
558-
(identifier - self._lo)
559-
* float(self._bins)
560-
/ (self._hi - self._lo)
569+
idx = None
570+
if isarray:
571+
idx = cupy.zeros_like(identifier)
572+
_clip_bins(float(self._bins), self._lo, self._hi, identifier, idx)
573+
else:
574+
idx = numpy.clip(
575+
numpy.floor(
576+
(identifier - self._lo)
577+
* float(self._bins)
578+
/ (self._hi - self._lo)
579+
)
580+
+ 1,
581+
0,
582+
self._bins + 1,
561583
)
562-
+ 1,
563-
0,
564-
self._bins + 1,
565-
)
566-
if isinstance(idx, numpy.ndarray):
567-
idx[numpy.isnan(idx)] = self.size - 1
584+
585+
if isinstance(idx, (cupy.ndarray, numpy.ndarray)):
586+
_replace_nans(self.size - 1, idx)
568587
idx = idx.astype(int)
569588
elif numpy.isnan(idx):
570589
idx = self.size - 1
571590
else:
572591
idx = int(idx)
573592
return idx
574593
else:
575-
return numpy.searchsorted(self._bins, identifier, side="right")
594+
return cupy.searchsorted(self._bins, identifier, side="right")
576595
elif isinstance(identifier, Interval):
577596
if identifier.nan():
578597
return self.size - 1
@@ -1095,7 +1114,9 @@ def __getitem__(self, keys):
10951114
dense_idx = tuple(dense_idx)
10961115

10971116
def dense_op(array):
1098-
return numpy.block(assemble_blocks(array, dense_idx))
1117+
as_numpy = array.get()
1118+
blocked = numpy.block(assemble_blocks(as_numpy, dense_idx))
1119+
return cupy.asarray(blocked)
10991120

11001121
out = Hist(self._label, *new_dims, dtype=self._dtype)
11011122
if self._sumw2 is not None:
@@ -1139,10 +1160,10 @@ def fill(self, **values):
11391160
11401161
"""
11411162
weight = values.pop("weight", None)
1142-
if isinstance(weight, (awkward.Array, numpy.ndarray)):
1143-
weight = numpy.asarray(weight)
1163+
if isinstance(weight, (awkward.Array, cupy.ndarray, numpy.ndarray)):
1164+
weight = cupy.array(weight)
11441165
if isinstance(weight, numbers.Number):
1145-
weight = numpy.atleast_1d(weight)
1166+
weight = cupy.atleast_1d(weight)
11461167
if not all(d.name in values for d in self._axes):
11471168
missing = ", ".join(d.name for d in self._axes if d.name not in values)
11481169
raise ValueError(
@@ -1161,44 +1182,46 @@ def fill(self, **values):
11611182

11621183
sparse_key = tuple(d.index(values[d.name]) for d in self.sparse_axes())
11631184
if sparse_key not in self._sumw:
1164-
self._sumw[sparse_key] = numpy.zeros(
1185+
self._sumw[sparse_key] = cupy.zeros(
11651186
shape=self._dense_shape, dtype=self._dtype
11661187
)
11671188
if self._sumw2 is not None:
1168-
self._sumw2[sparse_key] = numpy.zeros(
1189+
self._sumw2[sparse_key] = cupy.zeros(
11691190
shape=self._dense_shape, dtype=self._dtype
11701191
)
11711192

11721193
if self.dense_dim() > 0:
11731194
dense_indices = tuple(
1174-
d.index(values[d.name]) for d in self._axes if isinstance(d, DenseAxis)
1195+
cupy.asarray(d.index(values[d.name]))
1196+
for d in self._axes
1197+
if isinstance(d, DenseAxis)
11751198
)
1176-
xy = numpy.atleast_1d(
1177-
numpy.ravel_multi_index(dense_indices, self._dense_shape)
1199+
xy = cupy.atleast_1d(
1200+
cupy.ravel_multi_index(dense_indices, self._dense_shape)
11781201
)
11791202
if weight is not None:
1180-
self._sumw[sparse_key][:] += numpy.bincount(
1203+
self._sumw[sparse_key][:] += cupy.bincount(
11811204
xy, weights=weight, minlength=numpy.array(self._dense_shape).prod()
11821205
).reshape(self._dense_shape)
1183-
self._sumw2[sparse_key][:] += numpy.bincount(
1206+
self._sumw2[sparse_key][:] += cupy.bincount(
11841207
xy,
11851208
weights=weight**2,
11861209
minlength=numpy.array(self._dense_shape).prod(),
11871210
).reshape(self._dense_shape)
11881211
else:
1189-
self._sumw[sparse_key][:] += numpy.bincount(
1212+
self._sumw[sparse_key][:] += cupy.bincount(
11901213
xy, weights=None, minlength=numpy.array(self._dense_shape).prod()
11911214
).reshape(self._dense_shape)
11921215
if self._sumw2 is not None:
1193-
self._sumw2[sparse_key][:] += numpy.bincount(
1216+
self._sumw2[sparse_key][:] += cupy.bincount(
11941217
xy,
11951218
weights=None,
11961219
minlength=numpy.array(self._dense_shape).prod(),
11971220
).reshape(self._dense_shape)
11981221
else:
11991222
if weight is not None:
1200-
self._sumw[sparse_key] += numpy.sum(weight)
1201-
self._sumw2[sparse_key] += numpy.sum(weight**2)
1223+
self._sumw[sparse_key] += cupy.sum(weight)
1224+
self._sumw2[sparse_key] += cupy.sum(weight**2)
12021225
else:
12031226
self._sumw[sparse_key] += 1.0
12041227
if self._sumw2 is not None:
@@ -1604,14 +1627,14 @@ def expandkey(key):
16041627
for sparse_key, sumw in values.items():
16051628
index = tuple(expandkey(sparse_key))
16061629
view = out.view(flow=True)
1607-
view[index] = sumw
1630+
view[index] = sumw.get()
16081631
else:
16091632
values = self.values(sumw2=True, overflow="all")
16101633
for sparse_key, (sumw, sumw2) in values.items():
16111634
index = tuple(expandkey(sparse_key))
16121635
view = out.view(flow=True)
1613-
view[index].value = sumw
1614-
view[index].variance = sumw2
1636+
view[index].value = sumw.get()
1637+
view[index].variance = sumw2.get()
16151638

16161639
return out
16171640

src/coffea/jitters/hist/plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ def plot1d(
263263
the_slice = (the_slice[1], the_slice[0])
264264
sumw = sumw[the_slice]
265265
sumw2 = sumw2[the_slice]
266-
plot_info["sumw"].append(sumw)
267-
plot_info["sumw2"].append(sumw2)
266+
plot_info["sumw"].append(sumw.get())
267+
plot_info["sumw2"].append(sumw2.get())
268268

269269
def w2err(sumw, sumw2):
270270
err = []

tests/test_hist_tools.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pytest
44
from dummy_distributions import dummy_jagged_eta_pt
55

6+
pytest.importorskip("cupy")
7+
68
from coffea.jitters import hist
79

810

0 commit comments

Comments
 (0)