Skip to content

Commit 5a1155c

Browse files
committed
[hist] Implement translation from RBinIndex to RLinearizedIndex
1 parent c620481 commit 5a1155c

File tree

6 files changed

+235
-6
lines changed

6 files changed

+235
-6
lines changed

hist/histv7/inc/ROOT/RAxes.hxx

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
#ifndef ROOT_RAxes
66
#define ROOT_RAxes
77

8+
#include "RBinIndex.hxx"
89
#include "RLinearizedIndex.hxx"
910
#include "RRegularAxis.hxx"
1011
#include "RVariableBinAxis.hxx"
1112

13+
#include <array>
14+
#include <cassert>
1215
#include <stdexcept>
1316
#include <tuple>
1417
#include <utility>
@@ -104,6 +107,38 @@ public:
104107
return ComputeGlobalIndex<0, A...>(0, args);
105108
}
106109

110+
/// Compute the global index for all axes.
111+
///
112+
/// \param[in] indices the array of RBinIndex
113+
/// \return the global index that may be invalid
114+
template <std::size_t N>
115+
RLinearizedIndex ComputeGlobalIndex(const std::array<RBinIndex, N> &indices) const
116+
{
117+
if (N != fAxes.size()) {
118+
throw std::invalid_argument("invalid number of indices passed to ComputeGlobalIndex");
119+
}
120+
std::size_t globalIndex = 0;
121+
for (std::size_t i = 0; i < N; i++) {
122+
const auto &index = indices[i];
123+
const auto &axis = fAxes[i];
124+
RLinearizedIndex linIndex;
125+
if (auto *regular = std::get_if<RRegularAxis>(&axis)) {
126+
globalIndex *= regular->GetTotalNBins();
127+
linIndex = regular->GetLinearizedIndex(index);
128+
} else if (auto *variable = std::get_if<RVariableBinAxis>(&axis)) {
129+
globalIndex *= variable->GetTotalNBins();
130+
linIndex = variable->GetLinearizedIndex(index);
131+
} else {
132+
throw std::logic_error("unimplemented axis type");
133+
}
134+
if (!linIndex.fValid) {
135+
return {0, false};
136+
}
137+
globalIndex += linIndex.fIndex;
138+
}
139+
return {globalIndex, true};
140+
}
141+
107142
/// ROOT Streamer function to throw when trying to store an object of this class.
108143
void Streamer(TBuffer &) { throw std::runtime_error("unable to store RAxes"); }
109144
};

hist/histv7/inc/ROOT/RRegularAxis.hxx

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
#ifndef ROOT_RRegularAxis
66
#define ROOT_RRegularAxis
77

8+
#include "RBinIndex.hxx"
89
#include "RLinearizedIndex.hxx"
910

11+
#include <cassert>
1012
#include <cstddef>
1113
#include <stdexcept>
1214
#include <string>
@@ -97,6 +99,27 @@ public:
9799
return {bin, true};
98100
}
99101

102+
/// Get the linearized index for an RBinIndex.
103+
///
104+
/// The normal bins have indices \f$0\f$ to \f$fNNormalBins - 1\f$, the underflow bin has index
105+
/// \f$fNNormalBins\f$, and the overflow bin has index \f$fNNormalBins + 1\f$.
106+
///
107+
/// \param[in] index the RBinIndex
108+
/// \return the linearized index that may be invalid
109+
RLinearizedIndex GetLinearizedIndex(RBinIndex index) const
110+
{
111+
if (index.IsUnderflow()) {
112+
return {fNNormalBins, fEnableFlowBins};
113+
} else if (index.IsOverflow()) {
114+
return {fNNormalBins + 1, fEnableFlowBins};
115+
} else if (index.IsInvalid()) {
116+
return {0, false};
117+
}
118+
assert(index.IsNormal());
119+
std::size_t bin = index.GetIndex();
120+
return {bin, bin < fNNormalBins};
121+
}
122+
100123
/// ROOT Streamer function to throw when trying to store an object of this class.
101124
void Streamer(TBuffer &) { throw std::runtime_error("unable to store RRegularAxis"); }
102125
};

hist/histv7/inc/ROOT/RVariableBinAxis.hxx

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
#ifndef ROOT_RVariableBinAxis
66
#define ROOT_RVariableBinAxis
77

