Skip to content

Commit 37cbd25

Browse files
authored
fix: correctly pass through w2method (#558)
1 parent 07c45f2 commit 37cbd25

File tree

5 files changed

+62
-6
lines changed

5 files changed

+62
-6
lines changed

src/mplhep/plot.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ def histplot(
178178
"hint",
179179
"none",
180180
}, "flow must be show, sum, hint, or none"
181+
if hasattr(H, "values") or hasattr(H[0], "values"): # Check for hist-like inputs
182+
assert bins is None, (
183+
"When plotting hist(-like) objects, specifying bins is not allowed."
184+
)
185+
assert w2 is None, (
186+
"When plotting hist(-like) objects, specifying w2 is not allowed."
187+
)
188+
if w2 is not None:
189+
assert np.array(w2).shape == np.array(H).shape, (
190+
"w2 must have the same shape as H"
191+
)
181192

182193
# Convert 1/0 etc to real bools
183194
stack = bool(stack)

src/mplhep/utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,15 @@ def get_plottables(
273273

274274
if w2 is not None:
275275
for _w2, _plottable in zip(
276-
w2.reshape(len(plottables), len(final_bins) - 1), plottables
276+
np.array(w2).reshape(len(plottables), len(final_bins) - 1), plottables
277277
):
278278
_plottable.variances = _w2
279279
_plottable.method = w2method
280280

281+
for _plottable in plottables:
282+
if _plottable.variances is not None:
283+
_plottable.method = w2method
284+
281285
if w2 is not None and yerr is not None:
282286
msg = "Can only supply errors or w2"
283287
raise ValueError(msg)
@@ -417,7 +421,9 @@ def norm_stack_plottables(plottables, bins, stack=False, density=False, binwnorm
417421

418422

419423
class Plottable:
420-
def __init__(self, values, *, edges=None, variances=None, yerr=None):
424+
def __init__(
425+
self, values, *, edges=None, variances=None, yerr=None, w2method="poisson"
426+
):
421427
self._values = np.array(values).astype(float)
422428
self.variances = None
423429
self._variances = None
@@ -434,7 +440,7 @@ def __init__(self, values, *, edges=None, variances=None, yerr=None):
434440
if self.edges is None:
435441
self.edges = np.arange(len(values) + 1)
436442
self.centers = self.edges[:-1] + np.diff(self.edges) / 2
437-
self.method = "poisson"
443+
self.method = w2method
438444

439445
self.yerr = yerr
440446
assert self.variances is None or self.yerr is None
@@ -470,12 +476,11 @@ def errors(self, method=None):
470476
method = "poisson"
471477
else:
472478
method = "sqrt"
473-
474479
if self._errors_present:
475480
return
476481

477-
def sqrt_method(values, _):
478-
return values - np.sqrt(values), values + np.sqrt(values)
482+
def sqrt_method(values, variances):
483+
return values - np.sqrt(variances), values + np.sqrt(variances)
479484

480485
def calculate_relative(method_fcn, variances):
481486
return np.abs(method_fcn(self.values, variances) - self.values)
27.8 KB
Loading
11.5 KB
Loading

tests/test_basic.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,46 @@ def test_histplot_w2():
633633
return fig
634634

635635

636+
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
637+
def test_histplot_w2_methods():
638+
htype1 = [10, 20, 30, 40, 30, 20, 10, 1, 0, 1, 0]
639+
htype1_w2 = [10, 20, 30, 40, 30, 20, 50, 1, 0, 1, 0]
640+
np.random.seed(0)
641+
htype2 = hist.new.Reg(11, 0, 11).Weight().fill(np.random.normal(3, 2, 100))
642+
643+
def fcn1(w, _):
644+
return np.maximum(0, w - np.ones_like(w) * 3), w + np.ones_like(w) * 3
645+
646+
def fcn2(w, _):
647+
return w - np.ones_like(w) * 0.2 * np.mean(w), w + np.ones_like(
648+
w
649+
) * 0.2 * np.mean(w)
650+
651+
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
652+
for ax, method in zip(axs.flatten(), [None, "poisson", "sqrt", fcn1, fcn2]):
653+
hep.histplot(htype1, w2=htype1_w2, w2method=method, ax=ax, label="With w2")
654+
hep.histplot(htype1, w2method=method, ax=ax, label="No w2 passed")
655+
htype2.plot(w2method=method, ax=ax, label="Hist")
656+
ax.set_title(str(method))
657+
ax.legend()
658+
return fig
659+
660+
661+
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
662+
def test_histplot_w2_poisson_handling():
663+
np.random.seed(0)
664+
evts = np.random.normal(2, 2, 100)
665+
weights = np.random.uniform(0.99, 1.01, 100)
666+
htype1 = hist.new.Reg(5, 0, 5).Weight().fill(evts)
667+
htype2 = hist.new.Reg(5, 0, 5).Weight().fill(evts, weight=weights)
668+
669+
fig, ax = plt.subplots()
670+
htype1.plot(ax=ax, histtype="errorbar", capsize=4, label="Raw counts")
671+
htype2.plot(ax=ax, label="Weighted counts")
672+
ax.legend()
673+
return fig
674+
675+
636676
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
637677
def test_histplot_types():
638678
hs, bins = [[2, 3, 4], [5, 4, 3]], [0, 1, 2, 3]

0 commit comments

Comments
 (0)