Skip to content

Commit 828e9cd

Browse files
committed
feat: add flow parameter support to comparison plotters
Fixes #594
1 parent a04514e commit 828e9cd

File tree

8 files changed

+564
-164
lines changed

8 files changed

+564
-164
lines changed

src/mplhep/comparison_plotters.py

Lines changed: 284 additions & 32 deletions
Large diffs are not rendered by default.

src/mplhep/plot.py

Lines changed: 149 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def iterable_not_string(arg):
468468
flow=flow,
469469
xoffsets=xoffsets,
470470
)
471-
flow_bins, underflow, overflow = flow_info
471+
_flow_bins, underflow, overflow = flow_info
472472

473473
##########
474474
# Plotting
@@ -708,61 +708,45 @@ def iterable_not_string(arg):
708708
msg = "No figure found"
709709
raise ValueError(msg)
710710
if flow == "hint":
711+
# Get all shared x-axes to draw markers on all of them
712+
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
713+
shared_axes = [
714+
_ax for _ax in shared_axes if _ax.get_position().x0 == ax.get_position().x0
715+
]
716+
711717
_marker_size = (
712718
30
713719
* ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
714720
)
715-
if underflow > 0.0:
716-
ax.scatter(
717-
final_bins[0],
718-
0,
719-
_marker_size,
720-
marker=align_marker("<", halign="right"),
721-
edgecolor="black",
722-
zorder=5,
723-
clip_on=False,
724-
facecolor="white",
725-
transform=ax.get_xaxis_transform(),
726-
)
727-
if overflow > 0.0:
728-
ax.scatter(
729-
final_bins[-1],
730-
0,
731-
_marker_size,
732-
marker=align_marker(">", halign="left"),
733-
edgecolor="black",
734-
zorder=5,
735-
clip_on=False,
736-
facecolor="white",
737-
transform=ax.get_xaxis_transform(),
738-
)
739-
740-
elif flow == "show":
741-
underflow_xticklabel = f"<{flow_bins[1]:g}"
742-
overflow_xticklabel = f">{flow_bins[-2]:g}"
743721

744-
# Loop over shared x axes to get xticks and xticklabels
745-
xticks, xticklabels = np.array([]), []
746-
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
747-
shared_axes = [
748-
_ax for _ax in shared_axes if _ax.get_position().x0 == ax.get_position().x0
749-
]
722+
# Draw markers on all shared axes
750723
for _ax in shared_axes:
751-
_xticks = _ax.get_xticks()
752-
_xticklabels = [label.get_text() for label in _ax.get_xticklabels()]
753-
754-
# Check if underflow/overflow xtick already exists
755-
if (
756-
underflow_xticklabel in _xticklabels
757-
or overflow_xticklabel in _xticklabels
758-
):
759-
xticks = _xticks
760-
xticklabels = _xticklabels
761-
break
762-
if len(_xticklabels) > 0:
763-
xticks = _xticks
764-
xticklabels = _xticklabels
724+
if underflow > 0.0:
725+
_ax.scatter(
726+
final_bins[0],
727+
0,
728+
_marker_size,
729+
marker=align_marker("<", halign="right"),
730+
edgecolor="black",
731+
zorder=5,
732+
clip_on=False,
733+
facecolor="white",
734+
transform=_ax.get_xaxis_transform(),
735+
)
736+
if overflow > 0.0:
737+
_ax.scatter(
738+
final_bins[-1],
739+
0,
740+
_marker_size,
741+
marker=align_marker(">", halign="left"),
742+
edgecolor="black",
743+
zorder=5,
744+
clip_on=False,
745+
facecolor="white",
746+
transform=_ax.get_xaxis_transform(),
747+
)
765748

749+
elif flow == "show":
766750
lw = ax.spines["bottom"].get_linewidth()
767751
_edges = plottables[0].edges_1d()
768752
_centers = plottables[0].centers
@@ -771,91 +755,101 @@ def iterable_not_string(arg):
771755
* ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).width
772756
)
773757