8+
#include "RBinIndex.hxx"
89
#include "RLinearizedIndex.hxx"
910

11+
#include <cassert>
1012
#include <cstddef>
1113
#include <stdexcept>
1214
#include <string>
@@ -98,6 +100,27 @@ public:
98100
return {bin, true};
99101
}
100102

103+
/// Get the linearized index for an RBinIndex.
104+
///
105+
/// The normal bins have indices \f$0\f$ to \f$fBinEdges.size() - 2\f$, the underflow bin has index
106+
/// \f$fBinEdges.size() - 1\f$, and the overflow bin has index \f$fBinEdges.size()\f$.
107+
///
108+
/// \param[in] index the RBinIndex
109+
/// \return the linearized index that may be invalid
110+
RLinearizedIndex GetLinearizedIndex(RBinIndex index) const
111+
{
112+
if (index.IsUnderflow()) {
113+
return {fBinEdges.size() - 1, fEnableFlowBins};
114+
} else if (index.IsOverflow()) {
115+
return {fBinEdges.size(), fEnableFlowBins};
116+
} else if (index.IsInvalid()) {
117+
return {0, false};
118+
}
119+
assert(index.IsNormal());
120+
std::size_t bin = index.GetIndex();
121+
return {bin, bin < fBinEdges.size() - 1};
122+
}
123+
101124
/// ROOT Streamer function to throw when trying to store an object of this class.
102125
void Streamer(TBuffer &) { throw std::runtime_error("unable to store RVariableBinAxis"); }
103126
};

hist/histv7/test/hist_axes.cxx

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "hist_test.hxx"
22

3+
#include <array>
34
#include <stdexcept>
45
#include <tuple>
56
#include <variant>
@@ -96,21 +97,33 @@ TEST(RAxes, ComputeGlobalIndex)
9697
const RAxes axes({regularAxis, variableBinAxis});
9798

9899
{
99-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 2.5));
100+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 2.5));
101+
EXPECT_EQ(globalIndex.fIndex, 1 * (BinsY + 2) + 2);
102+
EXPECT_TRUE(globalIndex.fValid);
103+
const std::array<RBinIndex, 2> indices = {1, 2};
104+
globalIndex = axes.ComputeGlobalIndex(indices);
100105
EXPECT_EQ(globalIndex.fIndex, 1 * (BinsY + 2) + 2);
101106
EXPECT_TRUE(globalIndex.fValid);
102107
}
103108

104109
{
105110
// Underflow bin of the first axis.
106-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(-1, 2.5));
111+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(-1, 2.5));
112+
EXPECT_EQ(globalIndex.fIndex, BinsX * (BinsY + 2) + 2);
113+
EXPECT_TRUE(globalIndex.fValid);
114+
const std::array<RBinIndex, 2> indices = {RBinIndex::Underflow(), 2};
115+
globalIndex = axes.ComputeGlobalIndex(indices);
107116
EXPECT_EQ(globalIndex.fIndex, BinsX * (BinsY + 2) + 2);
108117
EXPECT_TRUE(globalIndex.fValid);
109118
}
110119

111120
{
112121
// Overflow bin of the second axis.
113-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 42));
122+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 42));
123+
EXPECT_EQ(globalIndex.fIndex, 1 * (BinsY + 2) + BinsY + 1);
124+
EXPECT_TRUE(globalIndex.fValid);
125+
const std::array<RBinIndex, 2> indices = {1, RBinIndex::Overflow()};
126+
globalIndex = axes.ComputeGlobalIndex(indices);
114127
EXPECT_EQ(globalIndex.fIndex, 1 * (BinsY + 2) + BinsY + 1);
115128
EXPECT_TRUE(globalIndex.fValid);
116129
}
@@ -131,21 +144,33 @@ TEST(RAxes, ComputeGlobalIndexNoFlowBins)
131144
ASSERT_EQ(axes.ComputeTotalNBins(), BinsX * BinsY);
132145

133146
{
134-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 2.5));
147+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 2.5));
148+
EXPECT_EQ(globalIndex.fIndex, 1 * BinsY + 2);
149+
EXPECT_TRUE(globalIndex.fValid);
150+
const std::array<RBinIndex, 2> indices = {1, 2};
151+
globalIndex = axes.ComputeGlobalIndex(indices);
135152
EXPECT_EQ(globalIndex.fIndex, 1 * BinsY + 2);
136153
EXPECT_TRUE(globalIndex.fValid);
137154
}
138155

