@@ -15,6 +15,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 
 #include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
+
 #include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
 #include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
 
@@ -214,10 +215,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     if (cur_size > 0) {
       CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
       CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
-      CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
+      CUDA_CHECK(cudaMallocManaged((void**)&d_values,
+                                   sizeof(V) * runtime_dim_ * cur_size));
       table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
-                   d_dump_counter, stream);
-      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
+                   d_dump_counter, stream);
+      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
+                      cudaMemcpyDeviceToHost, stream);
       CUDA_CHECK(cudaStreamSynchronize(stream));
     }
 
@@ -226,8 +229,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     CreateTable(new_max_size, &table_);
 
     if (cur_size > 0) {
-      table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
-                     h_dump_counter, stream);
+      table_->upsert((const K*)d_keys,
+                     (const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
+                     stream);
       cudaStreamSynchronize(stream);
       cudaFree(d_keys);
       cudaFree(d_values);
@@ -387,6 +391,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     return Status::OK();
   }
 
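+  // Dump all keys and values to the file at `filepath` on a dedicated CUDA
+  // stream, holding the shared lock for the duration of the dump.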
+  Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
+                            const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+      table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
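+  // Derive the number of keys from the size of the companion ".keys" file,
+  // clear and rehash the table, then load the dumped data back from disk.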
+  Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
+                              const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+
+      string keyfile = filepath + ".keys";
+      FILE* tmpfd = fopen(keyfile.c_str(), "rb");
+      if (tmpfd == nullptr) {
+        return errors::NotFound("Failed to read key file", keyfile);
+      }
+      fseek(tmpfd, 0, SEEK_END);
+      long int filesize = ftell(tmpfd);
+      if (filesize <= 0) {
+        fclose(tmpfd);
+        return errors::NotFound("Empty key file.");
+      }
+      size_t size = static_cast<size_t>(filesize) / sizeof(K);
+      fseek(tmpfd, 0, SEEK_SET);
+      fclose(tmpfd);
+
+      table_->clear(_stream);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+      RehashIfNeeded(_stream, size);
+      table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
+                             buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
   TensorShape key_shape() const final { return TensorShape(); }
@@ -625,6 +677,36 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
     HashTableExportGpuOp);
 
+// Op that exports all keys and values to a file.
+template <class K, class V>
+class HashTableExportToFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
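+    // The "buffer_size" attr is read here and forwarded unchanged to
+    // ExportValuesToFile.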
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
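+    // Input 1 carries the destination path as a scalar string tensor.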
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
 // Clear the table and insert data.
 class HashTableImportGpuOp : public OpKernel {
  public:
@@ -651,33 +733,76 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
     HashTableImportGpuOp);
 
+// Op that imports keys and values from a file.
+template <class K, class V>
+class HashTableImportFromFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
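+    // Same attr handling as the export op: "buffer_size" is forwarded
+    // unchanged to ImportValuesFromFile.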
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
 // Register the CuckooHashTableOfTensors op.
 
-#define REGISTER_KERNEL(key_dtype, value_dtype) \
-  REGISTER_KERNEL_BUILDER( \
-      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
-          .Device(DEVICE_GPU) \
-          .TypeConstraint<key_dtype>("key_dtype") \
-          .TypeConstraint<value_dtype>("value_dtype"), \
-      HashTableGpuOp< \
-          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
-          key_dtype, value_dtype>); \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<key_dtype>("key_dtype") \
-                              .TypeConstraint<value_dtype>("value_dtype"), \
-                          HashTableClearGpuOp<key_dtype, value_dtype>) \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<key_dtype>("key_dtype") \
-                              .TypeConstraint<value_dtype>("value_dtype"), \
-                          HashTableAccumGpuOp<key_dtype, value_dtype>) \
-  REGISTER_KERNEL_BUILDER( \
-      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
-          .Device(DEVICE_GPU) \
-          .TypeConstraint<key_dtype>("Tin") \
-          .TypeConstraint<value_dtype>("Tout"), \
-      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
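+// The file-based ops place their "filepath" input in host memory
+// (HostMemory("filepath")), since the path string is read on the CPU.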
+#define REGISTER_KERNEL(key_dtype, value_dtype) \
+  REGISTER_KERNEL_BUILDER( \
+      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
+          .Device(DEVICE_GPU) \
+          .TypeConstraint<key_dtype>("key_dtype") \
+          .TypeConstraint<value_dtype>("value_dtype"), \
+      HashTableGpuOp< \
+          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
+          key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<key_dtype>("key_dtype") \
+                              .TypeConstraint<value_dtype>("value_dtype"), \
+                          HashTableClearGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<key_dtype>("key_dtype") \
+                              .TypeConstraint<value_dtype>("value_dtype"), \
+                          HashTableAccumGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
+                              .Device(DEVICE_GPU) \
+                              .HostMemory("filepath") \
+                              .TypeConstraint<key_dtype>("key_dtype") \
+                              .TypeConstraint<value_dtype>("value_dtype"), \
+                          HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
+          .Device(DEVICE_GPU) \
+          .HostMemory("filepath") \
+          .TypeConstraint<key_dtype>("key_dtype") \
+          .TypeConstraint<value_dtype>("value_dtype"), \
+      HashTableImportFromFileGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER( \
+      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
+          .Device(DEVICE_GPU) \
+          .TypeConstraint<key_dtype>("Tin") \
+          .TypeConstraint<value_dtype>("Tout"), \
+      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
 
 REGISTER_KERNEL(int64, float);
 REGISTER_KERNEL(int64, Eigen::half);