Skip to content

Commit 6af113b

Browse files
committed
Store binning Information with ParamHistFunc
1 parent 5454e38 commit 6af113b

File tree

3 files changed

+95
-27
lines changed

3 files changed

+95
-27
lines changed

roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ class RooJSONFactoryWSTool {
229229
void importVariable(const RooFit::Detail::JSONNode &p);
230230
void importDependants(const RooFit::Detail::JSONNode &n);
231231

232-
void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &n, const bool storeConstant=true);
233-
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n, const bool storeConstant=true);
232+
void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);
233+
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);
234234

235235
void exportAllObjects(RooFit::Detail::JSONNode &n);
236236

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <RooAbsCachedPdf.h>
1616
#include <RooAddPdf.h>
1717
#include <RooAddModel.h>
18+
#include <RooBinning.h>
1819
#include <RooBinSamplingPdf.h>
1920
#include <RooBinWidthFunction.h>
2021
#include <RooCategory.h>
@@ -538,16 +539,55 @@ class ParamHistFuncFactory : public RooFit::JSONIO::Importer {
538539
bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
539540
{
540541
std::string name(RooJSONFactoryWSTool::name(p));
541-
RooArgList vars = tool->requestArgList<RooRealVar>(p, "variables");
542-
std::vector<int> nbins;
543-
nbins << p["nbins"];
544-
for (size_t i = 0; i < vars.size(); ++i) {
545-
auto *v = dynamic_cast<RooRealVar*>(vars.at(i));
546-
v->setBins(nbins[i]);
547-
}
548-
tool->wsEmplace<ParamHistFunc>(name, vars, tool->requestArgList<RooAbsReal>(p, "parameters"));
542+
RooArgList varList = tool->requestArgList<RooRealVar>(p, "variables");
543+
tool->wsEmplace<ParamHistFunc>(name, readBinning(p, varList), tool->requestArgList<RooAbsReal>(p, "parameters"));
549544
return true;
550545
}
546+
547+
private:
548+
RooArgList readBinning(const JSONNode &topNode, const RooArgList &varList) const
549+
{
550+
// Temporary map from variable name → RooRealVar
551+
std::map<std::string, std::unique_ptr<RooRealVar>> varMap;
552+
553+
// Build variables from JSON
554+
for (const JSONNode &node : topNode["axes"].children()) {
555+
const std::string name = node["name"].val();
556+
std::unique_ptr<RooRealVar> obs;
557+
558+
if (node.has_child("edges")) {
559+
std::vector<double> edges;
560+
for (const auto &bound : node["edges"].children()) {
561+
edges.push_back(bound.val_double());
562+
}
563+
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), edges.front(), edges.back());
564+
RooBinning bins(obs->getMin(), obs->getMax());
565+
for (auto b : edges)
566+
bins.addBoundary(b);
567+
obs->setBinning(bins);
568+
} else {
569+
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), node["min"].val_double(),
570+
node["max"].val_double());
571+
obs->setBins(node["nbins"].val_int());
572+
}
573+
574+
varMap[name] = std::move(obs);
575+
}
576+
577+
// Now build the final list following the order in varList
578+
RooArgList vars;
579+
for (int i = 0; i < varList.getSize(); ++i) {
580+
const auto *refVar = dynamic_cast<RooRealVar *>(varList.at(i));
581+
if (!refVar)
582+
continue;
583+
584+
auto it = varMap.find(refVar->GetName());
585+
if (it != varMap.end()) {
586+
vars.addOwned(std::move(it->second)); // preserve ownership
587+
}
588+
}
589+
return vars;
590+
}
551591
};
552592

553593
///////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -980,15 +1020,36 @@ class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter {
9801020
elem["type"] << key();
9811021
RooJSONFactoryWSTool::fillSeq(elem["variables"], pdf->dataVars());
9821022
RooJSONFactoryWSTool::fillSeq(elem["parameters"], pdf->paramList());
983-
std::vector<int> nbins;
984-
for (auto *arg : pdf->dataVars()) {
985-
auto *var = dynamic_cast<RooRealVar*>(arg);
986-
nbins.push_back(var->numBins());
987-
}
988-
elem["nbins"] << nbins;
989-
1023+
writeBinningInfo(pdf, elem);
9901024
return true;
9911025
}
1026+
1027+
private:
1028+
void writeBinningInfo(const ParamHistFunc *pdf, JSONNode &elem) const
1029+
{
1030+
auto &observablesNode = elem["axes"].set_seq();
1031+
// axes have to be ordered to get consistent bin indices
1032+
for (auto *var : static_range_cast<RooRealVar *>(pdf->dataVars())) {
1033+
std::string name = var->GetName();
1034+
RooJSONFactoryWSTool::testValidName(name, false);
1035+
JSONNode &obsNode = observablesNode.append_child().set_map();
1036+
obsNode["name"] << name;
1037+
if (var->getBinning().isUniform()) {
1038+
obsNode["min"] << var->getMin();
1039+
obsNode["max"] << var->getMax();
1040+
obsNode["nbins"] << var->getBins();
1041+
} else {
1042+
auto &edges = obsNode["edges"];
1043+
edges.set_seq();
1044+
double val = var->getBinning().binLow(0);
1045+
edges.append_child() << val;
1046+
for (int i = 0; i < var->getBinning().numBins(); ++i) {
1047+
val = var->getBinning().binHigh(i);
1048+
edges.append_child() << val;
1049+
}
1050+
}
1051+
}
1052+
}
9921053
};
9931054

