23
23
24
24
from kerastuner .engine import hypermodel as hypermodel_module
25
25
from kerastuner .engine import hyperparameters as hp_module
26
+ from kerastuner .engine import metrics_tracking
26
27
from kerastuner .engine import oracle as oracle_module
27
28
from kerastuner .engine import trial as trial_module
28
29
from kerastuner .engine import tuner as tuner_module
@@ -222,7 +223,6 @@ def update_trial(self,
222
223
"""Used by a worker to report the status of a trial."""
223
224
# Constructs the measurement.
224
225
# Adds the measurement of the objective functions to a trial.
225
- super (CloudOracle , self ).update_trial (trial_id , metrics , step )
226
226
elapsed_secs = time .time () - self ._start_time
227
227
if elapsed_secs < 0 or step < 0 :
228
228
raise ValueError (
@@ -234,10 +234,17 @@ def update_trial(self,
234
234
metric_list = []
235
235
for ob in self ._get_objective ():
236
236
if ob .name not in metrics :
237
+ ob_name = ob .name .replace ("val_" , "" )
238
+ if ob_name in metrics :
239
+ metric_list .append (
240
+ {"metric" : ob_name ,
241
+ "value" : float (metrics .get (ob_name ))}
242
+ )
237
243
tf .get_logger ().info (
238
244
'Objective "{}" is not found in metrics.' .format (ob .name )
239
245
)
240
246
continue
247
+
241
248
metric_list .append (
242
249
{"metric" : ob .name , "value" : float (metrics .get (ob .name ))}
243
250
)
@@ -246,7 +253,16 @@ def update_trial(self,
246
253
step , elapsed_secs , metric_list , trial_id
247
254
)
248
255
256
+ # Ensure metrics of trials are updated locally.
249
257
kerastuner_trial = self .trials [trial_id ]
258
+ for metric_name , metric_value in metrics .items ():
259
+ if not kerastuner_trial .metrics .exists (metric_name ):
260
+ direction = metrics_tracking .infer_metric_direction (
261
+ metric_name )
262
+ kerastuner_trial .metrics .register (
263
+ metric_name , direction = direction )
264
+ kerastuner_trial .metrics .update (
265
+ metric_name , metric_value , step = step )
250
266
251
267
# Checks whether a trial should stop or not.
252
268
tf .get_logger ().info ("UpdateTrial: polls the stop decision." )
@@ -501,7 +517,10 @@ def __init__(
501
517
)
502
518
# If study_id is not provided, CloudOracle creates one. Setting the
503
519
# study_id to what CloudOracle generates, to ensure they are the same.
504
- self ._study_id = oracle .study_id
520
+ if study_id :
521
+ self ._study_id = study_id
522
+ else :
523
+ self ._study_id = oracle .study_id
505
524
self .directory = directory
506
525
507
526
def run_trial (self , trial , * fit_args , ** fit_kwargs ):
@@ -573,16 +592,17 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
573
592
574
593
# Create an instance of tensorboard DirectoryWatcher to retrieve the
575
594
# logs for this trial run
576
- log_path = os .path .join (
595
+ train_log_path = os .path .join (
577
596
self ._get_tensorboard_log_dir (trial .trial_id ), "train" )
578
597
579
598
# Tensorboard log watcher expects the path to exist
580
- tf .io .gfile .makedirs (log_path )
599
+ tf .io .gfile .makedirs (train_log_path )
581
600
582
601
tf .get_logger ().info (
583
602
f"Retrieving training logs for trial { trial .trial_id } from"
584
- f" { log_path } " )
585
- log_reader = tf_utils .get_tensorboard_log_watcher_from_path (log_path )
603
+ f" { train_log_path } " )
604
+ train_log_reader = tf_utils .get_tensorboard_log_watcher_from_path (
605
+ train_log_path )
586
606
587
607
training_metrics = _TrainingMetrics ([], {})
588
608
epoch = 0
@@ -594,7 +614,7 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
594
614
595
615
# Retrieve available metrics if any
596
616
training_metrics = self ._get_remote_training_metrics (
597
- log_reader , training_metrics .partial_epoch_metrics )
617
+ train_log_reader , training_metrics .partial_epoch_metrics )
598
618
599
619
for epoch_metrics in training_metrics .completed_epoch_metrics :
600
620
# TODO(b/169197272) Validate metrics contain oracle objective
@@ -621,7 +641,8 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
621
641
622
642
# Retrieve and report any remaining metrics
623
643
training_metrics = self ._get_remote_training_metrics (
624
- log_reader , training_metrics .partial_epoch_metrics )
644
+ log_reader = train_log_reader ,
645
+ partial_epoch_metrics = training_metrics .partial_epoch_metrics )
625
646
626
647
for epoch_metrics in training_metrics .completed_epoch_metrics :
627
648
# TODO(b/169197272) Validate metrics contain oracle objective
@@ -640,6 +661,31 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
640
661
metrics = training_metrics .partial_epoch_metrics ,
641
662
step = epoch )
642
663
664
+ # Submit validation metrics if eval_files is provided at the end of
665
+ # the trial.
666
+ if copied_fit_kwargs .get ("eval_files" ):
667
+ # Create an instance of tensorboard DirectoryWatcher to retrieve the
668
+ # logs for validation run.
669
+ val_log_path = os .path .join (
670
+ self ._get_tensorboard_log_dir (trial .trial_id ), "validation" )
671
+ # Tensorboard log watcher expects the path to exist
672
+ tf .io .gfile .makedirs (val_log_path )
673
+ tf .get_logger ().info (
674
+ f"Retrieving validation logs for trial { trial .trial_id } from"
675
+ f" { val_log_path } " )
676
+ val_log_reader = tf_utils .get_tensorboard_log_watcher_from_path (
677
+ val_log_path )
678
+ validation_metrics = _TrainingMetrics ([], {})
679
+ validation_metrics = self ._get_remote_training_metrics (
680
+ log_reader = val_log_reader ,
681
+ partial_epoch_metrics = validation_metrics .partial_epoch_metrics ,
682
+ is_validation = True )
683
+ for metric in validation_metrics .completed_epoch_metrics :
684
+ if metric :
685
+ self .oracle .update_trial (
686
+ trial_id = trial .trial_id ,
687
+ metrics = metric )
688
+
643
689
def _get_job_spec_from_config (self , job_id : Text ) -> Dict [Text , Any ]:
644
690
"""Creates a request dictionary for the CAIP training service.
645
691
@@ -676,7 +722,8 @@ def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
676
722
def _get_remote_training_metrics (
677
723
self ,
678
724
log_reader ,
679
- partial_epoch_metrics : Dict [Text , float ]
725
+ partial_epoch_metrics : Dict [Text , float ],
726
+ is_validation : Optional [bool ] = False ,
680
727
) -> _TrainingMetrics :
681
728
"""Retrieves delta epoch metrics from tensorboard logs since last run.
682
729
@@ -693,6 +740,7 @@ def _get_remote_training_metrics(
693
740
pointing to the tensorboard logs directory.
694
741
partial_epoch_metrics: Any incomplete epoch metrics from previous
695
742
runs that should be used as a starting point.
743
+ is_validation: If True, get validation metrics.
696
744
Returns:
697
745
An instance of _TrainingMetrics a Namedtuple with
698
746
- 'completed_epoch_metrics'- a list of epoch metrics for completed
@@ -709,16 +757,23 @@ def _get_remote_training_metrics(
709
757
# epoch related metrics with a "epoch_" prefix. Please refer to
710
758
# https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/callbacks.py#L2179 # pylint: disable=line-too-long
711
759
if value .tag .startswith ("epoch_" ):
712
- metric = value .tag .replace ("epoch_" , "" )
713
- # If we have already seen this metric, this is a new epoch
714
- if metric in partial_epoch_metrics :
760
+ if is_validation :
761
+ metric = value .tag .replace ("epoch_" , "val_" )
762
+ # Validation metrics are calculated on trial end.
763
+ partial_epoch_metrics [metric ] = tf .make_ndarray (
764
+ event .summary .value [0 ].tensor )
715
765
completed_epoch_metrics .append (partial_epoch_metrics )
716
- partial_epoch_metrics = {}
717
- # Note this method captures all metrics even if they are not
718
- # part of the oracle objectives. We rely on oracle to ignore
719
- # the unrelated Objectives.
720
- partial_epoch_metrics [metric ] = tf .make_ndarray (
721
- event .summary .value [0 ].tensor )
766
+ else :
767
+ metric = value .tag .replace ("epoch_" , "" )
768
+ # If this metric has been seen, this is a new epoch.
769
+ if metric in partial_epoch_metrics :
770
+ completed_epoch_metrics .append (partial_epoch_metrics )
771
+ partial_epoch_metrics = {}
772
+ # Note this method captures all metrics even if they
773
+ # are not part of the oracle objectives. We rely on
774
+ # oracle to ignore the unrelated Objectives.
775
+ partial_epoch_metrics [metric ] = tf .make_ndarray (
776
+ event .summary .value [0 ].tensor )
722
777
return _TrainingMetrics (completed_epoch_metrics , partial_epoch_metrics )
723
778
724
779
def load_model (self , trial ):
0 commit comments