Skip to content

Commit a44092b

Browse files
No public description
PiperOrigin-RevId: 567692565
1 parent 0d75a9d commit a44092b

File tree

3 files changed: +31 additions, −6 deletions

official/projects/yt8m/configs/yt8m.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class DataConfig(cfg.DataConfig):
8080
sample_random_frames: bool = True
8181
# Sample random frames if not None. No sampling in inference.
8282
num_sample_frames: Optional[int] = 300
83+
input_per_feature_l2_norm: bool = False
8384
prefetch_buffer_size: int = 100
8485
shuffle_buffer_size: int = 100
8586
num_classes: int = 3862

official/projects/yt8m/dataloaders/yt8m_input.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
157157
return output_dict
158158

159159

160+
# TODO(allenyan, zhengxu): Adds a unit test for this function.
160161
def _get_video_matrix(features, feature_size, dtype, max_frames,
161162
max_quantized_value, min_quantized_value):
162163
"""Decodes features from an input string and quantizes it.
@@ -187,8 +188,16 @@ def _get_video_matrix(features, feature_size, dtype, max_frames,
187188
return feature_matrix, num_frames
188189

189190

190-
def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
191-
max_frames, max_quantized_value, min_quantized_value):
191+
def _concat_features(
192+
features,
193+
feature_names,
194+
feature_sizes,
195+
feature_dtypes,
196+
max_frames,
197+
max_quantized_value,
198+
min_quantized_value,
199+
per_feature_l2_norm=False,
200+
):
192201
"""Loads (potentially) different types of features and concatenates them.
193202
194203
Args:
@@ -199,6 +208,7 @@ def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
199208
max_frames: number of frames in the sequence
200209
max_quantized_value: the maximum of the quantized value.
201210
min_quantized_value: the minimum of the quantized value.
211+
per_feature_l2_norm: whether to l2 normalize each feature.
202212
203213
Returns:
204214
video_matrix: different features concatenated into one matrix
@@ -225,6 +235,8 @@ def _concat_features(features, feature_names, feature_sizes, feature_dtypes,
225235
min_quantized_value)
226236
num_common_frames = tf.math.minimum(num_frames_in_this_feature,
227237
num_common_frames)
238+
if per_feature_l2_norm:
239+
feature_matrix = tf.math.l2_normalize(feature_matrix, axis=-1)
228240
feature_matrices[i] = feature_matrix
229241

230242
for i in range(num_features):
@@ -347,14 +359,15 @@ def __init__(
347359
self._num_sample_frames = input_params.num_sample_frames
348360
self._max_quantized_value = max_quantized_value
349361
self._min_quantized_value = min_quantized_value
362+
self._input_per_feature_l2_norm = input_params.input_per_feature_l2_norm
350363

351364
def _parse_train_data(self, decoded_tensors):
352365
"""Parses data for training."""
353366
# loads (potentially) different types of features and concatenates them
354367
video_matrix, num_frames = _concat_features(
355368
decoded_tensors, self._feature_names, self._feature_sizes,
356369
self._feature_dtypes, self._max_frames, self._max_quantized_value,
357-
self._min_quantized_value)
370+
self._min_quantized_value, self._input_per_feature_l2_norm)
358371
if not self._include_video_id and "id" in decoded_tensors:
359372
del decoded_tensors["id"]
360373

@@ -383,7 +396,7 @@ def _parse_eval_data(self, decoded_tensors):
383396
video_matrix, num_frames = _concat_features(
384397
decoded_tensors, self._feature_names, self._feature_sizes,
385398
self._feature_dtypes, self._max_frames, self._max_quantized_value,
386-
self._min_quantized_value)
399+
self._min_quantized_value, self._input_per_feature_l2_norm)
387400
if not self._include_video_id and "id" in decoded_tensors:
388401
del decoded_tensors["id"]
389402

official/projects/yt8m/dataloaders/yt8m_input_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,16 @@ def test_read_segment_level_input(self, include_video_id, num_sample_frames):
160160
if include_video_id:
161161
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size])
162162

163-
@parameterized.parameters((True, 4), (False, 4), (False, None))
163+
@parameterized.parameters(
164+
(True, 4, False),
165+
(False, 4, False),
166+
(False, None, False),
167+
(True, 4, True),
168+
(False, 4, True),
169+
(False, None, True),
170+
)
164171
def test_read_video_level_float_input(
165-
self, include_video_id, num_sample_frames
172+
self, include_video_id, num_sample_frames, per_feature_l2_norm
166173
):
167174
data_dir = os.path.join(self.get_temp_dir(), 'data2')
168175
tf.io.gfile.makedirs(data_dir)
@@ -188,6 +195,7 @@ def test_read_video_level_float_input(
188195
params.feature_from_bytes = (False, False)
189196
params.label_field = 'clip/label/index'
190197
params.include_video_id = include_video_id
198+
params.input_per_feature_l2_norm = per_feature_l2_norm
191199
reader = self.create_input_reader(params)
192200

193201
dataset = reader.read()
@@ -211,6 +219,9 @@ def test_read_video_level_float_input(
211219
'FEATURE/feature/floats'].feature[0].float_list.value
212220
expected_labels = examples[0].context.feature[
213221
params.label_field].int64_list.value
222+
if per_feature_l2_norm:
223+
expected_feature = tf.math.l2_normalize(expected_feature, axis=-1)
224+
expected_context = tf.math.l2_normalize(expected_context, axis=-1)
214225
self.assertAllEqual(expected_feature,
215226
example['video_matrix'][0, 0, params.feature_sizes[0]:])
216227
self.assertAllEqual(expected_context,

0 commit comments

Comments
 (0)