@@ -46,13 +46,13 @@ def _build_inputs(self, image):
         offset=MEAN_RGB,
         scale=STDDEV_RGB)
 
-    image, _ = preprocess_ops.resize_and_crop_image(
+    image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._input_image_size,
         padded_size=self._input_image_size,
         aug_scale_min=1.0,
         aug_scale_max=1.0)
-    return image
+    return image, image_info
 
   def serve(self, images):
     """Cast image to float and run inference.
@@ -64,21 +64,27 @@ def serve(self, images):
     """
     # Skip image preprocessing when input_type is tflite so it is compatible
     # with TFLite quantization.
+    image_info = None
     if self._input_type != 'tflite':
       with tf.device('cpu:0'):
         images = tf.cast(images, dtype=tf.float32)
+        images_spec = tf.TensorSpec(
+            shape=self._input_image_size + [3], dtype=tf.float32)
+        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
 
-        images = tf.nest.map_structure(
+        images, image_info = tf.nest.map_structure(
             tf.identity,
             tf.map_fn(
                 self._build_inputs,
                 elems=images,
-                fn_output_signature=tf.TensorSpec(
-                    shape=self._input_image_size + [3], dtype=tf.float32),
+                fn_output_signature=(images_spec, image_info_spec),
                 parallel_iterations=32))
 
     outputs = self.inference_step(images)
     outputs['logits'] = tf.image.resize(
         outputs['logits'], self._input_image_size, method='bilinear')
 
+    if image_info is not None:
+      outputs.update({'image_info': image_info})
+
     return outputs
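
Net effect of the change: `_build_inputs` now also returns the `image_info` tensor produced by `preprocess_ops.resize_and_crop_image`, and `serve` forwards it in the output dictionary next to the resized `logits` (except for the `tflite` input type, where preprocessing is skipped). As a rough illustration of why that metadata is useful, here is a minimal sketch of mapping the logits back to the original image resolution; the helper name, the assumed `[original size, scaled size, scale, offset]` row layout, and the dummy shapes are illustrative assumptions, not part of this commit.

```python
import tensorflow as tf

def rescale_logits_to_original(logits, image_info):
  """Crops the padded region out of `logits` and resizes it back to the
  original image resolution.

  Assumes (not guaranteed by this diff) that `image_info` rows are
  [original size, scaled size, scale, offset] in (height, width) order and
  that padding was applied at the bottom/right.
  """
  original_size = tf.cast(image_info[0], tf.int32)  # size before preprocessing
  scaled_size = tf.cast(image_info[1], tf.int32)    # valid (unpadded) region
  # Keep only the region covered by the rescaled image, then undo the scaling.
  valid = logits[:scaled_size[0], :scaled_size[1], :]
  return tf.image.resize(valid, original_size, method='bilinear')

# Dummy values just to exercise the helper: a 512x512 logits map for an image
# that was originally 400x600 and scaled to 341x512 before padding.
dummy_logits = tf.zeros([512, 512, 21])
dummy_info = tf.constant([[400., 600.], [341., 512.], [0.8533, 0.8533], [0., 0.]])
restored = rescale_logits_to_original(dummy_logits, dummy_info)  # [400, 600, 21]
```

In a batched setting (the `serve` output carries a leading batch dimension), the same logic would be applied per example, e.g. via `tf.map_fn`.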