Skip to content

Commit 4866bbd

Browse files
MoFHekarhdong
authored and committed
[feat] Compatible with TF 2.13 by adding the new absl status error type.
Also adapt to the new TF Python-level API. In particular, HKV is also compatible with TF 2.10/2.11.
1 parent 51c188d commit 4866bbd

File tree

12 files changed

+179
-110
lines changed

12 files changed

+179
-110
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,11 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
653653
size_t last_hint_size_;
654654
size_t runtime_dim_;
655655
mutable mutex mu_;
656+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
657+
gpu::TableWrapperBase<K, V>* table_ = nullptr TF_GUARDED_BY(mu_);
658+
#else
656659
gpu::TableWrapperBase<K, V>* table_ = nullptr GUARDED_BY(mu_);
660+
#endif
657661
};
658662

659663
} // namespace lookup

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616

1717
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
1818
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h"
19-
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
2019

2120
#define EIGEN_USE_GPU
2221

@@ -37,7 +36,11 @@ limitations under the License.
3736
#include "tensorflow/core/util/env_var.h"
3837
#include "tensorflow/core/util/gpu_device_functions.h"
3938
#include "tensorflow/core/util/gpu_kernel_helper.h"
39+
#if TF_VERSION_INTEGER >= 2110 // 2.11.0
40+
#include "tensorflow/compiler/xla/stream_executor/stream.h"
41+
#else
4042
#include "tensorflow/stream_executor/stream.h"
43+
#endif
4144

4245
namespace tensorflow {
4346

@@ -187,14 +190,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
187190
is_full_default);
188191
CUDA_CHECK(cudaStreamSynchronize(stream));
189192
} catch (std::runtime_error& e) {
190-
return Status(tensorflow::error::INTERNAL, e.what());
193+
return gpu::ReturnInternalErrorStatus(e.what());
191194
}
192195
}
193196
CUDA_CHECK(cudaFreeAsync(d_status, stream));
194197
CUDA_CHECK(cudaStreamSynchronize(stream));
195198
}
196199

197-
return Status::OK();
200+
return TFOkStatus;
198201
}
199202

200203
Status FindWithExists(OpKernelContext* ctx, const Tensor& d_keys,
@@ -222,13 +225,13 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
222225
(V*)(default_value.tensor_data().data()), stream,
223226
is_full_default);
224227
} catch (std::runtime_error& e) {
225-
return Status(tensorflow::error::INTERNAL, e.what());
228+
return gpu::ReturnInternalErrorStatus(e.what());
226229
}
227230
}
228231
CUDA_CHECK(cudaStreamSynchronize(stream));
229232
}
230233

231-
return Status::OK();
234+
return TFOkStatus;
232235
}
233236

234237
Status Insert(OpKernelContext* ctx, const Tensor& keys,
@@ -241,12 +244,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
241244
table_->upsert((const K*)keys.tensor_data().data(),
242245
(const V*)(values.tensor_data().data()), len, stream);
243246
} catch (std::runtime_error& e) {
244-
return Status(tensorflow::error::INTERNAL, e.what());
247+
return gpu::ReturnInternalErrorStatus(e.what());
245248
}
246249
}
247250
CUDA_CHECK(cudaStreamSynchronize(stream));
248251

249-
return Status::OK();
252+
return TFOkStatus;
250253
}
251254

252255
Status Accum(OpKernelContext* ctx, const Tensor& keys,
@@ -260,12 +263,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
260263
(const V*)(values_or_deltas.tensor_data().data()),
261264
(const bool*)exists.tensor_data().data(), len, stream);
262265
} catch (std::runtime_error& e) {
263-
return Status(tensorflow::error::INTERNAL, e.what());
266+
return gpu::ReturnInternalErrorStatus(e.what());
264267
}
265268
}
266269
CUDA_CHECK(cudaStreamSynchronize(stream));
267270

268-
return Status::OK();
271+
return TFOkStatus;
269272
}
270273

271274
Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
@@ -285,14 +288,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
285288
try {
286289
table_->remove((const K*)d_keys, len, stream);
287290
} catch (std::runtime_error& e) {
288-
return Status(tensorflow::error::INTERNAL, e.what());
291+
return gpu::ReturnInternalErrorStatus(e.what());
289292
}
290293
}
291294
CUDA_CHECK(cudaFreeAsync(d_keys, stream));
292295
CUDA_CHECK(cudaStreamSynchronize(stream));
293296
}
294297

295-
return Status::OK();
298+
return TFOkStatus;
296299
}
297300

298301
Status Clear(OpKernelContext* ctx) {
@@ -302,11 +305,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
302305
try {
303306
table_->clear(stream);
304307
} catch (std::runtime_error& e) {
305-
return Status(tensorflow::error::INTERNAL, e.what());
308+
return gpu::ReturnInternalErrorStatus(e.what());
306309
}
307310
}
308311
CUDA_CHECK(cudaStreamSynchronize(stream));
309-
return Status::OK();
312+
return TFOkStatus;
310313
}
311314

