
Commit 651ce41

chongyouquan authored and Tensorflow Cloud maintainers committed
Integrate HParams plugin with DistributingCloudTuner
PiperOrigin-RevId: 339603283
1 parent 11a7cf6 commit 651ce41

File tree

4 files changed: +282, -49 lines changed


src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py

Lines changed: 58 additions & 9 deletions
@@ -25,6 +25,7 @@
 import mock
 
 import tensorflow as tf
+from tensorboard.plugins.hparams import api as hparams_api
 from tensorflow_cloud.core import deploy
 from tensorflow_cloud.core import machine_config
 from tensorflow_cloud.core import validate
@@ -125,9 +126,10 @@ def _remote_tuner(
         objective,
         hyperparameters,
         study_config,
-        directory="gs://remote_dir",
+        directory=None,
         max_trials=None
     ):
+        directory = directory or self._remote_dir
         return tuner.DistributingCloudTuner(
             hypermodel=build_model,
             objective=objective,
@@ -457,19 +459,66 @@ def test_get_best_trials_multi_tuners(self):
         self.assertEqual(best_trials_1[0].best_step, 3)
 
     @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
-    def test_add_tensorboard_callback(self, mock_super_tuner):
+    @mock.patch.object(tf.summary, "create_file_writer", auto_spec=True)
+    @mock.patch.object(hparams_api, "hparams", auto_spec=True)
+    def test_add_logging_user_specified(
+            self, mock_hparams, mock_create_file_writer, mock_super_tuner):
         remote_tuner = self._remote_tuner(None, None, self._study_config)
 
-        callbacks = [
-            tf.keras.callbacks.TensorBoard(log_dir="user_defined_path_1"),
-            tf.keras.callbacks.TensorBoard(log_dir="user_defined_path_2")]
+        callbacks = [tf.keras.callbacks.TensorBoard(
+            log_dir=remote_tuner.directory,
+            write_images=True)]
+
+        remote_tuner._add_logging(callbacks, self._test_trial)
+
+        expected_logdir = os.path.join(
+            remote_tuner.directory, self._test_trial.trial_id, "logs")
+        expected_hparams = {hparams_api.HParam(
+            "learning_rate", hparams_api.Discrete([1e-4, 1e-3, 1e-2])): 1e-4}
 
-        trial_id = "test_trial_id"
-        remote_tuner._add_tensorboard_callback(callbacks, trial_id)
         self.assertLen(callbacks, 1)
+        self.assertEqual(callbacks[0].log_dir, expected_logdir)
+        self.assertEqual(callbacks[0].write_images, True)
+        mock_create_file_writer.assert_called_once_with(expected_logdir)
+        self.assertEqual(mock_hparams.call_count, 1)
         self.assertEqual(
-            callbacks[0].log_dir,
-            os.path.join(remote_tuner.directory, trial_id, "logs"))
+            repr(mock_hparams.call_args[0][0]), repr(expected_hparams))
+
+    @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
+    @mock.patch.object(tf.summary, "create_file_writer", auto_spec=True)
+    @mock.patch.object(hparams_api, "hparams", auto_spec=True)
+    def test_add_logging_not_specified(
+            self, mock_hparams, mock_create_file_writer, mock_super_tuner):
+        remote_tuner = self._remote_tuner(None, None, self._study_config)
+
+        callbacks = []
+        remote_tuner._add_logging(callbacks, self._test_trial)
+
+        expected_logdir = os.path.join(
+            remote_tuner.directory, self._test_trial.trial_id, "logs")
+
+        self.assertLen(callbacks, 1)
+        self.assertEqual(callbacks[0].log_dir, expected_logdir)
+        mock_create_file_writer.assert_not_called()
+        mock_hparams.assert_not_called()
+
+    @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
+    @mock.patch.object(tf.summary, "create_file_writer", auto_spec=True)
+    @mock.patch.object(hparams_api, "hparams", auto_spec=True)
+    def test_add_logging_mismatched_dir(
+            self, mock_hparams, mock_create_file_writer, mock_super_tuner):
+        remote_tuner = self._remote_tuner(None, None, self._study_config)
+
+        callbacks = [tf.keras.callbacks.TensorBoard(
+            log_dir=os.path.join(remote_tuner.directory, "logs"))]
+
+        with self.assertRaisesRegex(
+                ValueError, "log_dir in TensorBoard callback should be "
+                "gs://remote_dir, but was gs://remote_dir/logs"):
+            remote_tuner._add_logging(callbacks, self._test_trial)
+
+        mock_create_file_writer.assert_not_called()
+        mock_hparams.assert_not_called()
 
     @mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
     def test_add_model_checkpoint_callback(self, mock_super_tuner):
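Taken together, these tests pin down the new user-facing contract: a user-supplied TensorBoard callback must point at the tuner's root directory, gets re-pointed to the per-trial log directory, and triggers hyperparameter logging; with no such callback, one is added for metrics only. A minimal usage sketch of that contract follows (build_model, hps, and the training data are placeholders, and the import path is assumed from the file tree above, not part of this commit):

import tensorflow as tf
from tensorflow_cloud.tuner import tuner  # import path assumed from the file tree

distributing_tuner = tuner.DistributingCloudTuner(
    hypermodel=build_model,        # placeholder model-building function
    objective="val_loss",
    hyperparameters=hps,           # placeholder kerastuner HyperParameters
    directory="gs://remote_dir",   # tuner root directory on GCS
)

# A TensorBoard callback whose log_dir equals the tuner directory is treated
# as intent to log: _add_logging re-points it to
# gs://remote_dir/<trial_id>/logs and records the trial's hyperparameters.
# Any other log_dir raises ValueError, per test_add_logging_mismatched_dir.
distributing_tuner.search(
    x_train, y_train,              # placeholder training data
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir="gs://remote_dir")],
)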

src/python/tensorflow_cloud/tuner/tests/unit/utils_test.py

Lines changed: 94 additions & 0 deletions
@@ -20,6 +20,7 @@
 from kerastuner.engine import oracle as oracle_module
 from kerastuner.engine import trial as trial_module
 import tensorflow as tf
+from tensorboard.plugins.hparams import api as hparams_api
 from tensorflow_cloud.tuner.tuner import utils
 
 STUDY_CONFIG_DISCRETE = {
@@ -315,6 +316,99 @@ def test_convert_optimizer_trial_to_keras_trial(self):
         self.assertEqual(
             trial.hyperparameters.values, {"learning_rate": 0.0001})
 
+    def test_convert_hyperparams_to_hparams_choice(self):
+        hps = hp_module.HyperParameters()
+        hps.Choice("learning_rate", [1e-4, 1e-3, 1e-2])
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam("learning_rate",
+                               hparams_api.Discrete([1e-4, 1e-3, 1e-2])): 1e-4,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
+    @parameterized.parameters(
+        ("units", 2, 16, None, hparams_api.IntInterval(2, 16), 2),
+        ("units", 32, 128, 32, hparams_api.Discrete([32, 64, 96, 128]), 32))
+    def test_convert_hyperparams_to_hparams_int(self, name, min_value,
+                                                max_value, step,
+                                                expected_domain,
+                                                expected_value):
+        hps = hp_module.HyperParameters()
+        if step:
+            hps.Int(name, min_value=min_value, max_value=max_value, step=step)
+        else:
+            hps.Int(name, min_value=min_value, max_value=max_value)
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam(name, expected_domain): expected_value,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
+    @parameterized.parameters(
+        ("learning_rate", 0.5, 1.5, 0.25,
+         hparams_api.Discrete([0.5, 0.75, 1.0, 1.25, 1.5]), 0.5),
+        ("learning_rate", 1e-4, 1e-1, None,
+         hparams_api.RealInterval(1e-4, 1e-1), 1e-4))
+    def test_convert_hyperparams_to_hparams_float(self, name, min_value,
+                                                  max_value, step,
+                                                  expected_domain,
+                                                  expected_value):
+        hps = hp_module.HyperParameters()
+        hps.Float(name, min_value=min_value, max_value=max_value, step=step)
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam(name, expected_domain): expected_value,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
+    def test_convert_hyperparams_to_hparams_multi_float(self):
+        hps = hp_module.HyperParameters()
+        hps.Float("theta", min_value=0.0, max_value=1.57)
+        hps.Float("r", min_value=0.0, max_value=1.0)
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam("r", hparams_api.RealInterval(0.0, 1.0)): 0.0,
+            hparams_api.HParam("theta",
+                               hparams_api.RealInterval(0.0, 1.57)): 0.0,
+        }
+        hparams_repr_list = [repr(hparams[x]) for x in hparams.keys()]
+        expected_hparams_repr_list = [
+            repr(expected_hparams[x]) for x in expected_hparams.keys()
+        ]
+        self.assertCountEqual(hparams_repr_list, expected_hparams_repr_list)
+
+    def test_convert_hyperparams_to_hparams_boolean(self):
+        hps = hp_module.HyperParameters()
+        hps.Boolean("has_beta")
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam("has_beta", hparams_api.Discrete([True, False])):
+                False,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
+    @parameterized.parameters(
+        ("beta", 0.1),
+        ("type", "WIDE_AND_DEEP"),
+        ("num_layers", 2))
+    def test_convert_hyperparams_to_hparams_fixed(self, name, value):
+        hps = hp_module.HyperParameters()
+        hps.Fixed(name, value)
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam(name, hparams_api.Discrete([value])): value,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
+    def test_convert_hyperparams_to_hparams_fixed_bool(self):
+        hps = hp_module.HyperParameters()
+        hps.Fixed("condition", True)
+        hparams = utils.convert_hyperparams_to_hparams(hps)
+        expected_hparams = {
+            hparams_api.HParam("condition", hparams_api.Discrete([1])): 1,
+        }
+        self.assertEqual(repr(hparams), repr(expected_hparams))
+
     @parameterized.parameters(
         ("val_loss", "min",
          [oracle_module.Objective(name="val_loss", direction="min")]),
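The conversion these tests exercise lives in utils.py, which this commit view does not display. The following is a sketch consistent with the expectations above; the isinstance dispatch and the step handling are assumptions for illustration, not the actual implementation:

# Sketch only: assumed to mirror what the tests above expect of
# utils.convert_hyperparams_to_hparams; the real utils.py may differ.
from kerastuner.engine import hyperparameters as hp_module
from tensorboard.plugins.hparams import api as hparams_api


def convert_hyperparams_to_hparams(hyperparams):
    """Maps kerastuner HyperParameters to a {HParam: default_value} dict."""
    hparams = {}
    for hp in hyperparams.space:
        default = hyperparams.values[hp.name]
        if isinstance(hp, hp_module.Fixed):
            # Per test_convert_hyperparams_to_hparams_fixed_bool, a
            # Fixed(True) surfaces as Discrete([1]): 1.
            domain = hparams_api.Discrete([hp.value])
        elif isinstance(hp, hp_module.Boolean):
            domain = hparams_api.Discrete([True, False])
        elif isinstance(hp, hp_module.Choice):
            domain = hparams_api.Discrete(hp.values)
        elif isinstance(hp, hp_module.Int):
            if hp.step is not None and hp.step != 1:
                # An explicit step makes the domain an enumerable grid.
                domain = hparams_api.Discrete(
                    list(range(hp.min_value, hp.max_value + 1, hp.step)))
            else:
                domain = hparams_api.IntInterval(hp.min_value, hp.max_value)
        elif isinstance(hp, hp_module.Float):
            if hp.step is not None:
                n_points = int(round((hp.max_value - hp.min_value) / hp.step))
                domain = hparams_api.Discrete(
                    [hp.min_value + i * hp.step for i in range(n_points + 1)])
            else:
                domain = hparams_api.RealInterval(hp.min_value, hp.max_value)
        else:
            raise ValueError("Unsupported hyperparameter type: %r" % hp)
        hparams[hparams_api.HParam(hp.name, domain)] = default
    return hparams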

src/python/tensorflow_cloud/tuner/tuner.py

Lines changed: 59 additions & 27 deletions
@@ -28,6 +28,7 @@
 from kerastuner.engine import tuner as tuner_module
 import tensorflow as tf
 
+from tensorboard.plugins.hparams import api as hparams_api
 from tensorflow_cloud.core import deploy
 from tensorflow_cloud.core import machine_config
 from tensorflow_cloud.core import validate
@@ -492,8 +493,8 @@ def __init__(
         super(DistributingCloudTuner, self,).__init__(
             oracle=oracle, hypermodel=hypermodel, **kwargs
         )
-        # If study id is not provided cloud_oracle creates ones. Setting the
-        # study_id based on cloud oracles logic to ensure they are the same.
+        # If study_id is not provided, CloudOracle creates one. Setting the
+        # study_id to what CloudOracle generates, to ensure they are the same.
         self._study_id = oracle.study_id
         self.directory = directory
 
@@ -519,16 +520,15 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
         callbacks = fit_kwargs.pop("callbacks", [])
         callbacks = self._deepcopy_callbacks(callbacks)
 
-        # Note run_trial does not use `TunerCallback` calls, since
+        # Note: run_trial does not use `TunerCallback` calls, since
         # training is performed on AI Platform training remotely.
 
-        # Creating a tensorboard callback with log-dir path specific for this
-        # trail_id. The tensorboard logs are used for passing metrics back from
-        # remote execution.
-        self._add_tensorboard_callback(callbacks, trial.trial_id)
+        # Handle TensorBoard/hyperparameter logging here. The TensorBoard
+        # logs are used for passing metrics back from remote execution.
+        self._add_logging(callbacks, trial)
 
         # Creating a save_model checkpoint callback with a saved model file path
-        # specific to this trial, this is to prevent different trials from
+        # specific to this trial. This is to prevent different trials from
         # overwriting each other.
         self._add_model_checkpoint_callback(
             callbacks, trial.trial_id)
@@ -605,7 +605,9 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
         if not google_api_client.wait_for_api_training_job_completion(
             job_id, self._project_id):
             raise RuntimeError(
-                "AIP Training job failed, see logs for details at https://console.cloud.google.com/ai-platform/jobs/{}/charts/cpu?project={}"  # pylint: disable=line-too-long
+                "AIP Training job failed, see logs for details at "
+                "https://console.cloud.google.com/ai-platform/jobs/"
+                "{}/charts/cpu?project={}"
                 .format(job_id, self._project_id))
 
         # Retrieve and report any remaining metrics
@@ -657,7 +659,7 @@ def _get_remote_training_metrics(
         self,
         log_reader,
         partial_epoch_metrics: Dict[Text, float]
-    )-> _TrainingMetrics:
+    ) -> _TrainingMetrics:
         """Retrieves delta epoch metrics from tensorboard logs since last run.
 
         This method reports any complete epoch metrics that are available since
@@ -683,9 +685,9 @@ def _get_remote_training_metrics(
         completed_epoch_metrics = []
         for event in log_reader.Load():
             for value in event.summary.value:
-                # Note tf.keras.callbacks.TensorBoard() with update_freq="epoch"
-                # logs the epoch related metrics with a "epoch_" prefix. This is
-                # not a requirement by tensorboard.
+                # Note: tf.keras.callbacks.TensorBoard.on_epoch_end() logs the
+                # epoch related metrics with an "epoch_" prefix. Please refer to
+                # https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/callbacks.py#L2179  # pylint: disable=line-too-long
                 if value.tag.startswith("epoch_"):
                     metric = value.tag.replace("epoch_", "")
                     # If we have already seen this metric, this is a new epoch
@@ -708,7 +710,6 @@ def load_model(self, trial):
         raise NotImplementedError("load_model for remote run is not supported.")
 
     def save_model(self, trial_id: int, model, step: int = 0):
-
         # In remote execution models are saved automatically in Google Cloud
         # Storage (GCS) bucket hence no additional actions are needed to save
         # the model.
@@ -719,27 +720,58 @@ def _add_model_checkpoint_callback(self, callbacks, trial_id):
             filepath=self._get_model_checkpoint_dir(trial_id),
             save_freq="epoch"))
 
-    def _add_tensorboard_callback(self, callbacks, trial_id):
-        # due to https://github.com/keras-team/keras/issues/14223 multiple
-        # tensorboard callbacks are not supported. Removing user defined
-        # tf.keras.callbacks.TensorBoard callback.
+    def _add_logging(self, callbacks, trial):
+        """Add a TensorBoard callback if needed, otherwise log hyperparameters.
 
-        tf.get_logger().info(
-            "Only one tf.keras.callbacks.TensorBoard callback is allowed, removing user defined callbacks."  # pylint: disable=line-too-long
-        )
-        callbacks[:] = [
-            x for x in callbacks if x.__class__.__name__ != "TensorBoard"]
+        Note: Due to https://github.com/keras-team/keras/issues/14223, multiple
+        TensorBoard callbacks are not supported. If the user specified a
+        TensorBoard callback, we treat it as an intent to log the metrics, and
+        we additionally log the hyperparameters as well. Otherwise, we add a
+        TensorBoard callback to pass back the epoch related metrics from
+        remote execution.
 
-        callbacks.append(tf.keras.callbacks.TensorBoard(
-            log_dir=self._get_tensorboard_log_dir(trial_id)))
+        Arguments:
+            callbacks: List of callbacks passed in to the search function.
+            trial: A `Trial` instance.
+        Raises:
+            ValueError: If the TensorBoard callback's log_dir does not match
+                self.directory.
+        """
+        logdir = self._get_tensorboard_log_dir(trial.trial_id)
+        for callback in callbacks:
+            if callback.__class__.__name__ == "TensorBoard":
+                # Validate the TensorBoard callback's log_dir
+                if callback.log_dir != self.directory:
+                    # TODO(b/170687807) Switch from using .format() to f-string
+                    raise ValueError(
+                        "log_dir in TensorBoard callback should be {}, "
+                        "but was {}".format(self.directory, callback.log_dir)
+                    )
+                # Patch the log_dir to the per-trial directory
+                callback.log_dir = logdir
+                # Do hyperparameter logging here to avoid having to
+                # serialize/deserialize the hyperparameters if logged through
+                # passing hparams_api.KerasCallback to client.cloud_fit.
+                with tf.summary.create_file_writer(logdir).as_default():
+                    hparams_api.hparams(utils.convert_hyperparams_to_hparams(
+                        trial.hyperparameters))
+                # We're done here, since there should only be one TensorBoard
+                # callback
+                return
+
+        # TensorBoard callback not specified by user, add it here. The
+        # TensorBoard logs are used for passing metrics back from
+        # remote execution.
+        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=logdir))
 
-    def _get_tensorboard_log_dir(self, trial_id)-> Text:
+    def _get_tensorboard_log_dir(self, trial_id) -> Text:
         # Defining <directory>/<trial_id>/logs as log structure.
         # self._add_tensorboard_callback uses this directory structure to
         # configure the tf.keras.callbacks.TensorBoard() for each trial.
         return os.path.join(self.directory, str(trial_id), "logs")
 
-    def _get_model_checkpoint_dir(self, trial_id)->Text:
+    def _get_model_checkpoint_dir(self, trial_id) -> Text:
         # Defining <directory>/<trial_id>/checkpoint as checkpoint structure.
         # self._add_model_checkpoint_callback uses this directory structure to
         # configure the tf.keras.callbacks.ModelCheckpoint() for each trial.
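For reference, the hyperparameter write that _add_logging performs for a user-supplied callback is equivalent to the following manual logging. This is a sketch: the trial id and the learning_rate domain are placeholders taken from the tests above, not values this commit hard-codes.

import tensorflow as tf
from tensorboard.plugins.hparams import api as hparams_api

# <directory>/<trial_id>/logs, per _get_tensorboard_log_dir; the trial id
# here is a placeholder.
logdir = "gs://remote_dir/test_trial_id/logs"
with tf.summary.create_file_writer(logdir).as_default():
    # hparams() takes a {HParam: value} mapping; this is the shape that
    # convert_hyperparams_to_hparams produces from trial.hyperparameters.
    hparams_api.hparams({
        hparams_api.HParam(
            "learning_rate",
            hparams_api.Discrete([1e-4, 1e-3, 1e-2])): 1e-4,
    })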
