@@ -75,6 +75,13 @@ def postprocess(self, prediction_dict, true_image_shapes):
75
75
}
76
76
return postprocessed_tensors
77
77
78
+ def predict_masks_from_boxes (self , prediction_dict , true_image_shapes , boxes ):
79
+ output_dict = self .postprocess (prediction_dict , true_image_shapes )
80
+ output_dict .update ({
81
+ 'detection_masks' : tf .ones (shape = (1 , 2 , 16 ), dtype = tf .float32 ),
82
+ })
83
+ return output_dict
84
+
78
85
def restore_map (self , checkpoint_path , fine_tune_checkpoint_type ):
79
86
pass
80
87
@@ -291,6 +298,83 @@ def test_export_checkpoint_and_run_inference_with_image(self):
291
298
[[150 + 0.7 , 150 + 0.6 ], [150 + 0.9 , 150 + 0.0 ]])
292
299
293
300
301
+ class DetectionFromImageAndBoxModuleTest (tf .test .TestCase ):
302
+
303
+ def get_dummy_input (self , input_type ):
304
+ """Get dummy input for the given input type."""
305
+
306
+ if input_type == 'image_tensor' or input_type == 'image_and_boxes_tensor' :
307
+ return np .zeros ((1 , 20 , 20 , 3 ), dtype = np .uint8 )
308
+ if input_type == 'float_image_tensor' :
309
+ return np .zeros ((1 , 20 , 20 , 3 ), dtype = np .float32 )
310
+ elif input_type == 'encoded_image_string_tensor' :
311
+ image = Image .new ('RGB' , (20 , 20 ))
312
+ byte_io = io .BytesIO ()
313
+ image .save (byte_io , 'PNG' )
314
+ return [byte_io .getvalue ()]
315
+ elif input_type == 'tf_example' :
316
+ image_tensor = tf .zeros ((20 , 20 , 3 ), dtype = tf .uint8 )
317
+ encoded_jpeg = tf .image .encode_jpeg (tf .constant (image_tensor )).numpy ()
318
+ example = tf .train .Example (
319
+ features = tf .train .Features (
320
+ feature = {
321
+ 'image/encoded' :
322
+ dataset_util .bytes_feature (encoded_jpeg ),
323
+ 'image/format' :
324
+ dataset_util .bytes_feature (six .b ('jpeg' )),
325
+ 'image/source_id' :
326
+ dataset_util .bytes_feature (six .b ('image_id' )),
327
+ })).SerializeToString ()
328
+ return [example ]
329
+
330
+ def _save_checkpoint_from_mock_model (self ,
331
+ checkpoint_dir ,
332
+ conv_weight_scalar = 6.0 ):
333
+ mock_model = FakeModel (conv_weight_scalar )
334
+ fake_image = tf .zeros (shape = [1 , 10 , 10 , 3 ], dtype = tf .float32 )
335
+ preprocessed_inputs , true_image_shapes = mock_model .preprocess (fake_image )
336
+ predictions = mock_model .predict (preprocessed_inputs , true_image_shapes )
337
+ mock_model .postprocess (predictions , true_image_shapes )
338
+
339
+ ckpt = tf .train .Checkpoint (model = mock_model )
340
+ exported_checkpoint_manager = tf .train .CheckpointManager (
341
+ ckpt , checkpoint_dir , max_to_keep = 1 )
342
+ exported_checkpoint_manager .save (checkpoint_number = 0 )
343
+
344
+ def test_export_saved_model_and_run_inference_for_segmentation (
345
+ self , input_type = 'image_and_boxes_tensor' ):
346
+ tmp_dir = self .get_temp_dir ()
347
+ self ._save_checkpoint_from_mock_model (tmp_dir )
348
+
349
+ with mock .patch .object (
350
+ model_builder , 'build' , autospec = True ) as mock_builder :
351
+ mock_builder .return_value = FakeModel ()
352
+ exporter_lib_v2 .INPUT_BUILDER_UTIL_MAP ['model_build' ] = mock_builder
353
+ output_directory = os .path .join (tmp_dir , 'output' )
354
+ pipeline_config = pipeline_pb2 .TrainEvalPipelineConfig ()
355
+ exporter_lib_v2 .export_inference_graph (
356
+ input_type = input_type ,
357
+ pipeline_config = pipeline_config ,
358
+ trained_checkpoint_dir = tmp_dir ,
359
+ output_directory = output_directory )
360
+
361
+ saved_model_path = os .path .join (output_directory , 'saved_model' )
362
+ detect_fn = tf .saved_model .load (saved_model_path )
363
+ image = self .get_dummy_input (input_type )
364
+ boxes = tf .constant ([
365
+ [
366
+ [0.0 , 0.0 , 0.5 , 0.5 ],
367
+ [0.5 , 0.5 , 0.8 , 0.8 ],
368
+ ],
369
+ ])
370
+ detections = detect_fn (tf .constant (image ), boxes )
371
+
372
+ detection_fields = fields .DetectionResultFields
373
+ self .assertIn (detection_fields .detection_masks , detections )
374
+ self .assertListEqual (
375
+ list (detections [detection_fields .detection_masks ].shape ), [1 , 2 , 16 ])
376
+
377
+
294
378
if __name__ == '__main__' :
295
379
tf .enable_v2_behavior ()
296
380
tf .test .main ()
0 commit comments