Skip to content

Commit a2c6bc3

Browse files
committed
[RF] Implement codegen support for RooFunctorBinding and friends
This enables easily plugging in user-defined functions including their gradients into RooFit computation graphs, which also works from Python thanks to the `std::function` pythonization.
1 parent 5a58529 commit a2c6bc3

File tree

5 files changed

+219
-30
lines changed

5 files changed

+219
-30
lines changed

roofit/codegen/inc/RooFit/CodegenImpl.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class RooAddPdf;
2525
class RooAddition;
2626
class RooBernstein;
2727
class RooBifurGauss;
28-
class RooCategory;
2928
class RooCBShape;
29+
class RooCategory;
3030
class RooChebychev;
3131
class RooConstVar;
3232
class RooConstraintSum;
@@ -35,13 +35,18 @@ class RooEfficiency;
3535
class RooExponential;
3636
class RooExtendPdf;
3737
class RooFormulaVar;
38+
class RooFunctor1DBinding;
39+
class RooFunctor1DPdfBinding;
40+
class RooFunctorBinding;
41+
class RooFunctorPdfBinding;
3842
class RooGamma;
3943
class RooGaussian;
4044
class RooGenericPdf;
4145
class RooHistFunc;
4246
class RooHistPdf;
4347
class RooLandau;
4448
class RooLognormal;
49+
class RooMultiPdf;
4550
class RooMultiVarGaussian;
4651
class RooParamHistFunc;
4752
class RooPoisson;
@@ -55,7 +60,6 @@ class RooRealSumPdf;
5560
class RooRealVar;
5661
class RooRecursiveFraction;
5762
class RooUniform;
58-
class RooMultiPdf;
5963
class RooWrapperPdf;
6064

