Skip to content

Commit 1cc876a

Browse files
committed
Refactor RooParametricHist detail into CombineMathFuncs
1 parent 41415b8 commit 1cc876a

File tree

3 files changed

+276
-103
lines changed

3 files changed

+276
-103
lines changed

interface/CombineMathFuncs.h

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,167 @@ inline Double_t verticalInterpPdfIntegral(double const* coefList, std::size_t nC
306306
return result > 0. ? result : integralFloorVal;
307307
}
308308

309+
inline int parametricHistFindBin(const int N_bins, const double* bins, const double x) {
310+
if (x < bins[0] || x >= bins[N_bins])
311+
return -1;
312+
313+
// Search for the bin
314+
for (int i = 0; i < N_bins; ++i) {
315+
if (x >= bins[i] && x < bins[i + 1])
316+
return i;
317+
}
318+
return -1;
319+
}
320+
321+
inline int parametricHistFindBin(const int N_bins, std::vector<double> const& bins, const double x) {
322+
return parametricHistFindBin(N_bins, bins.data(), x);
323+
}
324+
325+
inline Double_t parametricHistMorphScale(const double parVal,
326+
const int nMorphs,
327+
const double* morphCoeffs,
328+
const double* morphDiffs,
329+
const double* morphSums,
330+
double smoothRegion) {
331+
double morphScale = 1.0;
332+
if (!morphDiffs || !morphSums)
333+
return morphScale;
334+
for (int i = 0; i < nMorphs; ++i) {
335+
double coeff = morphCoeffs[i];
336+
double a = 0.5 * coeff;
337+
double b = smoothStepFunc(coeff, smoothRegion);
338+
morphScale *= 1 + (1.0 / parVal) * a * (morphDiffs[i] + b * morphSums[i]);
339+
}
340+
return morphScale;
341+
}
342+
343+
inline Double_t parametricHistEvaluate(const int bin_i,
344+
const double* parVals,
345+
const double* bins,
346+
const int N_bins,
347+
const double* morphCoeffs,
348+
const int nMorphs,
349+
const double* morphDiffs,
350+
const double* morphSums,
351+
const double* widths,
352+
const double smoothRegion) {
353+
if (bin_i < 0)
354+
return 0.0;
355+
// Morphing case
356+
if (morphCoeffs != nullptr && nMorphs > 0) {
357+
// morphDiffs and morphSums are flattened arrays of size N_bins * nMorphs
358+
const double* binMorphDiffs = nullptr;
359+
const double* binMorphSums = nullptr;
360+
if (morphDiffs) {
361+
binMorphDiffs = morphDiffs + bin_i * nMorphs;
362+
}
363+
if (morphSums) {
364+
binMorphSums = morphSums + bin_i * nMorphs;
365+
}
366+
double parVal = parVals[bin_i];
367+
double scale = parametricHistMorphScale(parVal, nMorphs, morphCoeffs, binMorphDiffs, binMorphSums, smoothRegion);
368+
return (parVal * scale) / widths[bin_i];
369+
}
370+
// No morphing case
371+
return parVals[bin_i] / widths[bin_i];
372+
}
373+
374+
inline Double_t parametricMorphFunction(const int j,
375+
const double parVal,
376+
const bool hasMorphs,
377+
const int nMorphs,
378+
const double* morphCoeffs,
379+
const double* morphDiffs,
380+
const double* morphSums,
381+
double smoothRegion) {
382+
double morphScale = 1.0;
383+
if (!hasMorphs)
384+
return morphScale;
385+
386+
int ndim = nMorphs;
387+
// apply all morphs one by one to the bin
388+
// almost certaintly a faster way to do this in a vectorized way ....
389+
for (int i = 0; i < ndim; ++i) {
390+
double x = morphCoeffs[i];
391+
double a = 0.5 * x, b = smoothStepFunc(x, smoothRegion);
392+
morphScale *= 1 + (1. / parVal) * a * (morphDiffs[j * nMorphs + i] + b * morphSums[j * nMorphs + i]);
393+
}
394+
return morphScale;
395+
}
396+
397+
inline Double_t parametricHistFullSum(const double* parVals,
398+
const int nBins,
399+
const bool hasMorphs,
400+
const int nMorphs,
401+
const double* morphCoeffs,
402+
const double* morphDiffs,
403+
const double* morphSums,
404+
double smoothRegion) {
405+
double sum = 0;
406+
for (int i = 0; i < nBins; ++i) {
407+
double thisVal = parVals[i];
408+
if (hasMorphs) {
409+
// Apply morphing to this bin, just like in RooParametricHist::evaluate
410+
thisVal *=
411+
parametricMorphFunction(i, thisVal, hasMorphs, nMorphs, morphCoeffs, morphDiffs, morphSums, smoothRegion);
412+
}
413+
sum += thisVal;
414+
}
415+
return sum;
416+
}
417+
418+
inline Double_t parametricHistIntegral(const double* parVals,
419+
const double* bins,
420+
const int N_bins,
421+
const double* morphCoeffs,
422+
const int nMorphs,
423+
const double* morphDiffs,
424+
const double* morphSums,
425+
const double* widths,
426+
const double smoothRegion,
427+
const char* rangeName,
428+
const double xmin,
429+
const double xmax) {
430+
// No ranges
431+
if (!rangeName) {
432+
return parametricHistFullSum(
433+
parVals, N_bins, morphCoeffs != nullptr, nMorphs, morphCoeffs, morphDiffs, morphSums, smoothRegion);
434+
}
435+
436+
// Case with ranges, calculate integral explicitly
437+
double sum = 0;
438+
int i;
439+
for (i = 1; i <= N_bins; i++) {
440+
// Get maybe-morphed bin value
441+
double binVal = parVals[i - 1] / widths[i - 1];
442+
if (morphCoeffs != nullptr) {
443+
binVal *= parametricMorphFunction(
444+
i - 1, parVals[i - 1], true, nMorphs, morphCoeffs, morphDiffs, morphSums, smoothRegion);
445+
}
446+
447+
if (bins[i - 1] >= xmin && bins[i] <= xmax) {
448+
// Bin fully in integration domain
449+
sum += (bins[i] - bins[i - 1]) * binVal;
450+
} else if (bins[i - 1] < xmin && bins[i] > xmax) {
451+
// Domain is fully contained in this bin
452+
sum += (xmax - xmin) * binVal;
453+
// Exit here, this is the last bin to be processed by construction
454+
double fullSum = parametricHistFullSum(
455+
parVals, N_bins, morphCoeffs != nullptr, nMorphs, morphCoeffs, morphDiffs, morphSums, smoothRegion);
456+
return sum / fullSum;
457+
} else if (bins[i - 1] < xmin && bins[i] <= xmax && bins[i] > xmin) {
458+
// Lower domain boundary is in bin
459+
sum += (bins[i] - xmin) * binVal;
460+
} else if (bins[i - 1] >= xmin && bins[i] > xmax && bins[i - 1] < xmax) {
461+
// Upper domain boundary is in bin
462+
// Exit here, this is the last bin to be processed by construction
463+
sum += (xmax - bins[i - 1]) * binVal;
464+
return sum;
465+
}
466+
}
467+
return sum;
468+
}
469+
309470
} // namespace MathFuncs
310471
} // namespace Detail
311472
} // namespace RooFit

