@@ -13,6 +13,8 @@
 # limitations under the License.
 
 """Video classification task definition."""
+from typing import Dict, List, Optional, Tuple
+
 from absl import logging
 import tensorflow as tf
 
@@ -95,31 +97,46 @@ def build_inputs(self, params: yt8m_cfg.DataConfig, input_context=None):
 
     return dataset
 
-  def build_losses(self, labels, model_outputs, aux_losses=None):
+  def build_losses(self,
+                   labels,
+                   model_outputs,
+                   label_weights=None,
+                   aux_losses=None):
     """Sigmoid Cross Entropy.
 
     Args:
       labels: tensor containing truth labels.
       model_outputs: output logits of the classifier.
+      label_weights: optional tensor of label weights.
       aux_losses: tensor containing auxiliary loss tensors, i.e. `losses` in
         keras.Model.
 
     Returns:
-      Tensors: The total loss, model loss tensors.
+      A dict of tensors containing the total loss and model loss tensors.
     """
     losses_config = self.task_config.losses
     model_loss = tf.keras.losses.binary_crossentropy(
         labels,
         model_outputs,
         from_logits=losses_config.from_logits,
-        label_smoothing=losses_config.label_smoothing)
+        label_smoothing=losses_config.label_smoothing,
+        axis=None)
+
+    if label_weights is None:
+      model_loss = tf_utils.safe_mean(model_loss)
+    else:
+      model_loss = model_loss * label_weights
+      # Manually compute the weighted mean loss.
+      total_loss = tf.reduce_sum(model_loss)
+      total_weight = tf.cast(
+          tf.reduce_sum(label_weights), dtype=total_loss.dtype)
+      model_loss = tf.math.divide_no_nan(total_loss, total_weight)
 
-    model_loss = tf_utils.safe_mean(model_loss)
     total_loss = model_loss
     if aux_losses:
       total_loss += tf.add_n(aux_losses)
 
-    return total_loss, model_loss
+    return {'total_loss': total_loss, 'model_loss': model_loss}
 
   def build_metrics(self, training=True):
     """Gets streaming metrics for training/validation.
@@ -130,10 +147,10 @@ def build_metrics(self, training=True):
       top_n: A positive integer specifying the average precision at n, or None
         to use all provided data points.
     Args:
-      training: bool value, true for training mode, false for eval/validation.
+      training: Bool value, true for training mode, false for eval/validation.
 
     Returns:
-      list of strings that indicate metrics to be used
+      A list of strings that indicate metrics to be used.
     """
     metrics = []
     metric_names = ['total_loss', 'model_loss']
@@ -149,15 +166,48 @@ def build_metrics(self, training=True):
     return metrics
 
+  def process_metrics(self,
+                      metrics: List[tf.keras.metrics.Metric],
+                      labels: tf.Tensor,
+                      outputs: tf.Tensor,
+                      model_losses: Optional[Dict[str, tf.Tensor]] = None,
+                      label_weights: Optional[tf.Tensor] = None,
+                      training: bool = True,
+                      **kwargs) -> Dict[str, Tuple[tf.Tensor, ...]]:
+    """Updates metrics.
+
+    Args:
+      metrics: Evaluation metrics to be updated.
+      labels: A tensor containing truth labels.
+      outputs: Model output logits of the classifier.
+      model_losses: An optional dict of model losses.
+      label_weights: Optional label weights that can be broadcast into the
+        shape of outputs/labels.
+      training: Bool indicating whether this is training mode.
+      **kwargs: Additional input arguments.
+
+    Returns:
+      An updated dict of metric logs.
+    """
+    if model_losses is None:
+      model_losses = {}
+
+    logs = {}
+    if not training:
+      logs.update({self.avg_prec_metric.name: (labels, outputs)})
+
+    for m in metrics:
+      m.update_state(model_losses[m.name])
+      logs[m.name] = m.result()
+    return logs
+
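A usage note on the loop above: it assumes each metric's name matches a key in model_losses, which holds for the 'total_loss' and 'model_loss' names listed in build_metrics. A minimal sketch with hypothetical values:

import tensorflow as tf

metrics = [tf.keras.metrics.Mean('total_loss'),
           tf.keras.metrics.Mean('model_loss')]
model_losses = {'total_loss': tf.constant(1.5),
                'model_loss': tf.constant(1.2)}

logs = {}
for m in metrics:
  m.update_state(model_losses[m.name])  # losses are looked up by metric name
  logs[m.name] = m.result()
# logs now maps 'total_loss' to 1.5 and 'model_loss' to 1.2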
   def train_step(self, inputs, model, optimizer, metrics=None):
     """Does forward and backward.
 
     Args:
-      inputs: a dictionary of input tensors. output_dict = {
-        "video_ids": batch_video_ids,
-        "video_matrix": batch_video_matrix,
-        "labels": batch_labels,
-        "num_frames": batch_frames, }
+      inputs: a dictionary of input tensors. output_dict = { "video_ids":
+        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
+        batch_labels, "num_frames": batch_frames, }
       model: the model, forward pass definition.
       optimizer: the optimizer for this training step.
       metrics: a nested structure of metrics objects.
@@ -167,6 +217,7 @@ def train_step(self, inputs, model, optimizer, metrics=None):
     """
     features, labels = inputs['video_matrix'], inputs['labels']
     num_frames = inputs['num_frames']
+    label_weights = inputs.get('label_weights', None)
 
     # sample random frames / random sequence
     num_frames = tf.cast(num_frames, tf.float32)
@@ -183,26 +234,28 @@ def train_step(self, inputs, model, optimizer, metrics=None):
     # Casting output layer as float32 is necessary when mixed_precision is
     # mixed_float16 or mixed_bfloat16 to ensure output is cast as float32.
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
-
     # Computes per-replica loss
-    loss, model_loss = self.build_losses(
-        model_outputs=outputs, labels=labels, aux_losses=model.losses)
+    all_losses = self.build_losses(
+        model_outputs=outputs,
+        labels=labels,
+        label_weights=label_weights,
+        aux_losses=model.losses)
+
+    loss = all_losses['total_loss']
     # Scales loss as the default gradients allreduce performs sum inside the
     # optimizer.
     scaled_loss = loss / num_replicas
 
     # For mixed_precision policy, when LossScaleOptimizer is used, loss is
     # scaled for numerical stability.
-    if isinstance(optimizer,
-                  tf.keras.mixed_precision.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       scaled_loss = optimizer.get_scaled_loss(scaled_loss)
 
     tvars = model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     # Scales back gradient before apply_gradients when LossScaleOptimizer is
     # used.
-    if isinstance(optimizer,
-                  tf.keras.mixed_precision.LossScaleOptimizer):
+    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
       grads = optimizer.get_unscaled_gradients(grads)
 
     # Apply gradient clipping.
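For context on the two isinstance checks above, the LossScaleOptimizer round trip scales the loss up before differentiation and scales the gradients back down afterwards. A minimal runnable sketch, with a single variable standing in for the model:

import tensorflow as tf

var = tf.Variable(2.0)  # toy stand-in for model.trainable_variables
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))

with tf.GradientTape() as tape:
  loss = var * var                               # toy forward pass
  scaled_loss = optimizer.get_scaled_loss(loss)  # multiply by the loss scale
grads = tape.gradient(scaled_loss, [var])
grads = optimizer.get_unscaled_gradients(grads)  # divide the scale back out
optimizer.apply_gradients(zip(grads, [var]))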
@@ -213,24 +266,24 @@ def train_step(self, inputs, model, optimizer, metrics=None):
 
     logs = {self.loss: loss}
 
-    all_losses = {'total_loss': loss, 'model_loss': model_loss}
-
-    if metrics:
-      for m in metrics:
-        m.update_state(all_losses[m.name])
-        logs.update({m.name: m.result()})
+    logs.update(
+        self.process_metrics(
+            metrics,
+            labels=labels,
+            outputs=outputs,
+            model_losses=all_losses,
+            label_weights=label_weights,
+            training=True))
 
     return logs
 
   def validation_step(self, inputs, model, metrics=None):
     """Validation step.
 
     Args:
-      inputs: a dictionary of input tensors. output_dict = {
-        "video_ids": batch_video_ids,
-        "video_matrix": batch_video_matrix,
-        "labels": batch_labels,
-        "num_frames": batch_frames, }
+      inputs: a dictionary of input tensors. output_dict = { "video_ids":
+        batch_video_ids, "video_matrix": batch_video_matrix, "labels":
+        batch_labels, "num_frames": batch_frames, }
       model: the model, forward definition.
       metrics: a nested structure of metrics objects.
@@ -239,6 +292,7 @@ def validation_step(self, inputs, model, metrics=None):
     """
     features, labels = inputs['video_matrix'], inputs['labels']
     num_frames = inputs['num_frames']
+    label_weights = inputs.get('label_weights', None)
 
     # sample random frames (None, 5, 1152) -> (None, 30, 1152)
     sample_frames = self.task_config.validation_data.num_frames
@@ -252,23 +306,28 @@ def validation_step(self, inputs, model, metrics=None):
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
     if self.task_config.validation_data.segment_labels:
       # workaround to ignore the unrated labels.
-      outputs *= inputs['label_weights']
+      outputs *= label_weights
       # remove padding
       outputs = outputs[~tf.reduce_all(labels == -1, axis=1)]
       labels = labels[~tf.reduce_all(labels == -1, axis=1)]
-    loss, model_loss = self.build_losses(
-        model_outputs=outputs, labels=labels, aux_losses=model.losses)
 
-    logs = {self.loss: loss}
+    all_losses = self.build_losses(
+        labels=labels,
+        model_outputs=outputs,
+        label_weights=label_weights,
+        aux_losses=model.losses)
 
-    all_losses = {'total_loss': loss, 'model_loss': model_loss}
+    logs = {self.loss: all_losses['total_loss']}
 
-    logs.update({self.avg_prec_metric.name: (labels, outputs)})
+    logs.update(
+        self.process_metrics(
+            metrics,
+            labels=labels,
+            outputs=outputs,
+            model_losses=all_losses,
+            label_weights=label_weights,
+            training=False))
 
-    if metrics:
-      for m in metrics:
-        m.update_state(all_losses[m.name])
-        logs.update({m.name: m.result()})
     return logs
 
   def inference_step(self, inputs, model):
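For reference, the padding removal in validation_step keeps only the rows whose labels are not all -1. A standalone sketch with made-up values; tf.boolean_mask is equivalent to the bracket indexing used above:

import tensorflow as tf

labels = tf.constant([[1, 0, 1],
                      [-1, -1, -1]])   # second row is all -1, i.e. padding
outputs = tf.constant([[0.9, 0.1, 0.8],
                       [0.5, 0.5, 0.5]])

keep = ~tf.reduce_all(labels == -1, axis=1)  # [True, False]
outputs = tf.boolean_mask(outputs, keep)     # padding row dropped
labels = tf.boolean_mask(labels, keep)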