@@ -183,7 +183,7 @@ def build_metrics(self, training: bool = True):
               num_classes=self.task_config.model.num_classes,
               rescale_predictions=False,
               dtype=tf.float32))
-      if self.task_config.model.mask_scoring_head:
+      if self.task_config.model.get('mask_scoring_head'):
         metrics.append(
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
     else:
@@ -193,7 +193,7 @@ def build_metrics(self, training: bool = True):
           rescale_predictions=not self.task_config.validation_data
           .resize_eval_groundtruth,
           dtype=tf.float32)
-      if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.mask_scoring_head:
+      if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.get('mask_scoring_head'):  # pylint: disable=line-too-long
         # Masks scores metric can only be computed if labels are scaled to match
         # preticted mask scores.
         metrics.append(
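
Both `build_metrics` branches now read the optional field with `.get('mask_scoring_head')` instead of attribute access, so the check degrades gracefully when a config predates the field. A minimal sketch of that pattern, using a hypothetical dict-backed `DemoConfig` rather than the real `task_config` class:

```python
# Hypothetical stand-in for a dict-backed config; not the real task_config class.
class DemoConfig(dict):

  def __getattr__(self, name):
    try:
      return self[name]
    except KeyError as e:
      raise AttributeError(name) from e


old_cfg = DemoConfig(num_classes=21)  # written before the field existed
new_cfg = DemoConfig(num_classes=21, mask_scoring_head={'num_convs': 4})

assert old_cfg.get('mask_scoring_head') is None  # falsy, branch is skipped
assert new_cfg.get('mask_scoring_head')          # truthy, metric is added
# old_cfg.mask_scoring_head would raise AttributeError instead of returning None.
```
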
@@ -232,6 +232,8 @@ def train_step(self,
     num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
     with tf.GradientTape() as tape:
       outputs = model(features, training=True)
+      if isinstance(outputs, tf.Tensor):
+        outputs = {'logits': outputs}
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
       outputs = tf.nest.map_structure(
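
The new `isinstance` guard lets `train_step` accept a model that returns a bare logits tensor as well as one that returns a dict of named outputs. A minimal sketch of the normalization, using a hypothetical `normalize_outputs` helper that is not part of the task itself:

```python
import tensorflow as tf


def normalize_outputs(outputs):
  """Wraps a bare tensor so downstream code can always index outputs['logits']."""
  if isinstance(outputs, tf.Tensor):
    outputs = {'logits': outputs}
  return outputs


bare = tf.zeros([2, 64, 64, 21])  # e.g. a model with only a segmentation head
named = {'logits': bare, 'mask_scores': tf.zeros([2, 64, 64, 1])}

assert set(normalize_outputs(bare)) == {'logits'}
assert set(normalize_outputs(named)) == {'logits', 'mask_scores'}
```
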
@@ -287,6 +289,8 @@ def validation_step(self,
           features, input_partition_dims)

     outputs = self.inference_step(features, model)
+    if isinstance(outputs, tf.Tensor):
+      outputs = {'logits': outputs}
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

     if self.task_config.validation_data.resize_eval_groundtruth:
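
As the pre-existing comment in `train_step` notes, both steps then cast the outputs to float32 so losses and metrics are computed in full precision even under `mixed_float16` or `mixed_bfloat16`. A small sketch of that cast, assuming the model produced float16 tensors:

```python
import tensorflow as tf

# Assume mixed precision left the model outputs in float16.
outputs = {
    'logits': tf.zeros([2, 64, 64, 21], dtype=tf.float16),
    'mask_scores': tf.zeros([2, 64, 64, 1], dtype=tf.float16),
}
# tf.nest.map_structure applies the cast to every tensor in the nested structure.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
assert all(t.dtype == tf.float32 for t in outputs.values())
```
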