Skip to content

Commit 9149973

Browse files
MoFHekarhdong
authored andcommitted
[fix] TrainableWrapper and DEResourceVariable should not be save or restore parameter when using tf.train.Saver.
1 parent 29df8b3 commit 9149973

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121

2222
from tensorflow_recommenders_addons import dynamic_embedding as de
23+
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import TrainableWrapper, DEResourceVariable
2324
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable \
2425
import load_de_variable_from_file_system
2526

@@ -293,6 +294,20 @@ def _get_dynamic_embedding_restore_ops(self):
293294
return control_flow_ops.group(restore_ops.as_list())
294295

295296
def _build(self, checkpoint_path, build_save, build_restore):
297+
# TrainableWrapper and DEResourceVariable should not be save or restore parameter.
298+
filter_lambda = lambda x: (isinstance(x, TrainableWrapper)) or (isinstance(
299+
x, DEResourceVariable))
300+
if isinstance(self._var_list, dict):
301+
for key, value in self._var_list.items():
302+
if filter_lambda(value):
303+
self._var_list.pop(key)
304+
elif isinstance(self._var_list, list):
305+
_tmp_var_list = []
306+
for value in self._var_list:
307+
if not filter_lambda(value):
308+
_tmp_var_list.append(value)
309+
self._var_list = _tmp_var_list
310+
296311
super(_DynamicEmbeddingSaver, self)._build(checkpoint_path, build_save,
297312
build_restore)
298313

0 commit comments

Comments
 (0)