Skip to content

Commit 1681382

Browse files
committed
[RF][HS3] Clean roundtripping of RooExponential and RooLogNormal
Make sure that when transforming workspaces with these classes, roundtripping from RooFit -> JSON -> RooFit -> JSON leaves both the RooFit workspace and the JSON unchanged. Closes #15756.
1 parent cab0c8c commit 1681382

File tree

5 files changed

+88
-65
lines changed

5 files changed

+88
-65
lines changed

roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,7 @@ class RooJSONFactoryWSTool {
181181

182182
void queueExport(RooAbsArg const &arg) { _serversToExport.push_back(&arg); }
183183

184-
RooFit::Detail::JSONNode &createAdHoc(const std::string &toplevel, const std::string &name);
185-
RooAbsReal *importTransformed(const std::string &name, const std::string &tag, const std::string &operation_name,
186-
const std::string &formula);
187-
std::string exportTransformed(const RooAbsReal *original, const std::string &tag, const std::string &operation_name,
188-
const std::string &formula);
184+
std::string exportTransformed(const RooAbsReal *original, const std::string &suffix, const std::string &formula);
189185

190186
void setAttribute(const std::string &obj, const std::string &attrib);
191187
bool hasAttribute(const std::string &obj, const std::string &attrib);

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include <RooTFnBinding.h>
3636
#include <RooWorkspace.h>
3737

38+
#include "JSONIOUtils.h"
39+
3840
#include <TF1.h>
3941
#include <TH1.h>
4042

@@ -237,12 +239,18 @@ class RooLogNormalFactory : public RooFit::JSONIO::Importer {
237239
{
238240
std::string name(RooJSONFactoryWSTool::name(p));
239241
RooAbsReal *x = tool->requestArg<RooAbsReal>(p, "x");
240-
RooAbsReal *mu = tool->requestArg<RooAbsReal>(p, "mu");
241-
RooAbsReal *sigma = tool->requestArg<RooAbsReal>(p, "sigma");
242242

243-
// TODO: check if the pdf was originally exported by ROOT, in which case
244-
// it can be imported back without using the standard parametrization.
245-
tool->wsEmplace<RooLognormal>(name, *x, *mu, *sigma, true);
243+
// Same mechanism to undo the parameter transformation as in the
244+
// RooExponentialFactory (see comments in that class for more info).
245+
const std::string muName = p["mu"].val();
246+
const std::string sigmaName = p["sigma"].val();
247+
const bool isTransformed = endsWith(muName, "_lognormal_log");
248+
const std::string suffixToRemove = isTransformed ? "_lognormal_log" : "";
249+
RooAbsReal *mu = tool->request<RooAbsReal>(removeSuffix(muName, suffixToRemove), name);
250+
RooAbsReal *sigma = tool->request<RooAbsReal>(removeSuffix(sigmaName, suffixToRemove), name);
251+
252+
tool->wsEmplace<RooLognormal>(name, *x, *mu, *sigma, !isTransformed);
253+
246254
return true;
247255
}
248256
};
@@ -253,11 +261,45 @@ class RooExponentialFactory : public RooFit::JSONIO::Importer {
253261
{
254262
std::string name(RooJSONFactoryWSTool::name(p));
255263
RooAbsReal *x = tool->requestArg<RooAbsReal>(p, "x");
256-
RooAbsReal *c = tool->requestArg<RooAbsReal>(p, "c");
257264

258-
// TODO: check if the pdf was originally exported by ROOT, in which case
259-
// it can be imported back without using the standard parametrization.
260-
tool->wsEmplace<RooExponential>(name, *x, *c, true);
265+
// If the parameter name ends with the "_exponential_inverted" suffix,
266+
// this means that it was exported from a RooFit object where the
267+
// parameter first needed to be transformed on export to match the HS3
268+
// specification. But when re-importing such a parameter, we can simply
269+
// skip the transformation and use the original RooFit parameter without
270+
// the suffix.
271+
//
272+
// A concrete example: take the following RooFit pdf in the factory language:
273+
//
274+
// "Exponential::exponential_1(x[0, 10], c[-0.1])"
275+
//
276+
// It defines en exponential exp(c * x). However, in HS3 the exponential
277+
// is defined as exp(-c * x), to RooFit would export these dictionaries
278+
// to the JSON:
279+
//
280+
// {
281+
// "name": "exponential_1", // HS3 exponential_dist with transformed parameter
282+
// "type": "exponential_dist",
283+
// "x": "x",
284+
// "c": "c_exponential_inverted"
285+
// },
286+
// {
287+
// "name": "c_exponential_inverted", // transformation function created on-the-fly on export
288+
// "type": "generic_function",
289+
// "expression": "-c"
290+
// }
291+
//
292+
// On import, we can directly take the non-transformed parameter, which is
293+
// we check for the suffix and optionally remove it from the requested
294+
// name next:
295+
296+
const std::string constParamName = p["c"].val();
297+
const bool isInverted = endsWith(constParamName, "_exponential_inverted");
298+
const std::string suffixToRemove = isInverted ? "_exponential_inverted" : "";
299+
RooAbsReal *c = tool->request<RooAbsReal>(removeSuffix(constParamName, suffixToRemove), name);
300+
301+
tool->wsEmplace<RooExponential>(name, *x, *c, !isInverted);
302+
261303
return true;
262304
}
263305
};
@@ -562,12 +604,12 @@ class RooLogNormalStreamer : public RooFit::JSONIO::Exporter {
562604
auto &m0 = pdf->getMedian();
563605
auto &k = pdf->getShapeK();
564606

565-
if(pdf->useStandardParametrization()) {
607+
if (pdf->useStandardParametrization()) {
566608
elem["mu"] << m0.GetName();
567609
elem["sigma"] << k.GetName();
568610
} else {
569-
elem["mu"] << tool->exportTransformed(&m0, "lognormal", "log", "log(%s)");
570-
elem["sigma"] << tool->exportTransformed(&k, "lognormal", "log", "log(%s)");
611+
elem["mu"] << tool->exportTransformed(&m0, "_lognormal_log", "log(%s)");
612+
elem["sigma"] << tool->exportTransformed(&k, "_lognormal_log", "log(%s)");
571613
}
572614

573615
return true;
@@ -586,7 +628,7 @@ class RooExponentialStreamer : public RooFit::JSONIO::Exporter {
586628
if (pdf->negateCoefficient()) {
587629
elem["c"] << c.GetName();
588630
} else {
589-
elem["c"] << tool->exportTransformed(&c, "exponential", "inverted", "-%s");
631+
elem["c"] << tool->exportTransformed(&c, "_exponential_inverted", "-%s");
590632
}
591633

592634
return true;

roofit/hs3/src/JSONIOUtils.cxx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "JSONIOUtils.h"
22

3+
#include <string>
4+
35
using RooFit::Detail::JSONNode;
46
using RooFit::Detail::JSONTree;
57

@@ -13,6 +15,21 @@ bool endsWith(std::string_view str, std::string_view suffix)
1315
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
1416
}
1517

18+
std::string removePrefix(std::string_view str, std::string_view prefix)
19+
{
20+
std::string out;
21+
out += str;
22+
out = out.substr(prefix.length());
23+
return out;
24+
}
25+
std::string removeSuffix(std::string_view str, std::string_view suffix)
26+
{
27+
std::string out;
28+
out += str;
29+
out = out.substr(0, out.length() - suffix.length());
30+
return out;
31+
}
32+
1633
std::unique_ptr<RooFit::Detail::JSONTree> varJSONString(const JSONNode &treeRoot)
1734
{
1835
std::string varName = treeRoot.find("name")->val();
@@ -71,4 +88,4 @@ std::unique_ptr<RooFit::Detail::JSONTree> varJSONString(const JSONNode &treeRoot
7188
}
7289

7390
return jsonDict;
74-
}
91+
}

roofit/hs3/src/JSONIOUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
bool startsWith(std::string_view str, std::string_view prefix);
88
bool endsWith(std::string_view str, std::string_view suffix);
9+
std::string removePrefix(std::string_view str, std::string_view prefix);
10+
std::string removeSuffix(std::string_view str, std::string_view suffix);
911
std::unique_ptr<RooFit::Detail::JSONTree> varJSONString(const RooFit::Detail::JSONNode &treeRoot);
1012

1113
#endif

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -956,57 +956,17 @@ void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &
956956
}
957957
}
958958