774-
if underflow > 0.0 or underflow_xticklabel in xticklabels:
775-
# Replace any existing xticks in underflow region with underflow bin center
776-
_mask = xticks > flow_bins[1]
777-
xticks = np.insert(xticks[_mask], 0, _centers[0])
778-
xticklabels = [underflow_xticklabel] + [
779-
xlab for i, xlab in enumerate(xticklabels) if _mask[i]
780-
]
758+
# Use edge values for flow bin labels (not center values)
759+
underflow_xticklabel = f"<{_edges[1]:g}"
760+
overflow_xticklabel = f">{_edges[-2]:g}"
761+
762+
# Get shared axes for marker placement
763+
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
764+
shared_axes = [
765+
_ax for _ax in shared_axes if _ax.get_position().x0 == ax.get_position().x0
766+
]
767+
768+
# Use existing tick positions (from matplotlib's default ticker)
769+
# rather than creating ticks at every bin edge
770+
existing_ticks = ax.get_xticks()
771+
# Filter to only include ticks within the regular bin range
772+
regular_edges = _edges[1:-1]
773+
new_xticks = [
774+
tick
775+
for tick in existing_ticks
776+
if regular_edges[0] <= tick <= regular_edges[-1]
777+
]
778+
new_xticklabels = [f"{tick:g}" for tick in new_xticks]
779+
780+
# Find bottom axis for marker placement
781+
bottom_axis = min(shared_axes, key=lambda a: a.get_position().y0)
781782

782-
# Don't draw markers on the top of the top axis
783-
top_axis = max(shared_axes, key=lambda a: a.get_position().y0)
783+
if underflow > 0.0:
784+
# Add underflow bin center at the beginning
785+
new_xticks.insert(0, _centers[0])
786+
new_xticklabels.insert(0, underflow_xticklabel)
784787

785-
# Draw on all shared axes
788+
# Draw markers only on the bottom (h=0) of axes
786789
for _ax in shared_axes:
787-
_ax.set_xticks(xticks)
788-
_ax.set_xticklabels(xticklabels)
789-
for h in [0, 1]:
790-
# Don't draw marker on the top of the top axis
791-
if _ax == top_axis and h == 1:
792-
continue
793-
794-
_ax.plot(
795-
[_edges[0], _edges[1]],
796-
[h, h],
797-
color="white",
798-
zorder=5,
799-
ls="--",
800-
lw=lw,
801-
transform=_ax.get_xaxis_transform(),
802-
clip_on=False,
803-
)
790+
h = 0 # Only draw on bottom
791+
_ax.plot(
792+
[_edges[0], _edges[1]],
793+
[h, h],
794+
color="white",
795+
zorder=5,
796+
ls="--",
797+
lw=lw,
798+
transform=_ax.get_xaxis_transform(),
799+
clip_on=False,
800+
)
804801

805-
_ax.scatter(
806-
_centers[0],
807-
h,
808-
_marker_size,
809-
marker=align_marker("d", valign="center"),
810-
edgecolor="black",
811-
zorder=5,
812-
clip_on=False,
813-
facecolor="white",
814-
transform=_ax.get_xaxis_transform(),
815-
)
816-
if overflow > 0.0 or overflow_xticklabel in xticklabels:
817-
# Replace any existing xticks in overflow region with overflow bin center
818-
_mask = xticks < flow_bins[-2]
819-
xticks = np.insert(xticks[_mask], sum(_mask), _centers[-1])
820-
xticklabels = [xlab for i, xlab in enumerate(xticklabels) if _mask[i]] + [
821-
overflow_xticklabel
822-
]
823-
824-
# Don't draw markers on the top of the top axis
825-
top_axis = max(shared_axes, key=lambda a: a.get_position().y0)
826-
827-
# Draw on all shared axes
802+
_ax.scatter(
803+
_centers[0],
804+
h,
805+
_marker_size,
806+
marker=align_marker("d", valign="center"),
807+
edgecolor="black",
808+
zorder=5,
809+
clip_on=False,
810+
facecolor="white",
811+
transform=_ax.get_xaxis_transform(),
812+
)
813+
if overflow > 0.0:
814+
# Add overflow bin center at the end
815+
new_xticks.append(_centers[-1])
816+
new_xticklabels.append(overflow_xticklabel)
817+
818+
# Draw markers only on the bottom (h=0) of axes
828819
for _ax in shared_axes:
829-
_ax.set_xticks(xticks)
830-
_ax.set_xticklabels(xticklabels)
831-
832-
for h in [0, 1]:
833-
# Don't draw marker on the top of the top axis
834-
if _ax == top_axis and h == 1:
835-
continue
836-
837-
_ax.plot(
838-
[_edges[-2], _edges[-1]],
839-
[h, h],
840-
color="white",
841-
zorder=5,
842-
ls="--",
843-
lw=lw,
844-
transform=_ax.get_xaxis_transform(),
845-
clip_on=False,
846-
)
820+
h = 0 # Only draw on bottom
821+
_ax.plot(
822+
[_edges[-2], _edges[-1]],
823+
[h, h],
824+
color="white",
825+
zorder=5,
826+
ls="--",
827+
lw=lw,
828+
transform=_ax.get_xaxis_transform(),
829+
clip_on=False,
830+
)
847831

