18
18
import mock
19
19
20
20
import tensorflow as tf
21
+ from tensorflow_cloud .core import machine_config
21
22
from tensorflow_cloud .core import run
22
23
from tensorflow_cloud .core .experimental import models
24
+ from official .core import config_definitions
25
+ from official .core import train_lib
23
26
from official .vision .image_classification .efficientnet import efficientnet_model
24
27
25
28
@@ -41,12 +44,17 @@ def setup_normalize_img_and_label(self):
41
44
3 ]
42
45
self .label = tf .convert_to_tensor (4 )
43
46
44
- def setup_run_models (self , run_return_value = None , remote = True ):
47
+ def setup_run (self , remote = True ):
48
+ if remote :
49
+ self .run_return_value = None
50
+ else :
51
+ self .run_return_value = {'job_id' : 'job_id' ,
52
+ 'docker_image' : 'docker_image' }
45
53
self .run = mock .patch .object (
46
54
run ,
47
55
'run' ,
48
56
autospec = True ,
49
- return_value = run_return_value ,
57
+ return_value = self . run_return_value ,
50
58
).start ()
51
59
52
60
self .remote = mock .patch .object (
@@ -56,14 +64,29 @@ def setup_run_models(self, run_return_value=None, remote=True):
56
64
return_value = remote ,
57
65
).start ()
58
66
67
+ def setup_run_models (self ):
59
68
self .classifier_trainer = mock .patch .object (
60
69
models ,
61
70
'classifier_trainer' ,
62
71
autospec = True ,
63
72
).start ()
64
73
65
- def cleanup_run_models (self ):
74
+ def setup_run_experiment (self ):
75
+ config = config_definitions .ExperimentConfig ()
76
+ self .run_experiment_kwargs = dict (task = config .task ,
77
+ mode = 'train_and_eval' ,
78
+ params = config ,
79
+ model_dir = 'model_path' )
80
+
81
+ self .run_experiment = mock .patch .object (
82
+ train_lib ,
83
+ 'run_experiment' ,
84
+ autospec = True ,
85
+ ).start ()
86
+
87
+ def tearDown (self ):
66
88
mock .patch .stopall ()
89
+ super (ModelsTest , self ).tearDown ()
67
90
68
91
def test_get_model_resnet (self ):
69
92
self .setup_get_model ()
@@ -114,10 +137,8 @@ def test_normalize_image_and_label_with_one_hot(self):
114
137
self .assertTrue ((result_label == expected_label ).numpy ().all ())
115
138
116
139
def test_run_models_locally (self ):
117
- run_return = {'job_id' : 'job_id' ,
118
- 'docker_image' : 'docker_image' }
119
-
120
- self .setup_run_models (run_return , remote = False )
140
+ self .setup_run (remote = False )
141
+ self .setup_run_models ()
121
142
run_kwargs = {'entry_point' : 'entry_point' ,
122
143
'requirements_txt' : 'requirements_txt' ,
123
144
'worker_count' : 5 ,}
@@ -130,9 +151,8 @@ def test_run_models_locally(self):
130
151
'model_checkpoint' , 'save_model' ]
131
152
self .assertListEqual (list (result .keys ()), return_keys )
132
153
133
- self .cleanup_run_models ()
134
-
135
154
def test_run_models_remote (self ):
155
+ self .setup_run ()
136
156
self .setup_run_models ()
137
157
result = models .run_models ('dataset_name' , 'model_name' , 'gcs_bucket' ,
138
158
'train' )
@@ -142,7 +162,80 @@ def test_run_models_remote(self):
142
162
143
163
self .assertIsNone (result )
144
164
145
- self .cleanup_run_models ()
165
+ def test_run_experiment_cloud_locally (self ):
166
+ self .setup_run (remote = False )
167
+ self .setup_run_experiment ()
168
+ models .run_experiment_cloud (
169
+ run_experiment_kwargs = self .run_experiment_kwargs )
170
+
171
+ self .remote .assert_called ()
172
+ self .run_experiment .assert_not_called ()
173
+ self .run .assert_called ()
174
+
175
+ def test_run_experiment_cloud_remote (self ):
176
+ self .setup_run ()
177
+ self .setup_run_experiment ()
178
+ models .run_experiment_cloud (
179
+ run_experiment_kwargs = self .run_experiment_kwargs )
180
+
181
+ self .remote .assert_called ()
182
+ self .run_experiment .assert_called ()
183
+ self .run .assert_called ()
184
+
185
+ def setup_tpu (self ):
186
+ mock .patch .object (tf .tpu .experimental ,
187
+ 'initialize_tpu_system' ,
188
+ autospec = True ).start ()
189
+ mock .patch .object (tf .config ,
190
+ 'experimental_connect_to_cluster' ,
191
+ autospec = True ).start ()
192
+ mock .patch ('tensorflow.distribute.cluster_resolver.TPUClusterResolver'
193
+ ).start ()
194
+ mock_tpu_strategy = mock .MagicMock ()
195
+ mock_tpu_strategy .__class__ = tf .distribute .TPUStrategy
196
+ mock .patch ('tensorflow.distribute.TPUStrategy' ,
197
+ return_value = mock_tpu_strategy ).start ()
198
+
199
+ def test_get_distribution_strategy_tpu (self ):
200
+ tpu_srategy = tf .distribute .TPUStrategy
201
+ self .setup_tpu ()
202
+ chief_config = None
203
+ worker_count = 1
204
+ worker_config = machine_config .COMMON_MACHINE_CONFIGS ['TPU' ]
205
+ strategy = models .get_distribution_strategy (chief_config ,
206
+ worker_count ,
207
+ worker_config )
208
+ self .assertIsInstance (strategy ,
209
+ tpu_srategy )
210
+
211
+ def test_get_distribution_strategy_multi_mirror (self ):
212
+ chief_config = None
213
+ worker_count = 1
214
+ worker_config = None
215
+ strategy = models .get_distribution_strategy (chief_config ,
216
+ worker_count ,
217
+ worker_config )
218
+ self .assertIsInstance (strategy ,
219
+ tf .distribute .MultiWorkerMirroredStrategy )
220
+
221
+ def test_get_distribution_strategy_mirror (self ):
222
+ chief_config = machine_config .COMMON_MACHINE_CONFIGS ['K80_4X' ]
223
+ worker_count = 0
224
+ worker_config = None
225
+ strategy = models .get_distribution_strategy (chief_config ,
226
+ worker_count ,
227
+ worker_config )
228
+ self .assertIsInstance (strategy , tf .distribute .MirroredStrategy )
229
+
230
+ def test_get_distribution_strategy_one_device (self ):
231
+ chief_config = machine_config .COMMON_MACHINE_CONFIGS ['K80_1X' ]
232
+ worker_count = 0
233
+ worker_config = None
234
+ strategy = models .get_distribution_strategy (chief_config ,
235
+ worker_count ,
236
+ worker_config )
237
+ self .assertIsInstance (strategy , tf .distribute .OneDeviceStrategy )
238
+
146
239
147
240
if __name__ == '__main__' :
148
241
absltest .main ()
0 commit comments