@@ -177,8 +177,8 @@ def update_weights(
177177 bucket_loss += np .maximum (prof [SequenceExampleFeatureNames .regret ], 0 )
178178 losses_per_bucket .append (bucket_loss )
179179 logging .info ('Losses per bucket: %s' , losses_per_bucket )
180- losses_per_bucket_normalized = losses_per_bucket / np . max (
181- np .abs (losses_per_bucket ))
180+ losses_per_bucket_normalized = losses_per_bucket / (
181+ np .max ( np . abs (losses_per_bucket )) + 1e-6 )
182182 probs_t = self ._get_exp_gradient_step (losses_per_bucket_normalized , 1.0 )
183183 self ._round += 1
184184 self ._probs = (self ._probs * (self ._round - 1 ) + probs_t ) / self ._round
@@ -228,6 +228,7 @@ def __init__(
228228 self ._trainig_weights = TrainingWeights ()
229229 self ._features_to_remove = features_to_remove
230230 self ._global_step = 0
231+ self ._is_model_init = False
231232
232233 observation_spec , action_spec = config .get_inlining_signature_spec ()
233234 sequence_features = {
@@ -322,13 +323,12 @@ def load_dataset(self, filepaths: list[str]) -> tf.data.TFRecordDataset:
322323 self ._make_feature_label , num_processors = self ._num_processors ))
323324 dataset = dataset .unbatch ().shuffle (self ._shuffle_size ).batch (
324325 self ._batch_size , drop_remainder = True ) # 4194304
325- dataset = dataset .apply (tf .data .experimental .ignore_errors ())
326326
327327 return dataset
328328
329329 def _create_weights (self , labels , weights_arr ):
330- p_norm = min (weights_arr ) # check that this should be min
331- weights_arr = tf .map_fn ( lambda x : p_norm / x , tf . constant ( weights_arr ) )
330+ p_norm = tf . reduce_min (weights_arr )
331+ weights_arr = tf .math . divide ( p_norm , weights_arr )
332332 int_labels = tf .cast (labels , tf .int32 )
333333 return tf .gather (weights_arr , int_labels )
334334
@@ -365,6 +365,7 @@ def _update_metrics(self, y_true, y_pred, loss, weights):
365365 tf .summary .scalar (
366366 name = metric .name , data = metric .result (), step = self ._global_step )
367367
368+ @tf .function
368369 def _train_step (self , example , label , weight_labels , weights_arr ):
369370 y_true = label [:, 0 ]
370371 y_true = tf .reshape (y_true , [self ._batch_size , 1 ])
@@ -381,10 +382,15 @@ def train(self, filepaths: list[str]):
381382 """Train the model for number of the specified number of epochs."""
382383 dataset = self .load_dataset (filepaths )
383384 logging .info ('Datasets loaded from %s' , str (filepaths ))
384- input_shape = dataset .element_spec [0 ].shape [- 1 ]
385- self ._initialize_model (input_shape = input_shape )
386- self ._initialize_metrics ()
387- for _ in range (self ._epochs ):
385+ input_shape = int (dataset .element_spec [0 ].shape [- 1 ])
386+ if not self ._is_model_init :
387+ self ._initialize_model (input_shape = input_shape )
388+ self ._initialize_metrics ()
389+ self ._is_model_init = True
390+ self ._global_step = 0
391+ logging .info ('Training started' )
392+ for epoch in range (self ._epochs ):
393+ logging .info ('Epoch %s' , epoch )
388394 for metric in self ._metrics :
389395 metric .reset_states ()
390396 for step , (x_batch_train , y_batch_train ) in enumerate (dataset ):
0 commit comments