|
15 | 15 | #include <RooAbsCachedPdf.h> |
16 | 16 | #include <RooAddPdf.h> |
17 | 17 | #include <RooAddModel.h> |
| 18 | +#include <RooBinning.h> |
18 | 19 | #include <RooBinSamplingPdf.h> |
19 | 20 | #include <RooBinWidthFunction.h> |
20 | 21 | #include <RooCategory.h> |
@@ -538,16 +539,55 @@ class ParamHistFuncFactory : public RooFit::JSONIO::Importer { |
538 | 539 | bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override |
539 | 540 | { |
540 | 541 | std::string name(RooJSONFactoryWSTool::name(p)); |
541 | | - RooArgList vars = tool->requestArgList<RooRealVar>(p, "variables"); |
542 | | - std::vector<int> nbins; |
543 | | - nbins << p["nbins"]; |
544 | | - for (size_t i = 0; i < vars.size(); ++i) { |
545 | | - auto *v = dynamic_cast<RooRealVar*>(vars.at(i)); |
546 | | - v->setBins(nbins[i]); |
547 | | - } |
548 | | - tool->wsEmplace<ParamHistFunc>(name, vars, tool->requestArgList<RooAbsReal>(p, "parameters")); |
| 542 | + RooArgList varList = tool->requestArgList<RooRealVar>(p, "variables"); |
| 543 | + tool->wsEmplace<ParamHistFunc>(name, readBinning(p, varList), tool->requestArgList<RooAbsReal>(p, "parameters")); |
549 | 544 | return true; |
550 | 545 | } |
| 546 | + |
| 547 | +private: |
| 548 | + RooArgList readBinning(const JSONNode &topNode, const RooArgList &varList) const |
| 549 | + { |
| 550 | + // Temporary map from variable name → RooRealVar |
| 551 | + std::map<std::string, std::unique_ptr<RooRealVar>> varMap; |
| 552 | + |
| 553 | + // Build variables from JSON |
| 554 | + for (const JSONNode &node : topNode["axes"].children()) { |
| 555 | + const std::string name = node["name"].val(); |
| 556 | + std::unique_ptr<RooRealVar> obs; |
| 557 | + |
| 558 | + if (node.has_child("edges")) { |
| 559 | + std::vector<double> edges; |
| 560 | + for (const auto &bound : node["edges"].children()) { |
| 561 | + edges.push_back(bound.val_double()); |
| 562 | + } |
| 563 | + obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), edges.front(), edges.back()); |
| 564 | + RooBinning bins(obs->getMin(), obs->getMax()); |
| 565 | + for (auto b : edges) |
| 566 | + bins.addBoundary(b); |
| 567 | + obs->setBinning(bins); |
| 568 | + } else { |
| 569 | + obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), node["min"].val_double(), |
| 570 | + node["max"].val_double()); |
| 571 | + obs->setBins(node["nbins"].val_int()); |
| 572 | + } |
| 573 | + |
| 574 | + varMap[name] = std::move(obs); |
| 575 | + } |
| 576 | + |
| 577 | + // Now build the final list following the order in varList |
| 578 | + RooArgList vars; |
| 579 | + for (int i = 0; i < varList.getSize(); ++i) { |
| 580 | + const auto *refVar = dynamic_cast<RooRealVar *>(varList.at(i)); |
| 581 | + if (!refVar) |
| 582 | + continue; |
| 583 | + |
| 584 | + auto it = varMap.find(refVar->GetName()); |
| 585 | + if (it != varMap.end()) { |
| 586 | + vars.addOwned(std::move(it->second)); // preserve ownership |
| 587 | + } |
| 588 | + } |
| 589 | + return vars; |
| 590 | + } |
551 | 591 | }; |
552 | 592 |
|
553 | 593 | /////////////////////////////////////////////////////////////////////////////////////////////////////// |
@@ -980,15 +1020,36 @@ class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter { |
980 | 1020 | elem["type"] << key(); |
981 | 1021 | RooJSONFactoryWSTool::fillSeq(elem["variables"], pdf->dataVars()); |
982 | 1022 | RooJSONFactoryWSTool::fillSeq(elem["parameters"], pdf->paramList()); |
983 | | - std::vector<int> nbins; |
984 | | - for (auto *arg : pdf->dataVars()) { |
985 | | - auto *var = dynamic_cast<RooRealVar*>(arg); |
986 | | - nbins.push_back(var->numBins()); |
987 | | - } |
988 | | - elem["nbins"] << nbins; |
989 | | - |
| 1023 | + writeBinningInfo(pdf, elem); |
990 | 1024 | return true; |
991 | 1025 | } |
| 1026 | + |
| 1027 | +private: |
| 1028 | + void writeBinningInfo(const ParamHistFunc *pdf, JSONNode &elem) const |
| 1029 | + { |
| 1030 | + auto &observablesNode = elem["axes"].set_seq(); |
| 1031 | + // axes have to be ordered to get consistent bin indices |
| 1032 | + for (auto *var : static_range_cast<RooRealVar *>(pdf->dataVars())) { |
| 1033 | + std::string name = var->GetName(); |
| 1034 | + RooJSONFactoryWSTool::testValidName(name, false); |
| 1035 | + JSONNode &obsNode = observablesNode.append_child().set_map(); |
| 1036 | + obsNode["name"] << name; |
| 1037 | + if (var->getBinning().isUniform()) { |
| 1038 | + obsNode["min"] << var->getMin(); |
| 1039 | + obsNode["max"] << var->getMax(); |
| 1040 | + obsNode["nbins"] << var->getBins(); |
| 1041 | + } else { |
| 1042 | + auto &edges = obsNode["edges"]; |
| 1043 | + edges.set_seq(); |
| 1044 | + double val = var->getBinning().binLow(0); |
| 1045 | + edges.append_child() << val; |
| 1046 | + for (int i = 0; i < var->getBinning().numBins(); ++i) { |
| 1047 | + val = var->getBinning().binHigh(i); |
| 1048 | + edges.append_child() << val; |
| 1049 | + } |
| 1050 | + } |
| 1051 | + } |
| 1052 | + } |
992 | 1053 | }; |
993 | 1054 |
|
994 | 1055 | #define DEFINE_EXPORTER_KEY(class_name, name) \ |
@@ -1028,7 +1089,7 @@ DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral"); |
1028 | 1089 | DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative"); |
1029 | 1090 | DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf"); |
1030 | 1091 | DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf"); |
1031 | | -DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "param_hist_func"); |
| 1092 | +DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step"); |
1032 | 1093 |
|
1033 | 1094 | /////////////////////////////////////////////////////////////////////////////////////////////////////// |
1034 | 1095 | // instantiate all importers and exporters |
@@ -1061,7 +1122,7 @@ STATIC_EXECUTE([]() { |
1061 | 1122 | registerImporter<RooDerivativeFactory>("derivative", false); |
1062 | 1123 | registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false); |
1063 | 1124 | registerImporter<RooExtendPdfFactory>("extend_pdf", false); |
1064 | | - registerImporter<ParamHistFuncFactory>("param_hist_func", false); |
| 1125 | + registerImporter<ParamHistFuncFactory>("step", false); |
1065 | 1126 |
|
1066 | 1127 | registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false); |
1067 | 1128 | registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false); |
|
0 commit comments