848-
_ax.scatter(
849-
_centers[-1],
850-
h,
851-
_marker_size,
852-
marker=align_marker("d", valign="center"),
853-
edgecolor="black",
854-
zorder=5,
855-
clip_on=False,
856-
facecolor="white",
857-
transform=_ax.get_xaxis_transform(),
858-
)
832+
_ax.scatter(
833+
_centers[-1],
834+
h,
835+
_marker_size,
836+
marker=align_marker("d", valign="center"),
837+
edgecolor="black",
838+
zorder=5,
839+
clip_on=False,
840+
facecolor="white",
841+
transform=_ax.get_xaxis_transform(),
842+
)
843+
844+
# Set the final xticks and xticklabels on all shared axes
845+
for _ax in shared_axes:
846+
_ax.set_xticks(new_xticks)
847+
# Only set tick labels on the bottom axis
848+
if _ax == bottom_axis:
849+
_ax.set_xticklabels(new_xticklabels)
850+
else:
851+
# Explicitly set empty labels on other axes
852+
_ax.set_xticklabels(["" for _ in new_xticks])
859853

860854
return return_artists
861855

@@ -1400,14 +1394,16 @@ def model(
14001394
histtype="band",
14011395
)
14021396
else:
1397+
# Remove flow parameter for funcplot (it only works with histplot)
1398+
funcplot_kwargs = {k: v for k, v in stacked_kwargs.items() if k != "flow"}
14031399
funcplot(
14041400
stacked_components,
14051401
ax=ax,
14061402
stack=True,
14071403
colors=stacked_colors,
14081404
labels=stacked_labels,
14091405
range=xlim,
1410-
**stacked_kwargs,
1406+
**funcplot_kwargs,
14111407
)
14121408

14131409
if len(unstacked_components) > 0:
@@ -1435,14 +1431,18 @@ def model(
14351431
**unstacked_kwargs,
14361432
)
14371433
else:
1434+
# Remove flow parameter for funcplot (it only works with histplot)
1435+
funcplot_unstacked_kwargs = {
1436+
k: v for k, v in unstacked_kwargs.items() if k != "flow"
1437+
}
14381438
funcplot(
14391439
component,
14401440
ax=ax,
14411441
stack=False,
14421442
color=color,
14431443
label=label,
14441444
range=xlim,
1445-
**unstacked_kwargs,
1445+
**funcplot_unstacked_kwargs,
14461446
)
14471447
# Plot the sum of all the components
14481448
if model_sum_kwargs.pop("show", True) and (
@@ -1472,11 +1472,15 @@ def model(
14721472
def sum_function(x):
14731473
return sum(f(x) for f in components)
14741474

1475+
# Remove flow parameter for funcplot (it only works with histplot)
1476+
funcplot_sum_kwargs = {
1477+
k: v for k, v in model_sum_kwargs.items() if k != "flow"
1478+
}
14751479
funcplot(
14761480
sum_function,
14771481
ax=ax,
14781482
range=xlim,
1479-
**model_sum_kwargs,
1483+
**funcplot_sum_kwargs,
14801484
)
14811485
elif (
14821486
model_uncertainty
@@ -1488,7 +1492,20 @@ def sum_function(x):
14881492
sum(components), ax=ax, label=model_uncertainty_label, histtype="band"
14891493
)
14901494

1491-
ax.set_xlim(xlim)
1495+
# Check if flow="show" is set in any of the kwargs
1496+
# If so, don't reset xlim as histplot will have set it correctly for flow bins
1497+
flow_in_kwargs = (
1498+
stacked_kwargs.get("flow") == "show"
1499+
or model_sum_kwargs.get("flow") == "show"
1500+
or any(
1501+
kwargs.get("flow") == "show"
1502+
for kwargs in unstacked_kwargs_list
1503+
if kwargs is not None
1504+
)
1505+
)
1506+
1507+
if not flow_in_kwargs:
1508+
ax.set_xlim(xlim)
14921509
ax.set_xlabel(xlabel)
14931510
ax.set_ylabel(ylabel)
14941511
set_fitting_ylabel_fontsize(ax)
84.9 KB
Loading
48.1 KB
Loading
35.4 KB
Loading

0 commit comments

Comments
 (0)