33#include " TROOT.h"
44#include " TSystem.h"
55#include " ROOT/RDataFrame.hxx"
6+ #include " TMath.h"
67
78#include < onnxruntime_cxx_api.h>
89
@@ -19,16 +20,13 @@ struct ONNXFunctor {
1920 // std::vector<float> input;
2021 // std::vector<std::shared_ptr<Func>> sessions;
2122
22- std::map<std::string, double > inputs;
23- std::vector<std::string> names;
24-
2523 std::shared_ptr<Ort::Session> session;
2624
2725 // std::vector<Ort::Value> input_tensors;
2826
2927 // Ort::Value * ort_input = nullptr;
3028
31- // float *input_arr = nullptr;
29+ // float *inputArray = nullptr;
3230
3331 std::vector<const char *> input_node_names;
3432 std::vector<const char *> output_node_names;
@@ -38,6 +36,10 @@ struct ONNXFunctor {
3836 std::vector<int64_t > input_node_dims;
3937 std::vector<int64_t > output_node_dims;
4038
39+ Ort::Value inputTensor{nullptr };
40+
41+ float *inputArray = nullptr ;
42+
4143 ONNXFunctor (unsigned nslots)
4244 {
4345
@@ -65,9 +67,6 @@ struct ONNXFunctor {
6567
6668 // Calculating the dimension of the input tensor
6769
68- // int bsize = input_node_dims[0];
69- // std::cout << "Using bsize = " << bsize << std::endl;
70- // int nbatches = nevt / bsize;
7170
7271 size_t input_tensor_size = std::accumulate (input_node_dims.begin (), input_node_dims.end (), 1 , std::multiplies<int >());
7372 // std::vector<float> input_tensor_values(input_tensor_size );
@@ -76,57 +75,50 @@ struct ONNXFunctor {
7675
7776 auto memory_info = Ort::MemoryInfo::CreateCpu (OrtArenaAllocator, OrtMemTypeDefault);
7877
79- // input_tensors.push_back(Ort::Value::CreateTensor<float>(
80- // memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), input_node_dims.size()) );
81-
82-
83- // Ort::Value
84- // Ort::Value *ort_input = new Ort::Value(nullptr);
85- // // input_tensor =
86- // *ort_input = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(),
87- // input_node_dims.data(), input_node_dims.size());
88-
89- // input_arr = input_tensor.GetTensorMutableData<float>();
90-
91- // Running the model
92- // input_arr = input_tensors[0].GetTensorMutableData<float>();
93-
94- // /////
78+ inputTensor =
79+ Ort::Value::CreateTensor<float >(memory_info, input_tensor_values.data (), input_tensor_values.size (),
80+ input_node_dims.data (), input_node_dims.size ());
9581
96-
97- // // Load inputs from argv
98- // std::cout << "input size is " << config.inputs.size() << std::endl;
99- // for (size_t n = 0; n < config.inputs.size(); n++) {
100- // inputs[config.inputs.at(n).name] = 0.0;
101- // names.push_back(config.inputs.at(n).name);
102- // }
82+ inputArray = inputTensor.GetTensorMutableData <float >();
10383 }
10484
10585 double operator ()(unsigned nslots, float x0, float x1, float x2, float x3, float x4, float x5, float x6)
10686 {
10787
108- // not sure how to cache input ort tensor
109- auto memory_info = Ort::MemoryInfo::CreateCpu (OrtArenaAllocator, OrtMemTypeDefault);
110- Ort::Value
111- input_tensor = Ort::Value::CreateTensor<float >(
112- memory_info, input_tensor_values.data (), input_tensor_values.size (), input_node_dims.data (), input_node_dims.size ());
113- float * input_arr = input_tensor.GetTensorMutableData <float >();
11488
11589 int off = 0 ;
116- input_arr [off] = x0;
117- input_arr [off + 1 ] = x1;
118- input_arr [off + 2 ] = x2;
119- input_arr [off + 3 ] = x3;
120- input_arr [off + 4 ] = x4;
121- input_arr [off + 5 ] = x5;
122- input_arr [off + 6 ] = x6;
90+ inputArray [off] = x0;
91+ inputArray [off + 1 ] = x1;
92+ inputArray [off + 2 ] = x2;
93+ inputArray [off + 3 ] = x3;
94+ inputArray [off + 4 ] = x4;
95+ inputArray [off + 5 ] = x5;
96+ inputArray [off + 6 ] = x6;
12397
12498
12599
126- auto output_tensors = session->Run (Ort::RunOptions{nullptr }, input_node_names.data (), &input_tensor , 1 , output_node_names.data (), 1 );
100+ auto output_tensors = session->Run (Ort::RunOptions{nullptr }, input_node_names.data (), &inputTensor , 1 , output_node_names.data (), 1 );
127101 float * floatarr = output_tensors.front ().GetTensorMutableData <float >();
128102 return floatarr[0 ];
129103 }
104+
105+ // need copy ctor for ONNXruntime
106+ // because I cannot copy Ort::Value
107+ ONNXFunctor (const ONNXFunctor & rhs) {
108+ session = rhs.session ;
109+ input_node_names = rhs.input_node_names ;
110+ output_node_names = rhs.output_node_names ;
111+
112+ input_tensor_values = rhs.input_tensor_values ;
113+
114+ input_node_dims = rhs.input_node_dims ;
115+ output_node_dims = rhs.output_node_dims ;
116+
117+ auto memory_info = Ort::MemoryInfo::CreateCpu (OrtArenaAllocator, OrtMemTypeDefault);
118+ inputTensor = Ort::Value::CreateTensor<float >(memory_info, input_tensor_values.data (), input_tensor_values.size (),
119+ input_node_dims.data (), input_node_dims.size ());
120+ inputArray = inputTensor.GetTensorMutableData <float >();
121+ }
130122};
131123
132124void BM_RDF_ONNX_Inference (benchmark::State &state)
@@ -149,6 +141,9 @@ void BM_RDF_ONNX_Inference(benchmark::State &state)
149141
150142 ONNXFunctor functor (nslot);
151143
144+ std::vector<double > durations;
145+ double ntot = 0 ;
146+
152147 for (auto _ : state) {
153148
154149 auto h1 = df.DefineSlot (" DNN_Value" , functor, {" m_jj" , " m_jjj" , " m_lv" , " m_jlv" , " m_bb" , " m_wbb" , " m_wwbb" })
@@ -160,14 +155,17 @@ void BM_RDF_ONNX_Inference(benchmark::State &state)
160155 auto t2 = std::chrono::high_resolution_clock::now ();
161156 auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count ();
162157
163- std::cout << " Processed " << n << " entries "
164- << " time = " << duration / 1 . E6 << " (sec) time/event = " << duration / double (n) << " musec "
165- << std::endl;
166-
167- // h1->DrawClone() ;
158+ durations. push_back (duration/ 1 . E6 );
159+ ntot += n;
160+ // std::cout << " Processed " << n << " entries "
161+ // << " time = " << duration / 1.E6 << " (sec) time/event = " << duration / double(n) << " musec"
162+ // << std::endl ;
168163 }
164+ double avgDuration = TMath::Mean (durations.begin (), durations.end ());
165+ state.counters [" avg-time(s)" ] = avgDuration;
166+ state.counters [" time/evt(s)" ] = avgDuration * double (durations.size ()) / ntot;
169167}
170168
171-
172169BENCHMARK (BM_RDF_ONNX_Inference)->Unit(benchmark::kMillisecond );
170+
173171BENCHMARK_MAIN ();
0 commit comments