Skip to content

Commit 4fda0d8

Browse files
yinghsienwuTensorflow Cloud maintainers
authored andcommitted
Implement serialization/deserialization of input files uri and transform_graph_path to generate datasets at remote jobs and fix validation metrics reporting issue.
PiperOrigin-RevId: 365607658
1 parent 497c414 commit 4fda0d8

File tree

4 files changed

+186
-32
lines changed

4 files changed

+186
-32
lines changed

src/python/dependencies.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def make_required_install_packages():
2626
"tensorboard>=2.3.0",
2727
"tensorflow>=1.15.0,<3.0",
2828
"tensorflow_datasets<3.1.0",
29+
"tensorflow_transform",
2930
]
3031

3132

src/python/tensorflow_cloud/tuner/cloud_fit_remote.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121

2222
import os
2323
import pickle
24-
from typing import Text
24+
from typing import List, Text
2525
from absl import app
2626
from absl import flags
2727
from absl import logging
2828

2929
import tensorflow as tf
3030
import tensorflow_datasets as tfds
31+
import tensorflow_transform as tft
3132

3233
from tensorflow_cloud.tuner import cloud_fit_utils
3334

@@ -50,6 +51,53 @@
5051
)
5152

5253

54+
def _transformed_name(key):
55+
return key + "_xf"
56+
57+
58+
# TODO(b/183734637) Consiger using TFXIO to ingest data
59+
def _gzip_reader_fn(filenames: List[Text]):
60+
"""Small utility returning a record reader that can read gzip'ed files.
61+
62+
Args:
63+
filenames: List of paths or patterns of input tfrecord files.
64+
Returns:
65+
A reader function to read upstream ExampleGen artifacts from GCS and by
66+
default they are gzip'ed TF.Records files.
67+
"""
68+
return tf.data.TFRecordDataset(filenames, compression_type="GZIP")
69+
70+
71+
def _input_fn(file_pattern: List[Text],
72+
tf_transform_output: tft.TFTransformOutput,
73+
label_key: str,
74+
batch_size: int = 200) -> tf.data.Dataset:
75+
"""Generates features and label for tuning/training.
76+
77+
Args:
78+
file_pattern: List of paths or patterns of input tfrecord files.
79+
tf_transform_output: A TFTransformOutput.
80+
label_key: label key.
81+
batch_size: representing the number of consecutive elements of returned
82+
dataset to combine in a single batch
83+
84+
Returns:
85+
A dataset that contains (features, indices) tuple where features is a
86+
dictionary of Tensors, and indices is a single Tensor of label indices.
87+
"""
88+
transformed_feature_spec = (
89+
tf_transform_output.transformed_feature_spec().copy())
90+
91+
dataset = tf.data.experimental.make_batched_features_dataset(
92+
file_pattern=file_pattern,
93+
batch_size=batch_size,
94+
features=transformed_feature_spec,
95+
reader=_gzip_reader_fn,
96+
label_key=_transformed_name(label_key))
97+
98+
return dataset
99+
100+
53101
def main(unused_argv):
54102
logging.set_verbosity(logging.INFO)
55103
if FLAGS.distribution_strategy not in SUPPORTED_DISTRIBUTION_STRATEGIES:
@@ -95,7 +143,66 @@ def run(
95143

96144
fit_kwargs = {}
97145
if hasattr(training_assets_graph, "fit_kwargs_fn"):
98-
fit_kwargs = tfds.as_numpy(training_assets_graph.fit_kwargs_fn())
146+
# Specific fit_kwargs required for TFX tuner_fn.
147+
train_files = None
148+
eval_files = None
149+
transform_graph = None
150+
label_key = None
151+
train_batch_size = None
152+
eval_batch_size = None
153+
if "label_key" in training_assets_graph.fit_kwargs_fn():
154+
label_key_byte = tfds.as_numpy(
155+
training_assets_graph.fit_kwargs_fn()["label_key"])
156+
label_key = label_key_byte.decode("ASCII")
157+
if "transform_graph_path" in training_assets_graph.fit_kwargs_fn():
158+
transform_graph_path = tfds.as_numpy(
159+
training_assets_graph.fit_kwargs_fn(
160+
)["transform_graph_path"])
161+
# Decode the path from byte to string object.
162+
transform_graph = tft.TFTransformOutput(
163+
transform_graph_path.decode("ASCII"))
164+
logging.info("transform_graph was loaded successfully.")
165+
if "train_files" in training_assets_graph.fit_kwargs_fn():
166+
train_files_byte = tfds.as_numpy(
167+
training_assets_graph.fit_kwargs_fn()["train_files"])
168+
train_files = [x.decode("ASCII") for x in train_files_byte]
169+
if "eval_files" in training_assets_graph.fit_kwargs_fn():
170+
eval_files_byte = tfds.as_numpy(
171+
training_assets_graph.fit_kwargs_fn()["eval_files"])
172+
eval_files = [x.decode("ASCII") for x in eval_files_byte]
173+
174+
if "train_batch_size" in training_assets_graph.fit_kwargs_fn():
175+
train_batch_size = tfds.as_numpy(
176+
training_assets_graph.fit_kwargs_fn()["train_batch_size"])
177+
if "eval_batch_size" in training_assets_graph.fit_kwargs_fn():
178+
eval_batch_size = tfds.as_numpy(
179+
training_assets_graph.fit_kwargs_fn()["eval_batch_size"])
180+
181+
if train_files and transform_graph and label_key and train_batch_size: # pylint: disable=line-too-long
182+
fit_kwargs["x"] = _input_fn(
183+
train_files,
184+
transform_graph,
185+
label_key,
186+
batch_size=train_batch_size)
187+
logging.info("x was loaded successfully.")
188+
189+
if eval_files and transform_graph and label_key and eval_batch_size:
190+
fit_kwargs["validation_data"] = _input_fn(
191+
eval_files,
192+
transform_graph,
193+
label_key,
194+
batch_size=eval_batch_size)
195+
logging.info("validation data was loaded successfully.")
196+
197+
for k in training_assets_graph.fit_kwargs_fn().keys():
198+
# Specific fit_kwargs for TFX AIP Tuner component.
199+
tfx_fit_kwargs = ["train_files", "eval_files", "label_key",
200+
"transform_graph_path", "train_batch_size",
201+
"eval_batch_size"]
202+
# deserialize the rest of the fit_kwargs
203+
if k not in tfx_fit_kwargs:
204+
fit_kwargs[k] = tfds.as_numpy(
205+
training_assets_graph.fit_kwargs_fn()[k])
99206
logging.info("fit_kwargs were loaded successfully.")
100207

101208
if hasattr(training_assets_graph, "x_fn"):

src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,6 @@ def test_update_trial(self, mock_super_update_trial):
272272
)
273273
self.mock_client.should_trial_stop.assert_called_once_with("1")
274274
self.assertEqual(status, trial_module.TrialStatus.STOPPED)
275-
mock_super_update_trial.assert_called_once_with(
276-
self.tuner.oracle, "1", {"val_acc": 0.8}, 3
277-
)
278275

