Skip to content

Commit f5c36ef

Browse files
authored
Change MicroInterpreter::SetCompressionMemory API (#3267)
* Support for DECODE operator @tensorflow/micro Add support for alternate decompression memory to DECODE operator. Additional unit tests. Update generic benchmark application and Makefile. bug=fixes #3212 * fixes. * Changes as per review. * Change SetDecompressionMemory API
1 parent 9cbf6e0 commit f5c36ef

File tree

10 files changed

+53
-36
lines changed

10 files changed

+53
-36
lines changed

tensorflow/lite/micro/kernels/decode_test_helpers.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ TfLiteStatus ExecuteDecodeTest(
102102
inputs_array, outputs_array, nullptr);
103103

104104
if (amr != nullptr) {
105-
runner.GetFakeMicroContext()->SetDecompressionMemory(*amr);
105+
runner.GetFakeMicroContext()->SetDecompressionMemory(amr->begin(),
106+
amr->size());
106107
}
107108

108109
if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {

tensorflow/lite/micro/micro_context.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ void* MicroContext::DecompressTensorToBuffer(
130130
#endif // USE_TFLM_COMPRESSION
131131

132132
TfLiteStatus MicroContext::SetDecompressionMemory(
133-
const std::initializer_list<AlternateMemoryRegion>& regions) {
133+
const AlternateMemoryRegion* regions, size_t count) {
134134
if (decompress_regions_ != nullptr) {
135135
return kTfLiteError;
136136
}
137137

138-
decompress_regions_ = &regions;
138+
decompress_regions_ = regions;
139+
decompress_regions_size_ = count;
139140
decompress_regions_allocations_ = static_cast<size_t*>(
140-
AllocatePersistentBuffer(sizeof(size_t) * regions.size()));
141+
AllocatePersistentBuffer(sizeof(size_t) * decompress_regions_size_));
141142
if (decompress_regions_allocations_ == nullptr) {
142143
return kTfLiteError;
143144
}
@@ -149,8 +150,8 @@ TfLiteStatus MicroContext::SetDecompressionMemory(
149150
void* MicroContext::AllocateDecompressionMemory(size_t bytes,
150151
size_t alignment) {
151152
if (decompress_regions_ != nullptr) {
152-
for (size_t i = 0; i < decompress_regions_->size(); i++) {
153-
const AlternateMemoryRegion* region = &decompress_regions_->begin()[i];
153+
for (size_t i = 0; i < decompress_regions_size_; i++) {
154+
const AlternateMemoryRegion* region = &decompress_regions_[i];
154155
uint8_t* start = static_cast<uint8_t*>(region->address) +
155156
decompress_regions_allocations_[i];
156157
uint8_t* aligned_start = AlignPointerUp(start, alignment);
@@ -170,7 +171,7 @@ void MicroContext::ResetDecompressionMemoryAllocations() {
170171
return;
171172
}
172173
TFLITE_DCHECK(decompress_regions_allocations_ != nullptr);
173-
std::fill_n(decompress_regions_allocations_, decompress_regions_->size(), 0);
174+
std::fill_n(decompress_regions_allocations_, decompress_regions_size_, 0);
174175
}
175176

176177
} // namespace tflite

tensorflow/lite/micro/micro_context.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class MicroContext {
138138
// Set the alternate decompression memory regions.
139139
// Can only be called during the MicroInterpreter kInit state.
140140
virtual TfLiteStatus SetDecompressionMemory(
141-
const std::initializer_list<AlternateMemoryRegion>& regions);
141+
const AlternateMemoryRegion* regions, size_t count);
142142

143143
// Return a pointer to memory that can be used for decompression.
144144
// The pointer will be aligned to the <alignment> value.
@@ -170,9 +170,9 @@ class MicroContext {
170170
}
171171

172172
private:
173-
const std::initializer_list<AlternateMemoryRegion>* decompress_regions_ =
174-
nullptr;
175-
// array of size_t elements with length equal to decompress_regions_.size()
173+
const AlternateMemoryRegion* decompress_regions_ = nullptr;
174+
size_t decompress_regions_size_ = 0;
175+
// array of size_t elements with length equal to decompress_regions_size_
176176
size_t* decompress_regions_allocations_ = nullptr;
177177

178178
TF_LITE_REMOVE_VIRTUAL_DELETE

tensorflow/lite/micro/micro_interpreter.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,8 @@ TfLiteStatus MicroInterpreter::SetAlternateProfiler(
340340
}
341341

342342
TfLiteStatus MicroInterpreter::SetDecompressionMemory(
343-
const std::initializer_list<MicroInterpreterContext::AlternateMemoryRegion>&
344-
regions) {
345-
return micro_context_.SetDecompressionMemory(regions);
343+
const MicroContext::AlternateMemoryRegion* regions, size_t count) {
344+
return micro_context_.SetDecompressionMemory(regions, count);
346345
}
347346

348347
} // namespace tflite

tensorflow/lite/micro/micro_interpreter.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,13 @@ class MicroInterpreter {
163163
// Set the alternate decompression memory regions.
164164
// Can only be called during the MicroInterpreter kInit state (i.e. must
165165
// be called before MicroInterpreter::AllocateTensors).
166+
// The regions pointer argument is the start of a
167+
// MicroContext::AlternateMemoryRegion array where the length of the array is
168+
// given by the count argument.
169+
// The lifetime of the MicroContext::AlternateMemoryRegion array must be at
170+
// least that of the MicroInterpreter.
166171
TfLiteStatus SetDecompressionMemory(
167-
const std::initializer_list<MicroContext::AlternateMemoryRegion>&
168-
regions);
172+
const MicroContext::AlternateMemoryRegion* regions, size_t count);
169173

170174
protected:
171175
const MicroAllocator& allocator() const { return allocator_; }

tensorflow/lite/micro/micro_interpreter_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,12 @@ void* MicroInterpreterContext::DecompressTensorToBuffer(
217217
#endif // USE_TFLM_COMPRESSION
218218

219219
TfLiteStatus MicroInterpreterContext::SetDecompressionMemory(
220-
const std::initializer_list<MicroContext::AlternateMemoryRegion>& regions) {
220+
const AlternateMemoryRegion* regions, size_t count) {
221221
if (state_ != InterpreterState::kInit) {
222222
return kTfLiteError;
223223
}
224224

225-
return MicroContext::SetDecompressionMemory(regions);
225+
return MicroContext::SetDecompressionMemory(regions, count);
226226
}
227227

228228
void* MicroInterpreterContext::AllocateDecompressionMemory(size_t bytes,

tensorflow/lite/micro/micro_interpreter_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class MicroInterpreterContext : public MicroContext {
135135

136136
// Set the alternate decompression memory regions.
137137
// Can only be called during the MicroInterpreter kInit state.
138-
TfLiteStatus SetDecompressionMemory(
139-
const std::initializer_list<AlternateMemoryRegion>& regions) override;
138+
TfLiteStatus SetDecompressionMemory(const AlternateMemoryRegion* regions,
139+
size_t count) override;
140140

141141
// Return a pointer to memory that can be used for decompression.
142142
// The pointer will be aligned to the <alignment> value.

tensorflow/lite/micro/micro_interpreter_context_test.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,25 +219,29 @@ TF_LITE_MICRO_TEST(TestSetDecompressionMemory) {
219219
// fail during Prepare state
220220
micro_context.SetInterpreterState(
221221
tflite::MicroInterpreterContext::InterpreterState::kPrepare);
222-
status = micro_context.SetDecompressionMemory(alt_memory_region);
222+
status = micro_context.SetDecompressionMemory(alt_memory_region.begin(),
223+
alt_memory_region.size());
223224
TF_LITE_MICRO_EXPECT(status == kTfLiteError);
224225

225226
// fail during Invoke state
226227
micro_context.SetInterpreterState(
227228
tflite::MicroInterpreterContext::InterpreterState::kInvoke);
228-
status = micro_context.SetDecompressionMemory(alt_memory_region);
229+
status = micro_context.SetDecompressionMemory(alt_memory_region.begin(),
230+
alt_memory_region.size());
229231
TF_LITE_MICRO_EXPECT(status == kTfLiteError);
230232

231233
// succeed during Init state
232234
micro_context.SetInterpreterState(
233235
tflite::MicroInterpreterContext::InterpreterState::kInit);
234-
status = micro_context.SetDecompressionMemory(alt_memory_region);
236+
status = micro_context.SetDecompressionMemory(alt_memory_region.begin(),
237+
alt_memory_region.size());
235238
TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
236239

237240
// fail on second Init state attempt
238241
micro_context.SetInterpreterState(
239242
tflite::MicroInterpreterContext::InterpreterState::kInit);
240-
status = micro_context.SetDecompressionMemory(alt_memory_region);
243+
status = micro_context.SetDecompressionMemory(alt_memory_region.begin(),
244+
alt_memory_region.size());
241245
TF_LITE_MICRO_EXPECT(status == kTfLiteError);
242246
}
243247

@@ -253,7 +257,8 @@ TF_LITE_MICRO_TEST(TestAllocateDecompressionMemory) {
253257

254258
micro_context.SetInterpreterState(
255259
tflite::MicroInterpreterContext::InterpreterState::kInit);
256-
TfLiteStatus status = micro_context.SetDecompressionMemory(alt_memory_region);
260+
TfLiteStatus status = micro_context.SetDecompressionMemory(
261+
alt_memory_region.begin(), alt_memory_region.size());
257262
TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
258263

259264
micro_context.SetInterpreterState(
@@ -287,7 +292,8 @@ TF_LITE_MICRO_TEST(TestResetDecompressionMemory) {
287292

288293
micro_context.SetInterpreterState(
289294
tflite::MicroInterpreterContext::InterpreterState::kInit);
290-
TfLiteStatus status = micro_context.SetDecompressionMemory(alt_memory_region);
295+
TfLiteStatus status = micro_context.SetDecompressionMemory(
296+
alt_memory_region.begin(), alt_memory_region.size());
291297
TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
292298

293299
micro_context.SetInterpreterState(

tensorflow/lite/micro/micro_interpreter_test.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemoryAfterInit) {
182182
tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
183183
kAllocatorBufferSize);
184184
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
185-
TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem),
186-
kTfLiteError);
185+
TF_LITE_MICRO_EXPECT_EQ(
186+
interpreter.SetDecompressionMemory(alt_mem.begin(), alt_mem.size()),
187+
kTfLiteError);
187188
}
188189
}
189190

@@ -208,8 +209,9 @@ TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemoryTooSmall) {
208209
{
209210
tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
210211
kAllocatorBufferSize);
211-
TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem),
212-
kTfLiteOk);
212+
TF_LITE_MICRO_EXPECT_EQ(
213+
interpreter.SetDecompressionMemory(alt_mem.begin(), alt_mem.size()),
214+
kTfLiteOk);
213215
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
214216
TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
215217
TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.outputs_size());
@@ -269,8 +271,9 @@ TF_LITE_MICRO_TEST(TestInterpreterCompressionAltMemory) {
269271
{
270272
tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
271273
kAllocatorBufferSize);
272-
TF_LITE_MICRO_EXPECT_EQ(interpreter.SetDecompressionMemory(alt_mem),
273-
kTfLiteOk);
274+
TF_LITE_MICRO_EXPECT_EQ(
275+
interpreter.SetDecompressionMemory(alt_mem.begin(), alt_mem.size()),
276+
kTfLiteOk);
274277
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
275278
TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
276279
TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.outputs_size());

tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#if !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
1617
#include <stdio.h>
1718
#include <sys/stat.h>
1819
#include <sys/types.h>
20+
#endif // !defined(GENERIC_BENCHMARK_USING_BUILTIN_MODEL)
1921

22+
#include <array>
2023
#include <cstring>
21-
#include <initializer_list>
2224
#include <memory>
2325
#include <random>
2426
#include <type_traits>
@@ -246,9 +248,10 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) {
246248
#ifdef USE_ALT_DECOMPRESSION_MEM
247249
event_handle =
248250
profiler.BeginEvent("tflite::MicroInterpreter::SetDecompressionMemory");
249-
std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
250-
alt_memory_region = {{g_alt_memory, kAltMemorySize}};
251-
status = interpreter.SetDecompressionMemory(alt_memory_region);
251+
tflite::MicroContext::AlternateMemoryRegion alt_memory_region[] = {
252+
{g_alt_memory, kAltMemorySize}};
253+
status = interpreter.SetDecompressionMemory(alt_memory_region,
254+
std::size(alt_memory_region));
252255
if (status != kTfLiteOk) {
253256
MicroPrintf("tflite::MicroInterpreter::SetDecompressionMemory failed");
254257
return -1;

0 commit comments

Comments
 (0)