Skip to content

Commit d21a9bb

Browse files
fyangftensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 422863802
1 parent a4df1c6 commit d21a9bb

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

official/vision/beta/serving/semantic_segmentation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ def _build_inputs(self, image):
4646
offset=MEAN_RGB,
4747
scale=STDDEV_RGB)
4848

49-
image, _ = preprocess_ops.resize_and_crop_image(
49+
image, image_info = preprocess_ops.resize_and_crop_image(
5050
image,
5151
self._input_image_size,
5252
padded_size=self._input_image_size,
5353
aug_scale_min=1.0,
5454
aug_scale_max=1.0)
55-
return image
55+
return image, image_info
5656

5757
def serve(self, images):
5858
"""Cast image to float and run inference.
@@ -64,21 +64,27 @@ def serve(self, images):
6464
"""
6565
# Skip image preprocessing when input_type is tflite so it is compatible
6666
# with TFLite quantization.
67+
image_info = None
6768
if self._input_type != 'tflite':
6869
with tf.device('cpu:0'):
6970
images = tf.cast(images, dtype=tf.float32)
71+
images_spec = tf.TensorSpec(
72+
shape=self._input_image_size + [3], dtype=tf.float32)
73+
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
7074

71-
images = tf.nest.map_structure(
75+
images, image_info = tf.nest.map_structure(
7276
tf.identity,
7377
tf.map_fn(
7478
self._build_inputs,
7579
elems=images,
76-
fn_output_signature=tf.TensorSpec(
77-
shape=self._input_image_size + [3], dtype=tf.float32),
80+
fn_output_signature=(images_spec, image_info_spec),
7881
parallel_iterations=32))
7982

8083
outputs = self.inference_step(images)
8184
outputs['logits'] = tf.image.resize(
8285
outputs['logits'], self._input_image_size, method='bilinear')
8386

87+
if image_info is not None:
88+
outputs.update({'image_info': image_info})
89+
8490
return outputs

official/vision/beta/serving/semantic_segmentation_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,15 @@ def test_export(self, input_type='image_tensor'):
9393

9494
images = self._get_dummy_input(input_type)
9595
if input_type != 'tflite':
96-
processed_images = tf.nest.map_structure(
96+
processed_images, _ = tf.nest.map_structure(
9797
tf.stop_gradient,
9898
tf.map_fn(
9999
module._build_inputs,
100100
elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
101-
fn_output_signature=tf.TensorSpec(
102-
shape=[112, 112, 3], dtype=tf.float32)))
101+
fn_output_signature=(tf.TensorSpec(
102+
shape=[112, 112, 3], dtype=tf.float32),
103+
tf.TensorSpec(
104+
shape=[4, 2], dtype=tf.float32))))
103105
else:
104106
processed_images = images
105107
expected_output = tf.image.resize(

0 commit comments

Comments
 (0)