Skip to content

Commit 532b946

Browse files
Vighnesh Birodkar and TF Object Detection Team
authored and committed
Add image+box module in exporter.
PiperOrigin-RevId: 363794263
1 parent dfaf525 commit 532b946

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

research/object_detection/exporter_lib_tf2_test.py

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,13 @@ def postprocess(self, prediction_dict, true_image_shapes):
7575
}
7676
return postprocessed_tensors
7777

78+
def predict_masks_from_boxes(self, prediction_dict, true_image_shapes, boxes):
  """Returns postprocessed detections extended with a constant mask tensor.

  Args:
    prediction_dict: predictions, forwarded unchanged to `self.postprocess`.
    true_image_shapes: image shapes, forwarded unchanged to
      `self.postprocess`.
    boxes: the provided boxes; not read by this implementation.

  Returns:
    The dict produced by `self.postprocess`, with 'detection_masks' set to an
    all-ones float32 tensor of shape (1, 2, 16).
  """
  detections = self.postprocess(prediction_dict, true_image_shapes)
  detections['detection_masks'] = tf.ones(shape=(1, 2, 16), dtype=tf.float32)
  return detections
84+
7885
def restore_map(self, checkpoint_path, fine_tune_checkpoint_type):
  """Intentionally does nothing; both arguments are ignored."""
8087

@@ -291,6 +298,83 @@ def test_export_checkpoint_and_run_inference_with_image(self):
291298
[[150 + 0.7, 150 + 0.6], [150 + 0.9, 150 + 0.0]])
292299

293300

301+
class DetectionFromImageAndBoxModuleTest(tf.test.TestCase):
  """Tests exporting and running the 'image_and_boxes_tensor' module."""

  def get_dummy_input(self, input_type):
    """Get dummy input for the given input type.

    Args:
      input_type: one of 'image_tensor', 'image_and_boxes_tensor',
        'float_image_tensor', 'encoded_image_string_tensor' or 'tf_example'.

    Returns:
      For the image tensor types, a zero-filled numpy array of shape
      (1, 20, 20, 3) (uint8, or float32 for 'float_image_tensor'). For
      'encoded_image_string_tensor', a one-element list holding the PNG bytes
      of a 20x20 RGB image. For 'tf_example', a one-element list holding a
      serialized tf.train.Example.

    Raises:
      ValueError: if `input_type` is not one of the supported values.
    """
    if input_type in ('image_tensor', 'image_and_boxes_tensor'):
      return np.zeros((1, 20, 20, 3), dtype=np.uint8)
    if input_type == 'float_image_tensor':
      return np.zeros((1, 20, 20, 3), dtype=np.float32)
    if input_type == 'encoded_image_string_tensor':
      image = Image.new('RGB', (20, 20))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    if input_type == 'tf_example':
      image_tensor = tf.zeros((20, 20, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(image_tensor).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      dataset_util.bytes_feature(encoded_jpeg),
                  'image/format':
                      dataset_util.bytes_feature(six.b('jpeg')),
                  'image/source_id':
                      dataset_util.bytes_feature(six.b('image_id')),
              })).SerializeToString()
      return [example]
    # Fail loudly instead of silently returning None, which would only surface
    # later as an obscure error when the value is fed to the model.
    raise ValueError('Unsupported input type: {}'.format(input_type))

  def _save_checkpoint_from_mock_model(self,
                                       checkpoint_dir,
                                       conv_weight_scalar=6.0):
    """Saves a checkpoint of a freshly built FakeModel into `checkpoint_dir`.

    Args:
      checkpoint_dir: directory the checkpoint manager writes to.
      conv_weight_scalar: value passed to the FakeModel constructor.
    """
    mock_model = FakeModel(conv_weight_scalar)
    # Run the model end to end once before checkpointing (presumably so that
    # its variables exist -- confirm against FakeModel).
    fake_image = tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32)
    preprocessed_inputs, true_image_shapes = mock_model.preprocess(fake_image)
    predictions = mock_model.predict(preprocessed_inputs, true_image_shapes)
    mock_model.postprocess(predictions, true_image_shapes)

    ckpt = tf.train.Checkpoint(model=mock_model)
    exported_checkpoint_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_dir, max_to_keep=1)
    exported_checkpoint_manager.save(checkpoint_number=0)

  def test_export_saved_model_and_run_inference_for_segmentation(
      self, input_type='image_and_boxes_tensor'):
    """Exports a saved model, runs it on an image plus boxes, checks masks."""
    tmp_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(tmp_dir)

    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()
      output_directory = os.path.join(tmp_dir, 'output')
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      # patch.dict restores the original 'model_build' entry on exit, so the
      # mock builder does not leak into the module-level map after this test.
      with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                           {'model_build': mock_builder}):
        exporter_lib_v2.export_inference_graph(
            input_type=input_type,
            pipeline_config=pipeline_config,
            trained_checkpoint_dir=tmp_dir,
            output_directory=output_directory)

      saved_model_path = os.path.join(output_directory, 'saved_model')
      detect_fn = tf.saved_model.load(saved_model_path)
      image = self.get_dummy_input(input_type)
      boxes = tf.constant([
          [
              [0.0, 0.0, 0.5, 0.5],
              [0.5, 0.5, 0.8, 0.8],
          ],
      ])
      detections = detect_fn(tf.constant(image), boxes)

      detection_fields = fields.DetectionResultFields
      self.assertIn(detection_fields.detection_masks, detections)
      self.assertListEqual(
          list(detections[detection_fields.detection_masks].shape), [1, 2, 16])
