Skip to content

Commit 2333570

Browse files
No public description
PiperOrigin-RevId: 671550909
1 parent 65f6f10 commit 2333570

File tree

2 files changed

+75
-8
lines changed

2 files changed

+75
-8
lines changed

official/vision/serving/semantic_segmentation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ class SegmentationModule(export_base.ExportModule):
2626

2727
def _build_model(self):
2828
input_specs = tf_keras.layers.InputSpec(
29-
shape=[self._batch_size] + self._input_image_size + [3])
29+
shape=[self._batch_size] + self._input_image_size + [self._num_channels]
30+
)
3031

3132
return factory.build_segmentation_model(
3233
input_specs=input_specs,
@@ -72,7 +73,9 @@ def serve(self, images):
7273
if self._input_type != 'tflite':
7374
with tf.device('cpu:0'):
7475
images_spec = tf.TensorSpec(
75-
shape=self._input_image_size + [3], dtype=tf.float32)
76+
shape=self._input_image_size + [self._num_channels],
77+
dtype=tf.float32,
78+
)
7679
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
7780

7881
images, image_info = tf.nest.map_structure(

official/vision/serving/semantic_segmentation_test.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,22 @@ def _export_from_module(self, module, input_type, save_directory):
4949
{input_type: 'serving_default'})
5050
tf.saved_model.save(module, save_directory, signatures=signatures)
5151

52-
def _get_dummy_input(self, input_type, input_image_size):
52+
def _get_dummy_input(self, input_type, input_image_size, num_channels):
5353
"""Get dummy input for the given input type."""
5454

5555
height = input_image_size[0]
5656
width = input_image_size[1]
5757
if input_type == 'image_tensor':
58-
return tf.zeros((1, height, width, 3), dtype=np.uint8)
58+
return tf.zeros((1, height, width, num_channels), dtype=np.uint8)
5959
elif input_type == 'image_bytes':
60-
image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
60+
image = Image.fromarray(
61+
np.zeros((height, width, num_channels), dtype=np.uint8)
62+
)
6163
byte_io = io.BytesIO()
6264
image.save(byte_io, 'PNG')
6365
return [byte_io.getvalue()]
6466
elif input_type == 'tf_example':
65-
image_tensor = tf.zeros((height, width, 3), dtype=tf.uint8)
67+
image_tensor = tf.zeros((height, width, num_channels), dtype=tf.uint8)
6668
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
6769
example = tf.train.Example(
6870
features=tf.train.Features(
@@ -73,7 +75,7 @@ def _get_dummy_input(self, input_type, input_image_size):
7375
})).SerializeToString()
7476
return [example]
7577
elif input_type == 'tflite':
76-
return tf.zeros((1, height, width, 3), dtype=np.float32)
78+
return tf.zeros((1, height, width, num_channels), dtype=np.float32)
7779

7880
@parameterized.parameters(
7981
('image_tensor', False, [112, 112], False),
@@ -105,7 +107,7 @@ def test_export(self, input_type, rescale_output, input_image_size,
105107
imported = tf.saved_model.load(tmp_dir)
106108
segmentation_fn = imported.signatures['serving_default']
107109

108-
images = self._get_dummy_input(input_type, input_image_size)
110+
images = self._get_dummy_input(input_type, input_image_size, num_channels=3)
109111
if input_type != 'tflite':
110112
processed_images, _ = tf.nest.map_structure(
111113
tf.stop_gradient,
@@ -128,6 +130,68 @@ def test_export(self, input_type, rescale_output, input_image_size,
128130
out = segmentation_fn(tf.constant(images))
129131
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
130132

133+
@parameterized.parameters(
134+
('image_tensor',),
135+
('tflite',),
136+
)
137+
def test_export_with_extra_input_channels(self, input_type):
138+
tmp_dir = self.get_temp_dir()
139+
num_channels = 6
140+
params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
141+
params.task.init_checkpoint = None
142+
params.task.model.input_size = [112, 112, num_channels]
143+
params.task.export_config.rescale_output = False
144+
params.task.train_data.preserve_aspect_ratio = False
145+
params.task.train_data.image_feature.mean = [0.5] * num_channels
146+
params.task.train_data.image_feature.stddev = [0.5] * num_channels
147+
params.task.train_data.image_feature.num_channels = num_channels
148+
module = semantic_segmentation.SegmentationModule(
149+
params,
150+
batch_size=1,
151+
input_image_size=[112, 112],
152+
input_type=input_type,
153+
num_channels=num_channels,
154+
)
155+
156+
self._export_from_module(module, input_type, tmp_dir)
157+
158+
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
159+
self.assertTrue(
160+
os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index'))
161+
)
162+
self.assertTrue(
163+
os.path.exists(
164+
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')
165+
)
166+
)
167+
168+
imported = tf.saved_model.load(tmp_dir)
169+
segmentation_fn = imported.signatures['serving_default']
170+
171+
images = self._get_dummy_input(input_type, [112, 112], num_channels)
172+
173+
if input_type != 'tflite':
174+
processed_images, _ = tf.nest.map_structure(
175+
tf.stop_gradient,
176+
tf.map_fn(
177+
module._build_inputs,
178+
elems=tf.zeros((1, 112, 112, num_channels), dtype=tf.uint8),
179+
fn_output_signature=(
180+
tf.TensorSpec(
181+
shape=[112, 112, num_channels], dtype=tf.float32
182+
),
183+
tf.TensorSpec(shape=[4, 2], dtype=tf.float32),
184+
),
185+
),
186+
)
187+
else:
188+
processed_images = images
189+
190+
logits = module.model(processed_images, training=False)['logits']
191+
expected_output = tf.image.resize(logits, [112, 112], method='bilinear')
192+
out = segmentation_fn(tf.constant(images))
193+
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
194+
131195
def test_export_invalid_batch_size(self):
132196
batch_size = 3
133197
tmp_dir = self.get_temp_dir()

0 commit comments

Comments
 (0)