Skip to content

Commit 53f0a3d

Browse files
committed
backport latest flow fixes
1 parent c5a6ba1 commit 53f0a3d

File tree

3 files changed

+54
-40
lines changed

3 files changed

+54
-40
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/indexing.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _temporarily_disable_add_directory():
3434
ROOT.TH1.AddDirectory(old_status)
3535

3636

37-
def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice_stop=False):
37+
def _process_index_for_axis(self, index, axis, is_slice_stop=False):
3838
"""Process an index for a histogram axis handling callables and index shifting."""
3939
if callable(index):
4040
# If the index is a `loc`, `underflow`, `overflow`, or `len`
@@ -56,7 +56,7 @@ def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice
5656
raise index
5757

5858

59-
def _compute_uhi_index(self, index, axis, include_flow_bins=True):
59+
def _compute_uhi_index(self, index, axis, flow=True):
6060
"""Convert tag functors to valid bin indices."""
6161
if isinstance(index, _rebin) or index is _sum:
6262
index = slice(None, None, index)
@@ -65,13 +65,13 @@ def _compute_uhi_index(self, index, axis, include_flow_bins=True):
6565
return _process_index_for_axis(self, index, axis)
6666

6767
if isinstance(index, slice):
68-
start, stop = _resolve_slice_indices(self, index, axis, include_flow_bins)
68+
start, stop = _resolve_slice_indices(self, index, axis, flow)
6969
return slice(start, stop, index.step)
7070

7171
raise TypeError(f"Unsupported index type: {type(index).__name__}")
7272

7373

74-
def _compute_common_index(self, index, include_flow_bins=True):
74+
def _compute_common_index(self, index, flow=True):
7575
"""Normalize and expand the index to match the histogram dimension."""
7676
dim = self.GetDimension()
7777
if isinstance(index, dict):
@@ -93,26 +93,26 @@ def _compute_common_index(self, index, include_flow_bins=True):
9393
if len(index) != dim:
9494
raise IndexError(f"Expected {dim} indices, got {len(index)}")
9595

96-
return [_compute_uhi_index(self, idx, axis, include_flow_bins) for axis, idx in enumerate(index)]
96+
return [_compute_uhi_index(self, idx, axis, flow) for axis, idx in enumerate(index)]
9797

9898

9999
def _setbin(self, index, value):
100100
"""Set the bin content for a specific bin index"""
101101
self.SetBinContent(index, value)
102102

103103

