@@ -77,13 +77,28 @@ def setup_run_experiment(self):
77
77
mode = 'train_and_eval' ,
78
78
params = config ,
79
79
model_dir = 'model_path' )
80
-
80
+ self . model = mock . MagicMock ()
81
81
self .run_experiment = mock .patch .object (
82
82
train_lib ,
83
83
'run_experiment' ,
84
84
autospec = True ,
85
+ return_value = (self .model , {})
85
86
).start ()
86
87
88
+ def setup_tpu (self ):
89
+ mock .patch .object (tf .tpu .experimental ,
90
+ 'initialize_tpu_system' ,
91
+ autospec = True ).start ()
92
+ mock .patch .object (tf .config ,
93
+ 'experimental_connect_to_cluster' ,
94
+ autospec = True ).start ()
95
+ mock .patch ('tensorflow.distribute.cluster_resolver.TPUClusterResolver'
96
+ ).start ()
97
+ mock_tpu_strategy = mock .MagicMock (
98
+ spec = tf .distribute .TPUStrategy )
99
+ mock .patch ('tensorflow.distribute.TPUStrategy' ,
100
+ return_value = mock_tpu_strategy ).start ()
101
+
87
102
def tearDown (self ):
88
103
mock .patch .stopall ()
89
104
super (ModelsTest , self ).tearDown ()
@@ -182,20 +197,8 @@ def test_run_experiment_cloud_remote(self):
182
197
self .remote .assert_called ()
183
198
self .run_experiment .assert_called ()
184
199
self .run .assert_called ()
185
-
186
- def setup_tpu (self ):
187
- mock .patch .object (tf .tpu .experimental ,
188
- 'initialize_tpu_system' ,
189
- autospec = True ).start ()
190
- mock .patch .object (tf .config ,
191
- 'experimental_connect_to_cluster' ,
192
- autospec = True ).start ()
193
- mock .patch ('tensorflow.distribute.cluster_resolver.TPUClusterResolver'
194
- ).start ()
195
- mock_tpu_strategy = mock .MagicMock ()
196
- mock_tpu_strategy .__class__ = tf .distribute .TPUStrategy
197
- mock .patch ('tensorflow.distribute.TPUStrategy' ,
198
- return_value = mock_tpu_strategy ).start ()
200
+ self .model .save .assert_called_with (
201
+ self .run_experiment_kwargs ['model_dir' ])
199
202
200
203
def test_get_distribution_strategy_tpu (self ):
201
204
tpu_srategy = tf .distribute .TPUStrategy
0 commit comments