Skip to content

Commit 0af87b9

Browse files
committed
[RF] Support for categorical parameters in codegen
1 parent 5b9d51d commit 0af87b9

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
154154
/// represents the data entry.
155155
class RooFuncWrapper {
156156
public:
157-
RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const& paramSet);
157+
RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const &paramSet);
158158

159159
bool hasGradient() const { return _hasGradient; }
160160
void gradient(double *out) const
@@ -221,7 +221,8 @@ void replaceAll(std::string &str, const std::string &from, const std::string &to
221221

222222
} // namespace
223223

224-
RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const&paramSet)
224+
RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
225+
RooArgSet const &paramSet)
225226
{
226227
// Load the parameters and observables.
227228
auto spans = loadParamsAndData(paramSet, data, simPdf);
@@ -290,15 +291,6 @@ RooFuncWrapper::loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *d
290291
}
291292

292293
for (auto *param : paramSet) {
293-
if (!dynamic_cast<RooAbsReal *>(param)) {
294-
if (param->isConstant()) {
295-
continue;
296-
}
297-
std::stringstream errorMsg;
298-
errorMsg << "In creation of function wrapper: input param expected to be of type RooAbsReal.";
299-
oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
300-
throw std::runtime_error(errorMsg.str().c_str());
301-
}
302294
if (spans.find(param) == spans.end()) {
303295
_params.add(*param);
304296
}
@@ -358,8 +350,10 @@ void RooFuncWrapper::createGradient()
358350

359351
void RooFuncWrapper::updateGradientVarBuffer() const
360352
{
361-
std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(),
362-
[](RooAbsArg *obj) { return static_cast<RooAbsReal *>(obj)->getVal(); });
353+
std::transform(_params.begin(), _params.end(), _gradientVarBuffer.begin(), [](RooAbsArg *obj) {
354+
return obj->isCategory() ? static_cast<RooAbsCategory *>(obj)->getCurrentIndex()
355+
: static_cast<RooAbsReal *>(obj)->getVal();
356+
});
363357
}
364358

365359
/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
@@ -487,9 +481,10 @@ void RooEvaluatorWrapper::createFuncWrapper()
487481
{
488482
// Get the parameters.
489483
RooArgSet paramSet;
490-
this->getParameters(_data ? _data->get() : nullptr, paramSet);
484+
this->getParameters(_data ? _data->get() : nullptr, paramSet, /*sripDisconnectedParams=*/false);
491485

492-
_funcWrapper = std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf), paramSet);
486+
_funcWrapper =
487+
std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast<RooSimultaneous const *>(_pdf), paramSet);
493488
}
494489

495490
void RooEvaluatorWrapper::generateGradient()

0 commit comments

Comments
 (0)