Skip to content

Commit 095db80

Browse files
authored
[Feat] Compatible with TensorFlow 2.15 (#386)
* [fix] check wrong array index in kvs_reply from Redis returning. Wanted to check the length of value, but the length of key was checked. * [fix] json would not printed when a wrong json value type was set. * [fix] some test should not be run when eager mode. Also add more new test for Redis KV backend. * [feat] Competible with Tensorflow 2.15
1 parent 8b071de commit 095db80

File tree

12 files changed

+1850
-63
lines changed

12 files changed

+1850
-63
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_table_op_util.hpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -273,32 +273,32 @@ Status ParseJsonConfig(const std::string *const redis_config_abs_dir,
273273
} \
274274
}
275275

276-
#define ReadArrayJsonToParams(json_key_name, json_val_type) \
277-
{ \
278-
json_hangar_it = json_hangar.find(#json_key_name); \
279-
if (json_hangar_it != json_hangar.end()) { \
280-
if (json_hangar_it->second->type == json_array) { \
281-
redis_connection_params->json_key_name.clear(); \
282-
for (unsigned i = 0; i < json_hangar_it->second->u.array.length; \
283-
++i) { \
284-
value_depth1 = json_hangar_it->second->u.array.values[i]; \
285-
if (value_depth1->type == json_##json_val_type) { \
286-
redis_connection_params->redis_host_port.push_back( \
287-
value_depth1->u.json_val_type); \
288-
} else { \
289-
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
290-
" array"; \
291-
return ReturnInvalidArgumentStatus( \
292-
" should be json " #json_val_type " array"); \
293-
} \
294-
} \
295-
} else { \
296-
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
297-
" array"; \
298-
return ReturnInvalidArgumentStatus(" should be json " #json_val_type \
299-
" array"); \
300-
} \
301-
} \
276+
#define ReadArrayJsonToParams(json_key_name, json_val_type) \
277+
{ \
278+
json_hangar_it = json_hangar.find(#json_key_name); \
279+
if (json_hangar_it != json_hangar.end()) { \
280+
if (json_hangar_it->second->type == json_array) { \
281+
redis_connection_params->json_key_name.clear(); \
282+
for (unsigned i = 0; i < json_hangar_it->second->u.array.length; \
283+
++i) { \
284+
value_depth1 = json_hangar_it->second->u.array.values[i]; \
285+
if (value_depth1->type == json_##json_val_type) { \
286+
redis_connection_params->redis_host_port.push_back( \
287+
value_depth1->u.json_val_type); \
288+
} else { \
289+
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
290+
" array"; \
291+
return ReturnInvalidArgumentStatus( \
292+
#json_key_name " should be json " #json_val_type " array"); \
293+
} \
294+
} \
295+
} else { \
296+
LOG(ERROR) << #json_key_name " should be json " #json_val_type \
297+
" array"; \
298+
return ReturnInvalidArgumentStatus( \
299+
#json_key_name " should be json " #json_val_type " array"); \
300+
} \
301+
} \
302302
}
303303

304304
#define ReadStringArrayJsonToParams(json_key_name) \

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,12 @@ class RedisTableOfTensors final : public LookupInterface {
599599
hscan_reply->elements > 1) {
600600
kvs_reply = hscan_reply->element[1];
601601
// fill Tensor keys and values
602+
if (kvs_reply->elements < 2 && cursor == 0) {
603+
// Find nothing in Redis
604+
break;
605+
}
602606
if constexpr (!std::is_same<V, tstring>::value) {
603-
if (kvs_reply->element[0]->len !=
607+
if (kvs_reply->element[1]->len !=
604608
runtime_value_dim_ * sizeof(V)) {
605609
return errors::InvalidArgument(
606610
"Embedding dim in Redis server is not equal to the OP "
@@ -1035,8 +1039,12 @@ class RedisTableOfTensors final : public LookupInterface {
10351039
}
10361040
kvs_reply = hscan_reply->element[1];
10371041
// fill Tensor keys and values
1042+
if (kvs_reply->elements < 2 && cursor == 0) {
1043+
// Find nothing in Redis
1044+
break;
1045+
}
10381046
if constexpr (!std::is_same<V, tstring>::value) {
1039-
if (kvs_reply->element[0]->len != runtime_value_dim_ * sizeof(V)) {
1047+
if (kvs_reply->element[1]->len != runtime_value_dim_ * sizeof(V)) {
10401048
return errors::InvalidArgument(
10411049
"Embedding dim in Redis server is not equal to the OP runtime "
10421050
"dim.");
@@ -1146,8 +1154,12 @@ class RedisTableOfTensors final : public LookupInterface {
11461154
}
11471155
kvs_reply = hscan_reply->element[1];
11481156
// fill Tensor keys and values
1157+
if (kvs_reply->elements < 2 && cursor == 0) {
1158+
// Find nothing in Redis
1159+
break;
1160+
}
11491161
if constexpr (!std::is_same<V, tstring>::value) {
1150-
if (kvs_reply->element[0]->len != runtime_value_dim_ * sizeof(V)) {
1162+
if (kvs_reply->element[1]->len != runtime_value_dim_ * sizeof(V)) {
11511163
return errors::InvalidArgument(
11521164
"Embedding dim in Redis server is not equal to the OP runtime "
11531165
"dim.");

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929

3030
from tensorflow.python.distribute import distribute_lib
3131
from tensorflow.python.keras.utils import tf_utils
32-
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
32+
try: # tf version >= 2.14.0
33+
from tensorflow.python.distribute import distribute_lib as distribute_ctx
34+
assert hasattr(distribute_ctx, 'has_strategy')
35+
except:
36+
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
3337
from tensorflow.python.distribute import values_util
3438
from tensorflow.python.framework import ops
3539
from tensorflow.python.eager import tape

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_ops_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
from tensorflow.python.ops import script_ops
5555
from tensorflow.python.ops import variables
5656
from tensorflow.python.ops import variable_scope
57+
try: # tf version >= 2.14.0
58+
from tensorflow.python.ops.array_ops_stack import stack
59+
except:
60+
from tensorflow.python.ops.array_ops import stack
5761
from tensorflow.python.platform import test
5862
from tensorflow.python.training import device_setter
5963
from tensorflow.python.training import server_lib
@@ -309,7 +313,7 @@ def test_max_norm_nontrivial(self):
309313
embedding = de.embedding_lookup(embeddings, ids, max_norm=2.0)
310314
norms = math_ops.sqrt(
311315
math_ops.reduce_sum(embedding_no_norm * embedding_no_norm, axis=1))
312-
normalized = embedding_no_norm / array_ops.stack([norms, norms], axis=1)
316+
normalized = embedding_no_norm / stack([norms, norms], axis=1)
313317
self.assertAllCloseAccordingToType(embedding.eval(),
314318
2 * self.evaluate(normalized))
315319

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@
5757
from tensorflow.python.training import saver
5858
from tensorflow.python.training import server_lib
5959
from tensorflow.python.training import training
60-
from tensorflow.python.training.tracking import util as track_util
60+
try: # tf version >= 2.14.0
61+
from tensorflow.python.checkpoint.checkpoint import Checkpoint
62+
except:
63+
from tensorflow.python.training.tracking.util import Checkpoint
6164
from tensorflow.python.util import compat
6265
from tensorflow_estimator.python.estimator import estimator
6366
from tensorflow_estimator.python.estimator import estimator_lib
@@ -1326,13 +1329,13 @@ def _loss_fn():
13261329
*sorted(zip(keys1, vals1), key=lambda x: x[0], reverse=False))
13271330
slot_keys_and_vals1 = [sv.export() for sv in model1.slot_vars]
13281331

1329-
ckpt1 = track_util.Checkpoint(model=model1, optimizer=model1.optmz)
1332+
ckpt1 = Checkpoint(model=model1, optimizer=model1.optmz)
13301333
ckpt_dir = self.get_temp_dir()
13311334
model_path = ckpt1.save(ckpt_dir)
13321335
del model1
13331336

13341337
model2 = TestModel()
1335-
ckpt2 = track_util.Checkpoint(model=model2, optimizer=model2.optmz)
1338+
ckpt2 = Checkpoint(model=model2, optimizer=model2.optmz)
13361339
model2.train(features) # Pre-build trace before restore.
13371340
ckpt2.restore(model_path)
13381341
loss2 = model2(features)

0 commit comments

Comments
 (0)