@@ -154,7 +154,7 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
154154// / represents the data entry.
155155class RooFuncWrapper {
156156public:
157- RooFuncWrapper (RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const & paramSet);
157+ RooFuncWrapper (RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const & paramSet);
158158
159159 bool hasGradient () const { return _hasGradient; }
160160 void gradient (double *out) const
@@ -221,7 +221,8 @@ void replaceAll(std::string &str, const std::string &from, const std::string &to
221221
222222} // namespace
223223
224- RooFuncWrapper::RooFuncWrapper (RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf, RooArgSet const ¶mSet)
224+ RooFuncWrapper::RooFuncWrapper (RooAbsReal &obj, const RooAbsData *data, RooSimultaneous const *simPdf,
225+ RooArgSet const ¶mSet)
225226{
226227 // Load the parameters and observables.
227228 auto spans = loadParamsAndData (paramSet, data, simPdf);
@@ -290,15 +291,6 @@ RooFuncWrapper::loadParamsAndData(RooArgSet const ¶mSet, const RooAbsData *d
290291 }
291292
292293 for (auto *param : paramSet) {
293- if (!dynamic_cast <RooAbsReal *>(param)) {
294- if (param->isConstant ()) {
295- continue ;
296- }
297- std::stringstream errorMsg;
298- errorMsg << " In creation of function wrapper: input param expected to be of type RooAbsReal." ;
299- oocoutE (nullptr , InputArguments) << errorMsg.str () << std::endl;
300- throw std::runtime_error (errorMsg.str ().c_str ());
301- }
302294 if (spans.find (param) == spans.end ()) {
303295 _params.add (*param);
304296 }
@@ -358,8 +350,10 @@ void RooFuncWrapper::createGradient()
358350
359351void RooFuncWrapper::updateGradientVarBuffer () const
360352{
361- std::transform (_params.begin (), _params.end (), _gradientVarBuffer.begin (),
362- [](RooAbsArg *obj) { return static_cast <RooAbsReal *>(obj)->getVal (); });
353+ std::transform (_params.begin (), _params.end (), _gradientVarBuffer.begin (), [](RooAbsArg *obj) {
354+ return obj->isCategory () ? static_cast <RooAbsCategory *>(obj)->getCurrentIndex ()
355+ : static_cast <RooAbsReal *>(obj)->getVal ();
356+ });
363357}
364358
365359// / @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
@@ -487,9 +481,10 @@ void RooEvaluatorWrapper::createFuncWrapper()
487481{
488482 // Get the parameters.
489483 RooArgSet paramSet;
490- this ->getParameters (_data ? _data->get () : nullptr , paramSet);
484+ this ->getParameters (_data ? _data->get () : nullptr , paramSet, /* sripDisconnectedParams= */ false );
491485
492- _funcWrapper = std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast <RooSimultaneous const *>(_pdf), paramSet);
486+ _funcWrapper =
487+ std::make_unique<RooFuncWrapper>(*_topNode, _data, dynamic_cast <RooSimultaneous const *>(_pdf), paramSet);
493488}
494489
495490void RooEvaluatorWrapper::generateGradient ()
0 commit comments