55
66#include < chrono>
77
8+ #ifdef MODULE_CUDA
9+ #include < cuda_runtime.h>
10+ #endif
11+
812#include " Util.hh"
913
1014namespace Onnx {
@@ -21,6 +25,14 @@ const Core::ParameterInt Session::paramInterOpNumThreads("inter-op-num-threads",
2125 " number of threads to use between ops" ,
2226 1 );
2327
28+ const Core::Choice Session::executionProviderChoice (
29+ " cpu" , ExecutionProviderType::cpu,
30+ " cuda" , ExecutionProviderType::cuda,
31+ Core::Choice::endMark ());
32+
33+ const Core::ParameterChoice Session::paramExecutionProviderType (
34+ " execution-provider-type" , &Session::executionProviderChoice, " type of execution provider" , ExecutionProviderType::cpu);
35+
2436Session::Session (Core::Configuration const & config)
2537 : Precursor(config),
2638 file_(paramFile(config)),
@@ -34,6 +46,37 @@ Session::Session(Core::Configuration const& config)
3446 Ort::SessionOptions session_opts;
3547 session_opts.SetIntraOpNumThreads (intraOpNumThreads_);
3648 session_opts.SetInterOpNumThreads (interOpNumThreads_);
49+
50+ auto providers = Ort::GetAvailableProviders ();
51+ switch (paramExecutionProviderType (config)) {
52+ case ExecutionProviderType::cpu: {
53+ if (std::find (providers.begin (), providers.end (), " CPUExecutionProvider" ) == providers.end ()) {
54+ error () << " Requested CPU execution provider for ONNX session but it is not available." ;
55+ }
56+ break ;
57+ }
58+ case ExecutionProviderType::cuda: {
59+ if (std::find (providers.begin (), providers.end (), " CUDAExecutionProvider" ) == providers.end ()) {
60+ error () << " Requested CUDA execution provider for ONNX session but it is not available." ;
61+ }
62+ #ifdef MODULE_CUDA
63+ int deviceCount = 0 ;
64+ if (cudaGetDeviceCount (&deviceCount) != cudaSuccess or deviceCount == 0 ) {
65+ error () << " Requested CUDA execution provider but no CUDA device was found." ;
66+ }
67+ OrtCUDAProviderOptionsV2* cuda_opts = nullptr ;
68+ Ort::ThrowOnError (Ort::GetApi ().CreateCUDAProviderOptions (&cuda_opts));
69+ session_opts.AppendExecutionProvider_CUDA_V2 (*cuda_opts);
70+ Ort::GetApi ().ReleaseCUDAProviderOptions (cuda_opts);
71+ break ;
72+ #else
73+ error () << " Requested CUDA execution provider but RASR was not compiled with MODULE_CUDA which is required for it." ;
74+ #endif
75+ }
76+ default :
77+ error () << " Execution provider for ONNX session not known." ;
78+ }
79+
3780 session_ = Ort::Session (env_, file_.c_str (), session_opts);
3881
3982 size_t num_inputs = session_.GetInputCount ();
0 commit comments