@@ -177,6 +177,8 @@ Evaluator::Evaluator(const RooAbsReal &absReal, bool useGPU)
177177
178178 _nodes.emplace_back ();
179179 auto &nodeInfo = _nodes.back ();
180+ _nodesMap[arg->namePtr ()] = &nodeInfo;
181+
180182 nodeInfo.absArg = arg;
181183 nodeInfo.originalOperMode = arg->operMode ();
182184 nodeInfo.iNode = iNode;
@@ -244,49 +246,51 @@ void Evaluator::setInput(std::string const &name, std::span<const double> inputA
244246 throw std::runtime_error (" Evaluator can only take device array as input in CUDA mode!" );
245247 }
246248
247- auto namePtr = RooNameReg::ptr (name.c_str ());
249+ // Check if "name" is used in the computation graph. If yes, add the span to
250+ // the data map and set the node info accordingly.
248251
249- // Iterate over the given data spans and add them to the data map. Check if
250- // they are used in the computation graph. If yes, add the span to the data
251- // map and set the node info accordingly.
252- std::size_t iNode = 0 ;
253- for (auto &info : _nodes) {
254- const bool fromArrayInput = info.absArg ->namePtr () == namePtr;
255- if (fromArrayInput) {
256- info.fromArrayInput = true ;
257- info.absArg ->setDataToken (iNode);
258- info.outputSize = inputArray.size ();
259- if (_useGPU && info.outputSize <= 1 ) {
260- // Empty or scalar observables from the data don't need to be
261- // copied to the GPU.
262- _evalContextCPU.set (info.absArg , inputArray);
263- _evalContextCUDA.set (info.absArg , inputArray);
264- } else if (_useGPU && info.outputSize > 1 ) {
265- // For simplicity, we put the data on both host and device for
266- // now. This could be optimized by inspecting the clients of the
267- // variable.
268- if (isOnDevice) {
269- _evalContextCUDA.set (info.absArg , inputArray);
270- auto gpuSpan = _evalContextCUDA.at (info.absArg );
271- info.buffer = _bufferManager->makeCpuBuffer (gpuSpan.size ());
272- info.buffer ->assignFromDevice (gpuSpan);
273- _evalContextCPU.set (info.absArg , {info.buffer ->hostReadPtr (), gpuSpan.size ()});
274- } else {
275- _evalContextCPU.set (info.absArg , inputArray);
276- auto cpuSpan = _evalContextCPU.at (info.absArg );
277- info.buffer = _bufferManager->makeGpuBuffer (cpuSpan.size ());
278- info.buffer ->assignFromHost (cpuSpan);
279- _evalContextCUDA.set (info.absArg , {info.buffer ->deviceReadPtr (), cpuSpan.size ()});
280- }
281- } else {
282- _evalContextCPU.set (info.absArg , inputArray);
283- }
284- }
285- info.isDirty = !info.fromArrayInput ;
286- ++iNode;
287- }
252+ auto found = _nodesMap.find (RooNameReg::ptr (name.c_str ()));
253+
254+ if (found == _nodesMap.end ())
255+ return ;
288256
289257 _needToUpdateOutputSizes = true ;
258+
259+ NodeInfo &info = *found->second ;
260+
261+ info.fromArrayInput = true ;
262+ info.absArg ->setDataToken (info.iNode );
263+ info.outputSize = inputArray.size ();
264+
265+ if (!_useGPU) {
266+ _evalContextCPU.set (info.absArg , inputArray);
267+ return ;
268+ }
269+
270+ if (info.outputSize <= 1 ) {
271+ // Empty or scalar observables from the data don't need to be
272+ // copied to the GPU.
273+ _evalContextCPU.set (info.absArg , inputArray);
274+ _evalContextCUDA.set (info.absArg , inputArray);
275+ return ;
276+ }
277+
278+ // For simplicity, we put the data on both host and device for
279+ // now. This could be optimized by inspecting the clients of the
280+ // variable.
281+ if (isOnDevice) {
282+ _evalContextCUDA.set (info.absArg , inputArray);
283+ auto gpuSpan = _evalContextCUDA.at (info.absArg );
284+ info.buffer = _bufferManager->makeCpuBuffer (gpuSpan.size ());
285+ info.buffer ->assignFromDevice (gpuSpan);
286+ _evalContextCPU.set (info.absArg , {info.buffer ->hostReadPtr (), gpuSpan.size ()});
287+ } else {
288+ _evalContextCPU.set (info.absArg , inputArray);
289+ auto cpuSpan = _evalContextCPU.at (info.absArg );
290+ info.buffer = _bufferManager->makeGpuBuffer (cpuSpan.size ());
291+ info.buffer ->assignFromHost (cpuSpan);
292+ _evalContextCUDA.set (info.absArg , {info.buffer ->deviceReadPtr (), cpuSpan.size ()});
293+ }
290294}
291295
292296void Evaluator::updateOutputSizes ()
@@ -309,6 +313,7 @@ void Evaluator::updateOutputSizes()
309313
310314 for (auto &info : _nodes) {
311315 info.outputSize = outputSizeMap.at (info.absArg );
316+ info.isDirty = true ;
312317
313318 // In principle we don't need dirty flag propagation because the driver
314319 // takes care of deciding which node needs to be re-evaluated. However,
0 commit comments