
Commit f1ae448

juanuribe28 authored and Tensorflow Cloud maintainers committed
Save model from run_experiment and add TODO comments.
PiperOrigin-RevId: 387643746
1 parent ac7fc6c, commit f1ae448

File tree

2 files changed: 24 additions & 17 deletions


src/python/tensorflow_cloud/core/experimental/models.py

Lines changed: 6 additions & 2 deletions
@@ -271,7 +271,8 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
                                             worker_config)
     run_experiment_kwargs.update(
         dict(distribution_strategy=distribution_strategy))
-    train_lib.run_experiment(**run_experiment_kwargs)
+    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
+    model.save(run_experiment_kwargs['model_dir'])
 
     run_kwargs.update(dict(entry_point=None,
                            distribution_strategy=None))
@@ -282,11 +283,14 @@ def get_distribution_strategy(chief_config, worker_count, worker_config):
     """Gets a tf distribution strategy based on the cloud run config."""
     if worker_count > 0:
         if machine_config.is_tpu_config(worker_config):
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+            # TODO(b/194857231) Dependency conflict for using TPUs
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+                tpu='local')
             tf.config.experimental_connect_to_cluster(resolver)
             tf.tpu.experimental.initialize_tpu_system(resolver)
             return tf.distribute.TPUStrategy(resolver)
         else:
+            # TODO(b/148619319) Saving model currently failing
             return tf.distribute.MultiWorkerMirroredStrategy()
     elif chief_config.accelerator_count > 1:
         return tf.distribute.MirroredStrategy()
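
For readers skimming the diff: run_experiment_cloud now keeps the (model, logs) tuple returned by train_lib.run_experiment and saves the model to model_dir, while get_distribution_strategy picks the strategy shown in the second hunk. The snippet below is a minimal, self-contained sketch of that selection pattern only, not the library's code: worker_count, worker_is_tpu, and chief_accelerator_count are hypothetical stand-ins for the chief_config/worker_config objects used above, and the final single-device fallback is an assumption.

import tensorflow as tf


def pick_strategy(worker_count, worker_is_tpu, chief_accelerator_count):
    """Hypothetical helper mirroring the selection pattern in get_distribution_strategy."""
    if worker_count > 0:
        if worker_is_tpu:
            # tpu='local' targets a locally attached TPU (e.g. a Cloud TPU VM);
            # TODO(b/194857231) above notes a dependency conflict for TPU runs.
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            return tf.distribute.TPUStrategy(resolver)
        # TODO(b/148619319) above notes that model saving currently fails
        # under MultiWorkerMirroredStrategy.
        return tf.distribute.MultiWorkerMirroredStrategy()
    if chief_accelerator_count > 1:
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()  # assumed fallback: default single-device strategy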

src/python/tensorflow_cloud/core/experimental/tests/unit/models_test.py

Lines changed: 18 additions & 15 deletions
@@ -77,13 +77,28 @@ def setup_run_experiment(self):
             mode='train_and_eval',
             params=config,
             model_dir='model_path')
-
+        self.model = mock.MagicMock()
         self.run_experiment = mock.patch.object(
             train_lib,
             'run_experiment',
             autospec=True,
+            return_value=(self.model, {})
         ).start()
 
+    def setup_tpu(self):
+        mock.patch.object(tf.tpu.experimental,
+                          'initialize_tpu_system',
+                          autospec=True).start()
+        mock.patch.object(tf.config,
+                          'experimental_connect_to_cluster',
+                          autospec=True).start()
+        mock.patch('tensorflow.distribute.cluster_resolver.TPUClusterResolver'
+                   ).start()
+        mock_tpu_strategy = mock.MagicMock(
+            spec=tf.distribute.TPUStrategy)
+        mock.patch('tensorflow.distribute.TPUStrategy',
+                   return_value=mock_tpu_strategy).start()
+
     def tearDown(self):
         mock.patch.stopall()
         super(ModelsTest, self).tearDown()
@@ -182,20 +197,8 @@ def test_run_experiment_cloud_remote(self):
         self.remote.assert_called()
         self.run_experiment.assert_called()
         self.run.assert_called()
-
-    def setup_tpu(self):
-        mock.patch.object(tf.tpu.experimental,
-                          'initialize_tpu_system',
-                          autospec=True).start()
-        mock.patch.object(tf.config,
-                          'experimental_connect_to_cluster',
-                          autospec=True).start()
-        mock.patch('tensorflow.distribute.cluster_resolver.TPUClusterResolver'
-                   ).start()
-        mock_tpu_strategy = mock.MagicMock()
-        mock_tpu_strategy.__class__ = tf.distribute.TPUStrategy
-        mock.patch('tensorflow.distribute.TPUStrategy',
-                   return_value=mock_tpu_strategy).start()
+        self.model.save.assert_called_with(
+            self.run_experiment_kwargs['model_dir'])
 
     def test_get_distribution_strategy_tpu(self):
         tpu_srategy = tf.distribute.TPUStrategy
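
The new assertion above checks that the mocked model's save method is called with the configured model_dir. Below is a self-contained sketch of the same mocking pattern under stated assumptions: the train_lib stub and the run_experiment_and_save wrapper are hypothetical stand-ins for the real modules, included only to make the example runnable.

import unittest
from unittest import mock


class train_lib:  # hypothetical stand-in for the real train_lib module
    @staticmethod
    def run_experiment(**kwargs):
        raise NotImplementedError('patched out in the test')


def run_experiment_and_save(**run_experiment_kwargs):
    # Mirrors the change to run_experiment_cloud: keep the returned model and save it.
    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
    model.save(run_experiment_kwargs['model_dir'])


class SaveAfterTrainTest(unittest.TestCase):

    def setUp(self):
        super().setUp()
        self.model = mock.MagicMock()
        # Patch the training entry point so it returns (mock_model, logs).
        self.run_experiment = mock.patch.object(
            train_lib, 'run_experiment',
            return_value=(self.model, {})).start()
        self.addCleanup(mock.patch.stopall)

    def test_model_is_saved_to_model_dir(self):
        run_experiment_and_save(mode='train_and_eval', model_dir='model_path')
        self.run_experiment.assert_called()
        self.model.save.assert_called_with('model_path')


if __name__ == '__main__':
    unittest.main()

Using return_value=(self.model, {}) on the patched function is what lets the test assert on model.save afterwards; addCleanup(mock.patch.stopall) plays the same role as the tearDown in the real test above.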
