@@ -214,10 +214,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     if (cur_size > 0) {
       CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
       CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
-      CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
+      CUDA_CHECK(cudaMallocManaged((void**)&d_values,
+                                   sizeof(V) * runtime_dim_ * cur_size));
       table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
-                   d_dump_counter, stream);
-      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
+                   d_dump_counter, stream);
+      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
+                      cudaMemcpyDeviceToHost, stream);
 
       CUDA_CHECK(cudaStreamSynchronize(stream));
     }
@@ -226,8 +228,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     CreateTable(new_max_size, &table_);
 
     if (cur_size > 0) {
-      table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
-                     h_dump_counter, stream);
+      table_->upsert((const K*)d_keys,
+                     (const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
+                     stream);
       cudaStreamSynchronize(stream);
       cudaFree(d_keys);
       cudaFree(d_values);
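A note on the rehash path above: the current entries are dumped into CUDA managed memory, the table is recreated at the larger capacity, and the saved entries are upserted back. A minimal standalone sketch of that managed-memory round trip (hypothetical sizes and key/value types; no hash table involved):

// Sketch of the buffer lifecycle behind the dump/upsert rehash above.
// Managed allocations are addressable from host and device, which is why
// h_dump_counter can be filled by a plain async device-to-host copy.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const size_t cur_size = 1024;  // hypothetical current entry count
  const size_t runtime_dim = 8;  // hypothetical value dimension

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int64_t* d_keys = nullptr;
  float* d_values = nullptr;
  size_t* d_dump_counter = nullptr;
  cudaMallocManaged((void**)&d_keys, sizeof(int64_t) * cur_size);
  cudaMallocManaged((void**)&d_values, sizeof(float) * runtime_dim * cur_size);
  cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t));

  // ... a real dump would write the entries and the counter here ...

  size_t h_dump_counter = 0;
  cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // must complete before reading the counter

  printf("dumped %zu entries\n", h_dump_counter);
  cudaFree(d_keys);
  cudaFree(d_values);
  cudaFree(d_dump_counter);
  cudaStreamDestroy(stream);
  return 0;
}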
@@ -387,6 +390,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     return Status::OK();
   }
 
+  Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
+                            const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+      table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
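ExportValuesToFile above and ImportValuesFromFile below each create a private stream with cudaStreamCreate and tear it down by hand. Note that the import path's early NotFound returns leave the function before cudaStreamDestroy runs, so the stream leaks on those error paths. A small RAII guard would close that gap; a sketch, not part of this PR:

// Sketch: RAII wrapper so cudaStreamDestroy runs on every exit path,
// including early error returns. Assumes only the CUDA runtime API.
#include <cuda_runtime.h>

class ScopedCudaStream {
 public:
  ScopedCudaStream() { cudaStreamCreate(&stream_); }
  ~ScopedCudaStream() { cudaStreamDestroy(stream_); }
  ScopedCudaStream(const ScopedCudaStream&) = delete;
  ScopedCudaStream& operator=(const ScopedCudaStream&) = delete;
  cudaStream_t get() const { return stream_; }

 private:
  cudaStream_t stream_ = nullptr;
};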
+  Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
+                              const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+
+      string keyfile = filepath + ".keys";
+      FILE* tmpfd = fopen(keyfile.c_str(), "rb");
+      if (tmpfd == nullptr) {
+        return errors::NotFound("Failed to read key file", keyfile);
+      }
+      fseek(tmpfd, 0, SEEK_END);
+      long int filesize = ftell(tmpfd);
+      if (filesize <= 0) {
+        fclose(tmpfd);
+        return errors::NotFound("Empty key file.");
+      }
+      size_t size = static_cast<size_t>(filesize) / sizeof(K);
+      fseek(tmpfd, 0, SEEK_SET);
+      fclose(tmpfd);
+
+      table_->clear(_stream);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+      RehashIfNeeded(_stream, size);
+      table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
+                             buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
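The import path sizes the table before loading: it divides the key file's byte length by sizeof(K) to get the entry count, clears the table, and lets RehashIfNeeded grow capacity once up front rather than mid-load. The same computation in isolation (a sketch; the int64_t key type and the -1 error convention are assumptions):

// Sketch: derive the serialized key count from the key file's size on disk.
// Returns -1 for a missing or empty file; assumes fixed-width int64_t keys.
#include <cstdint>
#include <cstdio>

long CountSerializedKeys(const char* keyfile) {
  FILE* f = fopen(keyfile, "rb");
  if (f == nullptr) return -1;   // caller would map this to NotFound
  fseek(f, 0, SEEK_END);
  long filesize = ftell(f);      // total bytes of packed keys
  fclose(f);
  if (filesize <= 0) return -1;  // an empty file is also treated as missing
  return filesize / (long)sizeof(int64_t);
}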
   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
   TensorShape key_shape() const final { return TensorShape(); }
@@ -625,6 +676,36 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
     HashTableExportGpuOp);
 
+// Op that exports all keys and values to a file.
+template <class K, class V>
+class HashTableExportToFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
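One nit on the constructor above: the Status returned by ctx->GetAttr is discarded, so a missing or mistyped buffer_size attr passes silently and the op runs with buffer_size_ == 0. The usual TensorFlow idiom wraps the call in OP_REQUIRES_OK; a fragment, assuming the surrounding class from this diff:

// Fragment (not standalone): same constructor, but failing construction
// loudly when the "buffer_size" attr is absent or has the wrong type.
explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  int64 signed_buffer_size = 0;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
  buffer_size_ = static_cast<size_t>(signed_buffer_size);
}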
 // Clear the table and insert data.
 class HashTableImportGpuOp : public OpKernel {
  public:
@@ -651,33 +732,76 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
     HashTableImportGpuOp);
 
+// Op that imports keys and values from a file.
+template <class K, class V>
+class HashTableImportFromFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
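Both file ops recover the concrete table type with a C-style cast, which yields a garbage pointer if the handle actually refers to some other LookupInterface implementation. An alternative that fails loudly instead (a sketch, assuming RTTI is enabled; not part of this PR):

// Fragment (not standalone): type-checked downcast inside Compute.
auto* table_cuckoo =
    dynamic_cast<lookup::CuckooHashTableOfTensorsGpu<K, V>*>(table);
OP_REQUIRES(ctx, table_cuckoo != nullptr,
            errors::InvalidArgument(
                "table_handle does not refer to a CuckooHashTableOfTensorsGpu"));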
 // Register the CuckooHashTableOfTensors op.
 
-#define REGISTER_KERNEL(key_dtype, value_dtype)                               \
-  REGISTER_KERNEL_BUILDER(                                                    \
-      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors))                          \
-          .Device(DEVICE_GPU)                                                 \
-          .TypeConstraint<key_dtype>("key_dtype")                             \
-          .TypeConstraint<value_dtype>("value_dtype"),                        \
-      HashTableGpuOp<                                                         \
-          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>,        \
-          key_dtype, value_dtype>);                                           \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear))          \
-                              .Device(DEVICE_GPU)                             \
-                              .TypeConstraint<key_dtype>("key_dtype")         \
-                              .TypeConstraint<value_dtype>("value_dtype"),    \
-                          HashTableClearGpuOp<key_dtype, value_dtype>)        \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum))          \
-                              .Device(DEVICE_GPU)                             \
-                              .TypeConstraint<key_dtype>("key_dtype")         \
-                              .TypeConstraint<value_dtype>("value_dtype"),    \
-                          HashTableAccumGpuOp<key_dtype, value_dtype>)        \
-  REGISTER_KERNEL_BUILDER(                                                    \
-      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists))                     \
-          .Device(DEVICE_GPU)                                                 \
-          .TypeConstraint<key_dtype>("Tin")                                   \
-          .TypeConstraint<value_dtype>("Tout"),                               \
-      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
+#define REGISTER_KERNEL(key_dtype, value_dtype)                                \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors))                           \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<key_dtype>("key_dtype")                              \
+          .TypeConstraint<value_dtype>("value_dtype"),                         \
+      HashTableGpuOp<                                                          \
+          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>,         \
+          key_dtype, value_dtype>);                                            \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear))           \
+                              .Device(DEVICE_GPU)                              \
+                              .TypeConstraint<key_dtype>("key_dtype")          \
+                              .TypeConstraint<value_dtype>("value_dtype"),     \
+                          HashTableClearGpuOp<key_dtype, value_dtype>);        \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum))           \
+                              .Device(DEVICE_GPU)                              \
+                              .TypeConstraint<key_dtype>("key_dtype")          \
+                              .TypeConstraint<value_dtype>("value_dtype"),     \
+                          HashTableAccumGpuOp<key_dtype, value_dtype>);        \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile))    \
+                              .Device(DEVICE_GPU)                              \
+                              .HostMemory("filepath")                          \
+                              .TypeConstraint<key_dtype>("key_dtype")          \
+                              .TypeConstraint<value_dtype>("value_dtype"),     \
+                          HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile))                      \
+          .Device(DEVICE_GPU)                                                  \
+          .HostMemory("filepath")                                              \
+          .TypeConstraint<key_dtype>("key_dtype")                              \
+          .TypeConstraint<value_dtype>("value_dtype"),                         \
+      HashTableImportFromFileGpuOp<key_dtype, value_dtype>);                   \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists))                      \
+          .Device(DEVICE_GPU)                                                  \
+          .TypeConstraint<key_dtype>("Tin")                                    \
+          .TypeConstraint<value_dtype>("Tout"),                                \
+      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
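The two new registrations pin the filepath input with .HostMemory("filepath"): the string is consumed by host-side file I/O, so it must stay in host memory even though the kernel is registered for DEVICE_GPU. For one concrete type pair, the ExportToFile entry of the macro expands to roughly:

// Illustration: the ExportToFile registration as REGISTER_KERNEL(int64, float)
// below would emit it (expanded by hand from the macro above).
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile))
                            .Device(DEVICE_GPU)
                            .HostMemory("filepath")
                            .TypeConstraint<int64>("key_dtype")
                            .TypeConstraint<float>("value_dtype"),
                        HashTableExportToFileGpuOp<int64, float>);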
 
 REGISTER_KERNEL(int64, float);
 REGISTER_KERNEL(int64, Eigen::half);