Skip to content

Commit ae267f8

Browse files
committed
all tests work except test_hist_plot::test_plotgrid
1 parent 9deeed5 commit ae267f8

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

src/coffea/jitters/hist/hist_tools.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ class Bin(DenseAxis):
487487
def __init__(self, name, label, n_or_arr, lo=None, hi=None):
488488
super().__init__(name, label)
489489
self._lazy_intervals = None
490-
if isinstance(n_or_arr, (list, numpy.ndarray)):
490+
if isinstance(n_or_arr, (list, numpy.ndarray, cupy.ndarray)):
491491
self._uniform = False
492-
self._bins = numpy.array(n_or_arr, dtype="d")
492+
self._bins = cupy.array(n_or_arr, dtype="d")
493493
if not all(numpy.sort(self._bins) == self._bins):
494494
raise ValueError("Binning not sorted!")
495495
self._lo = self._bins[0]
@@ -583,7 +583,10 @@ def index(self, identifier):
583583
)
584584

585585
if isinstance(idx, (cupy.ndarray, numpy.ndarray)):
586-
_replace_nans(self.size - 1, idx)
586+
_replace_nans(
587+
self.size - 1,
588+
idx if idx.dtype.kind == "f" else idx.astype(cupy.float32),
589+
)
587590
idx = idx.astype(int)
588591
elif numpy.isnan(idx):
589592
idx = self.size - 1
@@ -596,7 +599,13 @@ def index(self, identifier):
596599
if identifier.nan():
597600
return self.size - 1
598601
for idx, interval in enumerate(self._intervals):
599-
if interval._lo <= identifier._lo and interval._hi >= identifier._hi:
602+
if (
603+
interval._lo <= identifier._lo
604+
or cupy.isclose(interval._lo, identifier._lo)
605+
) and (
606+
interval._hi >= identifier._hi
607+
or cupy.isclose(interval._hi, identifier._hi)
608+
):
600609
return idx
601610
raise ValueError(
602611
"Axis %r has no interval that fully contains identifier %r"
@@ -759,10 +768,10 @@ def edges(self, overflow="none"):
759768
See `Hist.sum` description for the allowed values.
760769
"""
761770
if self._uniform:
762-
out = numpy.linspace(self._lo, self._hi, self._bins + 1)
771+
out = cupy.linspace(self._lo, self._hi, self._bins + 1)
763772
else:
764773
out = self._bins[:-1].copy()
765-
out = numpy.r_[
774+
out = cupy.r_[
766775
2 * out[0] - out[1], out, 2 * out[-1] - out[-2], 3 * out[-1] - 2 * out[-2]
767776
]
768777
return out[overflow_behavior(overflow)]

src/coffea/jitters/hist/plot.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def plot1d(
231231
elif isinstance(axis, DenseAxis):
232232
ax.set_xlabel(axis.label)
233233
ax.set_ylabel(hist.label)
234-
edges = axis.edges(overflow=overflow)
234+
edges = axis.edges(overflow=overflow).get()
235235
if order is None:
236236
identifiers = (
237237
hist.identifiers(overlay, overflow=overlay_overflow)
@@ -417,12 +417,14 @@ def plotratio(
417417
elif isinstance(axis, DenseAxis):
418418
ax.set_xlabel(axis.label)
419419
ax.set_ylabel(num.label)
420-
edges = axis.edges(overflow=overflow)
421-
centers = axis.centers(overflow=overflow)
420+
edges = axis.edges(overflow=overflow).get()
421+
centers = axis.centers(overflow=overflow).get()
422422
ranges = (edges[1:] - edges[:-1]) / 2 if xerr else None
423423

424424
sumw_num, sumw2_num = num.values(sumw2=True, overflow=overflow)[()]
425425
sumw_denom, sumw2_denom = denom.values(sumw2=True, overflow=overflow)[()]
426+
sumw_num, sumw2_num = sumw_num.get(), sumw2_num.get()
427+
sumw_denom, sumw2_denom = sumw_denom.get(), sumw2_denom.get()
426428

427429
rsumw = sumw_num / sumw_denom
428430
if unc == "clopper-pearson":
@@ -557,9 +559,10 @@ def plot2d(
557559
if isinstance(xaxis, SparseAxis) or isinstance(yaxis, SparseAxis):
558560
raise NotImplementedError("Plot a sparse axis (e.g. bar chart or labeled bins)")
559561
else:
560-
xedges = xaxis.edges(overflow=xoverflow)
561-
yedges = yaxis.edges(overflow=yoverflow)
562+
xedges = xaxis.edges(overflow=xoverflow).get()
563+
yedges = yaxis.edges(overflow=yoverflow).get()
562564
sumw, sumw2 = hist.values(sumw2=True, overflow="allnan")[()]
565+
sumw, sumw2 = sumw.get(), sumw2.get()
563566
if transpose:
564567
sumw = sumw.T
565568
sumw2 = sumw2.T

tests/test_hist_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def test_plotratio():
125125

126126
# Add some pseudodata to a pt histogram so we can make a nice data/mc plot
127127
pthist = lepton_kinematics.sum("eta")
128-
bin_values = pthist.axis("pt").centers()
129-
poisson_means = pthist.sum("flavor").values()[()]
128+
bin_values = pthist.axis("pt").centers().get()
129+
poisson_means = pthist.sum("flavor").values()[()].get()
130130
values = np.repeat(bin_values, np.random.poisson(poisson_means))
131131
pthist.fill(flavor="pseudodata", pt=values)
132132

tests/test_hist_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def test_issue_247():
283283
def test_issue_333():
284284
axis = hist.Bin("channel", "Channel b1", 50, 0, 2000)
285285
temp = np.arange(0, 2000, 40, dtype=np.int16)
286-
assert np.all(axis.index(temp) == np.arange(50) + 1)
286+
assert np.all(axis.index(temp).get() == np.arange(50) + 1)
287287

288288

289289
def test_issue_394():

0 commit comments

Comments
 (0)