@@ -63,10 +63,9 @@ def _create_replicated_step(self,
63
63
trainable_variables )
64
64
logging .info ('Filter trainable variables from %d to %d' ,
65
65
len (model .trainable_variables ), len (trainable_variables ))
66
- _update_state = lambda labels , outputs : None
66
+ update_state_fn = lambda labels , outputs : None
67
67
if isinstance (metric , tf .keras .metrics .Metric ):
68
- _update_state = lambda labels , outputs : metric .update_state (
69
- labels , outputs )
68
+ update_state_fn = metric .update_state
70
69
else :
71
70
logging .error ('Detection: train metric is not an instance of '
72
71
'tf.keras.metrics.Metric.' )
@@ -82,10 +81,11 @@ def _replicated_step(inputs):
82
81
for k , v in all_losses .items ():
83
82
losses [k ] = tf .reduce_mean (v )
84
83
per_replica_loss = losses ['total_loss' ] / strategy .num_replicas_in_sync
85
- _update_state (labels , outputs )
84
+ update_state_fn (labels , outputs )
86
85
87
86
grads = tape .gradient (per_replica_loss , trainable_variables )
88
- optimizer .apply_gradients (zip (grads , trainable_variables ))
87
+ clipped_grads , _ = tf .clip_by_global_norm (grads , clip_norm = 1.0 )
88
+ optimizer .apply_gradients (zip (clipped_grads , trainable_variables ))
89
89
return losses
90
90
91
91
return _replicated_step
0 commit comments