|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | """Tests for official.core.train_utils."""
|
| 16 | +import json |
16 | 17 | import os
|
17 | 18 | import pprint
|
18 | 19 |
|
@@ -138,5 +139,60 @@ def test_construct_experiment_from_flags(self):
|
138 | 139 | self.assertEqual(params_from_obj.trainer.validation_steps, 11)
|
139 | 140 |
|
140 | 141 |
|
| 142 | +class BestCheckpointExporterTest(tf.test.TestCase): |
| 143 | + |
| 144 | + def test_maybe_export(self): |
| 145 | + model_dir = self.create_tempdir().full_path |
| 146 | + best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1') |
| 147 | + metric_name = 'test_metric|metric_1' |
| 148 | + exporter = train_utils.BestCheckpointExporter( |
| 149 | + model_dir, metric_name, 'higher') |
| 150 | + v = tf.Variable(1.0) |
| 151 | + checkpoint = tf.train.Checkpoint(v=v) |
| 152 | + ret = exporter.maybe_export_checkpoint( |
| 153 | + checkpoint, {'test_metric': {'metric_1': 5.0}}, 100) |
| 154 | + with self.subTest(name='Successful first save.'): |
| 155 | + self.assertEqual(ret, True) |
| 156 | + v_2 = tf.Variable(2.0) |
| 157 | + checkpoint_2 = tf.train.Checkpoint(v=v_2) |
| 158 | + checkpoint_2.restore(best_ckpt_path) |
| 159 | + self.assertEqual(v_2.numpy(), 1.0) |
| 160 | + |
| 161 | + v = tf.Variable(3.0) |
| 162 | + checkpoint = tf.train.Checkpoint(v=v) |
| 163 | + ret = exporter.maybe_export_checkpoint( |
| 164 | + checkpoint, {'test_metric': {'metric_1': 6.0}}, 200) |
| 165 | + with self.subTest(name='Successful better metic save.'): |
| 166 | + self.assertEqual(ret, True) |
| 167 | + v_2 = tf.Variable(2.0) |
| 168 | + checkpoint_2 = tf.train.Checkpoint(v=v_2) |
| 169 | + checkpoint_2.restore(best_ckpt_path) |
| 170 | + self.assertEqual(v_2.numpy(), 3.0) |
| 171 | + |
| 172 | + v = tf.Variable(5.0) |
| 173 | + checkpoint = tf.train.Checkpoint(v=v) |
| 174 | + ret = exporter.maybe_export_checkpoint( |
| 175 | + checkpoint, {'test_metric': {'metric_1': 1.0}}, 300) |
| 176 | + with self.subTest(name='Worse metic no save.'): |
| 177 | + self.assertEqual(ret, False) |
| 178 | + v_2 = tf.Variable(2.0) |
| 179 | + checkpoint_2 = tf.train.Checkpoint(v=v_2) |
| 180 | + checkpoint_2.restore(best_ckpt_path) |
| 181 | + self.assertEqual(v_2.numpy(), 3.0) |
| 182 | + |
| 183 | + def test_export_best_eval_metric(self): |
| 184 | + model_dir = self.create_tempdir().full_path |
| 185 | + metric_name = 'test_metric|metric_1' |
| 186 | + exporter = train_utils.BestCheckpointExporter(model_dir, metric_name, |
| 187 | + 'higher') |
| 188 | + exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100) |
| 189 | + with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'), |
| 190 | + 'rb') as reader: |
| 191 | + metric = json.loads(reader.read()) |
| 192 | + self.assertAllEqual( |
| 193 | + metric, |
| 194 | + {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0}) |
| 195 | + |
| 196 | + |
141 | 197 | if __name__ == '__main__':
|
142 | 198 | tf.test.main()
|
0 commit comments