Skip to content

Commit 5dc4a61

Browse files
committed
backport fixes
1 parent b3f023f commit 5dc4a61

File tree

3 files changed

+59
-41
lines changed

3 files changed

+59
-41
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/indexing.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _temporarily_disable_add_directory():
3434
ROOT.TH1.AddDirectory(old_status)
3535

3636

37-
def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice_stop=False):
37+
def _process_index_for_axis(self, index, axis, is_slice_stop=False):
3838
"""Process an index for a histogram axis handling callables and index shifting."""
3939
if callable(index):
4040
# If the index is a `loc`, `underflow`, `overflow`, or `len`
@@ -56,7 +56,7 @@ def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice
5656
raise index
5757

5858

59-
def _compute_uhi_index(self, index, axis, include_flow_bins=True):
59+
def _compute_uhi_index(self, index, axis, flow=True):
6060
"""Convert tag functors to valid bin indices."""
6161
if isinstance(index, _rebin) or index is _sum:
6262
index = slice(None, None, index)
@@ -65,13 +65,13 @@ def _compute_uhi_index(self, index, axis, include_flow_bins=True):
6565
return _process_index_for_axis(self, index, axis)
6666

6767
if isinstance(index, slice):
68-
start, stop = _resolve_slice_indices(self, index, axis, include_flow_bins)
68+
start, stop = _resolve_slice_indices(self, index, axis, flow)
6969
return slice(start, stop, index.step)
7070

7171
raise TypeError(f"Unsupported index type: {type(index).__name__}")
7272

7373

74-
def _compute_common_index(self, index, include_flow_bins=True):
74+
def _compute_common_index(self, index, flow=True):
7575
"""Normalize and expand the index to match the histogram dimension."""
7676
dim = self.GetDimension()
7777
if isinstance(index, dict):
@@ -93,26 +93,26 @@ def _compute_common_index(self, index, include_flow_bins=True):
9393
if len(index) != dim:
9494
raise IndexError(f"Expected {dim} indices, got {len(index)}")
9595

96-
return [_compute_uhi_index(self, idx, axis, include_flow_bins) for axis, idx in enumerate(index)]
96+
return [_compute_uhi_index(self, idx, axis, flow) for axis, idx in enumerate(index)]
9797

9898

9999
def _setbin(self, index, value):
100100
"""Set the bin content for a specific bin index"""
101101
self.SetBinContent(index, value)
102102

103103

