Skip to content

Commit 4b2e51d

Browse files
PWZERrhdong
authored andcommitted
add test case test_table_save_load_local_file_system_for_estimator
1 parent ac2533a commit 4b2e51d

File tree

2 files changed

+76
-21
lines changed

2 files changed

+76
-21
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
from tensorflow.python.training.tracking import data_structures
5959
from tensorflow.python.training.tracking import util as track_util
6060
from tensorflow.python.util import compat
61+
from tensorflow_estimator.python.estimator import estimator
62+
from tensorflow_estimator.python.estimator import estimator_lib
6163

6264
try:
6365
import tensorflow_io
@@ -958,6 +960,53 @@ def test_table_save_load_local_file_system(self):
958960

959961
del table
960962

963+
def test_table_save_load_local_file_system_for_estimator(self):
964+
965+
def input_fn():
966+
return {"x": constant_op.constant([1], dtype=dtypes.int64)}
967+
968+
def model_fn(features, labels, mode, params):
969+
file_system_saver = de.FileSystemSaver()
970+
embedding = de.get_variable(
971+
name="embedding",
972+
dim=3,
973+
trainable=False,
974+
key_dtype=dtypes.int64,
975+
value_dtype=dtypes.float32,
976+
initializer=-1.0,
977+
kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
978+
)
979+
lookup = de.embedding_lookup(embedding, features["x"])
980+
upsert = embedding.upsert(features["x"],
981+
constant_op.constant([[1.0, 2.0, 3.0]]))
982+
983+
with ops.control_dependencies([lookup, upsert]):
984+
train_op = state_ops.assign_add(training.get_global_step(), 1)
985+
986+
scaffold = training.Scaffold(
987+
saver=saver.Saver(sharded=True,
988+
max_to_keep=1,
989+
keep_checkpoint_every_n_hours=None,
990+
defer_build=True,
991+
save_relative_paths=True))
992+
est = estimator_lib.EstimatorSpec(mode=mode,
993+
scaffold=scaffold,
994+
loss=constant_op.constant(0.),
995+
train_op=train_op,
996+
predictions=lookup)
997+
return est
998+
999+
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
1000+
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
1001+
1002+
# train and save
1003+
est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
1004+
est.train(input_fn=input_fn, steps=1)
1005+
1006+
# restore and predict
1007+
predict_results = next(est.predict(input_fn=input_fn))
1008+
self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
1009+
9611010
def test_save_restore_only_table(self):
9621011
if context.executing_eagerly():
9631012
self.skipTest('skip eager test when using legacy Saver.')

tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -179,23 +179,26 @@ def _get_dynamic_embedding_save_ops(self):
179179

180180
for var in self._var_list:
181181
de_var = None
182-
if isinstance(var, (de.FileSystemSaver._DynamicEmbeddingShardFileSystemSaveable,
183-
de.FileSystemSaver._DynamicEmbeddingVariabelFileSystemSaveable)):
182+
if isinstance(
183+
var,
184+
(de.FileSystemSaver._DynamicEmbeddingShardFileSystemSaveable,
185+
de.FileSystemSaver._DynamicEmbeddingVariabelFileSystemSaveable)):
184186
de_var = var._de_variable
185187
elif isinstance(var, de.Variable) and var._saveable_object_creator:
186188
de_var = var
187189

188-
if de_var and isinstance(de_var._saveable_object_creator, de.FileSystemSaver):
190+
if de_var and isinstance(de_var._saveable_object_creator,
191+
de.FileSystemSaver):
189192
if de_var._saveable_object_creator.config.save_path:
190193
de_variable_folder_dir = de_var._saveable_object_creator.config.save_path
191194
else:
192195
de_variable_folder_dir = self._de_var_fs_save_dir
193196

194197
save_op = de_var.save_to_file_system(
195-
dirpath=de_variable_folder_dir,
196-
proc_size=de_var._saveable_object_creator.config.proc_size,
197-
proc_rank=de_var._saveable_object_creator.config.proc_rank,
198-
buffer_size=de_var._saveable_object_creator.config.buffer_size)
198+
dirpath=de_variable_folder_dir,
199+
proc_size=de_var._saveable_object_creator.config.proc_size,
200+
proc_rank=de_var._saveable_object_creator.config.proc_rank,
201+
buffer_size=de_var._saveable_object_creator.config.buffer_size)
199202
save_ops.as_list().append(save_op)
200203
return control_flow_ops.group(save_ops.as_list())
201204

