Skip to content

Commit 4bbcda9

Browse files
committed
[RF] Make the RooFit math functions have templated array element types
The array element types for the RooFit math functions should be templated, so that we can try out other element types for intermediate arrays, like `std::reference_wrapper<double>`. After this commit, changing the element type for all intermediate arrays in the generated code would be a one-line change. It was validated with the ATLAS Higgs combination likelihood that using templated functions has no effect on gradient generation time and runtime.
1 parent 5baeea4 commit 4bbcda9

File tree

4 files changed

+55
-45
lines changed

4 files changed

+55
-45
lines changed

roofit/codegen/src/CodegenImpl.cxx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ void codegenImpl(RooFormulaVar &arg, CodegenContext &ctx)
381381
arg.getVal(); // to trigger the creation of the TFormula
382382
std::string funcName = arg.getUniqueFuncName();
383383
ctx.collectFunction(funcName);
384-
ctx.addResult(&arg, ctx.buildCall(funcName, arg.dependents()));
384+
// We have to force the array type to be "double" because that's what the
385+
// declared function wrapped by the TFormula expects.
386+
auto inputVar = ctx.buildArg(arg.dependents(), /*arrayType=*/"double");
387+
ctx.addResult(&arg, funcName + "(" + inputVar + ")");
385388
}
386389

387390
void codegenImpl(RooEffProd &arg, CodegenContext &ctx)
@@ -424,7 +427,10 @@ void codegenImpl(RooGenericPdf &arg, CodegenContext &ctx)
424427
arg.getVal(); // to trigger the creation of the TFormula
425428
std::string funcName = arg.getUniqueFuncName();
426429
ctx.collectFunction(funcName);
427-
ctx.addResult(&arg, ctx.buildCall(funcName, arg.dependents()));
430+
// We have to force the array type to be "double" because that's what the
431+
// declared function wrapped by the TFormula expects.
432+
auto inputVar = ctx.buildArg(arg.dependents(), /*arrayType=*/"double");
433+
ctx.addResult(&arg, funcName + "(" + inputVar + ")");
428434
}
429435

430436
void codegenImpl(RooHistFunc &arg, CodegenContext &ctx)

roofit/roofitcore/inc/RooFit/CodegenContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class CodegenContext {
109109

110110
std::string getTmpVarName() const;
111111

112-
std::string buildArg(RooAbsCollection const &x);
112+
std::string buildArg(RooAbsCollection const &x, std::string const &arrayType = "double");
113113

114114
std::string buildArg(std::span<const double> arr);
115115
std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }

roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
#include <algorithm>
2222
#include <cmath>
2323
#include <stdexcept>
24-
namespace RooFit {
25-
namespace Detail {
26-
namespace MathFuncs {
24+
25+
namespace RooFit::Detail::MathFuncs {
2726

2827
/// Calculates the binomial coefficient n over k.
2928
/// Equivalent to TMath::Binomial, but inlined.
@@ -44,7 +43,8 @@ inline double binomial(int n, int k)
4443
}
4544

4645
/// The caller needs to make sure that there is at least one coefficient.
47-
inline double bernstein(double x, double xmin, double xmax, double *coefs, int nCoefs)
46+
template <typename DoubleArray>
47+
double bernstein(double x, double xmin, double xmax, DoubleArray coefs, int nCoefs)
4848
{
4949
double xScaled = (x - xmin) / (xmax - xmin); // rescale to [0,1]
5050
int degree = nCoefs - 1; // n+1 polys of degree n
@@ -89,7 +89,8 @@ inline double gaussian(double x, double mean, double sigma)
8989
return std::exp(-0.5 * arg * arg / (sig * sig));
9090
}
9191

92-
inline double product(double const *factors, std::size_t nFactors)
92+
template <typename DoubleArray>
93+
double product(DoubleArray factors, std::size_t nFactors)
9394
{
9495
double out = 1.0;
9596
for (std::size_t i = 0; i < nFactors; ++i) {
@@ -125,8 +126,8 @@ inline double efficiency(double effFuncVal, int catIndex, int sigCatIndex)
125126
}
126127

127128
/// In pdfMode, a coefficient for the constant term of 1.0 is implied if lowestOrder > 0.
128-
template <bool pdfMode = false>
129-
inline double polynomial(double const *coeffs, int nCoeffs, int lowestOrder, double x)
129+
template <bool pdfMode = false, typename DoubleArray>
130+
double polynomial(DoubleArray coeffs, int nCoeffs, int lowestOrder, double x)
130131
{
131132
double retVal = coeffs[nCoeffs - 1];
132133
for (int i = nCoeffs - 2; i >= 0; i--) {
@@ -136,7 +137,8 @@ inline double polynomial(double const *coeffs, int nCoeffs, int lowestOrder, dou
136137
return retVal + (pdfMode && lowestOrder > 0 ? 1.0 : 0.0);
137138
}
138139

139-
inline double chebychev(double *coeffs, unsigned int nCoeffs, double x_in, double xMin, double xMax)
140+
template <typename DoubleArray>
141+
double chebychev(DoubleArray coeffs, unsigned int nCoeffs, double x_in, double xMin, double xMax)
140142
{
141143
// transform to range [-1, +1]
142144
const double xPrime = (x_in - 0.5 * (xMax + xMin)) / (0.5 * (xMax - xMin));
@@ -160,7 +162,8 @@ inline double chebychev(double *coeffs, unsigned int nCoeffs, double x_in, doubl
160162
return sum;
161163
}
162164

163-
inline double multipdf(int idx, double const *pdfs)
165+
template <typename DoubleArray>
166+
double multipdf(int idx, DoubleArray pdfs)
164167
{
165168
/* if (idx < 0 || idx >= static_cast<int>(pdfs.size())){
166169
throw std::out_of_range("Invalid PDF index");
@@ -169,7 +172,9 @@ inline double multipdf(int idx, double const *pdfs)
169172
*/
170173
return pdfs[idx];
171174
}
172-
inline double constraintSum(double const *comp, unsigned int compSize)
175+
176+
template <typename DoubleArray>
177+
double constraintSum(DoubleArray comp, unsigned int compSize)
173178
{
174179
double sum = 0;
175180
for (unsigned int i = 0; i < compSize; i++) {
@@ -184,26 +189,28 @@ inline unsigned int uniformBinNumber(double low, double high, double val, unsign
184189
return coef * (val >= high ? numBins - 1 : std::abs((val - low) / binWidth));
185190
}
186191

187-
inline unsigned int rawBinNumber(double x, double const *boundaries, std::size_t nBoundaries)
192+
template <typename DoubleArray>
193+
unsigned int rawBinNumber(double x, DoubleArray boundaries, std::size_t nBoundaries)
188194
{
189-
double const *end = boundaries + nBoundaries;
190-
double const *it = std::lower_bound(boundaries, end, x);
195+
DoubleArray end = boundaries + nBoundaries;
196+
DoubleArray it = std::lower_bound(boundaries, end, x);
191197
// always return valid bin number
192198
while (boundaries != it && (end == it || end == it + 1 || x < *it)) {
193199
--it;
194200
}
195201
return it - boundaries;
196202
}
197203

198-
inline unsigned int
199-
binNumber(double x, double coef, double const *boundaries, unsigned int nBoundaries, int nbins, int blo)
204+
template <typename DoubleArray>
205+
unsigned int binNumber(double x, double coef, DoubleArray boundaries, unsigned int nBoundaries, int nbins, int blo)
200206
{
201207
const int rawBin = rawBinNumber(x, boundaries, nBoundaries);
202208
int tmp = std::min(nbins, rawBin - blo);
203209
return coef * std::max(0, tmp);
204210
}
205211

206-
inline double interpolate1d(double low, double high, double val, unsigned int numBins, double const *vals)
212+
template <typename DoubleArray>
213+
double interpolate1d(double low, double high, double val, unsigned int numBins, DoubleArray vals)
207214
{
208215
double binWidth = (high - low) / numBins;
209216
int idx = val >= high ? numBins - 1 : std::abs((val - low) / binWidth);
@@ -352,8 +359,9 @@ inline double flexibleInterpSingle(unsigned int code, double low, double high, d
352359
return 0.0;
353360
}
354361

355-
inline double flexibleInterp(unsigned int code, double const *params, unsigned int n, double const *low,
356-
double const *high, double boundary, double nominal, int doCutoff)
362+
template <typename ParamsArray, typename DoubleArray>
363+
double flexibleInterp(unsigned int code, ParamsArray params, unsigned int n, DoubleArray low, DoubleArray high,
364+
double boundary, double nominal, int doCutoff)
357365
{
358366
double total = nominal;
359367
for (std::size_t i = 0; i < n; ++i) {
@@ -403,7 +411,8 @@ inline double nll(double pdf, double weight, int binnedL, int doBinOffset)
403411
}
404412
}
405413

406-
inline double recursiveFraction(double *a, unsigned int n)
414+
template <typename DoubleArray>
415+
double recursiveFraction(DoubleArray a, unsigned int n)
407416
{
408417
double prod = a[0];
409418

@@ -512,8 +521,8 @@ inline double exponentialIntegral(double xMin, double xMax, double constant)
512521
}
513522

514523
/// In pdfMode, a coefficient for the constant term of 1.0 is implied if lowestOrder > 0.
515-
template <bool pdfMode = false>
516-
inline double polynomialIntegral(double const *coeffs, int nCoeffs, int lowestOrder, double xMin, double xMax)
524+
template <bool pdfMode = false, typename DoubleArray>
525+
double polynomialIntegral(DoubleArray coeffs, int nCoeffs, int lowestOrder, double xMin, double xMax)
517526
{
518527
int denom = lowestOrder + nCoeffs;
519528
double min = coeffs[nCoeffs - 1] / double(denom);
@@ -545,8 +554,9 @@ inline double fast_fma(double x, double y, double z) noexcept
545554
#endif // defined(FP_FAST_FMA)
546555
}
547556

548-
inline double chebychevIntegral(double const *coeffs, unsigned int nCoeffs, double xMin, double xMax, double xMinFull,
549-
double xMaxFull)
557+
template <typename DoubleArray>
558+
double
559+
chebychevIntegral(DoubleArray coeffs, unsigned int nCoeffs, double xMin, double xMax, double xMinFull, double xMaxFull)
550560
{
551561
const double halfrange = .5 * (xMax - xMin);
552562
const double mid = .5 * (xMax + xMin);
@@ -744,7 +754,8 @@ inline double cbShapeIntegral(double mMin, double mMax, double m0, double sigma,
744754
return result;
745755
}
746756

747-
inline double bernsteinIntegral(double xlo, double xhi, double xmin, double xmax, double *coefs, int nCoefs)
757+
template <typename DoubleArray>
758+
double bernsteinIntegral(double xlo, double xhi, double xmin, double xmax, DoubleArray coefs, int nCoefs)
748759
{
749760
double xloScaled = (xlo - xmin) / (xmax - xmin);
750761
double xhiScaled = (xhi - xmin) / (xmax - xmin);
@@ -770,7 +781,8 @@ inline double bernsteinIntegral(double xlo, double xhi, double xmin, double xmax
770781
return norm * (xmax - xmin);
771782
}
772783

773-
inline double multiVarGaussian(int n, const double *x, const double *mu, const double *covI)
784+
template <typename DoubleArray>
785+
double multiVarGaussian(int n, DoubleArray x, DoubleArray mu, DoubleArray covI)
774786
{
775787
double result = 0.0;
776788

@@ -786,8 +798,8 @@ inline double multiVarGaussian(int n, const double *x, const double *mu, const d
786798
// Integral of a step function defined by `nBins` intervals, where the
787799
// intervals have values `coefs` and the boundary on the interval `iBin` is
788800
// given by `[boundaries[i], boundaries[i+1])`.
789-
inline double
790-
stepFunctionIntegral(double xmin, double xmax, std::size_t nBins, double const *boundaries, double const *coefs)
801+
template <typename DoubleArray>
802+
double stepFunctionIntegral(double xmin, double xmax, std::size_t nBins, DoubleArray boundaries, DoubleArray coefs)
791803
{
792804
double out = 0.0;
793805
for (std::size_t i = 0; i < nBins; ++i) {
@@ -798,15 +810,10 @@ stepFunctionIntegral(double xmin, double xmax, std::size_t nBins, double const *
798810
return out;
799811
}
800812

801-
} // namespace MathFuncs
802-
} // namespace Detail
803-
} // namespace RooFit
813+
} // namespace RooFit::Detail::MathFuncs
804814

805-
namespace clad {
806-
namespace custom_derivatives {
807-
namespace RooFit {
808-
namespace Detail {
809-
namespace MathFuncs {
815+
namespace clad::custom_derivatives {
816+
namespace RooFit::Detail::MathFuncs {
810817

811818
// Clad can't generate the pullback for binNumber because of the
812819
// std::lower_bound usage. But since binNumber returns an integer, and such
@@ -818,10 +825,7 @@ void binNumber_pullback(Types...)
818825
{
819826
}
820827

821-
} // namespace MathFuncs
822-
} // namespace Detail
823-
} // namespace RooFit
824-
} // namespace custom_derivatives
825-
} // namespace clad
828+
} // namespace RooFit::Detail::MathFuncs
829+
} // namespace clad::custom_derivatives
826830

827831
#endif

roofit/roofitcore/src/RooFit/CodegenContext.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ void CodegenContext::addResult(RooAbsArg const *in, std::string const &valueToSa
236236
/// @brief Function to save a RooListProxy as an array in the squashed code.
237237
/// @param in The list to convert to array.
238238
/// @return Name of the array that stores the input list in the squashed code.
239-
std::string CodegenContext::buildArg(RooAbsCollection const &in)
239+
std::string CodegenContext::buildArg(RooAbsCollection const &in, std::string const &arrayType)
240240
{
241241
if (in.empty()) {
242242
return "nullptr";
@@ -250,7 +250,7 @@ std::string CodegenContext::buildArg(RooAbsCollection const &in)
250250
bool canSaveOutside = true;
251251

252252
std::stringstream declStrm;
253-
declStrm << "double " << savedName << "[] = {";
253+
declStrm << arrayType << " " << savedName << "[]{";
254254
for (const auto arg : in) {
255255
declStrm << getResult(*arg) << ",";
256256
canSaveOutside = canSaveOutside && isScopeIndependent(arg);

0 commit comments

Comments
 (0)