Skip to content

Commit 6449120

Browse files
ziyeqinghantflite-support-robot
authored andcommitted
Object Detection Metadata Writer: Map output by outputs indices in the TFLite SubGraph.
PiperOrigin-RevId: 399863321
1 parent 93a0409 commit 6449120

File tree

7 files changed

+291
-12
lines changed

7 files changed

+291
-12
lines changed

tensorflow_lite_support/metadata/python/metadata_writers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ py_library(
7777
":metadata_writer",
7878
":writer_utils",
7979
"//tensorflow_lite_support/metadata:metadata_schema_py",
80+
"//tensorflow_lite_support/metadata:schema_py",
8081
"//tensorflow_lite_support/metadata/python:metadata",
8182
"@flatbuffers//:runtime_py",
8283
],

tensorflow_lite_support/metadata/python/metadata_writers/object_detector.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# ==============================================================================
1515
"""Writes metadata and label file to the object detector models."""
1616

17+
import logging
1718
from typing import List, Optional, Type, Union
1819

1920
import flatbuffers
2021
from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
22+
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
2123
from tensorflow_lite_support.metadata.python import metadata as _metadata
2224
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
2325
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer
@@ -30,6 +32,9 @@
3032
"stream.")
3133
_INPUT_NAME = "image"
3234
_INPUT_DESCRIPTION = "Input image to be detected."
35+
# The output tensor names shouldn't be changed since these name will be used
36+
# to handle the order of output in TFLite Task Library when doing inference
37+
# in on-device application.
3338
_OUTPUT_LOCATION_NAME = "location"
3439
_OUTPUT_LOCATION_DESCRIPTION = "The locations of the detected boxes."
3540
_OUTPUT_CATRGORY_NAME = "category"
@@ -77,6 +82,12 @@ def _create_metadata_with_value_range(
7782
return tensor_metadata
7883

7984

85+
def _get_tflite_outputs(model_buffer: bytearray) -> List[int]:
86+
"""Gets the tensor indices of output in the TFLite Subgraph."""
87+
model = _schema_fb.Model.GetRootAsModel(model_buffer, 0)
88+
return model.Subgraphs(0).OutputsAsNumpy()
89+
90+
8091
def _extend_new_files(
8192
file_list: List[str],
8293
associated_files: Optional[List[Type[metadata_info.AssociatedFileMd]]]):
@@ -126,7 +137,6 @@ def create_from_metadata_info(
126137
Returns:
127138
A MetadataWriter object.
128139
"""
129-
130140
if general_md is None:
131141
general_md = metadata_info.GeneralMd(
132142
name=_MODEL_NAME, description=_MODEL_DESCRIPTION)
@@ -137,23 +147,35 @@ def create_from_metadata_info(
137147
description=_INPUT_DESCRIPTION,
138148
color_space_type=_metadata_fb.ColorSpaceType.RGB)
139149

150+
warn_message_format = (
151+
"The output name isn't the default string \"%s\". This may cause the "
152+
"model not work in the TFLite Task Library since the tensor name will "
153+
"be used to handle the output order in the TFLite Task Library.")
140154
if output_location_md is None:
141155
output_location_md = metadata_info.TensorMd(
142156
name=_OUTPUT_LOCATION_NAME, description=_OUTPUT_LOCATION_DESCRIPTION)
157+
elif output_location_md.name != _OUTPUT_LOCATION_NAME:
158+
logging.warning(warn_message_format, _OUTPUT_LOCATION_NAME)
143159

144160
if output_category_md is None:
145161
output_category_md = metadata_info.CategoryTensorMd(
146162
name=_OUTPUT_CATRGORY_NAME, description=_OUTPUT_CATEGORY_DESCRIPTION)
163+
elif output_category_md.name != _OUTPUT_CATRGORY_NAME:
164+
logging.warning(warn_message_format, _OUTPUT_CATRGORY_NAME)
147165

148166
if output_score_md is None:
149167
output_score_md = metadata_info.ClassificationTensorMd(
150168
name=_OUTPUT_SCORE_NAME,
151169
description=_OUTPUT_SCORE_DESCRIPTION,
152170
)
171+
elif output_score_md.name != _OUTPUT_SCORE_NAME:
172+
logging.warning(warn_message_format, _OUTPUT_SCORE_NAME)
153173

154174
if output_number_md is None:
155175
output_number_md = metadata_info.TensorMd(
156176
name=_OUTPUT_NUMBER_NAME, description=_OUTPUT_NUMBER_DESCRIPTION)
177+
elif output_number_md.name != _OUTPUT_NUMBER_NAME:
178+
logging.warning(warn_message_format, _OUTPUT_NUMBER_NAME)
157179

158180
# Create output tensor group info.
159181
group = _metadata_fb.TensorGroupT()
@@ -162,15 +184,37 @@ def create_from_metadata_info(
162184
output_location_md.name, output_category_md.name, output_score_md.name
163185
]
164186

165-
# Create subgraph info.
166-
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
167-
subgraph_metadata.inputTensorMetadata = [input_md.create_metadata()]
168-
subgraph_metadata.outputTensorMetadata = [
187+
# Gets the tensor inidces of tflite outputs and then gets the order of the
188+
# output metadata by the value of tensor indices. For instance, if the
189+
# output indices are [601, 599, 598, 600], tensor names and indices aligned
190+
# are:
191+
# - location: 598
192+
# - category: 599
193+
# - score: 600
194+
# - number of detections: 601
195+
# because of the op's ports of TFLITE_DETECTION_POST_PROCESS
196+
# (https://github.com/tensorflow/tensorflow/blob/a4fe268ea084e7d323133ed7b986e0ae259a2bc7/tensorflow/lite/kernels/detection_postprocess.cc#L47-L50).
197+
# Thus, the metadata of tensors are sorted in this way, according to
198+
# output_tensor_indicies correctly.
199+
output_tensor_indices = _get_tflite_outputs(model_buffer)
200+
metadata_list = [
169201
_create_location_metadata(output_location_md),
170202
_create_metadata_with_value_range(output_category_md),
171203
_create_metadata_with_value_range(output_score_md),
172204
output_number_md.create_metadata()
173205
]
206+
207+
# Align indices with tensors.
208+
sorted_indices = sorted(output_tensor_indices)
209+
indices_to_tensors = dict(zip(sorted_indices, metadata_list))
210+
211+
# Output metadata according to output_tensor_indices.
212+
output_metadata = [indices_to_tensors[i] for i in output_tensor_indices]
213+
214+
# Create subgraph info.
215+
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
216+
subgraph_metadata.inputTensorMetadata = [input_md.create_metadata()]
217+
subgraph_metadata.outputTensorMetadata = output_metadata
174218
subgraph_metadata.outputTensorGroups = [group]
175219

176220
# Create model metadata

tensorflow_lite_support/metadata/python/tests/metadata_writers/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ py_test(
105105
"//tensorflow_lite_support/metadata/python/metadata_writers:metadata_info",
106106
"//tensorflow_lite_support/metadata/python/metadata_writers:object_detector",
107107
"@absl_py//absl/testing:parameterized",
108-
"@flatbuffers//:runtime_py",
109108
],
110109
)
111110

tensorflow_lite_support/metadata/python/tests/metadata_writers/object_detector_test.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensorflow_lite_support.metadata.python.metadata_writers import object_detector
2828
from tensorflow_lite_support.metadata.python.tests.metadata_writers import test_utils
2929

30+
_PATH = "../testdata/object_detector/"
3031
_MODEL = "../testdata/object_detector/ssd_mobilenet_v1.tflite"
3132
_LABEL_FILE = "../testdata/object_detector/labelmap.txt"
3233
_NORM_MEAN = 127.5
@@ -54,17 +55,31 @@ def setUp(self):
5455
self._dummy_score_file = test_utils.get_resource_path(
5556
_DUMMY_SCORE_CALIBRATION_FILE)
5657

57-
def test_create_for_inference_should_succeed(self):
58+
@parameterized.parameters(
59+
("ssd_mobilenet_v1"),
60+
("efficientdet_lite0_v1"),
61+
)
62+
def test_create_for_inference_should_succeed(self, model_name):
63+
model_path = os.path.join(_PATH, model_name + ".tflite")
5864
writer = object_detector.MetadataWriter.create_for_inference(
59-
test_utils.load_file(_MODEL), [_NORM_MEAN], [_NORM_STD],
65+
test_utils.load_file(model_path), [_NORM_MEAN], [_NORM_STD],
6066
[self._label_file])
61-
self._validate_metadata(writer, _JSON_FOR_INFERENCE)
67+
68+
json_path = os.path.join(_PATH, model_name + ".json")
69+
self._validate_metadata(writer, json_path)
6270
self._validate_populated_model(writer)
6371

64-
def test_create_from_metadata_info_by_default_should_succeed(self):
72+
@parameterized.parameters(
73+
("ssd_mobilenet_v1"),
74+
("efficientdet_lite0_v1"),
75+
)
76+
def test_create_from_metadata_info_by_default_should_succeed(
77+
self, model_name: str):
78+
model_path = os.path.join(_PATH, model_name + ".tflite")
6579
writer = object_detector.MetadataWriter.create_from_metadata_info(
66-
test_utils.load_file(_MODEL))
67-
self._validate_metadata(writer, _JSON_DEFAULT)
80+
test_utils.load_file(model_path))
81+
json_path = os.path.join(_PATH, model_name + "_default.json")
82+
self._validate_metadata(writer, json_path)
6883
self._validate_populated_model(writer)
6984

7085
def test_create_for_inference_score_calibration_should_succeed(self):
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
{
2+
"name": "ObjectDetector",
3+
"description": "Identify which of a known set of objects might be present and provide information about their positions within the given image or a video stream.",
4+
"subgraph_metadata": [
5+
{
6+
"input_tensor_metadata": [
7+
{
8+
"name": "image",
9+
"description": "Input image to be detected.",
10+
"content": {
11+
"content_properties_type": "ImageProperties",
12+
"content_properties": {
13+
"color_space": "RGB"
14+
}
15+
},
16+
"process_units": [
17+
{
18+
"options_type": "NormalizationOptions",
19+
"options": {
20+
"mean": [
21+
127.5
22+
],
23+
"std": [
24+
127.5
25+
]
26+
}
27+
}
28+
],
29+
"stats": {
30+
"max": [
31+
255.0
32+
],
33+
"min": [
34+
0.0
35+
]
36+
}
37+
}
38+
],
39+
"output_tensor_metadata": [
40+
{
41+
"name": "score",
42+
"description": "The scores of the detected boxes.",
43+
"content": {
44+
"content_properties_type": "FeatureProperties",
45+
"content_properties": {
46+
},
47+
"range": {
48+
"min": 2,
49+
"max": 2
50+
}
51+
},
52+
"stats": {
53+
}
54+
},
55+
{
56+
"name": "location",
57+
"description": "The locations of the detected boxes.",
58+
"content": {
59+
"content_properties_type": "BoundingBoxProperties",
60+
"content_properties": {
61+
"index": [
62+
1,
63+
0,
64+
3,
65+
2
66+
],
67+
"type": "BOUNDARIES"
68+
},
69+
"range": {
70+
"min": 2,
71+
"max": 2
72+
}
73+
},
74+
"stats": {
75+
}
76+
},
77+
{
78+
"name": "number of detections",
79+
"description": "The number of the detected boxes.",
80+
"content": {
81+
"content_properties_type": "FeatureProperties",
82+
"content_properties": {
83+
}
84+
},
85+
"stats": {
86+
}
87+
},
88+
{
89+
"name": "category",
90+
"description": "The categories of the detected boxes.",
91+
"content": {
92+
"content_properties_type": "FeatureProperties",
93+
"content_properties": {
94+
},
95+
"range": {
96+
"min": 2,
97+
"max": 2
98+
}
99+
},
100+
"stats": {
101+
},
102+
"associated_files": [
103+
{
104+
"name": "labelmap.txt",
105+
"description": "Labels for categories that the model can recognize.",
106+
"type": "TENSOR_VALUE_LABELS"
107+
}
108+
]
109+
}
110+
],
111+
"output_tensor_groups": [
112+
{
113+
"name": "detection_result",
114+
"tensor_names": [
115+
"location",
116+
"category",
117+
"score"
118+
]
119+
}
120+
]
121+
}
122+
]
123+
}

0 commit comments

Comments
 (0)