
Commit 912b54b

tensorflower-gardener and aman2930 authored and committed
Adding gradient clipping for detection models.
PiperOrigin-RevId: 365639389
1 parent e3ecfed · commit 912b54b

File tree

2 files changed: +6 additions, -6 deletions


official/vision/detection/configs/shapemask_config.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from official.modeling.hyperparams import params_dict
 from official.vision.detection.configs import base_config
 
-SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
+SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
 
 SHAPEMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
 SHAPEMASK_CFG.override({
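
For reference, a minimal sketch of what the widened prefix now freezes: the conv layers conv2d through conv2d_10 as before, plus their batch-normalization counterparts. The variable names below are made-up placeholders, and re.match stands in for however the executor actually applies the prefix:

import re

# Updated prefix from shapemask_config.py: covers conv2d .. conv2d_10 and,
# new in this commit, batch_normalization .. batch_normalization_10.
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = (
    r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/')

# Hypothetical variable names, for illustration only.
for name in ['conv2d/kernel:0',                # frozen
             'conv2d_10/kernel:0',             # frozen
             'conv2d_11/kernel:0',             # trainable: _11 is outside 1-10
             'batch_normalization_3/gamma:0',  # frozen (new with this change)
             'dense/kernel:0']:                # trainable
  frozen = re.match(SHAPEMASK_RESNET_FROZEN_VAR_PREFIX, name) is not None
  print(name, '->', 'frozen' if frozen else 'trainable')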

official/vision/detection/executor/detection_executor.py

Lines changed: 5 additions & 5 deletions
@@ -63,10 +63,9 @@ def _create_replicated_step(self,
           trainable_variables)
     logging.info('Filter trainable variables from %d to %d',
                  len(model.trainable_variables), len(trainable_variables))
-    _update_state = lambda labels, outputs: None
+    update_state_fn = lambda labels, outputs: None
     if isinstance(metric, tf.keras.metrics.Metric):
-      _update_state = lambda labels, outputs: metric.update_state(
-          labels, outputs)
+      update_state_fn = metric.update_state
     else:
       logging.error('Detection: train metric is not an instance of '
                     'tf.keras.metrics.Metric.')
@@ -82,10 +81,11 @@ def _replicated_step(inputs):
         for k, v in all_losses.items():
           losses[k] = tf.reduce_mean(v)
         per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
-        _update_state(labels, outputs)
+        update_state_fn(labels, outputs)
 
       grads = tape.gradient(per_replica_loss, trainable_variables)
-      optimizer.apply_gradients(zip(grads, trainable_variables))
+      clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
+      optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
       return losses
 
     return _replicated_step
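
And a self-contained sketch of the clipping pattern the new lines add: tf.clip_by_global_norm jointly rescales the whole gradient list so its combined L2 norm is at most clip_norm before apply_gradients runs. The toy model and data below are placeholders, not the detection executor's actual objects:

import tensorflow as tf

# Placeholder model and data standing in for the detection model and inputs.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))

grads = tape.gradient(loss, model.trainable_variables)
# Rescale all gradients together so their global L2 norm is at most 1.0,
# matching the clip_norm=1.0 hard-coded in the new executor code.
clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables))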
