Skip to content

Commit f19e308

Browse files
juanuribe28Tensorflow Cloud maintainers
authored and committed
Add wrapper for running experiments from TF Model Garden on GCP.
PiperOrigin-RevId: 384244039
1 parent 2fa03c1 commit f19e308

File tree

2 files changed

+173
-11
lines changed

2 files changed

+173
-11
lines changed

src/python/tensorflow_cloud/core/experimental/models.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
"""Module that contains the `run_models` wrapper for training models from TF Model Garden."""
1616

1717
import os
18-
from typing import Dict, Optional
18+
from typing import Any, Dict, Optional
1919

20+
from .. import machine_config
2021
from .. import run
2122
import tensorflow as tf
2223
import tensorflow_datasets as tfds
2324

25+
from official.core import train_lib
2426
from official.vision.image_classification.efficientnet import efficientnet_model
2527
from official.vision.image_classification.resnet import resnet_model
2628

@@ -224,3 +226,70 @@ def data_pipeline(original_ds, image_size, width_ratio, batch_size, num_classes,
224226
ds = ds.batch(batch_size, drop_remainder=True)
225227
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
226228
return ds
229+
230+
231+
def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
                         run_kwargs: Optional[Dict[str, Any]] = None,
                         ) -> Optional[Dict[str, str]]:
    """A wrapper for run API and tf-models-official run_experiment.

    This method takes a dictionary of the parameters for run and a dictionary
    of the parameters for run_experiment to run the experiment directly on GCP.

    Neither input dictionary is modified; overrides are applied to copies.

    Args:
        run_experiment_kwargs: keyword arguments for `train_lib.run_experiment`.
            The docs can be found at
            https://github.com/tensorflow/models/blob/master/official/core/train_lib.py
            The distribution_strategy param is ignored because the distribution
            strategy is selected based on run_kwargs.
        run_kwargs: keyword arguments for `tfc.run`. The docs can be found at
            https://github.com/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/core/run.py
            The params entry_point and distribution_strategy are ignored.

    Returns:
        Whatever `run.run` returns. Per the original documentation this is a
        dictionary with two keys:
            1. 'job_id': the training job id.
            2. 'docker_image': Docker image generated for the training job.
        The return annotation is Optional, so callers should also expect None.
    """
    # Shallow-copy so the overrides below never leak back into the caller's
    # dictionary (the caller may reuse it for another call).
    cloud_kwargs = dict(run_kwargs) if run_kwargs is not None else {}

    if run.remote():
        default_machine_config = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']
        chief_config = cloud_kwargs.get('chief_config', default_machine_config)
        worker_count = cloud_kwargs.get('worker_count', 0)
        worker_config = cloud_kwargs.get('worker_config',
                                         default_machine_config)
        distribution_strategy = get_distribution_strategy(chief_config,
                                                          worker_count,
                                                          worker_config)
        # Build a fresh dict: any caller-provided distribution_strategy is
        # overridden here without touching run_experiment_kwargs itself.
        experiment_kwargs = dict(run_experiment_kwargs,
                                 distribution_strategy=distribution_strategy)
        train_lib.run_experiment(**experiment_kwargs)

    # entry_point and distribution_strategy are always forced to None,
    # regardless of what the caller supplied.
    cloud_kwargs.update(entry_point=None, distribution_strategy=None)
    return run.run(**cloud_kwargs)
280+
281+
282+
def get_distribution_strategy(chief_config, worker_count, worker_config):
    """Gets a tf distribution strategy based on the cloud run config.

    Args:
        chief_config: machine config of the chief; only its
            `accelerator_count` is read, and only when `worker_count` is not
            greater than zero.
        worker_count: number of workers in the run.
        worker_config: machine config of the workers; only inspected when
            `worker_count` is greater than zero.

    Returns:
        A `TPUStrategy` for TPU workers, a `MultiWorkerMirroredStrategy` for
        other workers, a `MirroredStrategy` when there are no workers and the
        chief has more than one accelerator, and otherwise a
        `OneDeviceStrategy` on '/gpu:0'.
    """
    has_workers = worker_count > 0

    if has_workers and machine_config.is_tpu_config(worker_config):
        # Connect to and initialize the TPU system before building the
        # strategy on top of the same resolver.
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu='')
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        return tf.distribute.TPUStrategy(tpu_resolver)

    if has_workers:
        return tf.distribute.MultiWorkerMirroredStrategy()

    # chief_config is only dereferenced on the no-worker path.
    if chief_config.accelerator_count > 1:
        return tf.distribute.MirroredStrategy()

    return tf.distribute.OneDeviceStrategy(device='/gpu:0')

