Skip to content

Commit 7f76766

Browse files
GalinBistrev2guitargeek
authored andcommitted
[RF] Added test and correction factor in RooMultiPdf
1 parent acdc6b6 commit 7f76766

File tree

15 files changed

+140
-26
lines changed

15 files changed

+140
-26
lines changed

roofit/roofit/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ endif()
1616

1717
ROOT_STANDARD_LIBRARY_PACKAGE(RooFit
1818
HEADERS
19-
RooMultiPdf.h
2019
Roo2DKeysPdf.h
2120
RooArgusBG.h
2221
RooBCPEffDecay.h
@@ -82,7 +81,6 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit
8281
RooVoigtian.h
8382
RooJohnson.h
8483
SOURCES
85-
src/RooMultiPdf.cxx
8684
src/Roo2DKeysPdf.cxx
8785
src/RooArgusBG.cxx
8886
src/RooBCPEffDecay.cxx

roofit/roofit/inc/LinkDef1.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#pragma link off all globals;
33
#pragma link off all classes;
44
#pragma link off all functions;
5-
#pragma link C++ class Roo2DKeysPdf+ ;
6-
#pragma link C++ class RooMultiPdf + ;
5+
#pragma link C++ class Roo2DKeysPdf + ;
76
#pragma link C++ class RooArgusBG+ ;
87
#pragma link C++ class RooBCPEffDecay+ ;
98
#pragma link C++ class RooBCPGenDecay+ ;
@@ -59,6 +58,7 @@
5958
#pragma link C++ class RooSpline+ ;
6059
#pragma link C++ class RooStepFunction+ ;
6160
#pragma link C++ class RooMultiBinomial+ ;
61+
6262
/* #pragma link C++ class std::vector< TVector2 >; */
6363
/* #pragma link C++ class std::vector< TVector2 >::iterator ; */
6464
/* #pragma link C++ class RooPolyMorph2D+ ; */

roofit/roofit/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ ROOT_ADD_GTEST(testRooLandau testRooLandau.cxx LIBRARIES RooFit)
2020
ROOT_ADD_GTEST(testRooParamHistFunc testRooParamHistFunc.cxx LIBRARIES Gpad RooFit)
2121
ROOT_ADD_GTEST(testRooPoisson testRooPoisson.cxx LIBRARIES RooFit)
2222
ROOT_ADD_GTEST(testRooStepFunc testRooStepFunc.cxx LIBRARIES RooFit)
23-
ROOT_ADD_GTEST(testRooMultiPdf testRooMultiPdf.cxx LIBRARIES RooFit)
23+
2424
if(mathmore)
2525
ROOT_EXECUTABLE(testRooFit testRooFit.cxx LIBRARIES RooFit MathMore)
2626
ROOT_ADD_TEST(test-fit-testRooFit COMMAND testRooFit)

roofit/roofitcore/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
171171
RooMsgService.h
172172
RooMultiCategory.h
173173
RooMultiVarGaussian.h
174+
RooMultiPdf.h
174175
RooNameReg.h
175176
RooNormSetCache.h
176177
RooNumCdf.h
@@ -366,6 +367,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
366367
src/RooMsgService.cxx
367368
src/RooMultiCategory.cxx
368369
src/RooMultiVarGaussian.cxx
370+
src/RooMultiPdf.cxx
369371
src/RooNLLVarNew.cxx
370372
src/RooNameReg.cxx
371373
src/RooNormSetCache.cxx

roofit/roofitcore/inc/LinkDef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
#pragma read sourceClass="RooMappedCategory::Entry" targetClass="RooMappedCategory::Entry" version="[1]" include="RooFitLegacy/RooCatTypeLegacy.h" \
134134
source="RooCatType _cat" target="_catIdx" code="{ _catIdx = onfile._cat.getVal(); }"
135135
#pragma link C++ class RooMultiCategory+ ;
136+
#pragma link C++ class RooMultiPdf + ;
136137
#pragma link off class RooNameReg+ ;
137138
#pragma link C++ class RooNumConvolution+ ;
138139
#pragma link C++ class RooNumConvPdf+ ;

roofit/roofitcore/inc/RooAbsPdf.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ class RooAbsPdf : public RooAbsReal {
177177
// Project p.d.f into lower dimensional p.d.f
178178
virtual RooAbsPdf* createProjection(const RooArgSet& iset) ;
179179

180+
virtual double getCorrection() const;
181+
180182
// Create cumulative density function from p.d.f
181183
RooFit::OwningPtr<RooAbsReal> createCdf(const RooArgSet& iset, const RooArgSet& nset=RooArgSet()) ;
182184
RooFit::OwningPtr<RooAbsReal> createCdf(const RooArgSet& iset, const RooCmdArg& arg1, const RooCmdArg& arg2={},

roofit/roofitcore/inc/RooFit/Detail/RooNormalizedPdf.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class RooNormalizedPdf : public RooAbsPdf {
4747

4848
bool selfNormalized() const override { return true; }
4949

50+
inline double getCorrection() const override { return _pdf->getCorrection(); }
51+
5052
bool forceAnalyticalInt(const RooAbsArg & /*dep*/) const override { return true; }
5153
/// Forward determination of analytical integration capabilities to input p.d.f
5254
Int_t getAnalyticalIntegralWN(RooArgSet &allVars, RooArgSet &analVars, const RooArgSet * /*normSet*/,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class RooMultiPdf : public RooAbsPdf {
1818
TObject *clone(const char *newname) const override { return new RooMultiPdf(*this, newname); }
1919

2020
inline bool checkIndexDirty() const { return _oldIndex != x; }
21-
inline double getCorrection() const { return cFactor * static_cast<RooAbsReal *>(corr.at(x))->getVal(); }
21+
inline double getCorrection() const override { return cFactor * static_cast<RooAbsReal *>(corr.at(x))->getVal(); }
2222
inline RooAbsPdf *getCurrentPdf() const { return getPdf(getCurrentIndex()); }
2323
int getNumPdfs() const { return c.size(); }
2424

roofit/roofitcore/src/FitHelpers.cxx

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <RooFuncWrapper.h>
3434
#include <RooLinkedList.h>
3535
#include <RooMinimizer.h>
36+
#include <RooConstVar.h>
3637
#include <RooRealVar.h>
3738
#include <RooSimultaneous.h>
3839
#include <RooFormulaVar.h>
@@ -810,8 +811,26 @@ std::unique_ptr<RooAbsReal> createNLL(RooAbsPdf &pdf, RooAbsData &data, const Ro
810811
takeGlobalObservablesFromData);
811812
}
812813

814+
const double correction = pdfClone->getCorrection();
813815
nllWrapper->addOwnedComponents(std::move(nll));
814816
nllWrapper->addOwnedComponents(std::move(pdfClone));
817+
818+
if (correction > 0) {
819+
oocoutI(&pdf, Fitting) << "[FitHelpers] Detected correction term from RooAbsPdf::getCorrection(). "
820+
<< "Adding penalty to NLL." << std::endl;
821+
822+
// Convert the multiplicative correction to an additive term in -log L
823+
auto penaltyTerm = std::make_unique<RooConstVar>((baseName + "_Penalty").c_str(),
824+
"Penalty term from getCorrection()", correction);
825+
826+
auto correctedNLL = std::make_unique<RooAddition>(
827+
// add penalty and NLL
828+
(baseName + "_corrected").c_str(), "NLL + penalty", RooArgSet(*nllWrapper, *penaltyTerm));
829+
830+
// transfer ownership of terms
831+
correctedNLL->addOwnedComponents(std::move(nllWrapper), std::move(penaltyTerm));
832+
nllWrapper = std::move(correctedNLL);
833+
}
815834
return nllWrapper;
816835
}
817836

@@ -888,6 +907,23 @@ std::unique_ptr<RooAbsReal> createNLL(RooAbsPdf &pdf, RooAbsData &data, const Ro
888907
throw std::runtime_error("RooFit was not built with the legacy evaluation backend");
889908
#endif
890909

910+
if (const double correction = pdf.getCorrection(); correction > 0) {
911+
oocoutI(&pdf, Fitting) << "[FitHelpers] Detected correction term from RooAbsPdf::getCorrection(). "
912+
<< "Adding penalty to NLL." << std::endl;
913+
914+
// Convert the multiplicative correction to an additive term in -log L
915+
auto penaltyTerm = std::make_unique<RooConstVar>((baseName + "_Penalty").c_str(),
916+
"Penalty term from getCorrection()", correction);
917+
918+
auto correctedNLL = std::make_unique<RooAddition>(
919+
// add penalty and NLL
920+
(baseName + "_corrected").c_str(), "NLL + penalty", RooArgSet(*nll, *penaltyTerm));
921+
922+
// transfer ownership of terms
923+
correctedNLL->addOwnedComponents(std::move(nll), std::move(penaltyTerm));
924+
nll = std::move(correctedNLL);
925+
}
926+
891927
return nll;
892928
}
893929

roofit/roofitcore/src/RooAbsPdf.cxx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,14 @@ double RooAbsPdf::getLogVal(const RooArgSet* nset) const
619619
{
620620
return getLog(getVal(nset), this);
621621
}
622-
622+
////////////////////////////////////////////////////////////////////////////////
623+
/// This function returns the penalty term.
624+
/// Penalty terms modify the likelihood,during model parameter estimation.This penalty term is usually
625+
// a function of the model parameters
626+
double RooAbsPdf::getCorrection() const
627+
{
628+
return 0;
629+
}
623630

624631
////////////////////////////////////////////////////////////////////////////////
625632
/// Check for infinity or NaN.

0 commit comments

Comments
 (0)