@@ -60,6 +60,9 @@ class ModelState : public BackendModel {
6060 // Validate that model configuration is supported by this backend.
6161 // TRITONSERVER_Error* ValidateModelConfig();
6262
63+ // Default TFLite runtime options
64+ int32_t tflite_num_threads_ = int32_t (std::thread::hardware_concurrency());
65+
6366#ifdef ARMNN_DELEGATE_ENABLE
6467 // ArmNN Delegate options
6568 bool use_armnn_delegate_cpu_ = false ;
@@ -136,6 +139,37 @@ ModelState::LoadModel(
136139 (" failed to load model " + Name ()).c_str ());
137140 }
138141
142+ // Handle tflite default interpreter options set in parameters
143+ {
144+ triton::common::TritonJson::Value params;
145+ if (ModelConfig ().Find (" parameters" , &params)) {
146+ // Handle tflite_num_threads parameter
147+ std::string value_str;
148+ auto err = GetParameterValue (params, " tflite_num_threads" , &value_str);
149+
150+ // tflite_num_threads is not required so clear error if not found
151+ if (err != nullptr ) {
152+ if (TRITONSERVER_ErrorCode (err) != TRITONSERVER_ERROR_NOT_FOUND) {
153+ return err;
154+ } else {
155+ TRITONSERVER_ErrorDelete (err);
156+ }
157+ } else {
158+ RETURN_IF_ERROR (ParseIntValue (value_str, &tflite_num_threads_));
159+
160+ if (tflite_num_threads_ < 0 ) {
161+ return TRITONSERVER_ErrorNew (
162+ TRITONSERVER_ERROR_INVALID_ARG,
163+ (std::string (
164+ " parameter 'tflite_num_threads' must be non-negative "
165+ " number for tflite model '" ) +
166+ Name () + " '" )
167+ .c_str ());
168+ }
169+ }
170+ }
171+ }
172+
139173 // Handle tflite optimizations from model config
140174 {
141175 triton::common::TritonJson::Value optimization;
@@ -536,7 +570,7 @@ ModelInstanceState::BuildInterpreter()
536570 }
537571
538572 // Tell interpreter to use the configured thread count (defaults to the hardware concurrency)
539- if (interpreter_->SetNumThreads (std::thread::hardware_concurrency () ) !=
573+ if (interpreter_->SetNumThreads (model_state_-> tflite_num_threads_ ) !=
540574 kTfLiteOk ) {
541575 return TRITONSERVER_ErrorNew (
542576 TRITONSERVER_ERROR_INTERNAL,
0 commit comments