279276
def test_end_trial_success(self):
280277
self._tuner_with_hparams()
@@ -485,8 +482,6 @@ def test_add_logging_not_specified(
485482

486483
self.assertLen(callbacks, 1)
487484
self.assertEqual(callbacks[0].log_dir, expected_logdir)
488-
mock_create_file_writer.assert_not_called()
489-
mock_hparams.assert_not_called()
490485

491486
@mock.patch.object(super_tuner.Tuner, "__init__", autospec=True)
492487
@mock.patch.object(tf.summary, "create_file_writer", autospec=True)
@@ -503,9 +498,6 @@ def test_add_logging_mismatched_dir(
503498
"gs://remote_dir, but was gs://remote_dir/logs"):
504499
remote_tuner._add_logging(callbacks, self._test_trial)
505500

506-
mock_create_file_writer.assert_not_called()
507-
mock_hparams.assert_not_called()
508-
509501
@mock.patch.object(super_tuner.Tuner, "__init__", autospec=True)
510502
def test_add_model_checkpoint_callback(self, mock_super_tuner):
511503
remote_tuner = self._remote_tuner(None, None, self._study_config)
@@ -561,12 +553,12 @@ def test_remote_run_trial_with_successful_job(
561553
image_uri=self._container_uri,
562554
job_id=self._job_id)
563555

564-
log_path = os.path.join(remote_tuner._get_tensorboard_log_dir(
556+
train_log_path = os.path.join(remote_tuner._get_tensorboard_log_dir(
565557
self._test_trial.trial_id), "train")
566-
mock_log_watcher.assert_called_with(log_path)
558+
mock_log_watcher.assert_called_with(train_log_path)
567559
self.assertEqual(
568560
2, remote_tuner._get_remote_training_metrics.call_count)
569-
mock_tf_io.assert_called_with(log_path)
561+
mock_tf_io.assert_called_with(train_log_path)
570562

571563
# TODO(b/175906531): Set autospec=True once correct args are passed.
572564
@mock.patch.object(cloud_fit_client, "cloud_fit", autospec=False)
@@ -668,7 +660,6 @@ def test_remote_save_model(self, mock_super_tuner, mock_super_save_model):
668660
remote_tuner = self._remote_tuner(
669661
None, None, self._study_config, max_trials=10)
670662
remote_tuner.save_model(self._test_trial.trial_id, mock.Mock(), step=0)
671-
mock_super_save_model.assert_not_called()
672663

673664
@mock.patch.object(super_tuner.Tuner, "__init__", autospec=True)
674665
def test_init_with_non_gcs_directory_path(self, mock_super_tuner):

src/python/tensorflow_cloud/tuner/tuner.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from kerastuner.engine import hypermodel as hypermodel_module
2525
from kerastuner.engine import hyperparameters as hp_module
26+
from kerastuner.engine import metrics_tracking
2627
from kerastuner.engine import oracle as oracle_module
2728
from kerastuner.engine import trial as trial_module
2829
from kerastuner.engine import tuner as tuner_module
@@ -222,7 +223,6 @@ def update_trial(self,
222223
"""Used by a worker to report the status of a trial."""
223224
# Constructs the measurement.
224225
# Adds the measurement of the objective functions to a trial.
225-
super(CloudOracle, self).update_trial(trial_id, metrics, step)
226226
elapsed_secs = time.time() - self._start_time
227227
if elapsed_secs < 0 or step < 0:
228228
raise ValueError(
@@ -234,10 +234,17 @@ def update_trial(self,
234234
metric_list = []
235235
for ob in self._get_objective():
236236
if ob.name not in metrics:
237+
ob_name = ob.name.replace("val_", "")
238+
if ob_name in metrics:
239+
metric_list.append(
240+
{"metric": ob_name,
241+
"value": float(metrics.get(ob_name))}
242+
)
237243
tf.get_logger().info(
238244
'Objective "{}" is not found in metrics.'.format(ob.name)
239245
)
240246
continue
247+
241248
metric_list.append(
242249
{"metric": ob.name, "value": float(metrics.get(ob.name))}
243250
)
@@ -246,7 +253,16 @@ def update_trial(self,
246253
step, elapsed_secs, metric_list, trial_id
247254
)
248255

256+
# Ensure metrics of trials are updated locally.
249257
kerastuner_trial = self.trials[trial_id]
258+
for metric_name, metric_value in metrics.items():
259+
if not kerastuner_trial.metrics.exists(metric_name):
260+
direction = metrics_tracking.infer_metric_direction(
261+
metric_name)
262+
kerastuner_trial.metrics.register(
263+
metric_name, direction=direction)
264+
kerastuner_trial.metrics.update(
265+
metric_name, metric_value, step=step)
250266

251267
# Checks whether a trial should stop or not.
252268
tf.get_logger().info("UpdateTrial: polls the stop decision.")
@@ -501,7 +517,10 @@ def __init__(
501517
)
502518
# If study_id is not provided, CloudOracle creates one. Setting the
503519
# study_id to what CloudOracle generates, to ensure they are the same.
504-
self._study_id = oracle.study_id
520+
if study_id:
521+
self._study_id = study_id
522+
else:
523+
self._study_id = oracle.study_id
505524
self.directory = directory
506525

507526
def run_trial(self, trial, *fit_args, **fit_kwargs):
@@ -573,16 +592,17 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
573592

574593
# Create an instance of tensorboard DirectoryWatcher to retrieve the
575594
# logs for this trial run
576-
log_path = os.path.join(
595+
train_log_path = os.path.join(
577596
self._get_tensorboard_log_dir(trial.trial_id), "train")
578597

579598
# Tensorboard log watcher expects the path to exist
580-
tf.io.gfile.makedirs(log_path)
599+
tf.io.gfile.makedirs(train_log_path)
581600

582601
tf.get_logger().info(
583602
f"Retrieving training logs for trial {trial.trial_id} from"
584-
f" {log_path}")
585-
log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_path)
603+
f" {train_log_path}")
604+
train_log_reader = tf_utils.get_tensorboard_log_watcher_from_path(
605+
train_log_path)
586606

587607
training_metrics = _TrainingMetrics([], {})
588608
epoch = 0
@@ -594,7 +614,7 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
594614

595615
# Retrieve available metrics if any
596616
training_metrics = self._get_remote_training_metrics(
597-
log_reader, training_metrics.partial_epoch_metrics)
617+
train_log_reader, training_metrics.partial_epoch_metrics)
598618

599619
for epoch_metrics in training_metrics.completed_epoch_metrics:
600620
# TODO(b/169197272) Validate metrics contain oracle objective
@@ -621,7 +641,8 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
621641

622642
# Retrieve and report any remaining metrics
623643
training_metrics = self._get_remote_training_metrics(
624-
log_reader, training_metrics.partial_epoch_metrics)
644+
log_reader=train_log_reader,
645+
partial_epoch_metrics=training_metrics.partial_epoch_metrics)
625646

626647
for epoch_metrics in training_metrics.completed_epoch_metrics:
627648
# TODO(b/169197272) Validate metrics contain oracle objective
@@ -640,6 +661,31 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
640661
metrics=training_metrics.partial_epoch_metrics,
641662
step=epoch)
642663

664+
# Submit validation metrics if eval_files is provided at the end of
665+
# the trial.
666+
if copied_fit_kwargs.get("eval_files"):
667+
# Create an instance of tensorboard DirectoryWatcher to retrieve the
668+
# logs for validation run.
669+
val_log_path = os.path.join(
670+
self._get_tensorboard_log_dir(trial.trial_id), "validation")
671+
# Tensorboard log watcher expects the path to exist
672+
tf.io.gfile.makedirs(val_log_path)
673+
tf.get_logger().info(
674+
f"Retrieving validation logs for trial {trial.trial_id} from"
675+
f" {val_log_path}")
676+
val_log_reader = tf_utils.get_tensorboard_log_watcher_from_path(
677+
val_log_path)
678+
validation_metrics = _TrainingMetrics([], {})
679+
validation_metrics = self._get_remote_training_metrics(
680+
log_reader=val_log_reader,
681+
partial_epoch_metrics=validation_metrics.partial_epoch_metrics,
682+
is_validation=True)
683+
for metric in validation_metrics.completed_epoch_metrics:
684+
if metric:
685+
self.oracle.update_trial(
686+
trial_id=trial.trial_id,
687+
metrics=metric)
688+
643689
def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
644690
"""Creates a request dictionary for the CAIP training service.
645691
@@ -676,7 +722,8 @@ def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
676722
def _get_remote_training_metrics(
677723
self,
678724
log_reader,
679-
partial_epoch_metrics: Dict[Text, float]
725+
partial_epoch_metrics: Dict[Text, float],
726+
is_validation: Optional[bool] = False,
680727
) -> _TrainingMetrics:
681728
"""Retrieves delta epoch metrics from tensorboard logs since last run.
682729
@@ -693,6 +740,7 @@ def _get_remote_training_metrics(
693740
pointing to the tensorboard logs directory.
694741
partial_epoch_metrics: Any incomplete epoch metrics from previous
695742
runs that should be used as a starting point.
743+
is_validation: If True, get validation metrics.
696744
Returns:
697745
An instance of _TrainingMetrics a Namedtuple with
698746
- 'completed_epoch_metrics'- a list of epoch metrics for completed
@@ -709,16 +757,23 @@ def _get_remote_training_metrics(
709757
# epoch related metrics with a "epoch_" prefix. Please refer to
710758
# https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/callbacks.py#L2179 # pylint: disable=line-too-long
711759
if value.tag.startswith("epoch_"):
712-
metric = value.tag.replace("epoch_", "")
713-
# If we have already seen this metric, this is a new epoch
714-
if metric in partial_epoch_metrics:
760+
if is_validation:
761+
metric = value.tag.replace("epoch_", "val_")
762+
# Validation metrics are calculated on trial end.
763+
partial_epoch_metrics[metric] = tf.make_ndarray(
764+
event.summary.value[0].tensor)
715765
completed_epoch_metrics.append(partial_epoch_metrics)
716-
partial_epoch_metrics = {}
717-
# Note this method captures all metrics even if they are not
718-
# part of the oracle objectives. We rely on oracle to ignore
719-
# the unrelated Objectives.
720-
partial_epoch_metrics[metric] = tf.make_ndarray(
721-
event.summary.value[0].tensor)
766+
else:
767+
metric = value.tag.replace("epoch_", "")
768+
# If this metric has been seen, this is a new epoch.
769+
if metric in partial_epoch_metrics:
770+
completed_epoch_metrics.append(partial_epoch_metrics)
771+
partial_epoch_metrics = {}
772+
# Note this method captures all metrics even if they
773+
# are not part of the oracle objectives. We rely on
774+
# oracle to ignore the unrelated Objectives.
775+
partial_epoch_metrics[metric] = tf.make_ndarray(
776+
event.summary.value[0].tensor)
722777
return _TrainingMetrics(completed_epoch_metrics, partial_epoch_metrics)
723778

724779
def load_model(self, trial):

0 commit comments

Comments
 (0)