312315
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
@@ -345,7 +348,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
345348
table_->upsert((const K*)d_keys, (const V*)d_values, len, stream);
346349
CUDA_CHECK(cudaStreamSynchronize(stream));
347350
} catch (std::runtime_error& e) {
348-
return Status(tensorflow::error::INTERNAL, e.what());
351+
return gpu::ReturnInternalErrorStatus(e.what());
349352
}
350353
}
351354
if (keys_attr.type != cudaMemoryTypeDevice) {
@@ -355,7 +358,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
355358
CUDA_CHECK(cudaFree(d_values));
356359
}
357360
}
358-
return Status::OK();
361+
return TFOkStatus;
359362
}
360363

361364
Status ExportValues(OpKernelContext* ctx) override {
@@ -397,12 +400,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
397400
d_dump_counter, stream);
398401
CUDA_CHECK(cudaStreamSynchronize(stream));
399402
} catch (std::runtime_error& e) {
400-
return Status(tensorflow::error::INTERNAL, e.what());
403+
return gpu::ReturnInternalErrorStatus(e.what());
401404
}
402405
}
403406
CUDA_CHECK(cudaFreeAsync(d_dump_counter, stream));
404407
CUDA_CHECK(cudaStreamSynchronize(stream));
405-
return Status::OK();
408+
return TFOkStatus;
406409
}
407410

408411
Status ExportValuesWithScores(OpKernelContext* ctx) {
@@ -448,12 +451,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
448451
len, d_dump_counter, stream);
449452
CUDA_CHECK(cudaStreamSynchronize(stream));
450453
} catch (std::runtime_error& e) {
451-
return Status(tensorflow::error::INTERNAL, e.what());
454+
return gpu::ReturnInternalErrorStatus(e.what());
452455
}
453456
}
454457
CUDA_CHECK(cudaFreeAsync(d_dump_counter, stream));
455458
CUDA_CHECK(cudaStreamSynchronize(stream));
456-
return Status::OK();
459+
return TFOkStatus;
457460
}
458461

459462
Status ExportKeysAndScores(OpKernelContext* ctx, size_t split_size) {
@@ -486,12 +489,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
486489
static_cast<size_t>(size), split_size,
487490
stream);
488491
} catch (std::runtime_error& e) {
489-
return Status(tensorflow::error::INTERNAL, e.what());
492+
return gpu::ReturnInternalErrorStatus(e.what());
490493
}
491494
}
492495
}
493496
CUDA_CHECK(cudaStreamSynchronize(stream));
494-
return Status::OK();
497+
return TFOkStatus;
495498
}
496499

497500
Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
@@ -507,12 +510,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
507510
table_->dump_to_file(fs, filepath, runtime_dim_, stream, buffer_size,
508511
append_to_file);
509512
} catch (std::runtime_error& e) {
510-
return Status(tensorflow::error::INTERNAL, e.what());
513+
return gpu::ReturnInternalErrorStatus(e.what());
511514
}
512515
}
513516
CUDA_CHECK(cudaStreamSynchronize(stream));
514517

515-
return Status::OK();
518+
return TFOkStatus;
516519
}
517520

518521
Status ImportValuesFromFile(OpKernelContext* ctx, const string& dirpath,
@@ -564,11 +567,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
564567
buffer_size);
565568
}
566569
} catch (std::runtime_error& e) {
567-
return Status(tensorflow::error::INTERNAL, e.what());
570+
return gpu::ReturnInternalErrorStatus(e.what());
568571
}
569572
}
570573
CUDA_CHECK(cudaStreamSynchronize(stream));
571-
return Status::OK();
574+
return TFOkStatus;
572575
}
573576

574577
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
@@ -580,7 +583,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
580583
TensorShape value_shape_;
581584
size_t runtime_dim_;
582585
mutable mutex mu_;
586+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
587+
gpu::TableWrapper<K, V>* table_ = nullptr TF_GUARDED_BY(mu_);
588+
#else
583589
gpu::TableWrapper<K, V>* table_ = nullptr GUARDED_BY(mu_);
590+
#endif
584591
};
585592

586593
} // namespace lookup

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,21 @@ limitations under the License.
4343
#include "tensorflow/core/lib/io/random_inputstream.h"
4444
#include "tensorflow/core/platform/macros.h"
4545
#include "tensorflow/core/platform/thread_annotations.h"
46+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
4647

