Skip to content

Commit b108cc7

Browse files
0ctagonandrzejnovak
authored andcommitted
feat: added subplot wrapper
1 parent 6f81b42 commit b108cc7

File tree

3 files changed

+51
-18
lines changed

3 files changed

+51
-18
lines changed

src/mplhep/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
sort_legend,
7373
yscale_anchored_text,
7474
yscale_legend,
75+
subplots,
7576
)
7677

7778
# Configs
@@ -148,4 +149,5 @@
148149
"style",
149150
"yscale_anchored_text",
150151
"yscale_legend",
152+
"subplots",
151153
]

src/mplhep/comparison_plotters.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .utils import (
2929
_get_model_type,
3030
set_fitting_ylabel_fontsize,
31+
subplots,
3132
)
3233

3334

@@ -111,15 +112,7 @@ def hists(
111112
_check_counting_histogram(h2_plottable)
112113

113114
if fig is None and ax_main is None and ax_comparison is None:
114-
figsize = plt.rcParams["figure.figsize"]
115-
fig, (ax_main, ax_comparison) = plt.subplots(
116-
nrows=2,
117-
figsize=(figsize[0], figsize[1] * 1.25),
118-
gridspec_kw={"height_ratios": [4, 1]},
119-
)
120-
fig.subplots_adjust(hspace=0.15)
121-
ax_main.xaxis.set_ticklabels([])
122-
ax_main.set_xlabel(" ")
115+
fig, (ax_main, ax_comparison) = subplots(nrows=2)
123116
elif fig is None or ax_main is None or ax_comparison is None:
124117
msg = "Need to provide fig, ax_main and ax_comparison (or none of them)."
125118
raise ValueError(msg)
@@ -496,15 +489,7 @@ def data_model(
496489

497490
if fig is None and ax_main is None and ax_comparison is None:
498491
if plot_only is None:
499-
figsize = plt.rcParams["figure.figsize"]
500-
fig, (ax_main, ax_comparison) = plt.subplots(
501-
nrows=2,
502-
figsize=(figsize[0], figsize[1] * 1.25),
503-
gridspec_kw={"height_ratios": [4, 1]},
504-
)
505-
fig.subplots_adjust(hspace=0.15)
506-
ax_main.xaxis.set_ticklabels([])
507-
ax_main.set_xlabel(" ")
492+
fig, (ax_main, ax_comparison) = subplots(nrows=2)
508493
elif plot_only == "ax_main":
509494
_, ax_comparison = plt.subplots()
510495
fig, ax_main = plt.subplots()

src/mplhep/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,49 @@ def _get_model_type(components):
533533
if all(callable(x) for x in components):
534534
return "functions"
535535
return "histograms"
536+
537+
538+
def subplots(
539+
figsize: tuple[float, float] | None = None,
540+
nrows: int = 1,
541+
gridspec_kw: dict | None = None,
542+
hspace: float = 0.15,
543+
*args,
544+
**kwargs,
545+
) -> tuple[plt.Figure, np.ndarray]:
546+
"""
547+
Wrapper around plt.subplots to create a figure with multiple subplots
548+
549+
Parameters
550+
----------
551+
figsize : tuple[float, float], optional
552+
Figure size in inches.
553+
nrows : int, optional
554+
Number of rows in the subplot grid. Default is 2.
555+
gridspec_kw : dict | None, optional
556+
Additional keyword arguments for the GridSpec. Default is None.
557+
If None is provided, this is set to {"height_ratios": [4, 1]}.
558+
hspace : float, optional
559+
Height spacing between subplots. Default is 0.15.
560+
561+
Returns
562+
-------
563+
fig : matplotlib.figure.Figure
564+
The created figure.
565+
axes : np.ndarray
566+
Array of Axes objects representing the subplots.
567+
"""
568+
if gridspec_kw is None and nrows > 1:
569+
gridspec_kw = {"height_ratios": [4*(1+0.25*(nrows-2)), *(1 for _ in range(nrows-1))]}
570+
if figsize is None:
571+
figsize = (plt.rcParams["figure.figsize"][0], plt.rcParams["figure.figsize"][1] * (1 + 0.25 * (nrows - 1)))
572+
573+
fig, axes = plt.subplots(nrows=nrows, figsize=figsize, gridspec_kw=gridspec_kw)
574+
if nrows > 1:
575+
fig.subplots_adjust(hspace=hspace)
576+
577+
for ax in axes[:-1]:
578+
_ = ax.xaxis.set_ticklabels([])
579+
ax.set_xlabel(" ")
580+
581+
return fig, axes

0 commit comments

Comments
 (0)