Skip to content

Commit 676a4f7

Browse files
author
Jonathan Huang
authored
Merge pull request #2826 from tombstone/inference
add inference tools for Open Image dataset.
2 parents 60fc781 + e836fc6 commit 676a4f7

File tree

4 files changed

+453
-0
lines changed

4 files changed

+453
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# TensorFlow Object Detection API: inference runnables.
#
# Targets:
#   detection_inference       library with the graph-building/inference helpers
#   detection_inference_test  unit test for the library
#   infer_detections          command-line binary built on the library

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Apache 2.0

# Helpers that build the input pipeline, load a frozen inference graph and
# write detections back into tf.Examples.
py_library(
    name = "detection_inference",
    srcs = ["detection_inference.py"],
    deps = [
        "//tensorflow",
        "//tensorflow_models/object_detection/core:standard_fields",
    ],
)

# Unit test for :detection_inference; PIL and numpy are used to create the
# mock input image.
py_test(
    name = "detection_inference_test",
    srcs = ["detection_inference_test.py"],
    deps = [
        ":detection_inference",
        "//tensorflow",
        "//tensorflow_models/object_detection/core:standard_fields",
        "//tensorflow_models/object_detection/utils:dataset_util",
        "//third_party/py/PIL:pil",
        "//third_party/py/numpy",
    ],
)

# Binary entry point that runs detection inference.
py_binary(
    name = "infer_detections",
    srcs = ["infer_detections.py"],
    deps = [
        ":detection_inference",
        "//tensorflow",
    ],
)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility functions for detection inference."""
16+
from __future__ import division
17+
18+
import tensorflow as tf
19+
20+
from object_detection.core import standard_fields
21+
22+
23+
def build_input(tfrecord_paths):
  """Creates the input ops that read and decode examples from TFRecords.

  The files are queued in the given order (shuffle=False) and consumed for a
  single epoch (num_epochs=1).

  Args:
    tfrecord_paths: List of paths to the input TFRecords

  Returns:
    serialized_example_tensor: The next serialized example. String scalar Tensor
    image_tensor: The decoded image of the example. Uint8 tensor,
        shape=[1, None, None,3]
  """
  image_key = standard_fields.TfExampleFields.image_encoded

  path_queue = tf.train.string_input_producer(
      tfrecord_paths, shuffle=False, num_epochs=1)
  reader = tf.TFRecordReader()
  # The reader yields (key, value); only the serialized record is needed.
  _, serialized_example_tensor = reader.read(path_queue)

  # Only the encoded image bytes are parsed out of the example here.
  parsed = tf.parse_single_example(
      serialized_example_tensor,
      features={image_key: tf.FixedLenFeature([], tf.string)})

  decoded = tf.image.decode_image(parsed[image_key], channels=3)
  # Pin the static shape to height x width x 3 before adding the batch axis.
  decoded.set_shape([None, None, 3])
  image_tensor = tf.expand_dims(decoded, 0)

  return serialized_example_tensor, image_tensor
51+
52+
53+
def build_inference_graph(image_tensor, inference_graph_path):
  """Loads the inference graph and connects it to the input image.

  The serialized GraphDef is imported into the default graph with its
  'image_tensor' input remapped to `image_tensor`. The output tensors
  'num_detections:0', 'detection_boxes:0', 'detection_scores:0' and
  'detection_classes:0' are then looked up by name, their leading (batch)
  dimension is squeezed away, and each result is truncated to the first
  num_detections entries.

  Args:
    image_tensor: The input image. uint8 tensor, shape=[1, None, None, 3]
    inference_graph_path: Path to the inference graph with embedded weights

  Returns:
    detected_boxes_tensor: Detected boxes. Float tensor,
        shape=[num_detections, 4]
    detected_scores_tensor: Detected scores. Float tensor,
        shape=[num_detections]
    detected_labels_tensor: Detected labels. Int64 tensor,
        shape=[num_detections]
  """
  # A serialized GraphDef is binary protobuf data, so it must be read in
  # binary mode; text mode tries to decode the bytes as a string, which breaks
  # on Python 3.
  with tf.gfile.Open(inference_graph_path, 'rb') as graph_def_file:
    graph_content = graph_def_file.read()
  graph_def = tf.GraphDef()
  graph_def.MergeFromString(graph_content)

  # name='' imports the nodes without a prefix so they can be fetched by
  # their original names below.
  tf.import_graph_def(
      graph_def, name='', input_map={'image_tensor': image_tensor})

  g = tf.get_default_graph()

  # Squeezing axis 0 drops the batch dimension (the input holds one image).
  num_detections_tensor = tf.squeeze(
      g.get_tensor_by_name('num_detections:0'), 0)
  # Cast to an integer so the count can be used as a slice bound.
  num_detections_tensor = tf.cast(num_detections_tensor, tf.int32)

  detected_boxes_tensor = tf.squeeze(
      g.get_tensor_by_name('detection_boxes:0'), 0)
  detected_boxes_tensor = detected_boxes_tensor[:num_detections_tensor]

  detected_scores_tensor = tf.squeeze(
      g.get_tensor_by_name('detection_scores:0'), 0)
  detected_scores_tensor = detected_scores_tensor[:num_detections_tensor]

  detected_labels_tensor = tf.squeeze(
      g.get_tensor_by_name('detection_classes:0'), 0)
  detected_labels_tensor = tf.cast(detected_labels_tensor, tf.int64)
  detected_labels_tensor = detected_labels_tensor[:num_detections_tensor]

  return detected_boxes_tensor, detected_scores_tensor, detected_labels_tensor
96+
97+
98+
def infer_detections_and_add_to_example(
    serialized_example_tensor, detected_boxes_tensor, detected_scores_tensor,
    detected_labels_tensor, discard_image_pixels):
  """Evaluates the given tensors once and stores the detections in the example.

  All four tensors are fetched with a single run() call on the default
  session, so a default session must be active when this is called.

  Args:
    serialized_example_tensor: Serialized TF example. Scalar string tensor
    detected_boxes_tensor: Detected boxes. Float tensor,
        shape=[num_detections, 4]
    detected_scores_tensor: Detected scores. Float tensor,
        shape=[num_detections]
    detected_labels_tensor: Detected labels. Int64 tensor,
        shape=[num_detections]
    discard_image_pixels: If true, discards the image from the result
  Returns:
    The de-serialized TF example augmented with the inferred detections.
  """
  fields = standard_fields.TfExampleFields

  tf_example = tf.train.Example()
  fetches = [
      serialized_example_tensor, detected_boxes_tensor, detected_scores_tensor,
      detected_labels_tensor
  ]
  serialized, boxes, scores, labels = tf.get_default_session().run(fetches)
  # Transposing [num_detections, 4] gives one row per box coordinate.
  coordinate_rows = boxes.T

  tf_example.ParseFromString(serialized)
  feature_map = tf_example.features.feature

  feature_map[fields.detection_score].float_list.value[:] = scores

  # Row order of the transposed boxes: ymin, xmin, ymax, xmax.
  bbox_keys = (
      fields.detection_bbox_ymin,
      fields.detection_bbox_xmin,
      fields.detection_bbox_ymax,
      fields.detection_bbox_xmax,
  )
  for row_index, bbox_key in enumerate(bbox_keys):
    feature_map[bbox_key].float_list.value[:] = coordinate_rows[row_index]

  feature_map[fields.detection_class_label].int64_list.value[:] = labels

  if discard_image_pixels:
    del feature_map[fields.image_encoded]

  return tf_example
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
r"""Tests for detection_inference.py."""
16+
17+
import os
18+
import StringIO
19+
20+
import numpy as np
21+
from PIL import Image
22+
import tensorflow as tf
23+
24+
from object_detection.core import standard_fields
25+
from object_detection.inference import detection_inference
26+
from object_detection.utils import dataset_util
27+
28+
29+
def get_mock_tfrecord_path():
  """Returns the path of the mock TFRecord file in the test temp directory."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock.tfrec')