376+
377+
294378
if __name__ == '__main__':
  # Switch TensorFlow to v2 behavior before handing control to the test runner.
  tf.enable_v2_behavior()
  tf.test.main()

research/object_detection/exporter_lib_v2.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,3 +288,73 @@ def export_inference_graph(input_type,
288288
signatures=concrete_function)
289289

290290
config_util.save_pipeline_config(pipeline_config, output_directory)
291+
292+
293+
class DetectionFromImageAndBoxModule(DetectionInferenceModule):
  """Detection Inference Module for image with bounding box inputs.

  The saved model will require two inputs (image and normalized boxes) and run
  per-box mask prediction. To be compatible with this exporter, the detection
  model has to implement a method called predict_masks_from_boxes(
    prediction_dict, true_image_shapes, provided_boxes, **params), where
    - prediction_dict is a dict returned by the predict method.
    - true_image_shapes is a tensor of size [batch_size, 3], containing the
      true shape of each image in case it is padded.
    - provided_boxes is a [batch_size, num_boxes, 4] size tensor containing
      boxes specified in normalized coordinates.
  """

  def __init__(self,
               detection_model,
               use_side_inputs=False,
               zipped_side_inputs=None):
    """Initializes a module for detection.

    Args:
      detection_model: the detection model to use for inference.
      use_side_inputs: whether to use side inputs.
      zipped_side_inputs: the zipped side inputs.

    Raises:
      ValueError: if `detection_model` does not provide a
        `predict_masks_from_boxes` attribute.
    """
    # Validate explicitly rather than with `assert`, which is removed when
    # Python runs with -O and gives no hint about what is missing.
    if not hasattr(detection_model, 'predict_masks_from_boxes'):
      raise ValueError(
          'The detection model must implement predict_masks_from_boxes to be '
          'exported with DetectionFromImageAndBoxModule.')
    super(DetectionFromImageAndBoxModule,
          self).__init__(detection_model, use_side_inputs, zipped_side_inputs)

  def _run_segmentation_on_images(self, image, boxes, **kwargs):
    """Run segmentation on images with provided boxes.

    Args:
      image: uint8 Tensor of shape [1, None, None, 3].
      boxes: float32 tensor of shape [1, None, 4] containing normalized box
        coordinates.
      **kwargs: additional keyword arguments, forwarded to the model's
        predict method.

    Returns:
      Tensor dictionary holding detections (including masks). Every value is
      cast to float32 and the detection classes are shifted by the label id
      offset of 1.
    """
    label_id_offset = 1

    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
    prediction_dict = self._model.predict(image, shapes, **kwargs)
    raw_detections = self._model.predict_masks_from_boxes(
        prediction_dict, shapes, boxes)

    # Build a new dict instead of rewriting the model's returned dict while
    # iterating over it.
    detections = {
        key: tf.cast(val, tf.float32) for key, val in raw_detections.items()
    }
    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = detections[classes_field] + label_id_offset

    return detections

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8),
      tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32)
  ])
  def __call__(self, input_tensor, boxes):
    """Runs per-box mask prediction on one image and its normalized boxes."""
    return self._run_segmentation_on_images(input_tensor, boxes)
356+
357+
358+
# Register the image+boxes module under its exporter input type.
DETECTION_MODULE_MAP['image_and_boxes_tensor'] = DetectionFromImageAndBoxModule

0 commit comments

Comments
 (0)