@@ -20,6 +20,7 @@
 import re
 
 from tensorflow_recommenders_addons import dynamic_embedding as de
+from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import TrainableWrapper, DEResourceVariable
 from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable \
     import load_de_variable_from_file_system
 
@@ -293,6 +294,20 @@ def _get_dynamic_embedding_restore_ops(self):
     return control_flow_ops.group(restore_ops.as_list())
 
   def _build(self, checkpoint_path, build_save, build_restore):
+    # TrainableWrapper and DEResourceVariable should not be saved or restored.
+    filter_lambda = lambda x: (isinstance(x, TrainableWrapper)) or (isinstance(
+        x, DEResourceVariable))
+    if isinstance(self._var_list, dict):
+      for key, value in list(self._var_list.items()):
+        if filter_lambda(value):
+          self._var_list.pop(key)
+    elif isinstance(self._var_list, list):
+      _tmp_var_list = []
+      for value in self._var_list:
+        if not filter_lambda(value):
+          _tmp_var_list.append(value)
+      self._var_list = _tmp_var_list
+
     super(_DynamicEmbeddingSaver, self)._build(checkpoint_path, build_save,
                                                build_restore)
 
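
The added block strips TrainableWrapper and DEResourceVariable objects out of self._var_list before delegating to the parent _build, so they never become inputs to the regular save/restore ops. Note that the dict branch iterates over a copy of the items: popping keys while iterating the dict directly would raise a RuntimeError in Python 3. Below is a minimal standalone sketch of the same filtering pattern; FakeTrainableWrapper, FakeDEResourceVariable, and filter_de_wrappers are hypothetical stand-ins used only so the example runs without TensorFlow or TFRA installed.

# A hedged sketch of the filtering step above, not TFRA code. The Fake* classes
# and filter_de_wrappers are placeholder names for illustration only.


class FakeTrainableWrapper:
  pass


class FakeDEResourceVariable:
  pass


def filter_de_wrappers(var_list):
  """Drops wrapper objects that should not go through the regular saver."""
  is_wrapper = lambda x: isinstance(
      x, (FakeTrainableWrapper, FakeDEResourceVariable))
  if isinstance(var_list, dict):
    # Iterate over a snapshot of the items so keys can be popped safely.
    for key, value in list(var_list.items()):
      if is_wrapper(value):
        var_list.pop(key)
  elif isinstance(var_list, list):
    var_list = [v for v in var_list if not is_wrapper(v)]
  return var_list


if __name__ == "__main__":
  print(filter_de_wrappers({"w": FakeTrainableWrapper(), "b": 0.1}))  # {'b': 0.1}
  print(filter_de_wrappers([FakeDEResourceVariable(), 0.2]))          # [0.2]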