Skip to content

Commit 559690c

Browse files
committed
[Python][UHI] Add serialization tests
1 parent 6ef1065 commit 559690c

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""
2+
Tests to verify that TH1 and derived histograms conform to the UHI Serialization interfaces.
3+
"""
4+
"""
5+
Tests to verify that ROOT histograms can be serialized to and from
6+
the UHI JSON IR (round-trip) for 1D, 2D, and 3D histograms.
7+
"""
8+
9+
import json
10+
11+
import numpy as np
12+
import pytest
13+
import ROOT
14+
import uhi.io.json
15+
16+
17+
# Helpers
18+
# TODO import from plotting
19+
def _iterate_bins(hist, flow=False):
20+
ranges = [
21+
range(0 if flow else 1, hist.GetNbinsX() + (2 if flow else 1)),
22+
range(0 if flow else 1, hist.GetNbinsY() + (2 if flow else 1)) if hist.GetDimension() > 1 else [0],
23+
range(0 if flow else 1, hist.GetNbinsZ() + (2 if flow else 1)) if hist.GetDimension() > 2 else [0],
24+
]
25+
yield from ((i, j, k)[: hist.GetDimension()] for i in ranges[0] for j in ranges[1] for k in ranges[2])
26+
27+
28+
def _bin_contents(hist, flow=False):
29+
return np.array(
30+
[hist.GetBinContent(*idx) for idx in _iterate_bins(hist, flow=flow)]
31+
)
32+
33+
34+
def _roundtrip(hist):
35+
"""
36+
ROOT histogram -> JSON -> UHI IR -> ROOT histogram
37+
"""
38+
json_str = json.dumps(hist, default=uhi.io.json.default)
39+
ir = json.loads(json_str, object_hook=uhi.io.json.object_hook)
40+
return hist.__class__(ir)
41+
42+
43+
# Tests
44+
class TestTH1Serialization:
45+
def test_roundtrip_1d(self):
46+
h = ROOT.TH1D("h1", "h1", 10, -5, 5)
47+
h[...] = np.arange(10)
48+
49+
# test constructor based serialization roundtrip
50+
h_loaded = _roundtrip(h)
51+
52+
assert np.array_equal(
53+
_bin_contents(h, flow=False),
54+
_bin_contents(h_loaded, flow=False),
55+
)
56+
57+
# test classmethod based serialization roundtrip
58+
ir = h._to_uhi_()
59+
h_loaded_cls = ROOT.TH1D._from_uhi_(ir)
60+
assert np.array_equal(
61+
_bin_contents(h, flow=False),
62+
_bin_contents(h_loaded_cls, flow=False),
63+
)
64+
65+
def test_roundtrip_1d_flow(self):
66+
h = ROOT.TH1D("h1f", "h1f", 5, 0, 5)
67+
h.Fill(-1)
68+
h.Fill(6)
69+
h[...] = np.arange(5)
70+
71+
# test constructor based serialization roundtrip
72+
h_loaded = _roundtrip(h)
73+
74+
assert np.array_equal(
75+
_bin_contents(h, flow=True),
76+
_bin_contents(h_loaded, flow=True),
77+
)
78+
79+
# test classmethod based serialization roundtrip
80+
ir = h._to_uhi_()
81+
h_loaded_cls = ROOT.TH1D._from_uhi_(ir)
82+
assert np.array_equal(
83+
_bin_contents(h, flow=True),
84+
_bin_contents(h_loaded_cls, flow=True),
85+
)
86+
87+
88+
class TestTH2Serialization:
89+
def test_roundtrip_2d(self):
90+
h = ROOT.TH2D("h2", "h2", 4, 0, 4, 3, 0, 3)
91+
92+
values = np.arange(12).reshape(4, 3)
93+
h[...] = values
94+
95+
h_loaded = _roundtrip(h)
96+
97+
assert np.array_equal(
98+
_bin_contents(h, flow=False),
99+
_bin_contents(h_loaded, flow=False),
100+
)
101+
102+
def test_roundtrip_2d_flow(self):
103+
h = ROOT.TH2D("h2f", "h2f", 3, 0, 3, 2, 0, 2)
104+
h.Fill(-1, 1)
105+
h.Fill(4, 3)
106+
h[...] = np.arange(6).reshape(3, 2)
107+
108+
h_loaded = _roundtrip(h)
109+
110+
assert np.array_equal(
111+
_bin_contents(h, flow=True),
112+
_bin_contents(h_loaded, flow=True),
113+
)
114+
115+
116+
class TestTH3Serialization:
117+
def test_roundtrip_3d(self):
118+
h = ROOT.TH3D("h3", "h3", 3, 0, 3, 2, 0, 2, 2, 0, 2)
119+
120+
values = np.arange(12).reshape(3, 2, 2)
121+
h[...] = values
122+
123+
h_loaded = _roundtrip(h)
124+
125+
assert np.array_equal(
126+
_bin_contents(h, flow=False),
127+
_bin_contents(h_loaded, flow=False),
128+
)
129+
130+
def test_roundtrip_3d_flow(self):
131+
h = ROOT.TH3D("h3f", "h3f", 2, 0, 2, 2, 0, 2, 2, 0, 2)
132+
h.Fill(-1, 0, 0)
133+
h.Fill(3, 3, 3)
134+
h[...] = np.arange(8).reshape(2, 2, 2)
135+
136+
h_loaded = _roundtrip(h)
137+
138+
assert np.array_equal(
139+
_bin_contents(h, flow=True),
140+
_bin_contents(h_loaded, flow=True),
141+
)
142+
143+
144+
if __name__ == "__main__":
145+
raise SystemExit(pytest.main(args=[__file__]))

0 commit comments

Comments
 (0)