Skip to content

Commit 802024a

Browse files
authored
test: cleanup matplotlib rc params before and after each test (#576)
* test: cleanup matplotlib rc params before and after each test * chore: update ruff version to v0.12.1 and clean up imports in multiple files
1 parent e421876 commit 802024a

File tree

10 files changed

+71
-72
lines changed

10 files changed

+71
-72
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ repos:
1818

1919
- repo: https://github.com/astral-sh/ruff-pre-commit
2020
# Ruff version.
21-
rev: v0.12.0
21+
rev: v0.12.1
2222
hooks:
2323
# Run the linter.
2424
- id: ruff

src/mplhep/__init__.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -67,37 +67,37 @@
6767

6868
# Log submodules
6969
__all__ = [
70-
"cms",
71-
"atlas",
72-
"lhcb",
70+
"EnhancedPlottableHistogram",
71+
"_check_counting_histogram",
7372
"alice",
74-
"plot",
75-
"style",
76-
"label",
77-
"savelabels",
73+
"append_axes",
74+
"atlas",
75+
"box_aspect",
76+
"cms",
77+
"get_asymmetry",
78+
"get_comparison",
79+
"get_difference",
80+
"get_efficiency",
81+
"get_plottables",
82+
"get_pull",
83+
"get_ratio",
84+
"hist2dplot",
7885
# Log plot functions
7986
"histplot",
80-
"hist2dplot",
81-
"mpl_magic",
82-
"yscale_legend",
83-
"yscale_anchored_text",
84-
"ylow",
85-
"rescale_to_axessize",
86-
"box_aspect",
87+
"label",
88+
"lhcb",
89+
"make_plottable_histogram",
8790
"make_square_add_cbar",
8891
"merge_legend_handles_labels",
89-
"append_axes",
90-
"sort_legend",
92+
"mpl_magic",
93+
"plot",
94+
"rescale_to_axessize",
9195
"save_variations",
96+
"savelabels",
9297
"set_style",
93-
"get_plottables",
94-
"EnhancedPlottableHistogram",
95-
"make_plottable_histogram",
96-
"_check_counting_histogram",
97-
"get_comparison",
98-
"get_difference",
99-
"get_ratio",
100-
"get_pull",
101-
"get_asymmetry",
102-
"get_efficiency",
98+
"sort_legend",
99+
"style",
100+
"ylow",
101+
"yscale_anchored_text",
102+
"yscale_legend",
103103
]

src/mplhep/error_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ def poisson_interval(sumw, sumw2, coverage=_coverage1sd):
5050
lo = scale * scipy.stats.chi2.ppf((1 - coverage) / 2, 2 * counts) / 2.0
5151
hi = scale * scipy.stats.chi2.ppf((1 + coverage) / 2, 2 * (counts + 1)) / 2.0
5252
interval = np.array([lo, hi])
53-
interval[interval == np.nan] = 0.0 # chi2.ppf produces nan for counts=0
53+
interval[np.isnan(interval)] = 0.0 # chi2.ppf produces nan for counts=0
5454
return interval

src/mplhep/label.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def save_variations(fig, name, text_list=None, exp=None):
555555
if text_list is None:
556556
text_list = ["Preliminary", ""]
557557

558-
from mplhep.label import ExpSuffix, ExpText
558+
from mplhep.label import ExpSuffix, ExpText # noqa: PLC0415
559559

560560
for text in text_list:
561561
for ax in fig.get_axes():

src/mplhep/plot.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import matplotlib as mpl
1010
import matplotlib.pyplot as plt
1111
import numpy as np
12+
from matplotlib.lines import Line2D
1213
from matplotlib.offsetbox import AnchoredText
14+
from matplotlib.patches import Patch, Rectangle
15+
from matplotlib.text import Text
1316
from matplotlib.transforms import Bbox
1417
from mpl_toolkits.axes_grid1 import axes_size, make_axes_locatable
1518

@@ -1031,9 +1034,6 @@ def overlap(ax, bbox, get_vertices=False):
10311034
"""
10321035
Find overlap of bbox for drawn elements an axes.
10331036
"""
1034-
from matplotlib.lines import Line2D
1035-
from matplotlib.patches import Patch, Rectangle
1036-
from matplotlib.text import Text
10371037

10381038
# From
10391039
# https://github.com/matplotlib/matplotlib/blob/08008d5cb4d1f27692e9aead9a76396adc8f0b19/lib/matplotlib/legend.py#L845
@@ -1371,13 +1371,13 @@ def extend_ratio(ax, yhax):
13711371
fig.get_size_inches()[1],
13721372
)
13731373
elif position in ["left"]:
1374-
divider.set_horizontal(xsizes[::-1] + [axes_size.Fixed(width)])
1374+
divider.set_horizontal([*xsizes[::-1], axes_size.Fixed(width)])
13751375
fig.set_size_inches(
13761376
fig.get_size_inches()[0] * extend_ratio(ax, yhax)[0],
13771377
fig.get_size_inches()[1],
13781378
)
13791379
elif position in ["top"]:
1380-
divider.set_vertical([axes_size.Fixed(height)] + xsizes[::-1])
1380+
divider.set_vertical([axes_size.Fixed(height), *xsizes[::-1]])
13811381
fig.set_size_inches(
13821382
fig.get_size_inches()[0],
13831383
fig.get_size_inches()[1] * extend_ratio(ax, yhax)[1],
@@ -1397,8 +1397,6 @@ def extend_ratio(ax, yhax):
13971397
####################
13981398
# Legend Helpers
13991399
def hist_legend(ax=None, **kwargs):
1400-
from matplotlib.lines import Line2D
1401-
14021400
if ax is None:
14031401
ax = plt.gca()
14041402

src/mplhep/utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def isLight(rgb):
3131
def get_plottable_protocol_bins(
3232
axis: PlottableAxis,
3333
) -> tuple[np.ndarray, np.ndarray | None]:
34-
out = np.arange(len(axis) + 1).astype(float)
34+
out: np.ndarray = np.arange(len(axis) + 1).astype(float)
3535
if isinstance(axis[0], tuple): # Regular axis
3636
out[0] = axis[0][0]
3737
out[1:] = [axis[i][1] for i in range(len(axis))] # type: ignore[index]
@@ -451,7 +451,7 @@ def norm_stack_plottables(plottables, bins, stack=False, density=False, binwnorm
451451

452452
# Stack
453453
if stack and len(plottables) > 1:
454-
from .utils import stack as stack_fun
454+
from .utils import stack as stack_fun # noqa: PLC0415
455455

456456
plottables = stack_fun(*plottables)
457457

@@ -504,6 +504,7 @@ def __init__(
504504
np.zeros_like(self.values()),
505505
np.zeros_like(self.values()),
506506
)
507+
self._hash = None
507508

508509
def __eq__(self, other):
509510
"""Check equality between two EnhancedPlottableHistogram instances based on values(), variances(), and edges."""
@@ -515,6 +516,22 @@ def __eq__(self, other):
515516
]
516517
)
517518