104-
def _resolve_slice_indices(self, index, axis, include_flow_bins=True):
104+
def _resolve_slice_indices(self, index, axis, flow=True):
105105
"""Resolve slice start and stop indices for a given axis"""
106106
start, stop = index.start, index.stop
107107
start = (
108-
_process_index_for_axis(self, start, axis, include_flow_bins)
108+
_process_index_for_axis(self, start, axis, flow)
109109
if start is not None
110-
else _underflow(self, axis) + (0 if include_flow_bins else 1)
110+
else _underflow(self, axis) + (0 if flow else 1)
111111
)
112112
stop = (
113-
_process_index_for_axis(self, stop, axis, include_flow_bins, is_slice_stop=True)
113+
_process_index_for_axis(self, stop, axis, flow, is_slice_stop=True)
114114
if stop is not None
115-
else _overflow(self, axis) + (1 if include_flow_bins else 0)
115+
else _overflow(self, axis) + (1 if flow else 0)
116116
)
117117
if start < _underflow(self, axis) or stop > (_overflow(self, axis) + 1) or start > stop:
118118
raise IndexError(
@@ -205,14 +205,14 @@ def _slice_set(self, index, unprocessed_index, value):
205205
# Depending on the shape of the array provided, we can set or not the flow bins
206206
# Setting with a scalar does not set the flow bins
207207
# broadcasting an array to the shape of the slice does not set the flow bins neither
208-
include_flow_bins = False
208+
flow = False
209209
if isinstance(value, np.ndarray):
210210
processed_slices, _ = _get_processed_slices(self, index)
211211
slice_shape = tuple(stop - start for start, stop in processed_slices)
212-
include_flow_bins = value.size == np.prod(slice_shape)
212+
flow = value.size == np.prod(slice_shape)
213213

214-
if not include_flow_bins:
215-
index = _compute_common_index(self, unprocessed_index, include_flow_bins=False)
214+
if not flow:
215+
index = _compute_common_index(self, unprocessed_index, flow=False)
216216

217217
processed_slices, actions = _get_processed_slices(self, index)
218218
slice_shape = tuple(stop - start for start, stop in processed_slices)
@@ -252,6 +252,6 @@ def _setitem(self, index, value):
252252

253253

254254
def _iter(self):
255-
array = _values_by_copy(self, include_flow_bins=True)
255+
array = _values_by_copy(self, flow=True)
256256
for val in array.flat:
257257
yield val.item()

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/plotting.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -112,71 +112,76 @@ def _hasWeights(hist: Any) -> bool:
112112
def _axes(self) -> Tuple[Union[PlottableAxisContinuous, PlottableAxisDiscrete], ...]:
113113
return tuple(PlottableAxisFactory.create(_get_axis(self, i)) for i in range(self.GetDimension()))
114114

115+
115116
def _kind(self) -> Kind:
116117
# TProfile -> MEAN, everything else -> COUNT
117118
if self.__class__.__name__.startswith("TProfile"):
118119
return Kind.MEAN
119120
return Kind.COUNT
120121

121122

122-
def _shape(hist: Any, include_flow_bins: bool = True) -> Tuple[int, ...]:
123-
return tuple(_get_axis_len(hist, i, include_flow_bins) for i in range(hist.GetDimension()))
123+
def _shape(hist: Any, flow: bool = True) -> Tuple[int, ...]:
124+
return tuple(_get_axis_len(hist, i, flow) for i in range(hist.GetDimension()))
124125

125126

126-
def _values_default(self) -> np.typing.NDArray[Any]: # noqa: F821
127+
def _values_default(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
127128
import numpy as np
128129

129130
llv = self.GetArray()
130131
ret = np.frombuffer(llv, dtype=llv.typecode, count=self.GetSize())
131-
return ret.reshape(_shape(self), order="F")[tuple([slice(1, -1)] * len(_shape(self)))]
132+
reshaped = ret.reshape(_shape(self), order="F")
133+
134+
if flow:
135+
# include all bins
136+
slices = tuple([slice(None)] * len(_shape(self)))
137+
else:
138+
# exclude underflow/overflow
139+
slices = tuple([slice(1, -1)] * len(_shape(self)))
140+
141+
return reshaped[slices]
132142

133143

134144
# Special case for TH*C and TProfile*
135-
def _values_by_copy(self, include_flow_bins=False) -> np.typing.NDArray[Any]: # noqa: F821
145+
def _values_by_copy(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
136146
from itertools import product
137147

138148
import numpy as np
139149

140-
offset = 0 if include_flow_bins else 1
141-
dimensions = [
142-
range(offset, _get_axis_len(self, axis, include_flow_bins=include_flow_bins) + offset)
143-
for axis in range(self.GetDimension())
144-
]
150+
offset = 0 if flow else 1
151+
dimensions = [range(offset, _get_axis_len(self, axis, flow=flow) + offset) for axis in range(self.GetDimension())]
145152
bin_combinations = product(*dimensions)
146153

147-
return np.array([self.GetBinContent(*bin) for bin in bin_combinations]).reshape(
148-
_shape(self, include_flow_bins=include_flow_bins)
149-
)
154+
return np.array([self.GetBinContent(*bin) for bin in bin_combinations]).reshape(_shape(self, flow=flow))
150155

151156

152-
def _variances(self) -> np.typing.NDArray[Any]: # noqa: F821
157+
def _variances(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
153158
import numpy as np
154159

155-
sum_of_weights = self.values()
160+
sum_of_weights = self.values(flow=flow)
156161

157162
if not _hasWeights(self) and _kind(self) == Kind.COUNT:
158163
return sum_of_weights
159164

160-
sum_of_weights_squared = _get_sum_of_weights_squared(self)
165+
sum_of_weights_squared = _get_sum_of_weights_squared(self, flow=flow)
161166

162167
if _kind(self) == Kind.MEAN:
163-
counts = self.counts()
168+
counts = self.counts(flow=flow)
164169
variances = sum_of_weights_squared.copy()
165170
variances[counts <= 1] = np.nan
166171
return variances
167172

168173
return sum_of_weights_squared
169174

170175

171-
def _counts(self) -> np.typing.NDArray[Any]: # noqa: F821
176+
def _counts(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
172177
import numpy as np
173178

174-
sum_of_weights = self.values()
179+
sum_of_weights = self.values(flow=flow)
175180

176181
if not _hasWeights(self):
177182
return sum_of_weights
178183

179-
sum_of_weights_squared = _get_sum_of_weights_squared(self)
184+
sum_of_weights_squared = _get_sum_of_weights_squared(self, flow=flow)
180185

181186
return np.divide(
182187
sum_of_weights**2,
@@ -185,16 +190,25 @@ def _counts(self) -> np.typing.NDArray[Any]: # noqa: F821
185190
where=sum_of_weights_squared != 0,
186191
)
187192

193+
188194
def _get_sum_of_weights(self) -> np.typing.NDArray[Any]: # noqa: F821
189195
return self.values()
190196

191-
def _get_sum_of_weights_squared(self) -> np.typing.NDArray[Any]: # noqa: F821
197+
198+
def _get_sum_of_weights_squared(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
192199
import numpy as np
193200

194-
shape = _shape(self, include_flow_bins=False)
195201
sumw2_arr = np.frombuffer(
196202
self.GetSumw2().GetArray(),
197203
dtype=self.GetSumw2().GetArray().typecode,
198204
count=self.GetSumw2().GetSize(),
199205
)
200-
return sumw2_arr[tuple([slice(1, -1)] * len(shape))].reshape(shape, order="F") if sumw2_arr.size > 0 else sumw2_arr
206+
207+
reshaped = sumw2_arr.reshape(_shape(self, flow=True), order="F")
208+
209+
if flow:
210+
slices = tuple(slice(None) for _ in range(self.GetDimension()))
211+
else:
212+
slices = tuple(slice(1, -1) for _ in range(self.GetDimension()))
213+
214+
return reshaped[slices]

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/tags.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def _get_axis(self, axis):
1616
return getattr(self, f"Get{['X', 'Y', 'Z'][axis]}axis")()
1717

1818

19-
def _get_axis_len(self, axis, include_flow_bins=False):
20-
return _get_axis(self, axis).GetNbins() + (2 if include_flow_bins else 0)
19+
def _get_axis_len(self, axis, flow=False):
20+
return _get_axis(self, axis).GetNbins() + (2 if flow else 0)
2121

2222

2323
def _underflow(hist: Any, axis: int) -> int:

0 commit comments

Comments
 (0)