959-
RooAbsReal *RooJSONFactoryWSTool::importTransformed(const std::string &name, const std::string &tag,
960-
const std::string &operation_name, const std::string &formula)
959+
std::string RooJSONFactoryWSTool::exportTransformed(const RooAbsReal *original, const std::string &suffix,
960+
const std::string &formula)
961961
{
962-
RooAbsReal *transformed = nullptr;
963-
const std::string tagname = "autogen_transform_" + tag;
964-
if (this->hasAttribute(name, tagname)) {
965-
const std::string &original = this->getStringAttribute(name, tagname + "_original");
966-
transformed = this->workspace()->function(original);
967-
if (transformed)
968-
return transformed;
969-
}
970-
const std::string newname = name + "_" + tag + "_" + operation_name;
971-
transformed = this->workspace()->function(newname);
972-
if (!transformed) {
973-
auto *original = this->workspace()->arg(name);
974-
if (!original) {
975-
error("unable to import transformed of '" + name + "', original not present.");
976-
}
977-
RooArgSet components{*original};
978-
const std::string &expression = TString::Format(formula.c_str(), name.c_str()).Data();
979-
transformed = &wsEmplace<RooFormulaVar>(newname, expression.c_str(), components);
980-
transformed->setAttribute(tagname.c_str());
981-
}
982-
return transformed;
983-
}
984-
985-
std::string RooJSONFactoryWSTool::exportTransformed(const RooAbsReal *original, const std::string &tag,
986-
const std::string &operation_name, const std::string &formula)
987-
{
988-
const std::string tagname = "autogen_transform_" + tag;
989-
if (original->getAttribute(tagname.c_str())) {
990-
if (const RooFormulaVar *trafo = dynamic_cast<const RooFormulaVar *>(original)) {
991-
return trafo->dependents().first()->GetName();
992-
}
993-
}
994-
995-
std::string newname = std::string(original->GetName()) + "_" + tag + "_" + operation_name;
996-
auto &trafo_node = this->createAdHoc("functions", newname);
962+
std::string newname = std::string(original->GetName()) + suffix;
963+
RooFit::Detail::JSONNode &trafo_node = appendNamedChild((*_rootnodeOutput)["functions"], newname);
997964
trafo_node["type"] << "generic_function";
998965
trafo_node["expression"] << TString::Format(formula.c_str(), original->GetName()).Data();
999-
this->setAttribute(newname, tagname);
1000-
this->setStringAttribute(newname, tagname + "_original", original->GetName());
966+
this->setAttribute(newname, "roofit_skip"); // this function should not be imported back in
1001967
return newname;
1002968
}
1003969