519+
def __hash__(self):
520+
"""Return a hash of the EnhancedPlottableHistogram object based on its values, variances, and edges."""
521+
if self._hash is None:
522+
self._hash = hash(
523+
(
524+
tuple(self.values().flatten()),
525+
tuple(
526+
self.variances().flatten()
527+
if self.variances() is not None
528+
else []
529+
),
530+
tuple(self.edges_1d().flatten()),
531+
)
532+
)
533+
return self._hash
534+
518535
def __repr__(self):
519536
"""Return string representation of the EnhancedPlottableHistogram object."""
520537
return f"EnhancedPlottableHistogram(values={self.values()}, edges={self.axes[0].edges}, variances={self.variances()})"
@@ -640,7 +657,7 @@ def calculate_relative(method_fcn, variances):
640657
)
641658
elif method == "poisson":
642659
try:
643-
from .error_estimation import poisson_interval
660+
from .error_estimation import poisson_interval # noqa: PLC0415
644661

645662
self.yerr_lo, self.yerr_hi = calculate_relative(
646663
poisson_interval, self.variances()

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import matplotlib.pyplot as plt
2+
import pytest
3+
4+
5+
@pytest.fixture(autouse=True)
6+
def clear_mplhep_rcparams():
7+
"""Clear matplotlib rcParams before and after each test."""
8+
9+
plt.rcParams.update(plt.rcParamsDefault)
10+
yield
11+
plt.rcParams.update(plt.rcParamsDefault)

tests/test_basic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import matplotlib.pyplot as plt
99
import numpy as np
1010
import pytest
11+
import uproot
1112

1213
os.environ["RUNNING_PYTEST"] = "true"
1314

@@ -94,8 +95,6 @@ def test_log():
9495

9596
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
9697
def test_onebin_hist():
97-
import hist
98-
9998
fig, axs = plt.subplots()
10099
h = hist.Hist(hist.axis.Regular(1, 0, 1))
101100
h.fill([-1, 0.5])
@@ -217,7 +216,6 @@ def test_histplot_uproot_flow():
217216
h2.fill(entries[entries < 15])
218217
h3.fill(entries[entries > 5])
219218
h4.fill(entries[(entries > 5) & (entries < 15)])
220-
import uproot
221219

222220
with uproot.recreate("flow_th1.root") as f:
223221
f["h"] = h

tests/test_inputs.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99

1010
os.environ["RUNNING_PYTEST"] = "true"
1111

12+
import boost_histogram as bh
13+
import uproot
14+
import uproot4
15+
from skhep_testdata import data_path
16+
1217
import mplhep as hep
1318

1419
"""
@@ -39,9 +44,6 @@ def test_inputs_basic():
3944

4045
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
4146
def test_inputs_uproot():
42-
import uproot4
43-
from skhep_testdata import data_path
44-
4547
fname = data_path("uproot-hepdata-example.root")
4648
f = uproot4.open(fname)
4749

@@ -55,10 +57,6 @@ def test_inputs_uproot():
5557

5658
@check_figures_equal(extensions=("png", "pdf"))
5759
def test_uproot_versions(fig_test, fig_ref):
58-
import uproot
59-
import uproot4
60-
from skhep_testdata import data_path
61-
6260
fname = data_path("uproot-hepdata-example.root")
6361
f4 = uproot4.open(fname)
6462
f3 = uproot.open(fname)
@@ -80,7 +78,6 @@ def test_uproot_versions(fig_test, fig_ref):
8078
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
8179
def test_inputs_bh():
8280
np.random.seed(0)
83-
import boost_histogram as bh
8481

8582
hist2d = bh.Histogram(bh.axis.Regular(10, 0.0, 1.0), bh.axis.Regular(10, 0, 1))
8683
hist2d.fill(np.random.normal(0.5, 0.2, 1000), np.random.normal(0.5, 0.2, 1000))
@@ -95,7 +92,6 @@ def test_inputs_bh():
9592
@pytest.mark.mpl_image_compare(style="default", remove_text=True)
9693
def test_inputs_bh_cat():
9794
np.random.seed(0)
98-
import boost_histogram as bh
9995

10096
hist2d = bh.Histogram(
10197
bh.axis.IntCategory(range(10)), bh.axis.StrCategory("", growth=True)

tests/test_styles.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
2727
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
2828
def test_style_atlas():
29-
plt.rcParams.update(plt.rcParamsDefault)
30-
3129
# Test suite does not have Helvetica
3230
plt.style.use([hep.style.ATLAS, {"font.sans-serif": ["Tex Gyre Heros"]}])
3331
fig, ax = plt.subplots()
@@ -39,8 +37,6 @@ def test_style_atlas():
3937
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
4038
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
4139
def test_style_cms():
42-
plt.rcParams.update(plt.rcParamsDefault)
43-
4440
plt.style.use(hep.style.CMS)
4541
fig, ax = plt.subplots()
4642
hep.cms.label("Preliminary")
@@ -51,8 +47,6 @@ def test_style_cms():
5147
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
5248
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
5349
def test_style_alice():
54-
plt.rcParams.update(plt.rcParamsDefault)
55-
5650
plt.style.use(hep.style.ALICE)
5751
fig, ax = plt.subplots()
5852
hep.alice.label("Preliminary")
@@ -63,8 +57,6 @@ def test_style_alice():
6357
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
6458
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
6559
def test_style_lhcb():
66-
plt.rcParams.update(plt.rcParamsDefault)
67-
6860
plt.style.use([hep.style.LHCb1, {"figure.autolayout": False}])
6961
fig, ax = plt.subplots()
7062
hep.lhcb.label("Preliminary")
@@ -74,8 +66,6 @@ def test_style_lhcb():
7466
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
7567
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
7668
def test_style_lhcb2():
77-
plt.rcParams.update(plt.rcParamsDefault)
78-
7969
plt.style.use([hep.style.LHCb2, {"figure.autolayout": False}])
8070
fig, ax = plt.subplots()
8171
hep.lhcb.label("Preliminary")
@@ -85,7 +75,6 @@ def test_style_lhcb2():
8575
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
8676
@pytest.mark.mpl_image_compare(style="default", remove_text=False)
8777
def test_style_plothist():
88-
plt.rcParams.update(plt.rcParamsDefault)
8978
plt.style.use(hep.style.PLOTHIST)
9079
fig, ax = plt.subplots()
9180
return fig
@@ -106,8 +95,6 @@ def test_style_plothist():
10695
ids=["ALICE", "ATLAS", "CMS", "LHCb1", "LHCb2", "ROOT"],
10796
)
10897
def test_use_style(fig_test, fig_ref, mplhep_style):
109-
plt.rcParams.update(plt.rcParamsDefault)
110-
11198
hep.rcParams.clear()
11299
plt.style.use(mplhep_style)
113100
fig_ref.subplots()
@@ -120,8 +107,6 @@ def test_use_style(fig_test, fig_ref, mplhep_style):
120107
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only")
121108
@check_figures_equal(extensions=["pdf"])
122109
def test_use_style_LHCb_dep(fig_test, fig_ref):
123-
plt.rcParams.update(plt.rcParamsDefault)
124-
125110
hep.rcParams.clear()
126111
with pytest.warns(FutureWarning):
127112
plt.style.use(hep.style.LHCb)
@@ -148,8 +133,6 @@ def test_use_style_LHCb_dep(fig_test, fig_ref):
148133
ids=["ALICE", "ATLAS", "CMS", "LHCb", "LHCb1", "LHCb2", "ROOT"],
149134
)
150135
def test_use_style_str_alias(fig_test, fig_ref, mplhep_style, str_alias):
151-
plt.rcParams.update(plt.rcParamsDefault)
152-
153136
hep.rcParams.clear()
154137
plt.style.use(mplhep_style)
155138
fig_ref.subplots()
@@ -175,8 +158,6 @@ def test_use_style_str_alias(fig_test, fig_ref, mplhep_style, str_alias):
175158
ids=["ALICE", "ATLAS", "CMS", "LHCb", "LHCb1", "LHCb2", "ROOT"],
176159
)
177160
def test_use_style_self_consistent(fig_test, fig_ref, mplhep_style, str_alias):
178-
plt.rcParams.update(plt.rcParamsDefault)
179-
180161
hep.rcParams.clear()
181162
hep.style.use(mplhep_style)
182163
fig_ref.subplots()
@@ -202,8 +183,6 @@ def test_use_style_self_consistent(fig_test, fig_ref, mplhep_style, str_alias):
202183
ids=["ALICE", "ATLAS", "CMS", "LHCb", "LHCb1", "LHCb2", "ROOT"],
203184
)
204185
def test_use_style_style_list(fig_test, fig_ref, mplhep_style, str_alias):
205-
plt.rcParams.update(plt.rcParamsDefault)
206-
207186
hep.rcParams.clear()
208187
plt.style.use([mplhep_style, {"font.sans-serif": "Comic Sans MS"}])
209188
fig_ref.subplots()

0 commit comments

Comments
 (0)