Skip to content

Commit badc9d2

Browse files
committed
use call_once to prevent repeated thread count setting
1 parent db70751 commit badc9d2

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

src/libtorch.cc

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <cstdint>
3030
#include <exception>
31+
#include <mutex>
3132

3233
#include "libtorch_utils.h"
3334
#include "triton/backend/backend_common.h"
@@ -66,6 +67,11 @@
6667
// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
6768
//
6869

70+
namespace {
71+
std::once_flag pytorch_interop_threads_flag;
72+
std::once_flag pytorch_intraop_threads_flag;
73+
}
74+
6975
namespace triton { namespace backend { namespace pytorch {
7076

7177
//
@@ -509,13 +515,17 @@ ModelState::ParseParameters()
509515
}
510516
} else {
511517
if (intra_op_thread_count > 0) {
512-
at::set_num_threads(intra_op_thread_count);
513-
LOG_MESSAGE(
514-
TRITONSERVER_LOG_INFO,
515-
(std::string("Intra op thread count is set to ") +
516-
std::to_string(intra_op_thread_count) + " for model instance '" +
517-
Name() + "'")
518-
.c_str());
518+
// at::set_num_threads() does not throw if called more than once, but issues warnings.
519+
// std::call_once() is useful to limit these.
520+
std::call_once(pytorch_intraop_threads_flag, [this, intra_op_thread_count](){
521+
at::set_num_threads(intra_op_thread_count);
522+
LOG_MESSAGE(
523+
TRITONSERVER_LOG_INFO,
524+
(std::string("Intra op thread count is set to ") +
525+
std::to_string(intra_op_thread_count) + " for model instance '" +
526+
this->Name() + "'")
527+
.c_str());
528+
});
519529
}
520530
}
521531

@@ -533,13 +543,28 @@ ModelState::ParseParameters()
533543
}
534544
} else {
535545
if (inter_op_thread_count > 0) {
536-
at::set_num_interop_threads(inter_op_thread_count);
537-
LOG_MESSAGE(
538-
TRITONSERVER_LOG_INFO,
539-
(std::string("Inter op thread count is set to ") +
540-
std::to_string(inter_op_thread_count) + " for model instance '" +
541-
Name() + "'")
542-
.c_str());
546+
// at::set_num_interop_threads() throws if called more than once.
547+
// std::call_once() should prevent this, but try/catch is additionally used for safety.
548+
std::call_once(pytorch_interop_threads_flag, [this, inter_op_thread_count](){
549+
try {
550+
at::set_num_interop_threads(inter_op_thread_count);
551+
LOG_MESSAGE(
552+
TRITONSERVER_LOG_INFO,
553+
(std::string("Inter op thread count is set to ") +
554+
std::to_string(inter_op_thread_count) + " for model instance '" +
555+
Name() + "'")
556+
.c_str());
557+
} catch (const c10::Error& e) {
558+
int current_inter_op_thread_count = at::get_num_interop_threads();
559+
bool current_is_requested = inter_op_thread_count == current_inter_op_thread_count;
560+
LOG_MESSAGE(
561+
TRITONSERVER_LOG_INFO,
562+
(std::string("Inter op thread count is already set to ") +
563+
std::to_string(current_inter_op_thread_count) +
564+
(current_is_requested ? "" : " and cannot be changed. Setting ignored") +
565+
" for model instance '" + this->Name() + "'").c_str());
566+
}
567+
});
543568
}
544569
}
545570
}

0 commit comments

Comments
 (0)