@@ -49,20 +49,22 @@ def _export_from_module(self, module, input_type, save_directory):
49
49
{input_type : 'serving_default' })
50
50
tf .saved_model .save (module , save_directory , signatures = signatures )
51
51
52
- def _get_dummy_input (self , input_type , input_image_size ):
52
+ def _get_dummy_input (self , input_type , input_image_size , num_channels ):
53
53
"""Get dummy input for the given input type."""
54
54
55
55
height = input_image_size [0 ]
56
56
width = input_image_size [1 ]
57
57
if input_type == 'image_tensor' :
58
- return tf .zeros ((1 , height , width , 3 ), dtype = np .uint8 )
58
+ return tf .zeros ((1 , height , width , num_channels ), dtype = np .uint8 )
59
59
elif input_type == 'image_bytes' :
60
- image = Image .fromarray (np .zeros ((height , width , 3 ), dtype = np .uint8 ))
60
+ image = Image .fromarray (
61
+ np .zeros ((height , width , num_channels ), dtype = np .uint8 )
62
+ )
61
63
byte_io = io .BytesIO ()
62
64
image .save (byte_io , 'PNG' )
63
65
return [byte_io .getvalue ()]
64
66
elif input_type == 'tf_example' :
65
- image_tensor = tf .zeros ((height , width , 3 ), dtype = tf .uint8 )
67
+ image_tensor = tf .zeros ((height , width , num_channels ), dtype = tf .uint8 )
66
68
encoded_jpeg = tf .image .encode_jpeg (tf .constant (image_tensor )).numpy ()
67
69
example = tf .train .Example (
68
70
features = tf .train .Features (
@@ -73,7 +75,7 @@ def _get_dummy_input(self, input_type, input_image_size):
73
75
})).SerializeToString ()
74
76
return [example ]
75
77
elif input_type == 'tflite' :
76
- return tf .zeros ((1 , height , width , 3 ), dtype = np .float32 )
78
+ return tf .zeros ((1 , height , width , num_channels ), dtype = np .float32 )
77
79
78
80
@parameterized .parameters (
79
81
('image_tensor' , False , [112 , 112 ], False ),
@@ -105,7 +107,7 @@ def test_export(self, input_type, rescale_output, input_image_size,
105
107
imported = tf .saved_model .load (tmp_dir )
106
108
segmentation_fn = imported .signatures ['serving_default' ]
107
109
108
- images = self ._get_dummy_input (input_type , input_image_size )
110
+ images = self ._get_dummy_input (input_type , input_image_size , num_channels = 3 )
109
111
if input_type != 'tflite' :
110
112
processed_images , _ = tf .nest .map_structure (
111
113
tf .stop_gradient ,
@@ -128,6 +130,68 @@ def test_export(self, input_type, rescale_output, input_image_size,
128
130
out = segmentation_fn (tf .constant (images ))
129
131
self .assertAllClose (out ['logits' ].numpy (), expected_output .numpy ())
130
132
133
+ @parameterized .parameters (
134
+ ('image_tensor' ,),
135
+ ('tflite' ,),
136
+ )
137
+ def test_export_with_extra_input_channels (self , input_type ):
138
+ tmp_dir = self .get_temp_dir ()
139
+ num_channels = 6
140
+ params = exp_factory .get_exp_config ('mnv2_deeplabv3_pascal' )
141
+ params .task .init_checkpoint = None
142
+ params .task .model .input_size = [112 , 112 , num_channels ]
143
+ params .task .export_config .rescale_output = False
144
+ params .task .train_data .preserve_aspect_ratio = False
145
+ params .task .train_data .image_feature .mean = [0.5 ] * num_channels
146
+ params .task .train_data .image_feature .stddev = [0.5 ] * num_channels
147
+ params .task .train_data .image_feature .num_channels = num_channels
148
+ module = semantic_segmentation .SegmentationModule (
149
+ params ,
150
+ batch_size = 1 ,
151
+ input_image_size = [112 , 112 ],
152
+ input_type = input_type ,
153
+ num_channels = num_channels ,
154
+ )
155
+
156
+ self ._export_from_module (module , input_type , tmp_dir )
157
+
158
+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , 'saved_model.pb' )))
159
+ self .assertTrue (
160
+ os .path .exists (os .path .join (tmp_dir , 'variables' , 'variables.index' ))
161
+ )
162
+ self .assertTrue (
163
+ os .path .exists (
164
+ os .path .join (tmp_dir , 'variables' , 'variables.data-00000-of-00001' )
165
+ )
166
+ )
167
+
168
+ imported = tf .saved_model .load (tmp_dir )
169
+ segmentation_fn = imported .signatures ['serving_default' ]
170
+
171
+ images = self ._get_dummy_input (input_type , [112 , 112 ], num_channels )
172
+
173
+ if input_type != 'tflite' :
174
+ processed_images , _ = tf .nest .map_structure (
175
+ tf .stop_gradient ,
176
+ tf .map_fn (
177
+ module ._build_inputs ,
178
+ elems = tf .zeros ((1 , 112 , 112 , num_channels ), dtype = tf .uint8 ),
179
+ fn_output_signature = (
180
+ tf .TensorSpec (
181
+ shape = [112 , 112 , num_channels ], dtype = tf .float32
182
+ ),
183
+ tf .TensorSpec (shape = [4 , 2 ], dtype = tf .float32 ),
184
+ ),
185
+ ),
186
+ )
187
+ else :
188
+ processed_images = images
189
+
190
+ logits = module .model (processed_images , training = False )['logits' ]
191
+ expected_output = tf .image .resize (logits , [112 , 112 ], method = 'bilinear' )
192
+ out = segmentation_fn (tf .constant (images ))
193
+ self .assertAllClose (out ['logits' ].numpy (), expected_output .numpy ())
194
+
131
195
def test_export_invalid_batch_size (self ):
132
196
batch_size = 3
133
197
tmp_dir = self .get_temp_dir ()
0 commit comments