2828
2929#include < cstdint>
3030#include < exception>
31+ #include < mutex>
3132
3233#include " libtorch_utils.h"
3334#include " triton/backend/backend_common.h"
6667// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
6768//
6869
70+ namespace {  // File-local once-flags used with std::call_once() in ModelState::ParseParameters().
71+ std::once_flag pytorch_interop_threads_flag;  // Guards the at::set_num_interop_threads() call so it runs at most once.
72+ std::once_flag pytorch_intraop_threads_flag;  // Guards the at::set_num_threads() call so it runs at most once.
73+ }  // namespace
74+
6975namespace triton { namespace backend { namespace pytorch {
7076
7177//
@@ -509,13 +515,17 @@ ModelState::ParseParameters()
509515 }
510516 } else {
511517 if (intra_op_thread_count > 0 ) {
512- at::set_num_threads (intra_op_thread_count);
513- LOG_MESSAGE (
514- TRITONSERVER_LOG_INFO,
515- (std::string (" Intra op thread count is set to " ) +
516- std::to_string (intra_op_thread_count) + " for model instance '" +
517- Name () + " '" )
518- .c_str ());
518+ // at::set_num_threads() does not throw if called more than once, but issues warnings.
519+ // std::call_once() is useful to limit these.
520+ std::call_once (pytorch_intraop_threads_flag, [this , intra_op_thread_count](){
521+ at::set_num_threads (intra_op_thread_count);
522+ LOG_MESSAGE (
523+ TRITONSERVER_LOG_INFO,
524+ (std::string (" Intra op thread count is set to " ) +
525+ std::to_string (intra_op_thread_count) + " for model instance '" +
526+ this ->Name () + " '" )
527+ .c_str ());
528+ });
519529 }
520530 }
521531
@@ -533,13 +543,28 @@ ModelState::ParseParameters()
533543 }
534544 } else {
535545 if (inter_op_thread_count > 0 ) {
536- at::set_num_interop_threads (inter_op_thread_count);
537- LOG_MESSAGE (
538- TRITONSERVER_LOG_INFO,
539- (std::string (" Inter op thread count is set to " ) +
540- std::to_string (inter_op_thread_count) + " for model instance '" +
541- Name () + " '" )
542- .c_str ());
546+ // at::set_num_interop_threads() throws if called more than once.
547+ // std::call_once() should prevent this, but try/catch is additionally used for safety.
548+ std::call_once (pytorch_interop_threads_flag, [this , inter_op_thread_count](){
549+ try {
550+ at::set_num_interop_threads (inter_op_thread_count);
551+ LOG_MESSAGE (
552+ TRITONSERVER_LOG_INFO,
553+ (std::string (" Inter op thread count is set to " ) +
554+ std::to_string (inter_op_thread_count) + " for model instance '" +
555+ Name () + " '" )
556+ .c_str ());
557+ } catch (const c10::Error& e) {
558+ int current_inter_op_thread_count = at::get_num_interop_threads ();
559+ bool current_is_requested = inter_op_thread_count == current_inter_op_thread_count;
560+ LOG_MESSAGE (
561+ TRITONSERVER_LOG_INFO,
562+ (std::string (" Inter op thread count is already set to " ) +
563+ std::to_string (current_inter_op_thread_count) +
564+ (current_is_requested ? " " : " and cannot be changed. Setting ignored" ) +
565+ " for model instance '" + this ->Name () + " '" ).c_str ());
566+ }
567+ });
543568 }
544569 }
545570 }
0 commit comments