@@ -210,10 +210,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     if (cur_size > 0) {
       CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
       CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
-      CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
+      CUDA_CHECK(cudaMallocManaged((void**)&d_values,
+                                   sizeof(V) * runtime_dim_ * cur_size));
       table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
-                   d_dump_counter, stream);
-      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
+                   d_dump_counter, stream);
+      cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
+                      cudaMemcpyDeviceToHost, stream);
       CUDA_CHECK(cudaStreamSynchronize(stream));
     }

@@ -222,8 +224,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     CreateTable(new_max_size, &table_);

     if (cur_size > 0) {
-      table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
-                     h_dump_counter, stream);
+      table_->upsert((const K*)d_keys,
+                     (const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
+                     stream);
       cudaStreamSynchronize(stream);
       cudaFree(d_keys);
       cudaFree(d_values);
@@ -383,6 +386,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
     return Status::OK();
   }

+  Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
+                            const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+      table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
+  Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
+                              const size_t buffer_size) {
+    cudaStream_t _stream;
+    CUDA_CHECK(cudaStreamCreate(&_stream));
+
+    {
+      tf_shared_lock l(mu_);
+
+      string keyfile = filepath + ".keys";
+      FILE* tmpfd = fopen(keyfile.c_str(), "rb");
+      if (tmpfd == nullptr) {
+        return errors::NotFound("Failed to read key file", keyfile);
+      }
+      fseek(tmpfd, 0, SEEK_END);
+      long int filesize = ftell(tmpfd);
+      if (filesize <= 0) {
+        fclose(tmpfd);
+        return errors::NotFound("Empty key file.");
+      }
+      size_t size = static_cast<size_t>(filesize) / sizeof(K);
+      fseek(tmpfd, 0, SEEK_SET);
+      fclose(tmpfd);
+
+      table_->clear(_stream);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+      RehashIfNeeded(_stream, size);
+      table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
+                             buffer_size);
+      CUDA_CHECK(cudaStreamSynchronize(_stream));
+    }
+    CUDA_CHECK(cudaStreamDestroy(_stream));
+    return Status::OK();
+  }
+
   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
   TensorShape key_shape() const final { return TensorShape(); }
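The import path above sizes its workload from the serialized key file alone: the entry count is the key file's byte length divided by sizeof(K), with no header or length prefix, and values are then restored at runtime_dim_ entries per key. The following is a minimal standalone sketch of that arithmetic outside TensorFlow; the file path, the value dimension, and the expectation about a companion value file are illustrative assumptions, not part of this diff.

    #include <cstdint>
    #include <cstdio>

    using K = int64_t;  // matches the table's key type in this instantiation
    using V = float;    // matches the table's value type

    int main() {
      const char* keyfile = "/tmp/table.keys";  // hypothetical path
      const size_t dim = 8;                     // hypothetical runtime_dim_

      // Same probing pattern as ImportValuesFromFile: seek to the end,
      // read the offset, and treat the file as a flat array of K.
      std::FILE* fd = std::fopen(keyfile, "rb");
      if (fd == nullptr) return 1;  // analogous to errors::NotFound
      std::fseek(fd, 0, SEEK_END);
      long filesize = std::ftell(fd);
      std::fclose(fd);
      if (filesize <= 0) return 1;  // analogous to "Empty key file."

      size_t num_keys = static_cast<size_t>(filesize) / sizeof(K);
      // A companion value file in the same flat layout would be expected
      // to hold num_keys * dim values of type V.
      std::printf("%zu keys, %zu value bytes expected\n", num_keys,
                  num_keys * dim * sizeof(V));
      return 0;
    }
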
@@ -621,6 +672,36 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
     HashTableExportGpuOp);

+// Op that exports all keys and values to a file.
+template <class K, class V>
+class HashTableExportToFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
 // Clear the table and insert data.
 class HashTableImportGpuOp : public OpKernel {
  public:
@@ -647,33 +728,76 @@ REGISTER_KERNEL_BUILDER(
     Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
     HashTableImportGpuOp);

+// Op that imports all keys and values from a file.
+template <class K, class V>
+class HashTableImportFromFileGpuOp : public OpKernel {
+ public:
+  explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    int64 signed_buffer_size = 0;
+    ctx->GetAttr("buffer_size", &signed_buffer_size);
+    buffer_size_ = static_cast<size_t>(signed_buffer_size);
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    lookup::LookupInterface* table;
+    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+    core::ScopedUnref unref_me(table);
+
+    const Tensor& ftensor = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
+                errors::InvalidArgument("filepath must be scalar."));
+    string filepath = string(ftensor.scalar<tstring>()().data());
+    lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
+        (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
+    OP_REQUIRES_OK(
+        ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
+  }
+
+ private:
+  size_t buffer_size_;
+};
+
 // Register the CuckooHashTableOfTensors op.

-#define REGISTER_KERNEL(key_dtype, value_dtype)                            \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors))                       \
-          .Device(DEVICE_GPU)                                              \
-          .TypeConstraint<key_dtype>("key_dtype")                          \
-          .TypeConstraint<value_dtype>("value_dtype"),                     \
-      HashTableGpuOp<                                                      \
-          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>,     \
-          key_dtype, value_dtype>);                                        \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear))       \
-                              .Device(DEVICE_GPU)                          \
-                              .TypeConstraint<key_dtype>("key_dtype")      \
-                              .TypeConstraint<value_dtype>("value_dtype"), \
-                          HashTableClearGpuOp<key_dtype, value_dtype>)     \
-  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum))       \
-                              .Device(DEVICE_GPU)                          \
-                              .TypeConstraint<key_dtype>("key_dtype")      \
-                              .TypeConstraint<value_dtype>("value_dtype"), \
-                          HashTableAccumGpuOp<key_dtype, value_dtype>)     \
-  REGISTER_KERNEL_BUILDER(                                                 \
-      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists))                  \
-          .Device(DEVICE_GPU)                                              \
-          .TypeConstraint<key_dtype>("Tin")                                \
-          .TypeConstraint<value_dtype>("Tout"),                            \
-      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
+#define REGISTER_KERNEL(key_dtype, value_dtype)                              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors))                         \
+          .Device(DEVICE_GPU)                                                \
+          .TypeConstraint<key_dtype>("key_dtype")                            \
+          .TypeConstraint<value_dtype>("value_dtype"),                       \
+      HashTableGpuOp<                                                        \
+          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>,       \
+          key_dtype, value_dtype>);                                          \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear))         \
+                              .Device(DEVICE_GPU)                            \
+                              .TypeConstraint<key_dtype>("key_dtype")        \
+                              .TypeConstraint<value_dtype>("value_dtype"),   \
+                          HashTableClearGpuOp<key_dtype, value_dtype>);      \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum))         \
+                              .Device(DEVICE_GPU)                            \
+                              .TypeConstraint<key_dtype>("key_dtype")        \
+                              .TypeConstraint<value_dtype>("value_dtype"),   \
+                          HashTableAccumGpuOp<key_dtype, value_dtype>);      \
+  REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile))  \
+                              .Device(DEVICE_GPU)                            \
+                              .HostMemory("filepath")                        \
+                              .TypeConstraint<key_dtype>("key_dtype")        \
+                              .TypeConstraint<value_dtype>("value_dtype"),   \
+                          HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile))                    \
+          .Device(DEVICE_GPU)                                                \
+          .HostMemory("filepath")                                            \
+          .TypeConstraint<key_dtype>("key_dtype")                            \
+          .TypeConstraint<value_dtype>("value_dtype"),                       \
+      HashTableImportFromFileGpuOp<key_dtype, value_dtype>);                 \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists))                    \
+          .Device(DEVICE_GPU)                                                \
+          .TypeConstraint<key_dtype>("Tin")                                  \
+          .TypeConstraint<value_dtype>("Tout"),                              \
+      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);

 REGISTER_KERNEL(int64, float);
 REGISTER_KERNEL(int64, Eigen::half);
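For reference, here is one of the six registrations that REGISTER_KERNEL(int64, float) emits after manual macro expansion (PREFIX_OP_NAME is kept symbolic because its definition lies outside this diff). The HostMemory("filepath") constraint pins the string input to host memory, so the kernel can read the path directly without a device-to-host copy.

    // Manual expansion of the new export registration for <int64, float>.
    REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile))
                                .Device(DEVICE_GPU)
                                .HostMemory("filepath")
                                .TypeConstraint<int64>("key_dtype")
                                .TypeConstraint<float>("value_dtype"),
                            HashTableExportToFileGpuOp<int64, float>);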