139156
{
140157
// Underflow bin of the first axis.
141-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(-1, 2.5));
158+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(-1, 2.5));
159+
EXPECT_EQ(globalIndex.fIndex, 0);
160+
EXPECT_FALSE(globalIndex.fValid);
161+
const std::array<RBinIndex, 2> indices = {RBinIndex::Underflow(), 2};
162+
globalIndex = axes.ComputeGlobalIndex(indices);
142163
EXPECT_EQ(globalIndex.fIndex, 0);
143164
EXPECT_FALSE(globalIndex.fValid);
144165
}
145166

146167
{
147168
// Overflow bin of the second axis.
148-
const auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 42));
169+
auto globalIndex = axes.ComputeGlobalIndex(std::make_tuple(1.5, 42));
170+
EXPECT_EQ(globalIndex.fIndex, 0);
171+
EXPECT_FALSE(globalIndex.fValid);
172+
const std::array<RBinIndex, 2> indices = {1, RBinIndex::Overflow()};
173+
globalIndex = axes.ComputeGlobalIndex(indices);
149174
EXPECT_EQ(globalIndex.fIndex, 0);
150175
EXPECT_FALSE(globalIndex.fValid);
151176
}
@@ -166,4 +191,15 @@ TEST(RAxes, ComputeGlobalIndexInvalidNumberOfArguments)
166191
EXPECT_THROW(axes2.ComputeGlobalIndex(std::make_tuple(1)), std::invalid_argument);
167192
EXPECT_NO_THROW(axes2.ComputeGlobalIndex(std::make_tuple(1, 2)));
168193
EXPECT_THROW(axes2.ComputeGlobalIndex(std::make_tuple(1, 2, 3)), std::invalid_argument);
194+
195+
const std::array<RBinIndex, 1> indices1 = {1};
196+
const std::array<RBinIndex, 2> indices2 = {1, 2};
197+
const std::array<RBinIndex, 3> indices3 = {1, 2, 3};
198+
199+
EXPECT_NO_THROW(axes1.ComputeGlobalIndex(indices1));
200+
EXPECT_THROW(axes1.ComputeGlobalIndex(indices2), std::invalid_argument);
201+
202+
EXPECT_THROW(axes2.ComputeGlobalIndex(indices1), std::invalid_argument);
203+
EXPECT_NO_THROW(axes2.ComputeGlobalIndex(indices2));
204+
EXPECT_THROW(axes2.ComputeGlobalIndex(indices3), std::invalid_argument);
169205
}

hist/histv7/test/hist_regular.cxx

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,56 @@ TEST(RRegularAxis, ComputeLinearizedIndex)
110110
EXPECT_FALSE(linIndex.fValid);
111111
}
112112
}
113+
114+
TEST(RRegularAxis, GetLinearizedIndex)
115+
{
116+
static constexpr std::size_t Bins = 20;
117+
const RRegularAxis axis(Bins, 0, Bins);
118+
const RRegularAxis axisNoFlowBins(Bins, 0, Bins, /*enableFlowBins=*/false);
119+
120+
{
121+
const auto underflow = RBinIndex::Underflow();
122+
auto linIndex = axis.GetLinearizedIndex(underflow);
123+
EXPECT_EQ(linIndex.fIndex, Bins);
124+
EXPECT_TRUE(linIndex.fValid);
125+
linIndex = axisNoFlowBins.GetLinearizedIndex(underflow);
126+
EXPECT_EQ(linIndex.fIndex, Bins);
127+
EXPECT_FALSE(linIndex.fValid);
128+
}
129+
130+
for (std::size_t i = 0; i < Bins; i++) {
131+
auto linIndex = axis.GetLinearizedIndex(i);
132+
EXPECT_EQ(linIndex.fIndex, i);
133+
EXPECT_TRUE(linIndex.fValid);
134+
linIndex = axisNoFlowBins.GetLinearizedIndex(i);
135+
EXPECT_EQ(linIndex.fIndex, i);
136+
EXPECT_TRUE(linIndex.fValid);
137+
}
138+
139+
// Out of bounds
140+
{
141+
auto linIndex = axis.GetLinearizedIndex(Bins);
142+
EXPECT_EQ(linIndex.fIndex, Bins);
143+
EXPECT_FALSE(linIndex.fValid);
144+
linIndex = axisNoFlowBins.GetLinearizedIndex(Bins);
145+
EXPECT_EQ(linIndex.fIndex, Bins);
146+
EXPECT_FALSE(linIndex.fValid);
147+
}
148+
149+
{
150+
const auto overflow = RBinIndex::Overflow();
151+
auto linIndex = axis.GetLinearizedIndex(overflow);
152+
EXPECT_TRUE(linIndex.fValid);
153+
EXPECT_EQ(linIndex.fIndex, Bins + 1);
154+
linIndex = axisNoFlowBins.GetLinearizedIndex(overflow);
155+
EXPECT_FALSE(linIndex.fValid);
156+
}
157+
158+
{
159+
const RBinIndex invalid;
160+
auto linIndex = axis.GetLinearizedIndex(invalid);
161+
EXPECT_FALSE(linIndex.fValid);
162+
linIndex = axisNoFlowBins.GetLinearizedIndex(invalid);
163+
EXPECT_FALSE(linIndex.fValid);
164+
}
165+
}

