Skip to content

Commit c6c6f5a

Browse files
committed
update for new ONNXruntime version (13)
1 parent 3ec1884 commit c6c6f5a

File tree

3 files changed

+38
-14
lines changed

3 files changed

+38
-14
lines changed

root/tmva/sofie/ONNXRuntimeInference_Template.cxx.in

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,26 @@ static void @FUNC_NAME@(benchmark::State& state, string model_path)
3333

3434
vector<const char*> input_node_names(nin);
3535
vector<const char*> output_node_names(nout);
36+
vector<std::string> inputStrings(nin);
37+
vector<std::string> outputStrings(nout);
3638

3739
Ort::AllocatorWithDefaultOptions allocator;
38-
for (int i = 0; i < nin; i++)
39-
input_node_names[i] = session.GetInputName(i, allocator);
40-
for (int i = 0; i < nout; i++)
41-
output_node_names[i] = session.GetOutputName(i, allocator);
42-
40+
for (int i = 0; i < nin; i++) {
41+
#if ORT_API_VERSION > 12
42+
inputStrings[i] = session.GetInputNameAllocated(i, allocator).get();
43+
#else
44+
inputStrings[i] = session.GetInputName(i, allocator);
45+
#endif
46+
input_node_names[i] = inputStrings[i].c_str();
47+
}
48+
for (int i = 0; i < nout; i++) {
49+
#if ORT_API_VERSION > 12
50+
outputStrings[i] = session.GetOutputNameAllocated(i, allocator).get();
51+
#else
52+
outputStrings[i] = session.GetOutputName(i, allocator);
53+
#endif
54+
output_node_names[i] = outputStrings[i].c_str();
55+
}
4356
// Getting the shapes
4457
vector<vector<int64_t>> input_node_dims(nin);
4558
vector<vector<int64_t>> output_node_dims(nout);

root/tmva/sofie/RDF_ONNXRuntime_Inference.cxx

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct ONNXFunctor {
2929

3030
std::vector<const char *> input_node_names;
3131
std::vector<const char *> output_node_names;
32+
std::vector<std::string> input_node_str;
33+
std::vector<std::string> output_node_str;
3234

3335
std::vector<float> input_tensor_values;
3436

@@ -53,19 +55,25 @@ struct ONNXFunctor {
5355
// std::cout << "benchmarking model " << model_path << std::endl;
5456
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
5557

56-
5758

58-
Ort::AllocatorWithDefaultOptions allocator;
59-
input_node_names.push_back(session->GetInputName(0, allocator));
60-
output_node_names.push_back( session->GetOutputName(0, allocator));
6159

60+
Ort::AllocatorWithDefaultOptions allocator;
61+
#if ORT_API_VERSION > 12
62+
input_node_str.push_back(session->GetInputNameAllocated(0, allocator).get());
63+
output_node_str.push_back(session->GetOutputNameAllocated(0, allocator).get());
64+
#else
65+
input_node_str.push_back(session->GetInputName(0, allocator));
66+
output_node_str.push_back( session->GetOutputName(0, allocator));
67+
#endif
68+
input_node_names.push_back(input_node_str.back().c_str());
69+
output_node_names.push_back(output_node_str.back().c_str());
6270
// Getting the shapes
6371

6472
input_node_dims = session->GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
6573
output_node_dims = session->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
6674

6775
// Calculating the dimension of the input tensor
68-
76+
6977

7078
size_t input_tensor_size = std::accumulate(input_node_dims.begin(), input_node_dims.end(), 1, std::multiplies<int>());
7179
//std::vector<float> input_tensor_values(input_tensor_size );
@@ -94,7 +102,7 @@ struct ONNXFunctor {
94102
inputArray[off + 5] = x5;
95103
inputArray[off + 6] = x6;
96104

97-
105+
98106

99107
auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), &inputTensor, 1, output_node_names.data(), 1);
100108
float * floatarr = output_tensors.front().GetTensorMutableData<float>();
@@ -130,8 +138,10 @@ void BM_RDF_ONNX_Inference(benchmark::State &state)
130138
auto fileName = "Higgs_data_full.root";
131139
// file is available at "https://cernbox.cern.ch/index.php/s/YuSHwTXBa0UBEhD/download";
132140
// do curl https://cernbox.cern.ch/index.php/s/XaPBtaGrnN38wU0 -o Higgs_data_full.root
141+
// https://cernbox.cern.ch/s/vLOqclhWirZEWpj
142+
std::string directLink = "https://cernbox.cern.ch/remote.php/dav/public-files/vLOqclhWirZEWpj/Higgs_data_full.root";
133143
if (gSystem->AccessPathName(fileName)) {
134-
std::string cmd = "curl https://cernbox.cern.ch/index.php/s/YuSHwTXBa0UBEhD/download -o ";
144+
std::string cmd = "curl " + directLink + " -o ";
135145
cmd += fileName;
136146
gSystem->Exec(cmd.c_str());
137147
}

root/tmva/sofie/RDF_SOFIE_Inference.cxx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ void BM_RDF_SOFIE_Inference(benchmark::State &state)
5757
auto fileName = "Higgs_data_full.root";
5858
//file is available at "https://cernbox.cern.ch/index.php/s/YuSHwTXBa0UBEhD/download";
5959
// do curl https://cernbox.cern.ch/index.php/s/XaPBtaGrnN38wU0 -o Higgs_data_full.root
60-
if (gSystem->AccessPathName(fileName) ) {
61-
std::string cmd = "curl https://cernbox.cern.ch/index.php/s/YuSHwTXBa0UBEhD/download -o ";
60+
std::string directLink = "https://cernbox.cern.ch/remote.php/dav/public-files/vLOqclhWirZEWpj/Higgs_data_full.root";
61+
if (gSystem->AccessPathName(fileName)) {
62+
std::string cmd = "curl " + directLink + " -o ";
6263
cmd += fileName;
6364
gSystem->Exec(cmd.c_str());
6465
}

0 commit comments

Comments
 (0)