
Commit 0327186

yeqingli authored and tensorflower-gardener committed
Adds an example video classification export module. The module is subject to change in the near future.
PiperOrigin-RevId: 380907598
1 parent a964b89 commit 0327186

File tree

4 files changed: +314 -1 lines changed

official/vision/beta/serving/export_saved_model_lib.py

Lines changed: 8 additions & 0 deletions
@@ -27,6 +27,7 @@
 from official.vision.beta.serving import detection
 from official.vision.beta.serving import image_classification
 from official.vision.beta.serving import semantic_segmentation
+from official.vision.beta.serving import video_classification
 
 
 def export_inference_graph(
@@ -99,6 +100,13 @@ def export_inference_graph(
         batch_size=batch_size,
         input_image_size=input_image_size,
         num_channels=num_channels)
+  elif isinstance(params.task,
+                  configs.video_classification.VideoClassificationTask):
+    export_module = video_classification.VideoClassificationModule(
+        params=params,
+        batch_size=batch_size,
+        input_image_size=input_image_size,
+        num_channels=num_channels)
   else:
     raise ValueError('Export module not implemented for {} task.'.format(
         type(params.task)))
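
With this branch in place, a video classification experiment config is routed through the same export entry point as the image tasks. Below is a minimal sketch of an invocation, assuming the surrounding export_inference_graph(input_type, batch_size, input_image_size, params, checkpoint_path, export_dir, ...) signature from export_saved_model_lib.py and the registered 'video_classification_ucf101' experiment; the checkpoint and export paths are hypothetical.

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import export_saved_model_lib

# Build a video classification experiment config; its task type drives the
# isinstance() dispatch added in this commit.
params = exp_factory.get_exp_config('video_classification_ucf101')

# Export a SavedModel whose default signature takes a batch of uint8 frames.
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[8, 64, 64],  # [num_frames, height, width]
    params=params,
    checkpoint_path='/tmp/ckpt/best_ckpt',  # hypothetical checkpoint
    export_dir='/tmp/video_cls_saved_model')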

official/vision/beta/serving/image_classification.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Lint as: python3
-"""Detection input and model functions for serving/inference."""
+"""Image classification input and model functions for serving/inference."""
 
 import tensorflow as tf
 
official/vision/beta/serving/video_classification.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Video classification input and model functions for serving/inference."""
+from typing import Mapping, Dict, Text
+
+import tensorflow as tf
+
+from official.vision.beta.dataloaders import video_input
+from official.vision.beta.serving import export_base
+from official.vision.beta.tasks import video_classification
+
+MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
+STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
+
+
+class VideoClassificationModule(export_base.ExportModule):
+  """Video classification module."""
+
+  def _build_model(self):
+    input_params = self.params.task.train_data
+    self._num_frames = input_params.feature_shape[0]
+    self._stride = input_params.temporal_stride
+    self._min_resize = input_params.min_image_size
+    self._crop_size = input_params.feature_shape[1]
+
+    self._output_audio = input_params.output_audio
+    task = video_classification.VideoClassificationTask(self.params.task)
+    return task.build_model()
+
+  def _decode_tf_example(self, encoded_inputs: tf.Tensor):
+    sequence_description = {
+        # Each image is a JPEG-encoded string.
+        video_input.IMAGE_KEY:
+            tf.io.FixedLenSequenceFeature((), tf.string),
+    }
+    if self._output_audio:
+      sequence_description[self._params.task.validation_data.audio_feature] = (
+          tf.io.VarLenFeature(dtype=tf.float32))
+    _, decoded_tensors = tf.io.parse_single_sequence_example(
+        encoded_inputs, {}, sequence_description)
+    for key, value in decoded_tensors.items():
+      if isinstance(value, tf.SparseTensor):
+        decoded_tensors[key] = tf.sparse.to_dense(value)
+    return decoded_tensors
+
+  def _preprocess_image(self, image):
+    image = video_input.process_image(
+        image=image,
+        is_training=False,
+        num_frames=self._num_frames,
+        stride=self._stride,
+        num_test_clips=1,
+        min_resize=self._min_resize,
+        crop_size=self._crop_size,
+        num_crops=1)
+    image = tf.cast(image, tf.float32)  # Use config.
+    features = {'image': image}
+    return features
+
+  def _preprocess_audio(self, audio):
+    features = {}
+    audio = tf.cast(audio, dtype=tf.float32)  # Use config.
+    audio = video_input.preprocess_ops_3d.sample_sequence(
+        audio, 20, random=False, stride=1)
+    audio = tf.ensure_shape(
+        audio, self._params.task.validation_data.audio_feature_shape)
+    features['audio'] = audio
+    return features
+
+  @tf.function
+  def inference_from_tf_example(
+      self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    with tf.device('cpu:0'):
+      if self._output_audio:
+        inputs = tf.map_fn(
+            self._decode_tf_example, (encoded_inputs),
+            fn_output_signature={
+                video_input.IMAGE_KEY: tf.string,
+                self._params.task.validation_data.audio_feature: tf.float32
+            })
+        return self.serve(inputs['image'], inputs['audio'])
+      else:
+        inputs = tf.map_fn(
+            self._decode_tf_example, (encoded_inputs),
+            fn_output_signature={
+                video_input.IMAGE_KEY: tf.string,
+            })
+        return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))
+
+  @tf.function
+  def inference_from_image_tensors(
+      self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    return self.serve(input_frames, tf.zeros([1, 1]))
+
+  @tf.function
+  def inference_from_image_audio_tensors(
+      self, input_frames: tf.Tensor,
+      input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    return self.serve(input_frames, input_audio)
+
+  @tf.function
+  def inference_from_image_bytes(self, inputs: tf.Tensor):
+    raise NotImplementedError(
+        'Video classification does not support image bytes input.')
+
+  def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
+    """Casts the input frames to float and runs inference.
+
+    Args:
+      input_frames: uint8 Tensor of shape
+        [batch_size, num_frames, height, width, 3].
+      input_audio: float32 Tensor of audio features.
+
+    Returns:
+      A dictionary holding classification logits and probabilities.
+    """
+    with tf.device('cpu:0'):
+      inputs = tf.map_fn(
+          self._preprocess_image, (input_frames),
+          fn_output_signature={
+              'image': tf.float32,
+          })
+      if self._output_audio:
+        inputs.update(
+            tf.map_fn(
+                self._preprocess_audio, (input_audio),
+                fn_output_signature={'audio': tf.float32}))
+    logits = self.inference_step(inputs)
+    if self.params.task.train_data.is_multilabel:
+      probs = tf.math.sigmoid(logits)
+    else:
+      probs = tf.nn.softmax(logits)
+    return {'logits': logits, 'probs': probs}
+
+  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
+    """Gets defined function signatures.
+
+    Args:
+      function_keys: A dictionary with keys as the functions to create
+        signatures for and values as the signature keys to export under.
+
+    Returns:
+      A dictionary with signature keys as keys and concrete functions that
+      can be used for tf.saved_model.save as values.
+    """
+    signatures = {}
+    for key, def_name in function_keys.items():
+      if key == 'image_tensor':
+        input_signature = tf.TensorSpec(
+            shape=[self._batch_size] + self._input_image_size + [3],
+            dtype=tf.uint8,
+            name='INPUT_FRAMES')
+        signatures[
+            def_name] = self.inference_from_image_tensors.get_concrete_function(
+                input_signature)
+      elif key == 'frames_audio':
+        input_signature = [
+            tf.TensorSpec(
+                shape=[self._batch_size] + self._input_image_size + [3],
+                dtype=tf.uint8,
+                name='INPUT_FRAMES'),
+            tf.TensorSpec(
+                shape=[self._batch_size] +
+                self.params.task.train_data.audio_feature_shape,
+                dtype=tf.float32,
+                name='INPUT_AUDIO')
+        ]
+        signatures[
+            def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
+                *input_signature)
+      elif key == 'serve_examples' or key == 'tf_example':
+        input_signature = tf.TensorSpec(
+            shape=[self._batch_size], dtype=tf.string)
+        signatures[
+            def_name] = self.inference_from_tf_example.get_concrete_function(
+                input_signature)
+      else:
+        raise ValueError('Unrecognized `input_type`')
+    return signatures
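
To make the module's flow concrete, here is a short sketch that builds the module directly, exports it with the 'image_tensor' signature, and runs a clip through the reloaded SavedModel. It mirrors what the test below does and assumes the 'video_classification_ucf101' experiment config; the export path is hypothetical.

import numpy as np
import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import video_classification

params = exp_factory.get_exp_config('video_classification_ucf101')
params.task.train_data.feature_shape = (8, 64, 64, 3)
params.task.validation_data.feature_shape = (8, 64, 64, 3)

module = video_classification.VideoClassificationModule(
    params, batch_size=1, input_image_size=[8, 64, 64])

# Map the 'image_tensor' entry point onto the default serving signature.
signatures = module.get_inference_signatures(
    {'image_tensor': 'serving_default'})
tf.saved_model.save(module, '/tmp/video_cls', signatures=signatures)

# Reload and classify one clip of 8 random 64x64 RGB frames.
imported = tf.saved_model.load('/tmp/video_cls')
serve_fn = imported.signatures['serving_default']
frames = np.random.randint(0, 255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
outputs = serve_fn(tf.constant(frames))  # {'logits': ..., 'probs': ...}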
official/vision/beta/serving/video_classification_test.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+
+import os
+import random
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+from official.common import registry_imports  # pylint: disable=unused-import
+from official.core import exp_factory
+from official.vision.beta.dataloaders import tfexample_utils
+from official.vision.beta.serving import video_classification
+
+
+class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):
+
+  def _get_classification_module(self):
+    params = exp_factory.get_exp_config('video_classification_ucf101')
+    params.task.train_data.feature_shape = (8, 64, 64, 3)
+    params.task.validation_data.feature_shape = (8, 64, 64, 3)
+    params.task.model.backbone.resnet_3d.model_id = 50
+    classification_module = video_classification.VideoClassificationModule(
+        params, batch_size=1, input_image_size=[8, 64, 64])
+    return classification_module
+
+  def _export_from_module(self, module, input_type, save_directory):
+    signatures = module.get_inference_signatures(
+        {input_type: 'serving_default'})
+    tf.saved_model.save(module, save_directory, signatures=signatures)
+
+  def _get_dummy_input(self, input_type, module=None):
+    """Gets a dummy input for the given input type."""
+    if input_type == 'image_tensor':
+      images = np.random.randint(
+          low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
+      return images, images
+    elif input_type == 'tf_example':
+      example = tfexample_utils.make_video_test_example(
+          image_shape=(64, 64, 3),
+          audio_shape=(20, 128),
+          label=random.randint(0, 100)).SerializeToString()
+      images = tf.nest.map_structure(
+          tf.stop_gradient,
+          tf.map_fn(
+              module._decode_tf_example,
+              elems=tf.constant([example]),
+              fn_output_signature={
+                  video_classification.video_input.IMAGE_KEY: tf.string,
+              }))
+      images = images[video_classification.video_input.IMAGE_KEY]
+      return [example], images
+    else:
+      raise ValueError(f'Unsupported input_type: {input_type}.')
+
+  @parameterized.parameters(
+      {'input_type': 'image_tensor'},
+      {'input_type': 'tf_example'},
+  )
+  def test_export(self, input_type):
+    tmp_dir = self.get_temp_dir()
+    module = self._get_classification_module()
+
+    self._export_from_module(module, input_type, tmp_dir)
+
+    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
+    self.assertTrue(
+        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(tmp_dir, 'variables',
+                         'variables.data-00000-of-00001')))
+
+    imported = tf.saved_model.load(tmp_dir)
+    classification_fn = imported.signatures['serving_default']
+
+    images, images_tensor = self._get_dummy_input(input_type, module)
+    processed_images = tf.nest.map_structure(
+        tf.stop_gradient,
+        tf.map_fn(
+            module._preprocess_image,
+            elems=images_tensor,
+            fn_output_signature={
+                'image': tf.float32,
+            }))
+    expected_logits = module.model(processed_images, training=False)
+    expected_prob = tf.nn.softmax(expected_logits)
+    out = classification_fn(tf.constant(images))
+
+    # The exported signature should reproduce the outputs of the original
+    # Keras model.
+    self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
+    self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
+
+
+if __name__ == '__main__':
+  tf.test.main()
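
For the 'tf_example' path, the exported signature consumes serialized tf.train.SequenceExample protos whose frame feature matches the FixedLenSequenceFeature parsed in _decode_tf_example. The test uses tfexample_utils.make_video_test_example; as a rough sketch of what such a proto looks like when assembled by hand, assuming video_input.IMAGE_KEY names the JPEG frame feature (the helper function below is hypothetical):

import numpy as np
import tensorflow as tf

from official.vision.beta.dataloaders import video_input


def make_sequence_example(frames: np.ndarray) -> tf.train.SequenceExample:
  """Packs [num_frames, height, width, 3] uint8 frames as JPEG features."""
  example = tf.train.SequenceExample()
  feature_list = example.feature_lists.feature_list[video_input.IMAGE_KEY]
  for frame in frames:
    jpeg = tf.io.encode_jpeg(frame).numpy()
    feature_list.feature.add().bytes_list.value.append(jpeg)
  return example


frames = np.random.randint(0, 255, size=(8, 64, 64, 3), dtype=np.uint8)
serialized = make_sequence_example(frames).SerializeToString()
# A module exported with {'tf_example': 'serving_default'} then accepts a
# batch of serialized strings: serve_fn(tf.constant([serialized])).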
