@@ -16,7 +16,6 @@ limitations under the License.
16
16
17
17
#include " tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
18
18
#include " tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h"
19
- #include " tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
20
19
21
20
#define EIGEN_USE_GPU
22
21
@@ -37,7 +36,11 @@ limitations under the License.
37
36
#include " tensorflow/core/util/env_var.h"
38
37
#include " tensorflow/core/util/gpu_device_functions.h"
39
38
#include " tensorflow/core/util/gpu_kernel_helper.h"
39
+ #if TF_VERSION_INTEGER >= 2110 // 2.11.0
40
+ #include " tensorflow/compiler/xla/stream_executor/stream.h"
41
+ #else
40
42
#include " tensorflow/stream_executor/stream.h"
43
+ #endif
41
44
42
45
namespace tensorflow {
43
46
@@ -187,14 +190,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
187
190
is_full_default);
188
191
CUDA_CHECK (cudaStreamSynchronize (stream));
189
192
} catch (std::runtime_error& e) {
190
- return Status (tensorflow::error::INTERNAL, e.what ());
193
+ return gpu::ReturnInternalErrorStatus ( e.what ());
191
194
}
192
195
}
193
196
CUDA_CHECK (cudaFreeAsync (d_status, stream));
194
197
CUDA_CHECK (cudaStreamSynchronize (stream));
195
198
}
196
199
197
- return Status::OK () ;
200
+ return TFOkStatus ;
198
201
}
199
202
200
203
Status FindWithExists (OpKernelContext* ctx, const Tensor& d_keys,
@@ -222,13 +225,13 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
222
225
(V*)(default_value.tensor_data ().data ()), stream,
223
226
is_full_default);
224
227
} catch (std::runtime_error& e) {
225
- return Status (tensorflow::error::INTERNAL, e.what ());
228
+ return gpu::ReturnInternalErrorStatus ( e.what ());
226
229
}
227
230
}
228
231
CUDA_CHECK (cudaStreamSynchronize (stream));
229
232
}
230
233
231
- return Status::OK () ;
234
+ return TFOkStatus ;
232
235
}
233
236
234
237
Status Insert (OpKernelContext* ctx, const Tensor& keys,
@@ -241,12 +244,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
241
244
table_->upsert ((const K*)keys.tensor_data ().data (),
242
245
(const V*)(values.tensor_data ().data ()), len, stream);
243
246
} catch (std::runtime_error& e) {
244
- return Status (tensorflow::error::INTERNAL, e.what ());
247
+ return gpu::ReturnInternalErrorStatus ( e.what ());
245
248
}
246
249
}
247
250
CUDA_CHECK (cudaStreamSynchronize (stream));
248
251
249
- return Status::OK () ;
252
+ return TFOkStatus ;
250
253
}
251
254
252
255
Status Accum (OpKernelContext* ctx, const Tensor& keys,
@@ -260,12 +263,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
260
263
(const V*)(values_or_deltas.tensor_data ().data ()),
261
264
(const bool *)exists.tensor_data ().data (), len, stream);
262
265
} catch (std::runtime_error& e) {
263
- return Status (tensorflow::error::INTERNAL, e.what ());
266
+ return gpu::ReturnInternalErrorStatus ( e.what ());
264
267
}
265
268
}
266
269
CUDA_CHECK (cudaStreamSynchronize (stream));
267
270
268
- return Status::OK () ;
271
+ return TFOkStatus ;
269
272
}
270
273
271
274
Status Remove (OpKernelContext* ctx, const Tensor& keys) override {
@@ -285,14 +288,14 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
285
288
try {
286
289
table_->remove ((const K*)d_keys, len, stream);
287
290
} catch (std::runtime_error& e) {
288
- return Status (tensorflow::error::INTERNAL, e.what ());
291
+ return gpu::ReturnInternalErrorStatus ( e.what ());
289
292
}
290
293
}
291
294
CUDA_CHECK (cudaFreeAsync (d_keys, stream));
292
295
CUDA_CHECK (cudaStreamSynchronize (stream));
293
296
}
294
297
295
- return Status::OK () ;
298
+ return TFOkStatus ;
296
299
}
297
300
298
301
Status Clear (OpKernelContext* ctx) {
@@ -302,11 +305,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
302
305
try {
303
306
table_->clear (stream);
304
307
} catch (std::runtime_error& e) {
305
- return Status (tensorflow::error::INTERNAL, e.what ());
308
+ return gpu::ReturnInternalErrorStatus ( e.what ());
306
309
}
307
310
}
308
311
CUDA_CHECK (cudaStreamSynchronize (stream));
309
- return Status::OK () ;
312
+ return TFOkStatus ;
310
313
}
311
314
312
315
Status ImportValues (OpKernelContext* ctx, const Tensor& keys,
@@ -345,7 +348,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
345
348
table_->upsert ((const K*)d_keys, (const V*)d_values, len, stream);
346
349
CUDA_CHECK (cudaStreamSynchronize (stream));
347
350
} catch (std::runtime_error& e) {
348
- return Status (tensorflow::error::INTERNAL, e.what ());
351
+ return gpu::ReturnInternalErrorStatus ( e.what ());
349
352
}
350
353
}
351
354
if (keys_attr.type != cudaMemoryTypeDevice) {
@@ -355,7 +358,7 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
355
358
CUDA_CHECK (cudaFree (d_values));
356
359
}
357
360
}
358
- return Status::OK () ;
361
+ return TFOkStatus ;
359
362
}
360
363
361
364
Status ExportValues (OpKernelContext* ctx) override {
@@ -397,12 +400,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
397
400
d_dump_counter, stream);
398
401
CUDA_CHECK (cudaStreamSynchronize (stream));
399
402
} catch (std::runtime_error& e) {
400
- return Status (tensorflow::error::INTERNAL, e.what ());
403
+ return gpu::ReturnInternalErrorStatus ( e.what ());
401
404
}
402
405
}
403
406
CUDA_CHECK (cudaFreeAsync (d_dump_counter, stream));
404
407
CUDA_CHECK (cudaStreamSynchronize (stream));
405
- return Status::OK () ;
408
+ return TFOkStatus ;
406
409
}
407
410
408
411
Status ExportValuesWithScores (OpKernelContext* ctx) {
@@ -448,12 +451,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
448
451
len, d_dump_counter, stream);
449
452
CUDA_CHECK (cudaStreamSynchronize (stream));
450
453
} catch (std::runtime_error& e) {
451
- return Status (tensorflow::error::INTERNAL, e.what ());
454
+ return gpu::ReturnInternalErrorStatus ( e.what ());
452
455
}
453
456
}
454
457
CUDA_CHECK (cudaFreeAsync (d_dump_counter, stream));
455
458
CUDA_CHECK (cudaStreamSynchronize (stream));
456
- return Status::OK () ;
459
+ return TFOkStatus ;
457
460
}
458
461
459
462
Status ExportKeysAndScores (OpKernelContext* ctx, size_t split_size) {
@@ -486,12 +489,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
486
489
static_cast <size_t >(size), split_size,
487
490
stream);
488
491
} catch (std::runtime_error& e) {
489
- return Status (tensorflow::error::INTERNAL, e.what ());
492
+ return gpu::ReturnInternalErrorStatus ( e.what ());
490
493
}
491
494
}
492
495
}
493
496
CUDA_CHECK (cudaStreamSynchronize (stream));
494
- return Status::OK () ;
497
+ return TFOkStatus ;
495
498
}
496
499
497
500
Status ExportValuesToFile (OpKernelContext* ctx, const string filepath,
@@ -507,12 +510,12 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
507
510
table_->dump_to_file (fs, filepath, runtime_dim_, stream, buffer_size,
508
511
append_to_file);
509
512
} catch (std::runtime_error& e) {
510
- return Status (tensorflow::error::INTERNAL, e.what ());
513
+ return gpu::ReturnInternalErrorStatus ( e.what ());
511
514
}
512
515
}
513
516
CUDA_CHECK (cudaStreamSynchronize (stream));
514
517
515
- return Status::OK () ;
518
+ return TFOkStatus ;
516
519
}
517
520
518
521
Status ImportValuesFromFile (OpKernelContext* ctx, const string& dirpath,
@@ -564,11 +567,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
564
567
buffer_size);
565
568
}
566
569
} catch (std::runtime_error& e) {
567
- return Status (tensorflow::error::INTERNAL, e.what ());
570
+ return gpu::ReturnInternalErrorStatus ( e.what ());
568
571
}
569
572
}
570
573
CUDA_CHECK (cudaStreamSynchronize (stream));
571
- return Status::OK () ;
574
+ return TFOkStatus ;
572
575
}
573
576
574
577
DataType key_dtype () const override { return DataTypeToEnum<K>::v (); }
@@ -580,7 +583,11 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
580
583
TensorShape value_shape_;
581
584
size_t runtime_dim_;
582
585
mutable mutex mu_;
586
+ #if TF_VERSION_INTEGER >= 2130 // 2.13.0
587
+ gpu::TableWrapper<K, V>* table_ = nullptr TF_GUARDED_BY (mu_);
588
+ #else
583
589
gpu::TableWrapper<K, V>* table_ = nullptr GUARDED_BY (mu_);
590
+ #endif
584
591
};
585
592
586
593
} // namespace lookup
0 commit comments