Skip to content

Commit 48b4b57

Browse files
arashwantensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 416953858
1 parent 8280799 commit 48b4b57

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

official/vision/beta/tasks/semantic_segmentation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def build_metrics(self, training: bool = True):
183183
num_classes=self.task_config.model.num_classes,
184184
rescale_predictions=False,
185185
dtype=tf.float32))
186-
if self.task_config.model.mask_scoring_head:
186+
if self.task_config.model.get('mask_scoring_head'):
187187
metrics.append(
188188
tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
189189
else:
@@ -193,7 +193,7 @@ def build_metrics(self, training: bool = True):
193193
rescale_predictions=not self.task_config.validation_data
194194
.resize_eval_groundtruth,
195195
dtype=tf.float32)
196-
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.mask_scoring_head:
196+
if self.task_config.validation_data.resize_eval_groundtruth and self.task_config.model.get('mask_scoring_head'): # pylint: disable=line-too-long
197197
# Masks scores metric can only be computed if labels are scaled to match
198198
# preticted mask scores.
199199
metrics.append(
@@ -232,6 +232,8 @@ def train_step(self,
232232
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
233233
with tf.GradientTape() as tape:
234234
outputs = model(features, training=True)
235+
if isinstance(outputs, tf.Tensor):
236+
outputs = {'logits': outputs}
235237
# Casting output layer as float32 is necessary when mixed_precision is
236238
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
237239
outputs = tf.nest.map_structure(
@@ -287,6 +289,8 @@ def validation_step(self,
287289
features, input_partition_dims)
288290

289291
outputs = self.inference_step(features, model)
292+
if isinstance(outputs, tf.Tensor):
293+
outputs = {'logits': outputs}
290294
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
291295

292296
if self.task_config.validation_data.resize_eval_groundtruth:

0 commit comments

Comments
 (0)