@@ -194,7 +194,8 @@ namespace internal {
194
194
// return nullptr if no buffer is found.
195
195
void * GetFlatbufferTensorBuffer (
196
196
const tflite::Tensor& flatbuffer_tensor,
197
- const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers) {
197
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
198
+ const uint8_t * flatbuffer_start) {
198
199
// We need to figure out where the actual contents of this tensor are stored
199
200
// in memory. We'll check to see if there's a serialized buffer (pretty much
200
201
// the same as a constant op in TensorFlow) associated with this tensor first,
@@ -212,6 +213,9 @@ void* GetFlatbufferTensorBuffer(
212
213
// data structure to point to it.
213
214
out_buffer = const_cast <void *>(static_cast <const void *>(array->data ()));
214
215
}
216
+ } else if (buffer->offset () > 1 && buffer->size () > 0 ) {
217
+ out_buffer = const_cast <void *>(
218
+ static_cast <const void *>(flatbuffer_start + buffer->offset ()));
215
219
}
216
220
// TODO(petewarden): It's not clear in what circumstances we could have a
217
221
// buffer in the serialized tensor, but it doesn't have any data in it. Is
@@ -227,7 +231,7 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
227
231
INonPersistentBufferAllocator* non_persistent_buffer_allocator,
228
232
bool allocate_temp, const tflite::Tensor& flatbuffer_tensor,
229
233
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
230
- TfLiteTensor* result) {
234
+ TfLiteTensor* result, const uint8_t * flatbuffer_start ) {
231
235
TFLITE_DCHECK (result != nullptr );
232
236
233
237
*result = {};
@@ -238,7 +242,8 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
238
242
// Make sure we remember if the serialized tensor is designated as a variable.
239
243
result->is_variable = flatbuffer_tensor.is_variable ();
240
244
241
- result->data .data = GetFlatbufferTensorBuffer (flatbuffer_tensor, buffers);
245
+ result->data .data =
246
+ GetFlatbufferTensorBuffer (flatbuffer_tensor, buffers, flatbuffer_start);
242
247
243
248
// TODO(petewarden): Some of these paths aren't getting enough testing
244
249
// coverage, so we should figure out some tests that exercise them.
@@ -345,14 +350,15 @@ TfLiteStatus InitializeTfLiteTensorFromFlatbuffer(
345
350
TfLiteStatus InitializeTfLiteEvalTensorFromFlatbuffer (
346
351
const tflite::Tensor& flatbuffer_tensor,
347
352
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
348
- TfLiteEvalTensor* result) {
353
+ TfLiteEvalTensor* result, const uint8_t * flatbuffer_start ) {
349
354
*result = {};
350
355
// Make sure the serialized type is one we know how to deal with, and convert
351
356
// it from a flatbuffer enum into a constant used by the kernel C API.
352
357
TF_LITE_ENSURE_STATUS (
353
358
tflite::ConvertTensorType (flatbuffer_tensor.type (), &result->type ));
354
359
355
- result->data .data = GetFlatbufferTensorBuffer (flatbuffer_tensor, buffers);
360
+ result->data .data =
361
+ GetFlatbufferTensorBuffer (flatbuffer_tensor, buffers, flatbuffer_start);
356
362
357
363
if (flatbuffer_tensor.shape () == nullptr ) {
358
364
// flatbuffer_tensor.shape() can return a nullptr in the case of a scalar
@@ -376,6 +382,15 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata(
376
382
if (buffers == nullptr ) {
377
383
return nullptr ;
378
384
}
385
+
386
+ // needed for compression metadata buffer when model is larger than 2Gb
387
+ const uint8_t * model_flatbuffer_start =
388
+ flatbuffers::GetBufferStartFromRootPointer (&model);
389
+ if (model_flatbuffer_start == nullptr ) {
390
+ MicroPrintf (" %s: Unable to locate flatbuffer start" , __func__);
391
+ return nullptr ;
392
+ }
393
+
379
394
const size_t metadata_string_length = std::strlen (kCompressionMetadataString );
380
395
for (size_t metadata_index = 0 ; metadata_index < metadata_vector->size ();
381
396
metadata_index++) {
@@ -392,18 +407,33 @@ const tflite::micro::compression::Metadata* GetCompressionMetadata(
392
407
MicroPrintf (" Compression: Invalid buffer index %u" , buffer_index);
393
408
continue ;
394
409
}
395
- auto vp = buffers->Get (buffer_index)->data ();
396
- if (vp == nullptr || vp->data () == nullptr ) {
410
+
411
+ auto buffer = buffers->Get (buffer_index);
412
+ const uint8_t * metadata_start = nullptr ;
413
+ size_t metadata_size = 0 ;
414
+ if (buffer != nullptr ) {
415
+ auto vp = buffer->data ();
416
+ if (vp != nullptr ) {
417
+ metadata_start = vp->data ();
418
+ metadata_size = vp->size ();
419
+ } else if (buffer->offset () > 1 ) {
420
+ metadata_start = model_flatbuffer_start + buffer->offset ();
421
+ metadata_size = buffer->size ();
422
+ }
423
+ }
424
+
425
+ if (metadata_start == nullptr ) {
397
426
MicroPrintf (" Compression: Invalid data for buffer index %u" ,
398
427
buffer_index);
399
428
continue ;
400
429
}
430
+
401
431
// TODO(ddavis-2015): support multiple compression methods, possibly
402
432
// through multiple verification checks.
403
433
// Then return a pair<void*, compression_scheme>.
404
434
auto compression_metadata =
405
- tflite::micro::compression::GetSizePrefixedMetadata (vp );
406
- flatbuffers::Verifier verifier (vp-> data (), vp-> size () ,
435
+ tflite::micro::compression::GetMetadata (metadata_start );
436
+ flatbuffers::Verifier verifier (metadata_start, metadata_size ,
407
437
flatbuffers::Verifier::Options ());
408
438
if (!tflite::micro::compression::VerifyMetadataBuffer (verifier)) {
409
439
MicroPrintf (" Compression: verification failure" );
@@ -429,6 +459,14 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
429
459
const Model& model, const size_t subgraph_index,
430
460
const tflite::micro::compression::LutTensor& lut_tensor,
431
461
CompressionTensorData* ctd) {
462
+ // needed for LUT value buffer when model is larger than 2Gb
463
+ const uint8_t * model_flatbuffer_start =
464
+ flatbuffers::GetBufferStartFromRootPointer (&model);
465
+ if (model_flatbuffer_start == nullptr ) {
466
+ MicroPrintf (" %s: Unable to locate flatbuffer start" , __func__);
467
+ return kTfLiteError ;
468
+ }
469
+
432
470
// TODO(ddavis-2015): support multiple compression schemes
433
471
ctd->scheme = CompressionScheme::kBinQuant ;
434
472
@@ -446,19 +484,33 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
446
484
return kTfLiteError ;
447
485
}
448
486
ctd->data .lut_data ->compressed_bit_width = index_bit_width;
487
+
449
488
const size_t value_buffer_index = lut_tensor.value_buffer ();
450
489
if (value_buffer_index >= model.buffers ()->size ()) {
451
490
MicroPrintf (" Compression: invalid value_buffer %u in LutTensor" ,
452
491
value_buffer_index);
453
492
return kTfLiteError ;
454
493
}
455
- auto value_buffer = model.buffers ()->Get (value_buffer_index)->data ();
456
- if (value_buffer == nullptr || value_buffer->data () == nullptr ) {
494
+ auto value_buffer = model.buffers ()->Get (value_buffer_index);
495
+ const uint8_t * value_buffer_start = nullptr ;
496
+ size_t value_buffer_size = 0 ;
497
+ if (value_buffer != nullptr ) {
498
+ auto vp = value_buffer->data ();
499
+ if (vp != nullptr ) {
500
+ value_buffer_start = vp->data ();
501
+ value_buffer_size = vp->size ();
502
+ } else if (value_buffer->offset () > 1 ) {
503
+ value_buffer_start = model_flatbuffer_start + value_buffer->offset ();
504
+ value_buffer_size = value_buffer->size ();
505
+ }
506
+ }
507
+ if (value_buffer_start == nullptr ) {
457
508
MicroPrintf (" Compression: invalid value table for value_buffer %u" ,
458
509
value_buffer_index);
459
510
return kTfLiteError ;
460
511
}
461
- ctd->data .lut_data ->value_table = value_buffer->data ();
512
+ ctd->data .lut_data ->value_table = value_buffer_start;
513
+
462
514
auto tensor =
463
515
model.subgraphs ()->Get (subgraph_index)->tensors ()->Get (tensor_index);
464
516
if (tensor->shape () == nullptr ) {
@@ -495,12 +547,12 @@ TfLiteStatus InitializeCompressionTensorDataFromFlatbuffer(
495
547
return kTfLiteError ;
496
548
}
497
549
ctd->data .lut_data ->value_table_channel_stride =
498
- (value_buffer-> size () / tensor_type_size) / num_channels;
550
+ (value_buffer_size / tensor_type_size) / num_channels;
499
551
} else {
500
552
ctd->data .lut_data ->is_per_channel_quantized = false ;
501
553
ctd->data .lut_data ->use_alternate_axis = false ;
502
554
ctd->data .lut_data ->value_table_channel_stride =
503
- value_buffer-> size () / tensor_type_size;
555
+ value_buffer_size / tensor_type_size;
504
556
}
505
557
506
558
return kTfLiteOk ;
@@ -1038,6 +1090,14 @@ TfLiteStatus MicroAllocator::AllocateTfLiteEvalTensors(
1038
1090
const Model* model, SubgraphAllocations* subgraph_allocations) {
1039
1091
TFLITE_DCHECK (subgraph_allocations != nullptr );
1040
1092
1093
+ // needed for tensor data buffer when model is larger than 2Gb
1094
+ const uint8_t * flatbuffer_start =
1095
+ flatbuffers::GetBufferStartFromRootPointer (model);
1096
+ if (flatbuffer_start == nullptr ) {
1097
+ MicroPrintf (" %s: Unable to locate flatbuffer start" , __func__);
1098
+ return kTfLiteError ;
1099
+ }
1100
+
1041
1101
for (size_t subgraph_idx = 0 ; subgraph_idx < model->subgraphs ()->size ();
1042
1102
subgraph_idx++) {
1043
1103
const SubGraph* subgraph = model->subgraphs ()->Get (subgraph_idx);
@@ -1057,7 +1117,8 @@ TfLiteStatus MicroAllocator::AllocateTfLiteEvalTensors(
1057
1117
1058
1118
for (size_t i = 0 ; i < alloc_count; ++i) {
1059
1119
TfLiteStatus status = internal::InitializeTfLiteEvalTensorFromFlatbuffer (
1060
- *subgraph->tensors ()->Get (i), model->buffers (), &tensors[i]);
1120
+ *subgraph->tensors ()->Get (i), model->buffers (), &tensors[i],
1121
+ flatbuffer_start);
1061
1122
if (status != kTfLiteOk ) {
1062
1123
MicroPrintf (" Failed to initialize tensor %d" , i);
1063
1124
return kTfLiteError ;
@@ -1104,14 +1165,22 @@ TfLiteTensor* MicroAllocator::AllocatePersistentTfLiteTensorInternal() {
1104
1165
TfLiteStatus MicroAllocator::PopulateTfLiteTensorFromFlatbuffer (
1105
1166
const Model* model, TfLiteTensor* tensor, int tensor_index,
1106
1167
int subgraph_idx, bool allocate_temp) {
1168
+ // needed for tensor data buffer when model is larger than 2Gb
1169
+ const uint8_t * flatbuffer_start =
1170
+ flatbuffers::GetBufferStartFromRootPointer (model);
1171
+ if (flatbuffer_start == nullptr ) {
1172
+ MicroPrintf (" %s: Unable to locate flatbuffer start" , __func__);
1173
+ return kTfLiteError ;
1174
+ }
1175
+
1107
1176
// TODO(b/162311891): This method serves as a stub to ensure quantized
1108
1177
// allocations in the tail can be recorded. Once the interpreter has APIs for
1109
1178
// accessing buffers on TfLiteEvalTensor this method can be dropped.
1110
1179
return internal::InitializeTfLiteTensorFromFlatbuffer (
1111
1180
persistent_buffer_allocator_, non_persistent_buffer_allocator_,
1112
1181
allocate_temp,
1113
1182
*model->subgraphs ()->Get (subgraph_idx)->tensors ()->Get (tensor_index),
1114
- model->buffers (), tensor);
1183
+ model->buffers (), tensor, flatbuffer_start );
1115
1184
}
1116
1185
1117
1186
TfLiteStatus MicroAllocator::CommitStaticMemoryPlan (
0 commit comments