Skip to content

Commit 7843552

Browse files
committed
Support TfLite schema buffer and custom options offsets
@tensorflow/micro Allow models larger than 2GB (and less than 4GB) in size, as generated by the TfLite converter. Parse TfLite schema Buffer tables where the offset and size fields are active. Parse TfLite schema Operator tables where the large_custom_options_offset and large_custom_options_size fields are active. Correctly process the Offline Memory Planner metadata buffer. Correctly process the compression metadata buffer. Add unit tests for all of the above. bug=fixes #3196
1 parent ba753ce commit 7843552

10 files changed

+739
-131
lines changed

tensorflow/lite/micro/micro_allocation_info.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <algorithm>
1919

20+
#include "flatbuffers/flatbuffers.h"
2021
#include "tensorflow/lite/c/c_api_types.h"
2122
#include "tensorflow/lite/kernels/internal/compatibility.h"
2223
#include "tensorflow/lite/kernels/kernel_util.h"
@@ -352,8 +353,23 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets(
352353
model_->buffers();
353354
auto* buffer = (*buffers)[metadata->buffer()];
354355
auto* array = buffer->data();
355-
const uint32_t* metadata_buffer =
356-
reinterpret_cast<const uint32_t*>(array->data());
356+
const uint32_t* metadata_buffer = nullptr;
357+
const size_t min_length = sizeof(uint32_t) * 3;
358+
if (array != nullptr && array->size() >= min_length) {
359+
metadata_buffer = reinterpret_cast<const uint32_t*>(array->data());
360+
} else if (buffer->offset() > 1 && buffer->size() >= min_length) {
361+
const uint8_t* flatbuffer_start =
362+
flatbuffers::GetBufferStartFromRootPointer(model_);
363+
if (flatbuffer_start != nullptr) {
364+
metadata_buffer = reinterpret_cast<const uint32_t*>(
365+
flatbuffer_start + buffer->offset());
366+
}
367+
}
368+
if (metadata_buffer == nullptr) {
369+
MicroPrintf("Unable to locate offline buffer offsets");
370+
return kTfLiteError;
371+
}
372+
357373
const size_t nbr_tensors = static_cast<size_t>(metadata_buffer[2]);
358374
*offline_planner_offsets =
359375
reinterpret_cast<const int32_t*>(&metadata_buffer[3]);

tensorflow/lite/micro/micro_allocator.cc

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ namespace internal {
194194
// return nullptr if no buffer is found.
195195
void* GetFlatbufferTensorBuffer(
196196
const tflite::Tensor& flatbuffer_tensor,
197-
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers) {
197+
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
198+
const uint8_t* flatbuffer_start) {
198199
// We need to figure out where the actual contents of this tensor are stored
199200
// in memory. We'll check to see if there's a serialized buffer (pretty much
200201
// the same as a constant op in TensorFlow) associated with this tensor first,
@@ -212,6 +213,9 @@ void* GetFlatbufferTensorBuffer(
212213
// data structure to point to it.
213214
out_buffer = const_cast<void*>(static_cast<const void*>(array->data()));
214215
}
216+
} else if (buffer->offset() > 1 && buffer->size() > 0) {
217+
out_buffer = const_cast<void*>(
218+
static_cast<const void*>(flatbuffer_start + buffer->offset()));
215219
}
216220
// TODO(petewarden): It's not clear in what circumstances we could have a
217221
// buffer in the serialized tensor, but it doesn't have any data in it. Is
@@ -227,7 +231,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
227231
INonPersistentBufferAllocator* non_persistent_buffer_allocator,
228232
bool allocate_temp, const tflite::Tensor& flatbuffer_tensor,
229233
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
230-
TfLiteTensor* result) {
234+
TfLiteTensor* result, const uint8_t* flatbuffer_start) {
231235
TFLITE_DCHECK(result != nullptr);
232236

233237
*result = {};
@@ -238,7 +242,8 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
238242
// Make sure we remember if the serialized tensor is designated as a variable.
239243
result->is_variable = flatbuffer_tensor.is_variable();
240244

241-
result->data.data = GetFlatbufferTensorBuffer(flatbuffer_tensor, buffers);
245+
result->data.data =
246+
GetFlatbufferTensorBuffer(flatbuffer_tensor, buffers, flatbuffer_start);
242247

243248
// TODO(petewarden): Some of these paths aren't getting enough testing
244249
// coverage, so we should figure out some tests that exercise them.
@@ -345,14 +350,15 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
345350
TfLiteStatus InitializeTfLiteEvalTensorFromFlatbuffer(
346351
const tflite::Tensor& flatbuffer_tensor,
347352
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
348-
TfLiteEvalTensor* result) {
353+
TfLiteEvalTensor* result, const uint8_t* flatbuffer_start) {
349354
*result = {};
350355
// Make sure the serialized type is one we know how to deal with, and convert
351356
// it from a flatbuffer enum into a constant used by the kernel C API.
352357
TF_LITE_ENSURE_STATUS(
353358
tflite::ConvertTensorType(flatbuffer_tensor.type(), &result->type));
354359

355-
result->data.data = GetFlatbufferTensorBuffer(flatbuffer_tensor, buffers);
360+
result->data.data =
361+
GetFlatbufferTensorBuffer(flatbuffer_tensor, buffers, flatbuffer_start);
356362

357363
if (flatbuffer_tensor.shape() == nullptr) {
358364
// flatbuffer_tensor.shape() can return a nullptr in the case of a scalar
@@ -376,6 +382,15 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata(
376382
if (buffers == nullptr) {
377383
return nullptr;
378384
}
385+
386+
// needed for compression metadata buffer when model is larger than 2Gb
387+
const uint8_t* model_flatbuffer_start =
388+
flatbuffers::GetBufferStartFromRootPointer(&model);
389+
if (model_flatbuffer_start == nullptr) {
390+
MicroPrintf("%s: Unable to locate flatbuffer start", __func__);
391+
return nullptr;
392+
}
393+
379394
const size_t metadata_string_length = std::strlen(kCompressionMetadataString);
380395
for (size_t metadata_index = 0; metadata_index < metadata_vector->size();
381396
metadata_index++) {
@@ -392,18 +407,33 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata(
392407
MicroPrintf("Compression: Invalid buffer index %u", buffer_index);
393408
continue;
394409
}
395-
auto vp = buffers->Get(buffer_index)->data();
396-
if (vp == nullptr || vp->data() == nullptr) {
410+
411+
auto buffer = buffers->Get(buffer_index);
412+
const uint8_t* metadata_start = nullptr;
413+
size_t metadata_size = 0;
414+
if (buffer != nullptr) {
415+
auto vp = buffer->data();
416+
if (vp != nullptr) {
417+
metadata_start = vp->data();
418+
metadata_size = vp->size();
419+
} else if (buffer->offset() > 1) {
420+
metadata_start = model_flatbuffer_start + buffer->offset();
421+
metadata_size = buffer->size();
422+
}
423+
}
424+
425+
if (metadata_start == nullptr) {
397426
MicroPrintf("Compression: Invalid data for buffer index %u",
398427
buffer_index);
399428
continue;
400429
}
430+
401431
// TODO(ddavis-2015): support multiple compression methods, possibly
402432
// through multiple verification checks.
403433
// Then return a pair<void*, compression_scheme>.
404434
auto compression_metadata =
405-
tflite::micro::compression::GetSizePrefixedMetadata(vp);
406-
flatbuffers::Verifier verifier(vp->data(), vp->size(),
435+
tflite::micro::compression::GetMetadata(metadata_start);
436+
flatbuffers::Verifier verifier(metadata_start, metadata_size,
407437
flatbuffers::Verifier::Options());
408438
if (!tflite::micro::compression::VerifyMetadataBuffer(verifier)) {
409439
MicroPrintf("Compression: verification failure");
@@ -429,6 +459,14 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
429459
const Model& model, const size_t subgraph_index,
430460
const tflite::micro::compression::LutTensor& lut_tensor,
431461
CompressionTensorData* ctd) {
462+
// needed for LUT value buffer when model is larger than 2Gb
463+
const uint8_t* model_flatbuffer_start =
464+
flatbuffers::GetBufferStartFromRootPointer(&model);
465+
if (model_flatbuffer_start == nullptr) {
466+
MicroPrintf("%s: Unable to locate flatbuffer start", __func__);
467+
return kTfLiteError;
468+
}
469+
432470
// TODO(ddavis-2015): support multiple compression schemes
433471
ctd->scheme = CompressionScheme::kBinQuant;
434472

@@ -446,19 +484,33 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
446484
return kTfLiteError;
447485
}
448486
ctd->data.lut_data->compressed_bit_width = index_bit_width;
487+
449488
const size_t value_buffer_index = lut_tensor.value_buffer();
450489
if (value_buffer_index >= model.buffers()->size()) {
451490
MicroPrintf("Compression: invalid value_buffer %u in LutTensor",
452491
value_buffer_index);
453492
return kTfLiteError;
454493
}
455-
auto value_buffer = model.buffers()->Get(value_buffer_index)->data();
456-
if (value_buffer == nullptr || value_buffer->data() == nullptr) {
494+
auto value_buffer = model.buffers()->Get(value_buffer_index);
495+
const uint8_t* value_buffer_start = nullptr;
496+
size_t value_buffer_size = 0;
497+
if (value_buffer != nullptr) {
498+
auto vp = value_buffer->data();
499+
if (vp != nullptr) {
500+
value_buffer_start = vp->data();
501+
value_buffer_size = vp->size();
502+
} else if (value_buffer->offset() > 1) {
503+
value_buffer_start = model_flatbuffer_start + value_buffer->offset();
504+
value_buffer_size = value_buffer->size();
505+
}
506+
}
507+
if (value_buffer_start == nullptr) {
457508
MicroPrintf("Compression: invalid value table for value_buffer %u",
458509
value_buffer_index);
459510
return kTfLiteError;
460511
}
461-
ctd->data.lut_data->value_table = value_buffer->data();
512+
ctd->data.lut_data->value_table = value_buffer_start;
513+
462514
auto tensor =
463515
model.subgraphs()->Get(subgraph_index)->tensors()->Get(tensor_index);
464516
if (tensor->shape() == nullptr) {
@@ -495,12 +547,12 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
495547
return kTfLiteError;
496548
}
497549
ctd->data.lut_data->value_table_channel_stride =
498-
(value_buffer->size() / tensor_type_size) / num_channels;
550+
(value_buffer_size / tensor_type_size) / num_channels;
499551
} else {
500552
ctd->data.lut_data->is_per_channel_quantized = false;
501553
ctd->data.lut_data->use_alternate_axis = false;
502554
ctd->data.lut_data->value_table_channel_stride =
503-
value_buffer->size() / tensor_type_size;
555+
value_buffer_size / tensor_type_size;
504556
}
505557

506558
return kTfLiteOk;
@@ -1038,6 +1090,14 @@ TfLiteStatus MicroAllocator::AllocateTfLiteEvalTensors(
10381090
const Model* model, SubgraphAllocations* subgraph_allocations) {
10391091
TFLITE_DCHECK(subgraph_allocations != nullptr);
10401092

1093+
// needed for tensor data buffer when model is larger than 2Gb
1094+
const uint8_t* flatbuffer_start =
1095+
flatbuffers::GetBufferStartFromRootPointer(model);
1096+
if (flatbuffer_start == nullptr) {
1097+
MicroPrintf("%s: Unable to locate flatbuffer start", __func__);
1098+
return kTfLiteError;
1099+
}
1100+
10411101
for (size_t subgraph_idx = 0; subgraph_idx < model->subgraphs()->size();
10421102
subgraph_idx++) {
10431103
const SubGraph* subgraph = model->subgraphs()->Get(subgraph_idx);
@@ -1057,7 +1117,8 @@ TfLiteStatus MicroAllocator::AllocateTfLiteEvalTensors(
10571117

10581118
for (size_t i = 0; i < alloc_count; ++i) {
10591119
TfLiteStatus status = internal::InitializeTfLiteEvalTensorFromFlatbuffer(
1060-
*subgraph->tensors()->Get(i), model->buffers(), &tensors[i]);
1120+
*subgraph->tensors()->Get(i), model->buffers(), &tensors[i],
1121+
flatbuffer_start);
10611122
if (status != kTfLiteOk) {
10621123
MicroPrintf("Failed to initialize tensor %d", i);
10631124
return kTfLiteError;
@@ -1104,14 +1165,22 @@ TfLiteTensor* MicroAllocator::AllocatePersistentTfLiteTensorInternal() {
11041165
TfLiteStatus MicroAllocator::PopulateTfLiteTensorFromFlatbuffer(
11051166
const Model* model, TfLiteTensor* tensor, int tensor_index,
11061167
int subgraph_idx, bool allocate_temp) {
1168+
// needed for tensor data buffer when model is larger than 2Gb
1169+
const uint8_t* flatbuffer_start =
1170+
flatbuffers::GetBufferStartFromRootPointer(model);
1171+
if (flatbuffer_start == nullptr) {
1172+
MicroPrintf("%s: Unable to locate flatbuffer start", __func__);
1173+
return kTfLiteError;
1174+
}
1175+
11071176
// TODO(b/162311891): This method serves as a stub to ensure quantized
11081177
// allocations in the tail can be recorded. Once the interpreter has APIs for
11091178
// accessing buffers on TfLiteEvalTensor this method can be dropped.
11101179
return internal::InitializeTfLiteTensorFromFlatbuffer(
11111180
persistent_buffer_allocator_, non_persistent_buffer_allocator_,
11121181
allocate_temp,
11131182
*model->subgraphs()->Get(subgraph_idx)->tensors()->Get(tensor_index),
1114-
model->buffers(), tensor);
1183+
model->buffers(), tensor, flatbuffer_start);
11151184
}
11161185

11171186
TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(

tensorflow/lite/micro/micro_allocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
4747
INonPersistentBufferAllocator* non_persistent_buffer_allocator,
4848
bool allocate_temp, const tflite::Tensor& flatbuffer_tensor,
4949
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
50-
TfLiteTensor* result);
50+
TfLiteTensor* result, const uint8_t* flatbuffer_start);
5151

5252
// Holds placeholder information for a scratch buffer request from a kernel.
5353
// This struct is only used during the model prepare stage. Each request from a

0 commit comments

Comments
 (0)