@@ -28,6 +28,7 @@
 from kerastuner.engine import tuner as tuner_module
 import tensorflow as tf
 
+from tensorboard.plugins.hparams import api as hparams_api
 from tensorflow_cloud.core import deploy
 from tensorflow_cloud.core import machine_config
 from tensorflow_cloud.core import validate
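
The newly imported hparams module is TensorBoard's HParams summary API. As a minimal sketch of what it does, assuming an illustrative log directory and hyperparameter values (not taken from this change):

    import tensorflow as tf
    from tensorboard.plugins.hparams import api as hparams_api

    # Hypothetical per-trial log directory; the tuner derives the real one
    # as <directory>/<trial_id>/logs.
    logdir = "gs://some-bucket/tuner/7/logs"
    with tf.summary.create_file_writer(logdir).as_default():
        # hparams() accepts a plain dict of name -> value and writes it as
        # an HParams summary that the TensorBoard plugin can display.
        hparams_api.hparams({"learning_rate": 1e-3, "units": 64})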
@@ -492,8 +493,8 @@ def __init__(
         super(DistributingCloudTuner, self,).__init__(
             oracle=oracle, hypermodel=hypermodel, **kwargs
         )
-        # If study id is not provided cloud_oracle creates ones. Setting the
-        # study_id based on cloud oracles logic to ensure they are the same.
+        # If study_id is not provided, CloudOracle creates one. Setting the
+        # study_id to what CloudOracle generates ensures they are the same.
         self._study_id = oracle.study_id
         self.directory = directory
@@ -519,16 +520,15 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
         callbacks = fit_kwargs.pop("callbacks", [])
         callbacks = self._deepcopy_callbacks(callbacks)
 
-        # Note run_trial does not use `TunerCallback` calls, since
+        # Note: run_trial does not use `TunerCallback` calls, since
         # training is performed on AI Platform training remotely.
 
-        # Creating a tensorboard callback with log-dir path specific for this
-        # trail_id. The tensorboard logs are used for passing metrics back from
-        # remote execution.
-        self._add_tensorboard_callback(callbacks, trial.trial_id)
+        # Handle TensorBoard/hyperparameter logging here. The TensorBoard
+        # logs are used for passing metrics back from remote execution.
+        self._add_logging(callbacks, trial)
 
         # Creating a save_model checkpoint callback with a saved model file path
-        # specific to this trial, this is to prevent different trials from
+        # specific to this trial. This is to prevent different trials from
         # overwriting each other.
         self._add_model_checkpoint_callback(
             callbacks, trial.trial_id)
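
For reference, the per-trial checkpoint callback that _add_model_checkpoint_callback builds looks roughly like the sketch below; the GCS path here is hypothetical, the real one comes from _get_model_checkpoint_dir:

    import tensorflow as tf

    callbacks = []
    trial_id = "7"  # hypothetical trial id
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        # One checkpoint directory per trial prevents trials from
        # overwriting each other's saved models.
        filepath="gs://some-bucket/tuner/" + trial_id + "/checkpoint",
        save_freq="epoch"))  # save once per epoch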
@@ -605,7 +605,9 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
         if not google_api_client.wait_for_api_training_job_completion(
             job_id, self._project_id):
             raise RuntimeError(
-                "AIP Training job failed, see logs for details at https://console.cloud.google.com/ai-platform/jobs/{}/charts/cpu?project={}"  # pylint: disable=line-too-long
+                "AIP Training job failed, see logs for details at "
+                "https://console.cloud.google.com/ai-platform/jobs/"
+                "{}/charts/cpu?project={}"
                 .format(job_id, self._project_id))
 
         # Retrieve and report any remaining metrics
@@ -657,7 +659,7 @@ def _get_remote_training_metrics(
         self,
         log_reader,
         partial_epoch_metrics: Dict[Text, float]
-    )-> _TrainingMetrics:
+    ) -> _TrainingMetrics:
         """Retrieves delta epoch metrics from tensorboard logs since last run.
 
         This method reports any complete epoch metrics that are available since
@@ -683,9 +685,9 @@ def _get_remote_training_metrics(
         completed_epoch_metrics = []
         for event in log_reader.Load():
             for value in event.summary.value:
-                # Note tf.keras.callbacks.TensorBoard() with update_freq="epoch"
-                # logs the epoch related metrics with a "epoch_" prefix. This is
-                # not a requirement by tensorboard.
+                # Note: tf.keras.callbacks.TensorBoard.on_epoch_end() logs the
+                # epoch related metrics with an "epoch_" prefix. Please refer to
+                # https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/callbacks.py#L2179  # pylint: disable=line-too-long
                 if value.tag.startswith("epoch_"):
                     metric = value.tag.replace("epoch_", "")
                     # If we have already seen this metric, this is a new epoch
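
The "epoch_" prefix convention referenced above can be verified against any event file the TensorBoard callback writes; a small sketch using the public summary_iterator API (the file path is hypothetical):

    import tensorflow as tf

    events_file = "logs/train/events.out.tfevents.0"  # hypothetical path
    for event in tf.compat.v1.train.summary_iterator(events_file):
        for value in event.summary.value:
            if value.tag.startswith("epoch_"):
                metric = value.tag.replace("epoch_", "")
                # TF2 writes scalars as tensor protos; v1-style events
                # carry the number in value.simple_value instead.
                print(metric, tf.make_ndarray(value.tensor).item())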
@@ -708,7 +710,6 @@ def load_model(self, trial):
         raise NotImplementedError("load_model for remote run is not supported.")
 
     def save_model(self, trial_id: int, model, step: int = 0):
-
         # In remote execution models are saved automatically in Google Cloud
         # Storage (GCS) bucket hence no additional actions are needed to save
         # the model.
@@ -719,27 +720,58 @@ def _add_model_checkpoint_callback(self, callbacks, trial_id):
             filepath=self._get_model_checkpoint_dir(trial_id),
             save_freq="epoch"))
 
-    def _add_tensorboard_callback(self, callbacks, trial_id):
-        # due to https://github.com/keras-team/keras/issues/14223 multiple
-        # tensorboard callbacks are not supported. Removing user defined
-        # tf.keras.callbacks.TensorBoard callback.
+    def _add_logging(self, callbacks, trial):
+        """Add a TensorBoard callback if needed, otherwise log hyperparameters.
 
-        tf.get_logger().info(
-            "Only one tf.keras.callbacks.TensorBoard callback is allowed, removing user defined callbacks."  # pylint: disable=line-too-long
-        )
-        callbacks[:] = [
-            x for x in callbacks if x.__class__.__name__ != "TensorBoard"]
+        Note: Due to https://github.com/keras-team/keras/issues/14223, multiple
+        TensorBoard callbacks are not supported. If the user specified a
+        TensorBoard callback, we treat it as an intent to log the metrics, and
+        we additionally log the hyperparameters. Otherwise, we add a
+        TensorBoard callback to pass back the epoch related metrics from
+        remote execution.
 
-        callbacks.append(tf.keras.callbacks.TensorBoard(
-            log_dir=self._get_tensorboard_log_dir(trial_id)))
+        Arguments:
+            callbacks: List of callbacks passed in to the search function.
+            trial: A `Trial` instance.
+        Raises:
+            ValueError: If the TensorBoard callback's log_dir does not match
+                self.directory.
+        """
+
+        logdir = self._get_tensorboard_log_dir(trial.trial_id)
+        for callback in callbacks:
+            if callback.__class__.__name__ == "TensorBoard":
+                # Validate the TensorBoard callback's log_dir.
+                if callback.log_dir != self.directory:
+                    # TODO(b/170687807) Switch from using .format() to f-string
+                    raise ValueError(
+                        "log_dir in TensorBoard callback should be {}, "
+                        "but was {}".format(self.directory, callback.log_dir)
+                    )
+                # Patch the log_dir to point at this trial's log directory.
+                callback.log_dir = logdir
+                # Do hyperparameter logging here to avoid having to
+                # serialize/deserialize the hyperparameters if logged through
+                # passing hparams_api.KerasCallback to client.cloud_fit.
+                with tf.summary.create_file_writer(logdir).as_default():
+                    hparams_api.hparams(utils.convert_hyperparams_to_hparams(
+                        trial.hyperparameters))
+                # We are done here, since there should only be one TensorBoard
+                # callback.
+                return
+
+        # TensorBoard callback not specified by the user, add it here. The
+        # TensorBoard logs are used for passing metrics back from
+        # remote execution.
+        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=logdir))
 
-    def _get_tensorboard_log_dir(self, trial_id)-> Text:
+    def _get_tensorboard_log_dir(self, trial_id) -> Text:
         # Defining <directory>/<trial_id>/logs as log structure.
         # self._add_tensorboard_callback uses this directory structure to
         # configure the tf.keras.callbacks.TensorBoard() for each trial.
         return os.path.join(self.directory, str(trial_id), "logs")
 
-    def _get_model_checkpoint_dir(self, trial_id)-> Text:
+    def _get_model_checkpoint_dir(self, trial_id) -> Text:
         # Defining <directory>/<trial_id>/checkpoint as checkpoint structure.
         # self._add_model_checkpoint_callback uses this directory structure to
         # configure the tf.keras.callbacks.ModelCheckpoint() for each trial.
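
Taken together, the two path helpers at the end of this hunk give every trial its own subtree under self.directory. A quick sketch of the resulting layout, with a hypothetical bucket name:

    import os

    directory = "gs://some-bucket/tuner"  # hypothetical self.directory
    trial_id = "7"
    # <directory>/<trial_id>/logs: TensorBoard events, read back for metrics.
    log_dir = os.path.join(directory, str(trial_id), "logs")
    # <directory>/<trial_id>/checkpoint: this trial's saved model.
    checkpoint_dir = os.path.join(directory, str(trial_id), "checkpoint")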