Skip to content

Commit e7d56d5

Browse files
authored
Add support for CUDAExecutionProvider to Onnx::Session (#137)
1 parent f819b2f commit e7d56d5

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/Onnx/Session.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
#include <chrono>
77

8+
#ifdef MODULE_CUDA
9+
#include <cuda_runtime.h>
10+
#endif
11+
812
#include "Util.hh"
913

1014
namespace Onnx {
@@ -21,6 +25,14 @@ const Core::ParameterInt Session::paramInterOpNumThreads("inter-op-num-threads",
2125
"number of threads to use between ops",
2226
1);
2327

28+
const Core::Choice Session::executionProviderChoice(
29+
"cpu", ExecutionProviderType::cpu,
30+
"cuda", ExecutionProviderType::cuda,
31+
Core::Choice::endMark());
32+
33+
const Core::ParameterChoice Session::paramExecutionProviderType(
34+
"execution-provider-type", &Session::executionProviderChoice, "type of execution provider", ExecutionProviderType::cpu);
35+
2436
Session::Session(Core::Configuration const& config)
2537
: Precursor(config),
2638
file_(paramFile(config)),
@@ -34,6 +46,37 @@ Session::Session(Core::Configuration const& config)
3446
Ort::SessionOptions session_opts;
3547
session_opts.SetIntraOpNumThreads(intraOpNumThreads_);
3648
session_opts.SetInterOpNumThreads(interOpNumThreads_);
49+
50+
auto providers = Ort::GetAvailableProviders();
51+
switch (paramExecutionProviderType(config)) {
52+
case ExecutionProviderType::cpu: {
53+
if (std::find(providers.begin(), providers.end(), "CPUExecutionProvider") == providers.end()) {
54+
error() << "Requested CPU execution provider for ONNX session but it is not available.";
55+
}
56+
break;
57+
}
58+
case ExecutionProviderType::cuda: {
59+
if (std::find(providers.begin(), providers.end(), "CUDAExecutionProvider") == providers.end()) {
60+
error() << "Requested CUDA execution provider for ONNX session but it is not available.";
61+
}
62+
#ifdef MODULE_CUDA
63+
int deviceCount = 0;
64+
if (cudaGetDeviceCount(&deviceCount) != cudaSuccess or deviceCount == 0) {
65+
error() << "Requested CUDA execution provider but no CUDA device was found.";
66+
}
67+
OrtCUDAProviderOptionsV2* cuda_opts = nullptr;
68+
Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_opts));
69+
session_opts.AppendExecutionProvider_CUDA_V2(*cuda_opts);
70+
Ort::GetApi().ReleaseCUDAProviderOptions(cuda_opts);
71+
break;
72+
#else
73+
error() << "Requested CUDA execution provider but RASR was not compiled with MODULE_CUDA which is required for it.";
74+
#endif
75+
}
76+
default:
77+
error() << "Execution provider for ONNX session not known.";
78+
}
79+
3780
session_ = Ort::Session(env_, file_.c_str(), session_opts);
3881

3982
size_t num_inputs = session_.GetInputCount();

src/Onnx/Session.hh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515
namespace Onnx {
1616

17+
enum ExecutionProviderType {
18+
cpu,
19+
cuda
20+
};
21+
1722
class Session : public Core::Component {
1823
public:
1924
using Precursor = Core::Component;
@@ -22,6 +27,9 @@ public:
2227
static const Core::ParameterInt paramIntraOpNumThreads;
2328
static const Core::ParameterInt paramInterOpNumThreads;
2429

30+
static const Core::Choice executionProviderChoice;
31+
static const Core::ParameterChoice paramExecutionProviderType;
32+
2533
Session(Core::Configuration const& config);
2634
virtual ~Session() = default;
2735

0 commit comments

Comments
 (0)