interface/RooParametricHist.h

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,33 @@ class RooParametricHist : public RooAbsPdf {
3131
RooArgList & getAllBinVars() const ;
3232

3333
RooRealVar & getObs() const { return (RooRealVar&)x; };
34+
RooAbsReal& observable() const { return const_cast<RooAbsReal&>(static_cast<const RooAbsReal&>(x.arg())); }
3435
const std::vector<double> getBins() const { return bins; };
36+
const int getNBins() const { return N_bins; };
3537
const std::vector<double> getWidths() const { return widths; };
3638

37-
const double quickSum() const {return getFullSum() ;}
39+
const RooArgList& getPars() const { return pars; };
40+
const RooArgList& getCoeffList() const { return _coeffList; };
41+
42+
const double quickSum() const;
3843
//RooAddition & getYieldVar(){return sum;};
3944

4045
// how can we pass this version? is there a Collection object for RooDataHists?
4146
//void addMorphs(RooArgList &_morphPdfsUp, RooArgList &_morphPdfsDown, RooArgList &_coeffs, double smoothRegion);
4247
void addMorphs(RooDataHist&, RooDataHist&, RooRealVar&, double );
48+
Double_t evaluate() const override;
49+
50+
// Accessors for evaluation data
51+
double getX() const { return x; }
52+
double getSmoothRegion() const { return _smoothRegion; }
53+
bool hasMorphs() const { return _hasMorphs; }
54+
55+
double getParVal(int bin_i) const;
56+
57+
// Utility functions for data extraction
58+
const std::vector<double>& getParVals() const;
59+
const std::vector<double>& getCoeffs() const;
60+
void getFlattenedMorphs(std::vector<double>& diffs_flat, std::vector<double>& sums_flat) const;
4361

4462
protected:
4563

@@ -56,19 +74,16 @@ class RooParametricHist : public RooAbsPdf {
5674
mutable double _smoothRegion;
5775
mutable bool _hasMorphs;
5876
mutable std::vector<std::vector <double> > _diffs;
59-
mutable std::vector<std::vector <double> > _sums;
60-
double evaluateMorphFunction(int) const;
77+
mutable std::vector<std::vector<double> > _sums;
6178

79+
mutable std::vector<double> pars_vals_; //! Don't serialize me
80+
mutable std::vector<double> coeffs_; //! Don't serialize me
81+
mutable std::vector<double> diffs_flat_; //! Don't serialize me
82+
mutable std::vector<double> sums_flat_; //! Don't serialize me
6283

6384
void initializeBins(const TH1&) const;
6485
//void initializeNorm();
6586

66-
67-
double evaluatePartial() const ;
68-
double evaluateFull() const ;
69-
Double_t evaluate() const override ;
70-
double getFullSum() const ;
71-
7287
mutable double _cval;
7388
void update_cval(double r){_cval=r;};
7489

0 commit comments

Comments
 (0)