6165
namespace RooStats {
@@ -96,6 +100,10 @@ void codegenImpl(RooEfficiency &arg, CodegenContext &ctx);
96100
void codegenImpl(RooExponential &arg, CodegenContext &ctx);
97101
void codegenImpl(RooExtendPdf &arg, CodegenContext &ctx);
98102
void codegenImpl(RooFormulaVar &arg, CodegenContext &ctx);
103+
void codegenImpl(RooFunctor1DBinding &arg, CodegenContext &ctx);
104+
void codegenImpl(RooFunctor1DPdfBinding &arg, CodegenContext &ctx);
105+
void codegenImpl(RooFunctorBinding &arg, CodegenContext &ctx);
106+
void codegenImpl(RooFunctorPdfBinding &arg, CodegenContext &ctx);
99107
void codegenImpl(RooGamma &arg, CodegenContext &ctx);
100108
void codegenImpl(RooGaussian &arg, CodegenContext &ctx);
101109
void codegenImpl(RooGenericPdf &arg, CodegenContext &ctx);

roofit/codegen/src/CodegenImpl.cxx

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include <RooFit/Detail/RooNLLVarNew.h>
3232
#include <RooFit/Detail/RooNormalizedPdf.h>
3333
#include <RooFormulaVar.h>
34+
#include <RooFunctor1DBinding.h>
35+
#include <RooFunctorBinding.h>
3436
#include <RooGamma.h>
3537
#include <RooGaussian.h>
3638
#include <RooGenericPdf.h>
@@ -62,8 +64,9 @@
6264

6365
#include <TInterpreter.h>
6466

65-
namespace RooFit {
66-
namespace Experimental {
67+
#include <unordered_set>
68+
69+
namespace RooFit::Experimental {
6770

6871
namespace {
6972

@@ -371,6 +374,92 @@ void codegenImpl(RooConstraintSum &arg, CodegenContext &ctx)
371374
ctx.addResult(&arg, ctx.buildCall(mathFunc("constraintSum"), arg.list(), arg.list().size()));
372375
}
373376

377+
// Generate RooFit codegen wrappers for RooFunctorBinding and similar objects,
378+
// emitting both the primal function call and its gradient pullback for
379+
// Clad-based AD.
380+
template <class RooArg_t>
381+
void functorCodegenImpl(RooArg_t &arg, RooArgList const &variables, CodegenContext &ctx)
382+
{
383+
if (!arg.function()->HasGradient()) {
384+
std::stringstream errorMsg;
385+
errorMsg << "Functor wrapped by \"" << arg.GetName() << "\" doesn't provide a gradient function."
386+
<< " RooFit codegen is therefore not supported.";
387+
oocoutE(&arg, InputArguments) << errorMsg.str() << std::endl;
388+
throw std::runtime_error(errorMsg.str());
389+
}
390+
391+
std::string funcAddrStr = TString::Format("0x%zx", reinterpret_cast<std::size_t>(arg.function())).Data();
392+
std::string wrapperName = "roo_functor_" + funcAddrStr;
393+
394+
static std::unordered_set<std::string> wrapperNames;
395+
396+
if (wrapperNames.find(wrapperName) == wrapperNames.end()) {
397+
398+
wrapperNames.insert(wrapperName);
399+
400+
std::string pullbackName = wrapperName + "_pullback";
401+
std::string nStr = std::to_string(std::size(variables));
402+
403+
std::string type;
404+
if constexpr (std::is_same_v<RooArg_t, RooFunctor1DBinding> || std::is_same_v<RooArg_t, RooFunctor1DPdfBinding>)
405+
type = "::ROOT::Math::IGradientFunctionOneDim";
406+
else
407+
type = "::ROOT::Math::IGradientFunctionMultiDim";
408+
409+
std::string funcAddrCasted = "reinterpret_cast<" + type + " const *>(" + funcAddrStr + ")";
410+
411+
std::string code;
412+
413+
code += "double " + wrapperName +
414+
"(double const *x) {\n"
415+
" return " +
416+
funcAddrCasted +
417+
"->operator()(x);\n"
418+
"}\n\n"
419+
"namespace clad::custom_derivatives {\n\n"
420+
"void " +
421+
pullbackName +
422+
"(double const* x, double d_y, double *d_x) {\n"
423+
" double output[" +
424+
nStr +
425+
"]{};\n"
426+
" " +
427+
funcAddrCasted +
428+
"->Gradient(x, output);\n"
429+
" for (int i = 0; i < " +
430+
nStr +
431+
"; ++i) {\n"
432+
" d_x[i] += output[i] * d_y;\n"
433+
" }\n"
434+
"}\n"
435+
"} // namespace clad::custom_derivatives\n";
436+
437+
gInterpreter->Declare(code.c_str());
438+
}
439+
440+
ctx.addResult(&arg, ctx.buildCall(wrapperName, variables));
441+
}
442+
443+
void codegenImpl(RooFunctor1DBinding &arg, CodegenContext &ctx)
444+
{
445+
functorCodegenImpl(arg, arg.variable(), ctx);
446+
}
447+
448+
void codegenImpl(RooFunctor1DPdfBinding &arg, CodegenContext &ctx)
449+
{
450+
functorCodegenImpl(arg, arg.variable(), ctx);
451+
}
452+
453+
void codegenImpl(RooFunctorBinding &arg, CodegenContext &ctx)
454+
{
455+
functorCodegenImpl(arg, arg.variables(), ctx);
456+
}
457+
458+
void codegenImpl(RooFunctorPdfBinding &arg, CodegenContext &ctx)
459+
{
460+
functorCodegenImpl(arg, arg.variables(), ctx);
461+
}
462+
374463
void codegenImpl(RooGamma &arg, CodegenContext &ctx)
375464
{
376465
ctx.addResult(&arg, ctx.buildCall("TMath::GammaDist", arg.getX(), arg.getGamma(), arg.getMu(), arg.getBeta()));
@@ -909,5 +998,4 @@ std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName
909998
return doubleToString(arg.analyticalIntegral(code, rangeName));
910999
}
9111000

912-
} // namespace Experimental
913-
} // namespace RooFit
1001+
} // namespace RooFit::Experimental

roofit/roofit/inc/RooFunctor1DBinding.h

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "RooListProxy.h"
1919
#include "RooAbsPdf.h"
2020
#include "RooRealProxy.h"
21-
#include "RooMsgService.h"
2221
#include "Math/IFunction.h"
2322

2423

@@ -32,49 +31,45 @@ RooAbsPdf* bindPdf(const char* name, const ROOT::Math::IBaseFunctionOneDim& fto
3231

3332
class RooFunctor1DBinding : public RooAbsReal {
3433
public:
35-
RooFunctor1DBinding() : func(nullptr) {
36-
// Default constructor
37-
} ;
34+
RooFunctor1DBinding() = default;
3835
RooFunctor1DBinding(const char *name, const char *title, const ROOT::Math::IBaseFunctionOneDim& ftor, RooAbsReal& var);
3936
RooFunctor1DBinding(const RooFunctor1DBinding& other, const char* name=nullptr) ;
4037
TObject* clone(const char* newname=nullptr) const override { return new RooFunctor1DBinding(*this,newname); }
4138
void printArgs(std::ostream& os) const override ;
4239

40+
ROOT::Math::IBaseFunctionOneDim const *function() const { return func; }
41+
RooAbsReal const &variable() const { return *var; }
42+
4343
protected:
4444

4545
double evaluate() const override ;
4646

47-
const ROOT::Math::IBaseFunctionOneDim* func ; // Functor
47+
const ROOT::Math::IBaseFunctionOneDim *func = nullptr; // Functor
4848
RooRealProxy var ; // Argument reference
4949

50-
51-
private:
52-
5350
ClassDefOverride(RooFunctor1DBinding,1) // RooAbsReal binding to a ROOT::Math::IBaseFunctionOneDim
5451
};
5552

5653

5754

5855
class RooFunctor1DPdfBinding : public RooAbsPdf {
5956
public:
60-
RooFunctor1DPdfBinding() : func(nullptr) {
61-
// Default constructor
62-
} ;
57+
RooFunctor1DPdfBinding() = default;
6358
RooFunctor1DPdfBinding(const char *name, const char *title, const ROOT::Math::IBaseFunctionOneDim& ftor, RooAbsReal& vars);
6459
RooFunctor1DPdfBinding(const RooFunctor1DPdfBinding& other, const char* name=nullptr) ;
6560
TObject* clone(const char* newname=nullptr) const override { return new RooFunctor1DPdfBinding(*this,newname); }
6661
void printArgs(std::ostream& os) const override ;
6762

63+
ROOT::Math::IBaseFunctionOneDim const *function() const { return func; }
64+
RooAbsReal const &variable() const { return *var; }
65+
6866
protected:
6967

7068
double evaluate() const override ;
7169

72-
const ROOT::Math::IBaseFunctionOneDim* func ; // Functor
70+
ROOT::Math::IBaseFunctionOneDim const *func = nullptr; // Functor
7371
RooRealProxy var ; // Argument reference
7472

75-
76-
private:
77-
7873
ClassDefOverride(RooFunctor1DPdfBinding,1) // RooAbsPdf binding to a ROOT::Math::IBaseFunctionOneDim
7974
};
8075

roofit/roofit/inc/RooFunctorBinding.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "RooListProxy.h"
1919
#include "RooAbsPdf.h"
2020
#include "RooRealProxy.h"
21-
#include "RooMsgService.h"
2221
#include "Math/IFunction.h"
2322

2423
namespace RooFit {
@@ -37,6 +36,9 @@ class RooFunctorBinding : public RooAbsReal {
3736
inline ~RooFunctorBinding() override { delete[] x ; }
3837
void printArgs(std::ostream& os) const override ;
3938

39+
ROOT::Math::IBaseFunctionMultiDim const *function() const { return func; }
40+
RooArgList const &variables() const { return vars; }
41+
4042
protected:
4143

4244
double evaluate() const override ;
@@ -45,9 +47,6 @@ class RooFunctorBinding : public RooAbsReal {
4547
RooListProxy vars; // Argument reference
4648
double *x = nullptr; // Argument value array
4749

48-
49-
private:
50-
5150
ClassDefOverride(RooFunctorBinding,1) // RooAbsReal binding to a ROOT::Math::IBaseFunctionMultiDim
5251
};
5352

@@ -62,6 +61,9 @@ class RooFunctorPdfBinding : public RooAbsPdf {
6261
inline ~RooFunctorPdfBinding() override { delete[] x ; }
6362
void printArgs(std::ostream& os) const override ;
6463

64+
ROOT::Math::IBaseFunctionMultiDim const *function() const { return func; }
65+
RooArgList const &variables() const { return vars; }
66+
6567
protected:
6668

6769
double evaluate() const override ;
@@ -70,9 +72,6 @@ class RooFunctorPdfBinding : public RooAbsPdf {
7072
RooListProxy vars ; // Argument reference
7173
double *x = nullptr; // Argument value array
7274

73-
74-
private:
75-
7675
ClassDefOverride(RooFunctorPdfBinding,1) // RooAbsPdf binding to a ROOT::Math::IBaseFunctionMultiDim
7776
};
7877

0 commit comments

Comments
 (0)