src/python/tensorflow_cloud/core/experimental/tests/unit/models_test.py

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
import mock
1919

2020
import tensorflow as tf
21+
from tensorflow_cloud.core import machine_config
2122
from tensorflow_cloud.core import run
2223
from tensorflow_cloud.core.experimental import models
24+
from official.core import config_definitions
25+
from official.core import train_lib
2326
from official.vision.image_classification.efficientnet import efficientnet_model
2427

2528

@@ -41,12 +44,17 @@ def setup_normalize_img_and_label(self):
4144
3]
4245
self.label = tf.convert_to_tensor(4)
4346

44-
def setup_run_models(self, run_return_value=None, remote=True):
47+
def setup_run(self, remote=True):
48+
if remote:
49+
self.run_return_value = None
50+
else:
51+
self.run_return_value = {'job_id': 'job_id',
52+
'docker_image': 'docker_image'}
4553
self.run = mock.patch.object(
4654
run,
4755
'run',
4856
autospec=True,
49-
return_value=run_return_value,
57+
return_value=self.run_return_value,
5058
).start()
5159

5260
self.remote = mock.patch.object(
@@ -56,14 +64,29 @@ def setup_run_models(self, run_return_value=None, remote=True):
5664
return_value=remote,
5765
).start()
5866

67+
def setup_run_models(self):
    """Replaces `models.classifier_trainer` with an autospec'd mock."""
    trainer_patcher = mock.patch.object(
        models, 'classifier_trainer', autospec=True)
    self.classifier_trainer = trainer_patcher.start()
6473

65-
def cleanup_run_models(self):
74+
def setup_run_experiment(self):
    """Builds default run_experiment kwargs and mocks `train_lib.run_experiment`."""
    experiment_config = config_definitions.ExperimentConfig()
    self.run_experiment_kwargs = {
        'task': experiment_config.task,
        'mode': 'train_and_eval',
        'params': experiment_config,
        'model_dir': 'model_path',
    }

    # Keep a handle on the mock so tests can assert on its calls.
    run_experiment_patcher = mock.patch.object(
        train_lib, 'run_experiment', autospec=True)
    self.run_experiment = run_experiment_patcher.start()
86+
87+
def tearDown(self):
    """Stops all started mock patches, then runs the parent class's tearDown."""
    # The setup_* helpers start patchers without keeping a reference to stop
    # them individually, so every active patch is stopped here in one call.
    mock.patch.stopall()
    super(ModelsTest, self).tearDown()
6790

6891
def test_get_model_resnet(self):
6992
self.setup_get_model()
@@ -114,10 +137,8 @@ def test_normalize_image_and_label_with_one_hot(self):
114137
self.assertTrue((result_label == expected_label).numpy().all())
115138

116139
def test_run_models_locally(self):
117-
run_return = {'job_id': 'job_id',
118-
'docker_image': 'docker_image'}
119-
120-
self.setup_run_models(run_return, remote=False)
140+
self.setup_run(remote=False)
141+
self.setup_run_models()
121142
run_kwargs = {'entry_point': 'entry_point',
122143
'requirements_txt': 'requirements_txt',
123144
'worker_count': 5,}
@@ -130,9 +151,8 @@ def test_run_models_locally(self):
130151
'model_checkpoint', 'save_model']
131152
self.assertListEqual(list(result.keys()), return_keys)
132153

133-
self.cleanup_run_models()
134-
135154
def test_run_models_remote(self):
155+
self.setup_run()
136156
self.setup_run_models()
137157
result = models.run_models('dataset_name', 'model_name', 'gcs_bucket',
138158
'train')
@@ -142,7 +162,80 @@ def test_run_models_remote(self):
142162

