Skip to content

Commit e637767

Browse files
committed
[Python][UHI] Start introducing UHI serialization
1 parent 5dc4a61 commit e637767

File tree

9 files changed

+230
-8
lines changed

9 files changed

+230
-8
lines changed

bindings/pyroot/pythonizations/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ set(py_sources
127127
ROOT/_pythonization/_ttree.py
128128
ROOT/_pythonization/_tvector3.py
129129
ROOT/_pythonization/_tvectort.py
130-
ROOT/_pythonization/_uhi/main.py
130+
ROOT/_pythonization/_uhi/__init__.py
131131
ROOT/_pythonization/_uhi/tags.py
132132
ROOT/_pythonization/_uhi/indexing.py
133133
ROOT/_pythonization/_uhi/plotting.py
134+
ROOT/_pythonization/_uhi/serialization.py
134135
${PYROOT_EXTRA_PYTHON_SOURCES}
135136
)
136137

bindings/pyroot/pythonizations/python/ROOT/_facade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def uhi(self):
479479
uhi_module.__file__ = "<module ROOT>"
480480
uhi_module.__package__ = self
481481
try:
482-
from ._pythonization._uhi.main import _add_module_level_uhi_helpers
482+
from ._pythonization._uhi import _add_module_level_uhi_helpers
483483

484484
_add_module_level_uhi_helpers(uhi_module)
485485
except ImportError:

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_th1.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _FillWithArrayTH1(self, *args):
241241
for klass in _th1_derived_classes_to_pythonize:
242242
pythonization(klass)(inject_constructor_releasing_ownership)
243243

244-
from ROOT._pythonization._uhi.main import _add_plotting_features
244+
from ROOT._pythonization._uhi import _add_plotting_features
245245

246246
# Add UHI plotting features
247247
pythonization(klass)(_add_plotting_features)
@@ -252,20 +252,24 @@ def _enable_numpy_fill(klass):
252252
klass._Fill = klass.Fill
253253
klass.Fill = _FillWithArrayTH1
254254

255+
# Add serialization features
256+
from ROOT._pythonization._uhi import _add_serialization_features
257+
pythonization(klass)(_add_serialization_features)
258+
255259

256260
@pythonization("TH1")
257261
def pythonize_th1(klass):
258262
# Parameters:
259263
# klass: class to be pythonized
260-
from ROOT._pythonization._uhi.main import _add_indexing_features
264+
from ROOT._pythonization._uhi import _add_indexing_features
261265

262266
# Support hist *= scalar
263267
klass.__imul__ = _imul
264268

265269
klass._Original_SetDirectory = klass.SetDirectory
266270
klass.SetDirectory = _SetDirectory_SetOwnership
267271

268-
# Add UHI indexing features
272+
# Add UHI indexing and serialization features
269273
_add_indexing_features(klass)
270274

271275
inject_clone_releasing_ownership(klass)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_th2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _FillWithArrayTH2(self, *args):
7878
for klass in _th2_derived_classes_to_pythonize:
7979
pythonization(klass)(inject_constructor_releasing_ownership)
8080

81-
from ROOT._pythonization._uhi.main import _add_plotting_features
81+
from ROOT._pythonization._uhi import _add_plotting_features
8282

8383
# Add UHI plotting features
8484
pythonization(klass)(_add_plotting_features)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_th3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
for klass in _th3_derived_classes_to_pythonize:
2828
pythonization(klass)(inject_constructor_releasing_ownership)
2929

30-
from ROOT._pythonization._uhi.main import _add_plotting_features
30+
from ROOT._pythonization._uhi import _add_plotting_features
3131

3232
# Add UHI plotting features
3333
pythonization(klass)(_add_plotting_features)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/main.py renamed to bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,28 @@ def _add_plotting_features(klass: Any) -> None:
6464
klass.counts = _counts
6565
klass.axes = property(_axes)
6666
klass.values = values_func_dict.get(klass.__name__, _values_default)
67+
68+
69+
"""
70+
Implementation of the serialization component of the UHI
71+
"""
72+
73+
def _TH1_Constructor(self, *args, **kwargs):
74+
"""
75+
If UHI IR is detected, use the UHI deserialization constructor
76+
else forward to the original
77+
"""
78+
if len(args) == 1 and isinstance(args[0], dict):
79+
from .serialization import _from_uhi_
80+
81+
_from_uhi_(self, args[0])
82+
else:
83+
self._original_init_(*args, **kwargs)
84+
85+
def _add_serialization_features(klass: Any) -> None:
86+
from .serialization import _to_uhi_
87+
88+
klass._to_uhi_ = _to_uhi_
89+
90+
klass._original_init_ = klass.__init__
91+
klass.__init__ = _TH1_Constructor

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/plotting.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def circular(self) -> bool:
3838
def discrete(self) -> bool:
3939
return self._discrete
4040

