2828
2929#include  < cstdint> 
3030#include  < exception> 
31+ #include  < mutex> 
3132
3233#include  " libtorch_utils.h" 
3334#include  " triton/backend/backend_common.h" 
6667//  PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
6768// 
6869
70+ namespace  {  // file-local flags handed to std::call_once() in ModelState::ParseParameters()
71+ std::once_flag pytorch_interop_threads_flag;  // wraps the at::set_num_interop_threads() call
72+ std::once_flag pytorch_intraop_threads_flag;  // wraps the at::set_num_threads() call
73+ }  //  namespace
74+ 
6975namespace  triton  { namespace  backend  { namespace  pytorch  {
7076
7177// 
@@ -509,11 +515,15 @@ ModelState::ParseParameters()
509515      }
510516    } else  {
511517      if  (intra_op_thread_count > 0 ) {
512-         at::set_num_threads (intra_op_thread_count);
518+         //  at::set_num_threads() does not throw if called more than once, but
519+         //  issues warnings. std::call_once() is useful to limit these.
520+         std::call_once (pytorch_intraop_threads_flag, [intra_op_thread_count]() {
521+           at::set_num_threads (intra_op_thread_count);
522+         });
513523        LOG_MESSAGE (
514524            TRITONSERVER_LOG_INFO,
515525            (std::string (" Intra op thread count is set to "  ) +
516-              std::to_string (intra_op_thread_count ) + "  for model instance '"   +
526+              std::to_string (at::get_num_threads () ) + "  for model instance '"   +
517527             Name () + " '"  )
518528                .c_str ());
519529      }
@@ -533,12 +543,22 @@ ModelState::ParseParameters()
533543      }
534544    } else  {
535545      if  (inter_op_thread_count > 0 ) {
536-         at::set_num_interop_threads (inter_op_thread_count);
546+         //  at::set_num_interop_threads() throws if called more than once.
547+         //  std::call_once() should prevent this, but try/catch is additionally
548+         //  used for safety.
549+         std::call_once (pytorch_interop_threads_flag, [inter_op_thread_count]() {
550+           try  {
551+             at::set_num_interop_threads (inter_op_thread_count);
552+           }
553+           catch  (const  c10::Error& e) {
554+             //  do nothing
555+           }
556+         });
537557        LOG_MESSAGE (
538558            TRITONSERVER_LOG_INFO,
539559            (std::string (" Inter op thread count is set to "  ) +
540-              std::to_string (inter_op_thread_count) +  "  for model instance ' "   +
541-              Name () + " '"  )
560+              std::to_string (at::get_num_interop_threads ())  +
561+              "  for model instance ' "  +  Name () + " '"  )
542562                .c_str ());
543563      }
544564    }
0 commit comments