1004-
RooFit::Detail::JSONNode &RooJSONFactoryWSTool::createAdHoc(const std::string &toplevel, const std::string &name)
1005-
{
1006-
auto &collectionNode = (*_rootnodeOutput)[toplevel];
1007-
return appendNamedChild(collectionNode, name);
1008-
}
1009-
1010970
/**
1011971
* @brief Export an object from the workspace to a JSONNode.
1012972
*
@@ -1167,11 +1127,17 @@ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set<std::str
11671127
*/
11681128
void RooJSONFactoryWSTool::importFunction(const JSONNode &p, bool importAllDependants)
11691129
{
1130+
std::string name(RooJSONFactoryWSTool::name(p));
1131+
1132+
// If this node if marked to be skipped by RooFit, exit
1133+
if (hasAttribute(name, "roofit_skip")) {
1134+
return;
1135+
}
1136+
11701137
auto const &importers = RooFit::JSONIO::importers();
11711138
auto const &factoryExpressions = RooFit::JSONIO::importExpressions();
11721139

11731140
// some preparations: what type of function are we dealing with here?
1174-
std::string name(RooJSONFactoryWSTool::name(p));
11751141
if (!::isValidName(name)) {
11761142
std::stringstream ss;
11771143
ss << "RooJSONFactoryWSTool() function name '" << name << "' is not valid!" << std::endl;

0 commit comments

Comments
 (0)