hist/histv7/test/hist_variable.cxx

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,62 @@ TEST(RVariableBinAxis, ComputeLinearizedIndex)
135135
EXPECT_FALSE(linIndex.fValid);
136136
}
137137
}
138+
139+
TEST(RVariableBinAxis, GetLinearizedIndex)
140+
{
141+
static constexpr std::size_t Bins = 20;
142+
std::vector<double> bins;
143+
for (std::size_t i = 0; i < Bins; i++) {
144+
bins.push_back(i);
145+
}
146+
bins.push_back(Bins);
147+
148+
const RVariableBinAxis axis(bins);
149+
const RVariableBinAxis axisNoFlowBins(bins, /*enableFlowBins=*/false);
150+
151+
{
152+
const auto underflow = RBinIndex::Underflow();
153+
auto linIndex = axis.GetLinearizedIndex(underflow);
154+
EXPECT_EQ(linIndex.fIndex, Bins);
155+
EXPECT_TRUE(linIndex.fValid);
156+
linIndex = axisNoFlowBins.GetLinearizedIndex(underflow);
157+
EXPECT_EQ(linIndex.fIndex, Bins);
158+
EXPECT_FALSE(linIndex.fValid);
159+
}
160+
161+
for (std::size_t i = 0; i < Bins; i++) {
162+
auto linIndex = axis.GetLinearizedIndex(i);
163+
EXPECT_EQ(linIndex.fIndex, i);
164+
EXPECT_TRUE(linIndex.fValid);
165+
linIndex = axisNoFlowBins.GetLinearizedIndex(i);
166+
EXPECT_EQ(linIndex.fIndex, i);
167+
EXPECT_TRUE(linIndex.fValid);
168+
}
169+
170+
// Out of bounds
171+
{
172+
auto linIndex = axis.GetLinearizedIndex(Bins);
173+
EXPECT_EQ(linIndex.fIndex, Bins);
174+
EXPECT_FALSE(linIndex.fValid);
175+
linIndex = axisNoFlowBins.GetLinearizedIndex(Bins);
176+
EXPECT_EQ(linIndex.fIndex, Bins);
177+
EXPECT_FALSE(linIndex.fValid);
178+
}
179+
180+
{
181+
const auto overflow = RBinIndex::Overflow();
182+
auto linIndex = axis.GetLinearizedIndex(overflow);
183+
EXPECT_TRUE(linIndex.fValid);
184+
EXPECT_EQ(linIndex.fIndex, Bins + 1);
185+
linIndex = axisNoFlowBins.GetLinearizedIndex(overflow);
186+
EXPECT_FALSE(linIndex.fValid);
187+
}
188+
189+
{
190+
const RBinIndex invalid;
191+
auto linIndex = axis.GetLinearizedIndex(invalid);
192+
EXPECT_FALSE(linIndex.fValid);
193+
linIndex = axisNoFlowBins.GetLinearizedIndex(invalid);
194+
EXPECT_FALSE(linIndex.fValid);
195+
}
196+
}

0 commit comments

Comments
 (0)