9941055
#define DEFINE_EXPORTER_KEY(class_name, name) \
@@ -1028,7 +1089,7 @@ DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral");
10281089
DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative");
10291090
DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf");
10301091
DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf");
1031-
DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "param_hist_func");
1092+
DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step");
10321093

10331094
///////////////////////////////////////////////////////////////////////////////////////////////////////
10341095
// instantiate all importers and exporters
@@ -1061,7 +1122,7 @@ STATIC_EXECUTE([]() {
10611122
registerImporter<RooDerivativeFactory>("derivative", false);
10621123
registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false);
10631124
registerImporter<RooExtendPdfFactory>("extend_pdf", false);
1064-
registerImporter<ParamHistFuncFactory>("param_hist_func", false);
1125+
registerImporter<ParamHistFuncFactory>("step", false);
10651126

10661127
registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false);
10671128
registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false);

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl<RooAbsReal>(const std::string &obj
965965
* @param node The JSONNode to which the variable will be exported.
966966
* @return void
967967
*/
968-
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, const bool storeConstant)
968+
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, bool storeConstant, bool storeBins)
969969
{
970970
auto *cv = dynamic_cast<const RooConstVar *>(v);
971971
auto *rrv = dynamic_cast<const RooRealVar *>(v);
@@ -987,7 +987,7 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, co
987987
if (rrv->isConstant() && storeConstant) {
988988
var["const"] << rrv->isConstant();
989989
}
990-
if (rrv->getBins() != 100) {
990+
if (rrv->getBins() != 100 && storeBins) {
991991
var["nbins"] << rrv->getBins();
992992
}
993993
_domains->readVariable(*rrv);
@@ -1004,12 +1004,12 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, co
10041004
* @param n The JSONNode to which the variables will be exported.
10051005
* @return void
10061006
*/
1007-
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n, const bool storeConstant)
1007+
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n, bool storeConstant, bool storeBins)
10081008
{
10091009
// export a list of RooRealVar objects
10101010
n.set_seq();
10111011
for (RooAbsArg *arg : allElems) {
1012-
exportVariable(arg, n, storeConstant);
1012+
exportVariable(arg, n, storeConstant, storeBins);
10131013
}
10141014
}
10151015

@@ -1070,7 +1070,7 @@ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set<std::str
10701070
// categories are created by the respective RooSimultaneous, so we're skipping the export here
10711071
return;
10721072
} else if (dynamic_cast<RooRealVar const *>(&func) || dynamic_cast<RooConstVar const *>(&func)) {
1073-
exportVariable(&func, *_varsNode);
1073+
exportVariable(&func, *_varsNode, true, false);
10741074
return;
10751075
}
10761076

@@ -1554,7 +1554,7 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15541554

15551555
// this really is an unbinned dataset
15561556
output["type"] << "unbinned";
1557-
exportVariables(variables, output["axes"], false);
1557+
exportVariables(variables, output["axes"], false, true);
15581558
auto &coords = output["entries"].set_seq();
15591559
std::vector<double> weightVals;
15601560
bool hasNonUnityWeights = false;
@@ -1955,7 +1955,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)
19551955
snapshotSorted.sort();
19561956
std::string name(snsh->GetName());
19571957
if (name != "default_values") {
1958-
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"]);
1958+
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"], true,
1959+
false);
19591960
}
19601961
}
19611962
_varsNode = nullptr;
@@ -2235,8 +2236,14 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
22352236
combineDatasets(*_rootnodeInput, datasets);
22362237

22372238
for (auto const &d : datasets) {
2238-
if (d)
2239+
if (d) {
22392240
_workspace.import(*d);
2241+
for (auto const &obs : *d->get()) {
2242+
if (auto *rrv = dynamic_cast<RooRealVar *>(obs)) {
2243+
_workspace.var(rrv->GetName())->setBinning(rrv->getBinning());
2244+
}
2245+
}
2246+
}
22402247
}
22412248

22422249
_rootnodeInput = nullptr;

0 commit comments

Comments
 (0)