104-
def _resolve_slice_indices(self, index, axis, include_flow_bins=True):
104+
def _resolve_slice_indices(self, index, axis, flow=True):
105105
"""Resolve slice start and stop indices for a given axis"""
106106
start, stop = index.start, index.stop
107107
start = (
108-
_process_index_for_axis(self, start, axis, include_flow_bins)
108+
_process_index_for_axis(self, start, axis, flow)
109109
if start is not None
110-
else _underflow(self, axis) + (0 if include_flow_bins else 1)
110+
else _underflow(self, axis) + (0 if flow else 1)
111111
)
112112
stop = (
113-
_process_index_for_axis(self, stop, axis, include_flow_bins, is_slice_stop=True)
113+
_process_index_for_axis(self, stop, axis, flow, is_slice_stop=True)
114114
if stop is not None
115-
else _overflow(self, axis) + (1 if include_flow_bins else 0)
115+
else _overflow(self, axis) + (1 if flow else 0)
116116
)
117117
if start < _underflow(self, axis) or stop > (_overflow(self, axis) + 1) or start > stop:
118118
raise IndexError(
@@ -205,14 +205,14 @@ def _slice_set(self, index, unprocessed_index, value):
205205
# Depending on the shape of the array provided, we can set or not the flow bins
206206
# Setting with a scalar does not set the flow bins
207207
# broadcasting an array to the shape of the slice does not set the flow bins neither
208-
include_flow_bins = False
208+
flow = False
209209
if isinstance(value, np.ndarray):
210210
processed_slices, _ = _get_processed_slices(self, index)
211211
slice_shape = tuple(stop - start for start, stop in processed_slices)
212-
include_flow_bins = value.size == np.prod(slice_shape)
212+
flow = value.size == np.prod(slice_shape)
213213

214-
if not include_flow_bins:
215-
index = _compute_common_index(self, unprocessed_index, include_flow_bins=False)
214+
if not flow:
215+
index = _compute_common_index(self, unprocessed_index, flow=False)
216216

217217
processed_slices, actions = _get_processed_slices(self, index)
218218
slice_shape = tuple(stop - start for start, stop in processed_slices)
@@ -252,6 +252,6 @@ def _setitem(self, index, value):
252252

253253

254254
def _iter(self):
255-
array = _values_by_copy(self, include_flow_bins=True)
255+
array = _values_by_copy(self, flow=True)
256256
for val in array.flat:
257257
yield val.item()

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/plotting.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -106,66 +106,74 @@ def _axes(self) -> Tuple[Union[PlottableAxisContinuous, PlottableAxisDiscrete],
106106

107107

108108
def _kind(self) -> Kind:
109-
return Kind.COUNT if not _hasWeights(self) else Kind.MEAN
109+
# TProfile -> MEAN, everything else -> COUNT
110+
if self.__class__.__name__.startswith("TProfile"):
111+
return Kind.MEAN
112+
return Kind.COUNT
110113

111114

112-
def _shape(hist: Any, include_flow_bins: bool = True) -> Tuple[int, ...]:
113-
return tuple(_get_axis_len(hist, i, include_flow_bins) for i in range(hist.GetDimension()))
115+
def _shape(hist: Any, flow: bool = True) -> Tuple[int, ...]:
116+
return tuple(_get_axis_len(hist, i, flow) for i in range(hist.GetDimension()))
114117

115118

116-
def _values_default(self) -> np.typing.NDArray[Any]: # noqa: F821
119+
def _values_default(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
117120
import numpy as np
118121

119122
llv = self.GetArray()
120123
ret = np.frombuffer(llv, dtype=llv.typecode, count=self.GetSize())
121-
return ret.reshape(_shape(self), order="F")[tuple([slice(1, -1)] * len(_shape(self)))]
124+
reshaped = ret.reshape(_shape(self), order="F")
122125

126+
if flow:
127+
# include all bins
128+
slices = tuple([slice(None)] * len(_shape(self)))
129+
else:
130+
# exclude underflow/overflow
131+
slices = tuple([slice(1, -1)] * len(_shape(self)))
123132

124-
def _values_by_copy(self, include_flow_bins=False) -> np.typing.NDArray[Any]: # noqa: F821
133+
return reshaped[slices]
134+
135+
136+
# Special case for TH*C and TProfile*
137+
def _values_by_copy(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
125138
from itertools import product
126139

127140
import numpy as np
128141

129-
offset = 0 if include_flow_bins else 1
130-
dimensions = [
131-
range(offset, _get_axis_len(self, axis, include_flow_bins=include_flow_bins) + offset)
132-
for axis in range(self.GetDimension())
133-
]
142+
offset = 0 if flow else 1
143+
dimensions = [range(offset, _get_axis_len(self, axis, flow=flow) + offset) for axis in range(self.GetDimension())]
134144
bin_combinations = product(*dimensions)
135145

136-
return np.array([self.GetBinContent(*bin) for bin in bin_combinations]).reshape(
137-
_shape(self, include_flow_bins=include_flow_bins)
138-
)
146+
return np.array([self.GetBinContent(*bin) for bin in bin_combinations]).reshape(_shape(self, flow=flow))
139147

140148

141-
def _variances(self) -> np.typing.NDArray[Any]: # noqa: F821
149+
def _variances(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
142150
import numpy as np
143151

144-
sum_of_weights = self.values()
152+
sum_of_weights = self.values(flow=flow)
145153

146154
if not _hasWeights(self) and _kind(self) == Kind.COUNT:
147155
return sum_of_weights
148156

149-
sum_of_weights_squared = _get_sum_of_weights_squared(self)
157+
sum_of_weights_squared = _get_sum_of_weights_squared(self, flow=flow)
150158

151159
if _kind(self) == Kind.MEAN:
152-
counts = self.counts()
160+
counts = self.counts(flow=flow)
153161
variances = sum_of_weights_squared.copy()
154162
variances[counts <= 1] = np.nan
155163
return variances
156164

157165
return sum_of_weights_squared
158166

159167

160-
def _counts(self) -> np.typing.NDArray[Any]: # noqa: F821
168+
def _counts(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
161169
import numpy as np
162170

163-
sum_of_weights = self.values()
171+
sum_of_weights = self.values(flow=flow)
164172

165173
if not _hasWeights(self):
166174
return sum_of_weights
167175

168-
sum_of_weights_squared = _get_sum_of_weights_squared(self)
176+
sum_of_weights_squared = _get_sum_of_weights_squared(self, flow=flow)
169177

170178
return np.divide(
171179
sum_of_weights**2,
@@ -175,13 +183,23 @@ def _counts(self) -> np.typing.NDArray[Any]: # noqa: F821
175183
)
176184

177185

178-
def _get_sum_of_weights_squared(self) -> np.typing.NDArray[Any]: # noqa: F821
186+
def _get_sum_of_weights(self) -> np.typing.NDArray[Any]: # noqa: F821
187+
return self.values()
188+
189+
def _get_sum_of_weights_squared(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
179190
import numpy as np
180191

181-
shape = _shape(self, include_flow_bins=False)
182192
sumw2_arr = np.frombuffer(
183193
self.GetSumw2().GetArray(),
184194
dtype=self.GetSumw2().GetArray().typecode,
185195
count=self.GetSumw2().GetSize(),
186196
)
187-
return sumw2_arr[tuple([slice(1, -1)] * len(shape))].reshape(shape, order="F") if sumw2_arr.size > 0 else sumw2_arr
197+
198+
reshaped = sumw2_arr.reshape(_shape(self, flow=True), order="F")
199+
200+
if flow:
201+
slices = tuple(slice(None) for _ in range(self.GetDimension()))
202+
else:
203+
slices = tuple(slice(1, -1) for _ in range(self.GetDimension()))
204+
205+
return reshaped[slices]

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/tags.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def _get_axis(self, axis):
1616
return getattr(self, f"Get{['X', 'Y', 'Z'][axis]}axis")()
1717

1818

19-
def _get_axis_len(self, axis, include_flow_bins=False):
20-
return _get_axis(self, axis).GetNbins() + (2 if include_flow_bins else 0)
19+
def _get_axis_len(self, axis, flow=False):
20+
return _get_axis(self, axis).GetNbins() + (2 if flow else 0)
2121

2222

2323
def _underflow(hist: Any, axis: int) -> int:

0 commit comments

Comments
 (0)