4748
namespace tensorflow {
4849
namespace recommenders_addons {
4950
namespace lookup {
5051
namespace gpu {
5152

53+
// Builds a Status with an "internal error" code and `str` as its message.
// The error-code enum is chosen at compile time from TF_VERSION_INTEGER:
// absl::StatusCode::kInternal when it is >= 2130 (TF 2.13.0), otherwise
// tensorflow::error::INTERNAL.
// NOTE(review): TF_VERSION_INTEGER is presumably provided by the utils.h
// include added above -- confirm.
inline Status ReturnInternalErrorStatus(const char* const str) {
54+
#if TF_VERSION_INTEGER >= 2130 /* 2.13.0 */
55+
return Status(absl::StatusCode::kInternal, str);
56+
#else
57+
return Status(tensorflow::error::INTERNAL, str);
58+
#endif
59+
}
60+
5261
template <typename K, typename V, typename S>
5362
class KVOnlyFile : public nv::merlin::BaseKVFile<K, V, S> {
5463
public:
@@ -173,7 +182,7 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
173182
auto has_atomic_move_ret =
174183
fs_->HasAtomicMove(filepath_, &has_atomic_move);
175184
bool need_tmp_file =
176-
(has_atomic_move == false) || (has_atomic_move_ret != Status::OK());
185+
(has_atomic_move == false) || (has_atomic_move_ret != TFOkStatus);
177186

178187
if (!need_tmp_file) {
179188
key_tmpfilepath = key_filepath;
@@ -193,7 +202,7 @@ class RandomKVFile : public nv::merlin::BaseKVFile<K, V, S> {
193202
fs_->NewWritableFile(value_tmpfilepath, &value_writer_));
194203
}
195204
}
196-
return Status::OK();
205+
return TFOkStatus;
197206
}
198207

199208
void close() {
@@ -445,9 +454,9 @@ class TableWrapper {
445454
try {
446455
table_->init(mkv_options_, allocator);
447456
} catch (std::runtime_error& e) {
448-
return Status(tensorflow::error::INTERNAL, e.what());
457+
return ReturnInternalErrorStatus(e.what());
449458
}
450-
return Status::OK();
459+
return TFOkStatus;
451460
}
452461

453462
~TableWrapper() { delete table_; }
@@ -534,7 +543,7 @@ class TableWrapper {
534543
string valuefile = filepath + "-values";
535544
string scorefile = filepath + "-scores";
536545
bool has_scores = false;
537-
Status status = Status::OK();
546+
Status status = TFOkStatus;
538547

539548
if (is_valid_scores(keyfile, scorefile)) {
540549
wfile.reset(new nv::merlin::LocalKVFile<K, V, uint64_t>);
@@ -585,7 +594,7 @@ class TableWrapper {
585594
string valuefile = filepath + "-values";
586595
string scorefile = filepath + "-scores";
587596
bool has_scores = false;
588-
Status status = Status::OK();
597+
Status status = TFOkStatus;
589598

590599
if (is_valid_scores(keyfile, scorefile)) {
591600
rfile.reset(new nv::merlin::LocalKVFile<K, V, uint64_t>);

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_cluster_connection_pool.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,22 @@ class RedisWrapper<RedisInstance, K, V,
144144
if (this->isRedisConnect == false) {
145145
LOG(ERROR) << "Can not connect to the Redis Cluster servers.";
146146
if (redis_conn_read == nullptr && redis_conn_write != nullptr) {
147+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
148+
return Status(absl::StatusCode::kUnavailable,
149+
"Can not access Redis Slave servers, Exit without any "
150+
"Redis connection.");
151+
#else
147152
return Status(error::UNAVAILABLE,
148153
"Can not access Redis Slave servers, Exit without any "
149154
"Redis connection.");
155+
#endif
150156
}
157+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
158+
return Status(absl::StatusCode::kUnavailable,
159+
"Exit without any Redis connection.");
160+
#else
151161
return Status(error::UNAVAILABLE, "Exit without any Redis connection.");
162+
#endif
152163
}
153164
}
154165
return TFOkStatus;

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_connection_pool.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,22 @@ class RedisWrapper<
198198
if (this->isRedisConnect == false) {
199199
LOG(ERROR) << "Can not connect to the Redis Master servers.";
200200
if (redis_conn_read == nullptr && redis_conn_write != nullptr) {
201+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
202+
return Status(absl::StatusCode::kUnavailable,
203+
"Can not access Redis Slave servers, Exit without any "
204+
"Redis connection.");
205+
#else
201206
return Status(error::UNAVAILABLE,
202-
"Can not access Redis Slave service, Exit without any "
207+
"Can not access Redis Slave servers, Exit without any "
203208
"Redis connection.");
209+
#endif
204210
}
211+
#if TF_VERSION_INTEGER >= 2130 // 2.13.0
212+
return Status(absl::StatusCode::kUnavailable,
213+
"Exit without any Redis connection.");
214+
#else
205215
return Status(error::UNAVAILABLE, "Exit without any Redis connection.");
216+
#endif
206217
}
207218
}
208219
return TFOkStatus;

0 commit comments

Comments
 (0)