diff --git a/tensorflow/lite/micro/kernels/reduce.h b/tensorflow/lite/micro/kernels/reduce.h index a9d007c3df4..881c693742d 100644 --- a/tensorflow/lite/micro/kernels/reduce.h +++ b/tensorflow/lite/micro/kernels/reduce.h @@ -24,14 +24,12 @@ limitations under the License. namespace tflite { -extern const int kMaxNumberOfAxis; -extern const int kMaxNumberOfReducedAxis; - struct OpDataReduce { int32_t multiplier; int shift; - int temp_buffer_idx; - int resolved_axis_idx; + int scratch_accumulator_idx; + int scratch_resolved_axis_idx; + int scratch_input_iter_idx; int input_zp; float input_scale; int output_zp; diff --git a/tensorflow/lite/micro/kernels/reduce_common.cc b/tensorflow/lite/micro/kernels/reduce_common.cc index 25303060dac..2de973779ae 100644 --- a/tensorflow/lite/micro/kernels/reduce_common.cc +++ b/tensorflow/lite/micro/kernels/reduce_common.cc @@ -28,9 +28,6 @@ limitations under the License. namespace tflite { -const int kMaxNumberOfAxis = 5; -const int kMaxNumberOfReducedAxis = 2; - namespace { TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node, @@ -80,7 +77,7 @@ void ResolveAxis(const int* axis_data, int axis_count, template TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node, - int* temp_index, int* resolved_axis, + int* input_iter, int* resolved_axis, int32_t* temp_sum, OpDataReduce* op_data, bool compute_sum) { const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); @@ -96,7 +93,7 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node, op_data->multiplier, op_data->shift, op_data->output_zp, &output->dims->data[0], output->dims->size, tflite::micro::GetTensorData(axis), op_data->num_axis, - params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum); + params->keep_dims, input_iter, resolved_axis, temp_sum, compute_sum); TF_LITE_ENSURE(context, result); return kTfLiteOk; @@ -105,11 +102,11 @@ TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node, int num_axis, OpDataReduce* op_data, - int* temp_index, int* resolved_axis) { + int* input_iter, int* resolved_axis) { int32_t* temp_sum = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + context->GetScratchBuffer(context, op_data->scratch_accumulator_idx)); - QuantizedMeanOrSum(context, node, temp_index, resolved_axis, + QuantizedMeanOrSum(context, node, input_iter, resolved_axis, temp_sum, op_data, /*compute_sum=*/false); return kTfLiteOk; @@ -155,10 +152,10 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node, // Interpret an axis tensor with null dimensions as a scalar int num_axis = static_cast(ElementCount(*axis->dims)); - int* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + int* input_iter = static_cast( + context->GetScratchBuffer(context, op_data->scratch_input_iter_idx)); int* resolved_axis = static_cast( - context->GetScratchBuffer(context, op_data->resolved_axis_idx)); + context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx)); switch (input->type) { case kTfLiteFloat32: { MinMaxReducerCompare reducer(evalType); @@ -169,7 +166,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node, input->dims->size, tflite::micro::GetTensorData(output), output->dims->data, output->dims->size, tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_buffer, resolved_axis, + params->keep_dims, input_iter, resolved_axis, reducer.initialValue(), reducer.compare())); } break; case kTfLiteInt8: { @@ -184,7 +181,7 @@ TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node, input->dims->size, tflite::micro::GetTensorData(output), output->dims->data, output->dims->size, tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_buffer, resolved_axis, + params->keep_dims, input_iter, resolved_axis, reducer.initialValue(), reducer.compare())); } break; default: @@ -211,12 +208,11 @@ TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node, op_data->output_zp = output->params.zero_point; op_data->output_scale = output->params.scale; op_data->num_output_elements = NumElements(output); - context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size, - &op_data->temp_buffer_idx); + &op_data->scratch_input_iter_idx); context->RequestScratchBufferInArena( context, sizeof(int) * static_cast(ElementCount(*axis->dims)), - &op_data->resolved_axis_idx); + &op_data->scratch_resolved_axis_idx); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(output); @@ -236,17 +232,22 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node, QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift); } - int output_size = NumElements(output); op_data->num_axis = NumElements(axis); + op_data->num_output_elements = NumElements(output); if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) { - context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t), - &op_data->temp_buffer_idx); + context->RequestScratchBufferInArena( + context, sizeof(int32_t) * op_data->num_output_elements, + &op_data->scratch_accumulator_idx); op_data->input_zp = input->params.zero_point; op_data->input_scale = input->params.scale; op_data->output_zp = output->params.zero_point; op_data->output_scale = output->params.scale; } + context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size, + &op_data->scratch_input_iter_idx); + context->RequestScratchBufferInArena(context, sizeof(int) * op_data->num_axis, + &op_data->scratch_resolved_axis_idx); TF_LITE_ENSURE_OK( context, @@ -274,12 +275,11 @@ TfLiteStatus PrepareAllHelper(TfLiteContext* context, TfLiteNode* node, op_data->output_zp = output->params.zero_point; op_data->output_scale = output->params.scale; op_data->num_output_elements = NumElements(output); - context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size, - &op_data->temp_buffer_idx); + &op_data->scratch_input_iter_idx); context->RequestScratchBufferInArena( context, sizeof(int) * static_cast(ElementCount(*axis->dims)), - &op_data->resolved_axis_idx); + &op_data->scratch_resolved_axis_idx); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(output); @@ -296,8 +296,10 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node, reinterpret_cast(node->builtin_data); int num_axis = static_cast(ElementCount(*axis->dims)); - int temp_index[kMaxNumberOfAxis]; - int resolved_axis[kMaxNumberOfReducedAxis]; + int* input_iter = static_cast( + context->GetScratchBuffer(context, op_data->scratch_input_iter_idx)); + int* resolved_axis = static_cast( + context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx)); switch (input->type) { case kTfLiteFloat32: { @@ -326,19 +328,19 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node, input->dims->size, tflite::micro::GetTensorData(output), output->dims->data, output->dims->size, tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_index, resolved_axis, + params->keep_dims, input_iter, resolved_axis, tflite::micro::GetTensorData(output))); } } break; case kTfLiteInt8: { TF_LITE_ENSURE_OK( context, EvalIntegerMean(context, node, num_axis, op_data, - temp_index, resolved_axis)); + input_iter, resolved_axis)); } break; case kTfLiteInt16: { TF_LITE_ENSURE_OK( context, EvalIntegerMean(context, node, num_axis, op_data, - temp_index, resolved_axis)); + input_iter, resolved_axis)); } break; default: TF_LITE_ENSURE_MSG(context, false, @@ -369,8 +371,10 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node, // Interpret an axis tensor with null dimensions as a scalar. int num_axis = static_cast(ElementCount(*axis->dims)); - int temp_index[kMaxNumberOfAxis]; - int resolved_axis[kMaxNumberOfReducedAxis]; + int* input_iter = static_cast( + context->GetScratchBuffer(context, op_data->scratch_input_iter_idx)); + int* resolved_axis = static_cast( + context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx)); switch (input->type) { case kTfLiteFloat32: { @@ -381,21 +385,21 @@ TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node, input->dims->size, tflite::micro::GetTensorData(output), output->dims->data, output->dims->size, tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_index, resolved_axis, /*init_value=*/0.f, + params->keep_dims, input_iter, resolved_axis, /*init_value=*/0.f, [](const float current, const float in) -> float { return in + current; })); } break; case kTfLiteInt8: { int32_t* temp_sum = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - QuantizedMeanOrSum(context, node, temp_index, resolved_axis, + context->GetScratchBuffer(context, op_data->scratch_accumulator_idx)); + QuantizedMeanOrSum(context, node, input_iter, resolved_axis, temp_sum, op_data, /*compute_sum=*/true); } break; case kTfLiteInt16: { int32_t* temp_sum = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); - QuantizedMeanOrSum(context, node, temp_index, resolved_axis, + context->GetScratchBuffer(context, op_data->scratch_accumulator_idx)); + QuantizedMeanOrSum(context, node, input_iter, resolved_axis, temp_sum, op_data, /*compute_sum=*/true); } break; default: @@ -416,10 +420,10 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node, // Interpret an axis tensor with null dimensions as a scalar int num_axis = static_cast(ElementCount(*axis->dims)); - int* temp_buffer = static_cast( - context->GetScratchBuffer(context, op_data->temp_buffer_idx)); + int* input_iter = static_cast( + context->GetScratchBuffer(context, op_data->scratch_input_iter_idx)); int* resolved_axis = static_cast( - context->GetScratchBuffer(context, op_data->resolved_axis_idx)); + context->GetScratchBuffer(context, op_data->scratch_resolved_axis_idx)); switch (input->type) { case kTfLiteBool: TF_LITE_ENSURE( @@ -429,7 +433,7 @@ TfLiteStatus EvalAllHelper(TfLiteContext* context, TfLiteNode* node, input->dims->size, tflite::micro::GetTensorData(output), output->dims->data, output->dims->size, tflite::micro::GetTensorData(axis), num_axis, - params->keep_dims, temp_buffer, resolved_axis, true, + params->keep_dims, input_iter, resolved_axis, true, [](const bool current, const bool in) -> bool { return in && current; }));