41+
@property
42+
def underflow(self) -> bool:
43+
return True
44+
45+
@property
46+
def overflow(self) -> bool:
47+
return True
48+
4149

4250
class PlottableAxisBase(ABC):
4351
def __init__(self, tAxis: Any) -> None:
@@ -104,7 +112,7 @@ def _hasWeights(hist: Any) -> bool:
104112
def _axes(self) -> Tuple[Union[PlottableAxisContinuous, PlottableAxisDiscrete], ...]:
105113
return tuple(PlottableAxisFactory.create(_get_axis(self, i)) for i in range(self.GetDimension()))
106114

107-
115+
# TODO this is not correct?
108116
def _kind(self) -> Kind:
109117
# TProfile -> MEAN, everything else -> COUNT
110118
if self.__class__.__name__.startswith("TProfile"):
@@ -182,6 +190,8 @@ def _counts(self, flow=False) -> np.typing.NDArray[Any]: # noqa: F821
182190
where=sum_of_weights_squared != 0,
183191
)
184192

193+
def _get_sum_of_weights(self) -> np.typing.NDArray[Any]: # noqa: F821
194+
return self.values()
185195

186196
def _get_sum_of_weights(self) -> np.typing.NDArray[Any]: # noqa: F821
187197
return self.values()
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Author: Silia Taider CERN 10/2025
2+
3+
################################################################################
4+
# Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. #
5+
# All rights reserved. #
6+
# #
7+
# For the licensing terms see $ROOTSYS/LICENSE. #
8+
# For the list of contributors see $ROOTSYS/README/CREDITS. #
9+
################################################################################
10+
from __future__ import annotations
11+
12+
from typing import Any
13+
14+
import ROOT
15+
16+
from .plotting import PlottableAxisBase, _get_sum_of_weights, _get_sum_of_weights_squared, _hasWeights
17+
from .tags import _get_axis
18+
19+
"""
20+
Implementation of the serialization component of the UHI
21+
"""
22+
23+
24+
def _axis_to_dict(root_axis: ROOT.TAxis, uhi_axis: PlottableAxisBase) -> dict[str, Any]:
25+
"""
26+
Return a dictionary representation of the given ROOT axis.
27+
"""
28+
return {
29+
"type": "regular",
30+
"lower": root_axis.GetBinLowEdge(root_axis.GetFirst()),
31+
"upper": root_axis.GetBinUpEdge(root_axis.GetLast()),
32+
"bins": root_axis.GetNbins(),
33+
"underflow": uhi_axis.traits.underflow,
34+
"overflow": uhi_axis.traits.overflow,
35+
"circular": uhi_axis.traits.circular,
36+
}
37+
38+
39+
def _axis_from_dict(axis_dict: dict[str, Any]) -> list[Any]:
40+
"""
41+
Return the arguments needed to construct the corresponding ROOT histogram axis.
42+
For now only supports regular axes.
43+
"""
44+
45+
axis_type = axis_dict["type"]
46+
47+
if axis_type == "regular":
48+
nbins = axis_dict["bins"]
49+
lower = axis_dict["lower"]
50+
upper = axis_dict["upper"]
51+
return [nbins, lower, upper]
52+
53+
raise ValueError(f"Unsupported axis type for conversion to ROOT: {axis_type}")
54+
55+
56+
def _storage_to_dict(hist: Any) -> dict[str, Any]:
57+
"""
58+
Logic:
59+
- If histogram is a profile (TProfile*) --> Kind="MEAN":
60+
- if histogram has Sumw2: type is weighted_mean_storage (if _hasWeights(hist))
61+
- else: storage type is mean_storage
62+
- Else (TH1*/TH2*/TH3*) --> Kind="COUNT":
63+
- if histogram has Sumw2: type is weighted_storage
64+
- else if histogram is TH*I: type is int_storage
65+
- else: type is double_storage
66+
"""
67+
storage_dict = {
68+
"values": hist.values(),
69+
}
70+
71+
if hist.kind == "MEAN":
72+
storage_dict["variances"] = hist.variances()
73+
74+
if _hasWeights(hist):
75+
storage_dict["type"] = "weighted_mean"
76+
storage_dict["sum_of_weights"] = _get_sum_of_weights(hist)
77+
storage_dict["sum_of_weights_squared"] = _get_sum_of_weights_squared(hist)
78+
else:
79+
storage_dict["type"] = "mean"
80+
storage_dict["counts"] = hist.counts()
81+
82+
else: # COUNT
83+
if _hasWeights(hist):
84+
storage_dict["type"] = "weighted"
85+
storage_dict["variances"] = hist.variances()
86+
else:
87+
if hist.ClassName().endswith("I"):
88+
storage_dict["type"] = "int"
89+
else:
90+
storage_dict["type"] = "double"
91+
92+
return storage_dict
93+
94+
95+
def _set_histogram_storage_from_dict(hist: Any, storage_dict: dict[str, Any]) -> None:
96+
"""
97+
Set the histogram storage (values and statistics) from the given storage dictionary.
98+
"""
99+
hist_values = storage_dict["values"]
100+
hist[...] = hist_values
101+
102+
stype = storage_dict.get("type")
103+
if stype in ["weighted_mean", "mean", "weighted"]:
104+
hist.variances()[:] = storage_dict["variances"]
105+
106+
107+
def _to_uhi_(self) -> dict[str, Any]:
108+
return {
109+
"uhi_schema": 1,
110+
"writer_info": {"ROOT": {"version": ROOT.__version__, "class": self.ClassName(), "name": self.GetName()}},
111+
"axes": [_axis_to_dict(_get_axis(self, i), self.axes[i]) for i in range(self.GetDimension())],
112+
"storage": _storage_to_dict(self),
113+
}
114+
115+
116+
def _from_uhi_(self, uhi_dict: dict[str, Any]) -> ROOT.TH1:
117+
# rebuild axes
118+
axes = uhi_dict["axes"]
119+
axes_specs = [_axis_from_dict(axis_dict) for axis_dict in axes]
120+
121+
# construct the histogram
122+
ctor_args = ["h_uhi", "h_uhi"]
123+
for axis_spec in axes_specs:
124+
ctor_args.extend(axis_spec)
125+
126+
self._original_init_(*ctor_args)
127+
128+
# set storage
129+
_set_histogram_storage_from_dict(self, uhi_dict["storage"])
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import json
2+
3+
import numpy as np
4+
import ROOT
5+
import uhi.io.json
6+
7+
import hist
8+
9+
## to_uhi
10+
11+
h = ROOT.TH1D("h", "h", 10, -5, 5)
12+
h[...] = range(10)
13+
print("\nh =", h)
14+
print("values=", h.values())
15+
16+
ob = json.dumps(h, default=uhi.io.json.default)
17+
ir = json.loads(ob, object_hook=uhi.io.json.object_hook)
18+
19+
h_loaded = hist.Hist(ir)
20+
print("\nh_loaded =")
21+
print("values=", h_loaded.values())
22+
23+
24+
## from_uhi
25+
26+
h2 = hist.Hist(hist.axis.Regular(10, -5, 5, name="x"))
27+
h2[...] = range(10)
28+
print("\nh2 =")
29+
print("values=", h2.values())
30+
31+
ob2 = json.dumps(h2, default=uhi.io.json.default)
32+
ir2 = json.loads(ob2, object_hook=uhi.io.json.object_hook)
33+
34+
h2_loaded = ROOT.TH1D(ir2)
35+
print("\nh2_loaded =", h2_loaded)
36+
print("values=", h2_loaded.values())
37+
38+
39+
## 2D histogram
40+
h2d = hist.Hist(
41+
hist.axis.Regular(10, -5, 5, name="x"),
42+
hist.axis.Regular(10, -5, 5, name="y"),
43+
)
44+
h2d[...] = np.arange(100).reshape(10, 10)
45+
print("\nh2d =")
46+
print("values=", h2d.values())
47+
48+
ob3 = json.dumps(h2d, default=uhi.io.json.default)
49+
ir3 = json.loads(ob3, object_hook=uhi.io.json.object_hook)
50+
51+
h2d_loaded = hist.Hist(ir3)
52+
print("\nh2d_loaded =")
53+
print("values=", h2d_loaded.values())

0 commit comments

Comments
 (0)