143163
self.assertIsNone(result)
144164

145-
self.cleanup_run_models()
165+
def test_run_experiment_cloud_locally(self):
    """With remote() mocked to False, only `run` is invoked, not run_experiment."""
    self.setup_run(remote=False)
    self.setup_run_experiment()

    experiment_kwargs = self.run_experiment_kwargs
    models.run_experiment_cloud(run_experiment_kwargs=experiment_kwargs)

    self.remote.assert_called()
    # The experiment itself must not start when remote() is False.
    self.run_experiment.assert_not_called()
    self.run.assert_called()
174+
175+
def test_run_experiment_cloud_remote(self):
    """With remote() mocked to True, both run_experiment and `run` are invoked."""
    self.setup_run()
    self.setup_run_experiment()

    experiment_kwargs = self.run_experiment_kwargs
    models.run_experiment_cloud(run_experiment_kwargs=experiment_kwargs)

    self.remote.assert_called()
    self.run_experiment.assert_called()
    self.run.assert_called()
184+
185+
def setup_tpu(self):
    """Patches the TPU-related TensorFlow entry points with mocks.

    `tf.distribute.TPUStrategy` is replaced by a factory returning a
    MagicMock whose `__class__` is the real `TPUStrategy`, so isinstance
    checks against the real class succeed.
    """
    autospec_targets = (
        (tf.tpu.experimental, 'initialize_tpu_system'),
        (tf.config, 'experimental_connect_to_cluster'),
    )
    for owner, attribute in autospec_targets:
        mock.patch.object(owner, attribute, autospec=True).start()

    mock.patch(
        'tensorflow.distribute.cluster_resolver.TPUClusterResolver').start()

    # Read the real class before the patch below replaces the attribute.
    fake_strategy = mock.MagicMock()
    fake_strategy.__class__ = tf.distribute.TPUStrategy
    mock.patch('tensorflow.distribute.TPUStrategy',
               return_value=fake_strategy).start()
198+
199+
def test_get_distribution_strategy_tpu(self):
    """A TPU worker config with one worker yields a TPUStrategy."""
    # Capture the real class first: setup_tpu() swaps
    # tf.distribute.TPUStrategy for a mock.
    real_tpu_strategy_cls = tf.distribute.TPUStrategy
    self.setup_tpu()

    strategy = models.get_distribution_strategy(
        chief_config=None,
        worker_count=1,
        worker_config=machine_config.COMMON_MACHINE_CONFIGS['TPU'])

    self.assertIsInstance(strategy, real_tpu_strategy_cls)
210+
211+
def test_get_distribution_strategy_multi_mirror(self):
    """One worker with a `None` worker config yields MultiWorkerMirroredStrategy."""
    strategy = models.get_distribution_strategy(
        chief_config=None,
        worker_count=1,
        worker_config=None)

    expected_cls = tf.distribute.MultiWorkerMirroredStrategy
    self.assertIsInstance(strategy, expected_cls)
220+
221+
def test_get_distribution_strategy_mirror(self):
    """No workers and the 'K80_4X' chief config yields MirroredStrategy."""
    multi_gpu_chief = machine_config.COMMON_MACHINE_CONFIGS['K80_4X']

    strategy = models.get_distribution_strategy(
        chief_config=multi_gpu_chief,
        worker_count=0,
        worker_config=None)

    self.assertIsInstance(strategy, tf.distribute.MirroredStrategy)
229+
230+
def test_get_distribution_strategy_one_device(self):
    """No workers and the 'K80_1X' chief config yields OneDeviceStrategy."""
    single_gpu_chief = machine_config.COMMON_MACHINE_CONFIGS['K80_1X']

    strategy = models.get_distribution_strategy(
        chief_config=single_gpu_chief,
        worker_count=0,
        worker_config=None)

    self.assertIsInstance(strategy, tf.distribute.OneDeviceStrategy)
238+
146239

147240
# Run the test suite through absltest when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()

0 commit comments

Comments
 (0)