31+
32+
33+
def create_mock_tfrecord():
  """Writes a one-example TFRecord to get_mock_tfrecord_path().

  The example holds a PNG-encoded 1x1 RGB image with pixel value (123, 0, 0)
  and an extra float feature named 'test_field'.
  """
  pixels = np.array([[[123, 0, 0]]], dtype=np.uint8)
  # NOTE(review): the StringIO module only exists on Python 2; io.BytesIO
  # would be the portable buffer for binary PNG data.
  png_buffer = StringIO.StringIO()
  Image.fromarray(pixels, 'RGB').save(png_buffer, format='png')
  png_bytes = png_buffer.getvalue()

  features = tf.train.Features(feature={
      'test_field':
          dataset_util.float_list_feature([1, 2, 3, 4]),
      standard_fields.TfExampleFields.image_encoded:
          dataset_util.bytes_feature(png_bytes),
  })
  example = tf.train.Example(features=features)

  with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
    writer.write(example.SerializeToString())
49+
50+
51+
def get_mock_graph_path():
  """Returns the path of the mock inference graph in the test temp directory."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock_graph.pb')
53+
54+
55+
def create_mock_graph():
  """Serializes a tiny stand-in detection graph to get_mock_graph_path().

  The graph exposes the tensor names the inference code looks up:
  'image_tensor' (input placeholder), 'num_detections', 'detection_boxes',
  'detection_scores' and 'detection_classes'. All outputs are constants except
  the classes, which are scaled by the sum of the input pixels so that the
  output depends on the image that was fed in.
  """
  mock_graph = tf.Graph()
  with mock_graph.as_default():
    image_placeholder = tf.placeholder(
        tf.uint8, shape=[1, None, None, 3], name='image_tensor')

    # Two detections are reported even though three boxes/scores are present.
    tf.constant([2.0], name='num_detections')
    tf.constant(
        [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
        name='detection_boxes')
    tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')

    class_multipliers = tf.constant([[1.0, 2.0, 3.0]])
    pixel_sum = tf.reduce_sum(tf.cast(image_placeholder, dtype=tf.float32))
    tf.identity(class_multipliers * pixel_sum, name='detection_classes')

  serialized_graph = mock_graph.as_graph_def().SerializeToString()
  with tf.gfile.Open(get_mock_graph_path(), 'w') as graph_file:
    graph_file.write(serialized_graph)
73+
74+
75+
class InferDetectionsTests(tf.test.TestCase):
  """Runs detection_inference end to end on a mock graph and a mock TFRecord."""

  def test_simple(self):
    """Detections are added to the example and the encoded image is kept."""
    create_mock_graph()
    create_mock_tfrecord()

    record_paths = [get_mock_tfrecord_path()]
    serialized_example_tensor, image_tensor = detection_inference.build_input(
        record_paths)
    self.assertAllEqual(image_tensor.get_shape().as_list(), [1, None, None, 3])

    boxes_tensor, scores_tensor, labels_tensor = (
        detection_inference.build_inference_graph(
            image_tensor, get_mock_graph_path()))

    # The mock graph reports two detections, so only the first two of its
    # three boxes/scores/classes are expected. Labels are the mock class
    # values [1, 2] multiplied by the pixel sum of the mock image (123).
    expected_proto = r"""
        features {
          feature {
            key: "image/detection/bbox/ymin"
            value { float_list { value: [0.0, 0.1] } } }
          feature {
            key: "image/detection/bbox/xmin"
            value { float_list { value: [0.8, 0.2] } } }
          feature {
            key: "image/detection/bbox/ymax"
            value { float_list { value: [0.7, 0.8] } } }
          feature {
            key: "image/detection/bbox/xmax"
            value { float_list { value: [1.0, 0.9] } } }
          feature {
            key: "image/detection/label"
            value { int64_list { value: [123, 246] } } }
          feature {
            key: "image/detection/score"
            value { float_list { value: [0.1, 0.2] } } }
          feature {
            key: "image/encoded"
            value { bytes_list { value:
              "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\001\000\000"
              "\000\001\010\002\000\000\000\220wS\336\000\000\000\022IDATx"
              "\234b\250f`\000\000\000\000\377\377\003\000\001u\000|gO\242"
              "\213\000\000\000\000IEND\256B`\202" } } }
          feature {
            key: "test_field"
            value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
    """

    with self.test_session(use_gpu=False) as sess:
      # Both initializers are run before the input queue runners are started.
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()

      tf_example = detection_inference.infer_detections_and_add_to_example(
          serialized_example_tensor, boxes_tensor, scores_tensor,
          labels_tensor, False)

    self.assertProtoEquals(expected_proto, tf_example)

  def test_discard_image(self):
    """With discard_image_pixels=True the encoded image feature is removed."""
    create_mock_graph()
    create_mock_tfrecord()

    record_paths = [get_mock_tfrecord_path()]
    serialized_example_tensor, image_tensor = detection_inference.build_input(
        record_paths)
    boxes_tensor, scores_tensor, labels_tensor = (
        detection_inference.build_inference_graph(
            image_tensor, get_mock_graph_path()))

    # Same expectation as test_simple, minus the "image/encoded" feature.
    expected_proto = r"""
        features {
          feature {
            key: "image/detection/bbox/ymin"
            value { float_list { value: [0.0, 0.1] } } }
          feature {
            key: "image/detection/bbox/xmin"
            value { float_list { value: [0.8, 0.2] } } }
          feature {
            key: "image/detection/bbox/ymax"
            value { float_list { value: [0.7, 0.8] } } }
          feature {
            key: "image/detection/bbox/xmax"
            value { float_list { value: [1.0, 0.9] } } }
          feature {
            key: "image/detection/label"
            value { int64_list { value: [123, 246] } } }
          feature {
            key: "image/detection/score"
            value { float_list { value: [0.1, 0.2] } } }
          feature {
            key: "test_field"
            value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
    """

    with self.test_session(use_gpu=False) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()

      tf_example = detection_inference.infer_detections_and_add_to_example(
          serialized_example_tensor, boxes_tensor, scores_tensor,
          labels_tensor, True)

    self.assertProtoEquals(expected_proto, tf_example)
173+
174+
175+
# Run the test cases through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)