@@ -206,33 +209,36 @@ def _get_dynamic_embedding_restore_ops(self):
206209

207210
for var in self._var_list:
208211
de_var = None
209-
if isinstance(var, (de.FileSystemSaver._DynamicEmbeddingShardFileSystemSaveable,
210-
de.FileSystemSaver._DynamicEmbeddingVariabelFileSystemSaveable)):
212+
if isinstance(
213+
var,
214+
(de.FileSystemSaver._DynamicEmbeddingShardFileSystemSaveable,
215+
de.FileSystemSaver._DynamicEmbeddingVariabelFileSystemSaveable)):
211216
de_var = var._de_variable
212217
elif isinstance(var, de.Variable) and var._saveable_object_creator:
213218
de_var = var
214219

215-
if de_var and isinstance(de_var._saveable_object_creator, de.FileSystemSaver):
220+
if de_var and isinstance(de_var._saveable_object_creator,
221+
de.FileSystemSaver):
216222
if de_var._saveable_object_creator.config.save_path:
217223
de_variable_folder_dir = de_var._saveable_object_creator.config.save_path
218224
else:
219225
de_variable_folder_dir = self._de_var_fs_save_dir
220226

221227
restore_op = de_var.load_from_file_system_with_restore_function(
222-
dirpath=de_variable_folder_dir,
223-
proc_size=de_var._saveable_object_creator.config.proc_size,
224-
proc_rank=de_var._saveable_object_creator.config.proc_rank,
225-
buffer_size=de_var._saveable_object_creator.config.buffer_size)
228+
dirpath=de_variable_folder_dir,
229+
proc_size=de_var._saveable_object_creator.config.proc_size,
230+
proc_rank=de_var._saveable_object_creator.config.proc_rank,
231+
buffer_size=de_var._saveable_object_creator.config.buffer_size)
226232
restore_ops.as_list().append(restore_op)
227233
return control_flow_ops.group(restore_ops.as_list())
228234

229235
def _build(self, checkpoint_path, build_save, build_restore):
230-
super(_DynamicEmbeddingSaver, self)._build(
231-
checkpoint_path, build_save, build_restore)
236+
super(_DynamicEmbeddingSaver, self)._build(checkpoint_path, build_save,
237+
build_restore)
232238

233239
with ops.name_scope("FileSystemSaver", "save_to_file_system", []) as name:
234240
self._de_var_fs_save_dir = array_ops.placeholder(
235-
dtype=dtypes.string, shape=(), name="de_var_file_system_save_dir")
241+
dtype=dtypes.string, shape=(), name="de_var_file_system_save_dir")
236242
self._de_save_ops = self._get_dynamic_embedding_save_ops()
237243
self._de_restore_ops = self._get_dynamic_embedding_restore_ops()
238244

@@ -337,14 +343,14 @@ def save(self,
337343

338344
if global_step is not None:
339345
de_variable_folder_dir = os.path.join(
340-
save_path_parent, "TFRADynamicEmbedding-{}".format(global_step))
346+
save_path_parent, "TFRADynamicEmbedding-{}".format(global_step))
341347
if self._pad_step_number:
342348
# Zero-pads the step numbers, so that they are sorted when listed.
343349
de_variable_folder_dir = os.path.join(
344-
save_path_parent, "TFRADynamicEmbedding-{:08d}".format(global_step))
350+
save_path_parent, "TFRADynamicEmbedding-{:08d}".format(global_step))
345351
else:
346-
de_variable_folder_dir = os.path.join(
347-
save_path_parent, "TFRADynamicEmbedding")
352+
de_variable_folder_dir = os.path.join(save_path_parent,
353+
"TFRADynamicEmbedding")
348354

349355
if not self._is_empty:
350356
try:

0 commit comments

Comments
 (0)