Skip to content

Commit 439d515

Browse files
Internal change
PiperOrigin-RevId: 417875109
1 parent e6337e3 commit 439d515

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

official/core/train_utils_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Tests for official.core.train_utils."""
16+
import json
1617
import os
1718
import pprint
1819

@@ -138,5 +139,60 @@ def test_construct_experiment_from_flags(self):
138139
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
139140

140141

142+
class BestCheckpointExporterTest(tf.test.TestCase):
143+
144+
def test_maybe_export(self):
145+
model_dir = self.create_tempdir().full_path
146+
best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
147+
metric_name = 'test_metric|metric_1'
148+
exporter = train_utils.BestCheckpointExporter(
149+
model_dir, metric_name, 'higher')
150+
v = tf.Variable(1.0)
151+
checkpoint = tf.train.Checkpoint(v=v)
152+
ret = exporter.maybe_export_checkpoint(
153+
checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
154+
with self.subTest(name='Successful first save.'):
155+
self.assertEqual(ret, True)
156+
v_2 = tf.Variable(2.0)
157+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
158+
checkpoint_2.restore(best_ckpt_path)
159+
self.assertEqual(v_2.numpy(), 1.0)
160+
161+
v = tf.Variable(3.0)
162+
checkpoint = tf.train.Checkpoint(v=v)
163+
ret = exporter.maybe_export_checkpoint(
164+
checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
165+
with self.subTest(name='Successful better metic save.'):
166+
self.assertEqual(ret, True)
167+
v_2 = tf.Variable(2.0)
168+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
169+
checkpoint_2.restore(best_ckpt_path)
170+
self.assertEqual(v_2.numpy(), 3.0)
171+
172+
v = tf.Variable(5.0)
173+
checkpoint = tf.train.Checkpoint(v=v)
174+
ret = exporter.maybe_export_checkpoint(
175+
checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
176+
with self.subTest(name='Worse metic no save.'):
177+
self.assertEqual(ret, False)
178+
v_2 = tf.Variable(2.0)
179+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
180+
checkpoint_2.restore(best_ckpt_path)
181+
self.assertEqual(v_2.numpy(), 3.0)
182+
183+
def test_export_best_eval_metric(self):
184+
model_dir = self.create_tempdir().full_path
185+
metric_name = 'test_metric|metric_1'
186+
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
187+
'higher')
188+
exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
189+
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
190+
'rb') as reader:
191+
metric = json.loads(reader.read())
192+
self.assertAllEqual(
193+
metric,
194+
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
195+
196+
141197
if __name__ == '__main__':
142198
tf.test.main()

0 commit comments

Comments
 (0)