Commit 7aa320c

chaoyan1037 authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 481234282
1 parent 9bcbe96 commit 7aa320c

1 file changed: +99 -40 lines changed

official/projects/yt8m/tasks/yt8m_task.py

Lines changed: 99 additions & 40 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 """Video classification task definition."""
+from typing import Dict, List, Optional, Tuple
+
 from absl import logging
 import tensorflow as tf
 
@@ -95,31 +97,46 @@ def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
 
     return dataset
 
-  def build_losses(self, labels, model_outputs, aux_losses=None):
+  def build_losses(self,
+                   labels,
+                   model_outputs,
+                   label_weights=None,
+                   aux_losses=None):
     """Sigmoid Cross Entropy.
 
     Args:
       labels: tensor containing truth labels.
       model_outputs: output logits of the classifier.
+      label_weights: optional tensor of label weights.
       aux_losses: tensor containing auxiliarly loss tensors, i.e. `losses` in
         keras.Model.
 
     Returns:
-      Tensors: The total loss, model loss tensors.
+      A dict of tensors containing the total loss and model loss tensors.
     """
     losses_config = self.task_config.losses
     model_loss = tf.keras.losses.binary_crossentropy(
         labels,
         model_outputs,
         from_logits=losses_config.from_logits,
-        label_smoothing=losses_config.label_smoothing)
+        label_smoothing=losses_config.label_smoothing,
+        axis=None)
+
+    if label_weights is None:
+      model_loss = tf_utils.safe_mean(model_loss)
+    else:
+      model_loss = model_loss * label_weights
+      # Manually compute weighted mean loss.
+      total_loss = tf.reduce_sum(model_loss)
+      total_weight = tf.cast(
+          tf.reduce_sum(label_weights), dtype=total_loss.dtype)
+      model_loss = tf.math.divide_no_nan(total_loss, total_weight)
 
-    model_loss = tf_utils.safe_mean(model_loss)
     total_loss = model_loss
     if aux_losses:
      total_loss += tf.add_n(aux_losses)
 
-    return total_loss, model_loss
+    return {'total_loss': total_loss, 'model_loss': model_loss}
 
   def build_metrics(self, training=True):
     """Gets streaming metrics for training/validation.
@@ -130,10 +147,10 @@ def build_metrics(self, training=True):
       top_n: A positive Integer specifying the average precision at n, or None
         to use all provided data points.
     Args:
-      training: bool value, true for training mode, false for eval/validation.
+      training: Bool value, true for training mode, false for eval/validation.
 
     Returns:
-      list of strings that indicate metrics to be used
+      A list of strings that indicate metrics to be used.
     """
     metrics = []
     metric_names = ['total_loss', 'model_loss']
@@ -149,15 +166,48 @@ def build_metrics(self, training=True):
 
     return metrics
 
+  def process_metrics(self,
+                      metrics: List[tf.keras.metrics.Metric],
+                      labels: tf.Tensor,
+                      outputs: tf.Tensor,
+                      model_losses: Optional[Dict[str, tf.Tensor]] = None,
+                      label_weights: Optional[tf.Tensor] = None,
+                      training: bool = True,
+                      **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
+    """Updates metrics.
+
+    Args:
+      metrics: Evaluation metrics to be updated.
+      labels: A tensor containing truth labels.
+      outputs: Model output logits of the classifier.
+      model_losses: An optional dict of model losses.
+      label_weights: Optional label weights, can be broadcast into shape of
+        outputs/labels.
+      training: Bool indicates if in training mode.
+      **kwargs: Additional input arguments.
+
+    Returns:
+      Updated dict of metrics log.
+    """
+    if model_losses is None:
+      model_losses = {}
+
+    logs = {}
+    if not training:
+      logs.update({self.avg_prec_metric.name: (labels, outputs)})
+
+    for m in metrics:
+      m.update_state(model_losses[m.name])
+      logs[m.name] = m.result()
+    return logs
+
   def train_step(self, inputs, model, optimizer, metrics=None):
     """Does forward and backward.
 
     Args:
-      inputs: a dictionary of input tensors. output_dict = {
-          "video_ids": batch_video_ids,
-          "video_matrix": batch_video_matrix,
-          "labels": batch_labels,
-          "num_frames": batch_frames, }
+      inputs: a dictionary of input tensors. output_dict = { "video_ids":
+        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
+        batch_labels, "num_frames": batch_frames, }
       model: the model, forward pass definition.
       optimizer: the optimizer for this training step.
       metrics: a nested structure of metrics objects.
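
The new process_metrics helper assumes each loss-tracking metric is named after a key of the model_losses dict; in eval mode it also stashes (labels, outputs) under the average-precision metric's name instead of updating that metric directly. A small self-contained sketch of the update loop, using hypothetical tf.keras.metrics.Mean objects and made-up loss values:

    import tensorflow as tf

    # Hypothetical metrics named after the keys returned by build_losses.
    metrics = [tf.keras.metrics.Mean(name='total_loss'),
               tf.keras.metrics.Mean(name='model_loss')]
    all_losses = {'total_loss': tf.constant(0.7), 'model_loss': tf.constant(0.5)}

    logs = {}
    for m in metrics:
      m.update_state(all_losses[m.name])  # loss-dict keys must match metric names
      logs[m.name] = m.result()
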
@@ -167,6 +217,7 @@ def train_step(self, inputs, model, optimizer, metrics=None):
     """
     features, labels = inputs['video_matrix'], inputs['labels']
     num_frames = inputs['num_frames']
+    label_weights = inputs.get('label_weights', None)
 
     # sample random frames / random sequence
     num_frames = tf.cast(num_frames, tf.float32)
@@ -183,26 +234,28 @@ def train_step(self, inputs, model, optimizer, metrics=None):
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
       outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
-
       # Computes per-replica loss
-      loss, model_loss = self.build_losses(
-          model_outputs=outputs, labels=labels, aux_losses=model.losses)
+      all_losses = self.build_losses(
+          model_outputs=outputs,
+          labels=labels,
+          label_weights=label_weights,
+          aux_losses=model.losses)
+
+      loss = all_losses['total_loss']
       # Scales loss as the default gradients allreduce performs sum inside the
       # optimizer.
       scaled_loss = loss / num_replicas
 
       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
-      if isinstance(optimizer,
-                    tf.keras.mixed_precision.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)
 
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(optimizer,
-                  tf.keras.mixed_precision.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
 
     # Apply gradient clipping.
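
Two scalings happen in the hunk above: the per-replica loss is divided by the replica count because the gradient all-reduce sums contributions across replicas, and, under mixed precision, the LossScaleOptimizer scales the loss up before the backward pass and scales the gradients back down before they are applied. A toy sketch of the loss-scaling half, assuming a TF 2 environment, with an SGD optimizer and a single variable standing in for the real model:

    import tensorflow as tf

    # Toy setup; the optimizer and variable are assumptions for illustration.
    opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
    var = tf.Variable(1.0)

    with tf.GradientTape() as tape:
      loss = tf.square(var)                      # per-replica loss
      scaled_loss = opt.get_scaled_loss(loss)    # scale up for float16 stability

    scaled_grads = tape.gradient(scaled_loss, [var])
    grads = opt.get_unscaled_gradients(scaled_grads)  # undo the scaling
    opt.apply_gradients(zip(grads, [var]))
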
@@ -213,24 +266,24 @@ def train_step(self, inputs, model, optimizer, metrics=None):
 
     logs = {self.loss: loss}
 
-    all_losses = {'total_loss': loss, 'model_loss': model_loss}
-
-    if metrics:
-      for m in metrics:
-        m.update_state(all_losses[m.name])
-        logs.update({m.name: m.result()})
+    logs.update(
+        self.process_metrics(
+            metrics,
+            labels=labels,
+            outputs=outputs,
+            model_losses=all_losses,
+            label_weights=label_weights,
+            training=True))
 
     return logs
 
   def validation_step(self, inputs, model, metrics=None):
     """Validatation step.
 
     Args:
-      inputs: a dictionary of input tensors. output_dict = {
-          "video_ids": batch_video_ids,
-          "video_matrix": batch_video_matrix,
-          "labels": batch_labels,
-          "num_frames": batch_frames, }
+      inputs: a dictionary of input tensors. output_dict = { "video_ids":
+        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
+        batch_labels, "num_frames": batch_frames, }
       model: the model, forward definition
       metrics: a nested structure of metrics objects.
 
@@ -239,6 +292,7 @@ def validation_step(self, inputs, model, metrics=None):
     """
     features, labels = inputs['video_matrix'], inputs['labels']
     num_frames = inputs['num_frames']
+    label_weights = inputs.get('label_weights', None)
 
     # sample random frames (None, 5, 1152) -> (None, 30, 1152)
     sample_frames = self.task_config.validation_data.num_frames
@@ -252,23 +306,28 @@ def validation_step(self, inputs, model, metrics=None):
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
     if self.task_config.validation_data.segment_labels:
       # workaround to ignore the unrated labels.
-      outputs *= inputs['label_weights']
+      outputs *= label_weights
       # remove padding
       outputs = outputs[~tf.reduce_all(labels == -1, axis=1)]
       labels = labels[~tf.reduce_all(labels == -1, axis=1)]
-    loss, model_loss = self.build_losses(
-        model_outputs=outputs, labels=labels, aux_losses=model.losses)
 
-    logs = {self.loss: loss}
+    all_losses = self.build_losses(
+        labels=labels,
+        model_outputs=outputs,
+        label_weights=label_weights,
+        aux_losses=model.losses)
 
-    all_losses = {'total_loss': loss, 'model_loss': model_loss}
+    logs = {self.loss: all_losses['total_loss']}
 
-    logs.update({self.avg_prec_metric.name: (labels, outputs)})
+    logs.update(
+        self.process_metrics(
+            metrics,
+            labels=labels,
+            outputs=outputs,
+            model_losses=all_losses,
+            label_weights=inputs.get('label_weights', None),
+            training=False))
 
-    if metrics:
-      for m in metrics:
-        m.update_state(all_losses[m.name])
-        logs.update({m.name: m.result()})
     return logs
 
   def inference_step(self, inputs, model):
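
For segment-level validation, the hunk above zeroes out unrated labels by multiplying the logits with label_weights and then drops padded rows whose labels are all -1 before computing the loss. A short sketch of that masking, with made-up values:

    import tensorflow as tf

    # Made-up batch: the last row is padding (every label == -1).
    labels = tf.constant([[1., 0.], [0., 1.], [-1., -1.]])
    outputs = tf.constant([[0.9, 0.2], [0.1, 0.8], [0.5, 0.5]])
    label_weights = tf.constant([[1., 1.], [1., 0.], [1., 1.]])  # 0 = unrated

    outputs *= label_weights                     # ignore unrated labels
    keep = ~tf.reduce_all(labels == -1, axis=1)  # True for non-padding rows
    outputs = outputs[keep]                      # boolean-mask indexing, as in the hunk
    labels = labels[keep]
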
