Skip to content

Commit a7d6dfe

Browse files
committed
Merge branch 'dev'
2 parents 69a296d + 46fb322 commit a7d6dfe

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ optimization { execution_accelerators {
120120
}}
121121
```
122122

123+
### Default Optimization Options
124+
Optimization parameters for the default tflite interpreter can be passed using the `parameters` section of the model configuration.
125+
126+
By default the tflite interpreter will use the maximum number of threads available to the system.
127+
To set the number of threads available to the tflite interpreter, add the following section to your model configuration:
128+
```
129+
parameters: {
130+
key: "tflite_num_threads"
131+
value: {
132+
string_value:"<num_threads>"
133+
}
134+
}
135+
```
136+
123137
### ArmNN Delegate Optimization Options
124138
Users also have the ability to specify ArmNN specific optimizations.
125139
The following options are available for CPU:

src/tflite.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class ModelState : public BackendModel {
6060
// Validate that model configuration is supported by this backend.
6161
// TRITONSERVER_Error* ValidateModelConfig();
6262

63+
// Default TFLite runtime options
64+
int32_t tflite_num_threads_ = int32_t(std::thread::hardware_concurrency());
65+
6366
#ifdef ARMNN_DELEGATE_ENABLE
6467
// ArmNN Delegate options
6568
bool use_armnn_delegate_cpu_ = false;
@@ -136,6 +139,37 @@ ModelState::LoadModel(
136139
("failed to load model " + Name()).c_str());
137140
}
138141

142+
// Handle tflite default interpreter options set in parameters
143+
{
144+
triton::common::TritonJson::Value params;
145+
if (ModelConfig().Find("parameters", &params)) {
146+
// Handle tflite_num_threads parameter
147+
std::string value_str;
148+
auto err = GetParameterValue(params, "tflite_num_threads", &value_str);
149+
150+
// tflite_num_threads is not required so clear error if not found
151+
if (err != nullptr) {
152+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
153+
return err;
154+
} else {
155+
TRITONSERVER_ErrorDelete(err);
156+
}
157+
} else {
158+
RETURN_IF_ERROR(ParseIntValue(value_str, &tflite_num_threads_));
159+
160+
if (tflite_num_threads_ < 0) {
161+
return TRITONSERVER_ErrorNew(
162+
TRITONSERVER_ERROR_INVALID_ARG,
163+
(std::string(
164+
"parameter 'tflite_num_threads' must be non-negative "
165+
"number for tflite model '") +
166+
Name() + "'")
167+
.c_str());
168+
}
169+
}
170+
}
171+
}
172+
139173
// Handle tflite optimizations from model config
140174
{
141175
triton::common::TritonJson::Value optimization;
@@ -536,7 +570,7 @@ ModelInstanceState::BuildInterpreter()
536570
}
537571

538572
// Tell interpreter how many threads to use (tflite_num_threads_ defaults to the max available to the system)
539-
if (interpreter_->SetNumThreads(std::thread::hardware_concurrency()) !=
573+
if (interpreter_->SetNumThreads(model_state_->tflite_num_threads_) !=
540574
kTfLiteOk) {
541575
return TRITONSERVER_ErrorNew(
542576
TRITONSERVER_ERROR_INTERNAL,

0 commit comments

Comments
 (0)