Skip to content

Commit c9acc63

Browse files
committed
Add ops of ExportToFile and ImportFromFile without full volume copying
1 parent 097bd21 commit c9acc63

File tree

9 files changed

+762
-36
lines changed

9 files changed

+762
-36
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/BUILD

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ custom_op_library(
1414
"kernels/cuckoo_hashtable_op.h",
1515
"kernels/cuckoo_hashtable_op.cc",
1616
"ops/cuckoo_hashtable_ops.cc",
17-
"utils/utils.h",
17+
"utils/filebuffer.h",
1818
"utils/types.h",
19+
"utils/utils.h",
1920
] + glob(["kernels/lookup_impl/lookup_table_op_cpu*"]),
2021
cuda_deps = if_cuda_for_tf_serving(
2122
["//tensorflow_recommenders_addons/dynamic_embedding/core/lib/nvhash:nvhashtable"],
@@ -26,8 +27,9 @@ custom_op_library(
2627
"kernels/cuckoo_hashtable_op.h",
2728
"kernels/cuckoo_hashtable_op_gpu.h",
2829
"kernels/cuckoo_hashtable_op_gpu.cu.cc",
29-
"utils/utils.h",
30+
"utils/filebuffer.h",
3031
"utils/types.h",
32+
"utils/utils.h",
3133
] + glob(["kernels/lookup_impl/lookup_table_op_gpu*"])),
3234
deps = ["//tensorflow_recommenders_addons/dynamic_embedding/core/lib/cuckoo:cuckoohash"],
3335
)

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface {
304304
return table_->export_values(ctx, value_dim);
305305
}
306306

307+
Status SaveToFile(OpKernelContext* ctx, const string filepath,
308+
const size_t buffer_size) {
309+
int64 value_dim = value_shape_.dim_size(0);
310+
return table_->save_to_file(ctx, value_dim, filepath, buffer_size);
311+
}
312+
313+
Status LoadFromFile(OpKernelContext* ctx, const string filepath,
314+
const size_t buffer_size) {
315+
int64 value_dim = value_shape_.dim_size(0);
316+
return table_->load_from_file(ctx, value_dim, filepath, buffer_size);
317+
}
318+
307319
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
308320

309321
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
@@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel {
607619
}
608620
};
609621

622+
// Op that export all keys and values to file.
623+
template <class K, class V>
624+
class HashTableExportToFileOp : public HashTableOpKernel {
625+
public:
626+
explicit HashTableExportToFileOp(OpKernelConstruction* ctx)
627+
: HashTableOpKernel(ctx) {
628+
int64 signed_buffer_size = 0;
629+
ctx->GetAttr("buffer_size", &signed_buffer_size);
630+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
631+
}
632+
633+
void Compute(OpKernelContext* ctx) override {
634+
LookupInterface* table;
635+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
636+
core::ScopedUnref unref_me(table);
637+
638+
const Tensor& ftensor = ctx->input(1);
639+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
640+
errors::InvalidArgument("filepath must be scalar."));
641+
string filepath = string(ftensor.scalar<tstring>()().data());
642+
643+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
644+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
645+
OP_REQUIRES_OK(ctx, table_cuckoo->SaveToFile(ctx, filepath, buffer_size_));
646+
}
647+
648+
private:
649+
size_t buffer_size_;
650+
};
651+
610652
// Clear the table and insert data.
611653
class HashTableImportOp : public HashTableOpKernel {
612654
public:
@@ -637,6 +679,37 @@ class HashTableImportOp : public HashTableOpKernel {
637679
}
638680
};
639681

682+
// Op that export all keys and values to file.
683+
template <class K, class V>
684+
class HashTableImportFromFileOp : public HashTableOpKernel {
685+
public:
686+
explicit HashTableImportFromFileOp(OpKernelConstruction* ctx)
687+
: HashTableOpKernel(ctx) {
688+
int64 signed_buffer_size = 0;
689+
ctx->GetAttr("buffer_size", &signed_buffer_size);
690+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
691+
}
692+
693+
void Compute(OpKernelContext* ctx) override {
694+
LookupInterface* table;
695+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
696+
core::ScopedUnref unref_me(table);
697+
698+
const Tensor& ftensor = ctx->input(1);
699+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
700+
errors::InvalidArgument("filepath must be scalar."));
701+
string filepath = string(ftensor.scalar<tstring>()().data());
702+
703+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
704+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
705+
OP_REQUIRES_OK(ctx,
706+
table_cuckoo->LoadFromFile(ctx, filepath, buffer_size_));
707+
}
708+
709+
private:
710+
size_t buffer_size_;
711+
};
712+
640713
REGISTER_KERNEL_BUILDER(
641714
Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU),
642715
HashTableFindOp);
@@ -679,7 +752,17 @@ REGISTER_KERNEL_BUILDER(
679752
.Device(DEVICE_CPU) \
680753
.TypeConstraint<key_dtype>("Tin") \
681754
.TypeConstraint<value_dtype>("Tout"), \
682-
HashTableFindWithExistsOp<key_dtype, value_dtype>);
755+
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
756+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
757+
.Device(DEVICE_CPU) \
758+
.TypeConstraint<key_dtype>("key_dtype") \
759+
.TypeConstraint<value_dtype>("value_dtype"), \
760+
HashTableExportToFileOp<key_dtype, value_dtype>); \
761+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
762+
.Device(DEVICE_CPU) \
763+
.TypeConstraint<key_dtype>("key_dtype") \
764+
.TypeConstraint<value_dtype>("value_dtype"), \
765+
HashTableImportFromFileOp<key_dtype, value_dtype>);
683766

684767
REGISTER_KERNEL(int32, double);
685768
REGISTER_KERNEL(int32, float);

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 154 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
210210
if (cur_size > 0) {
211211
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
212212
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
213-
CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
213+
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
214+
sizeof(V) * runtime_dim_ * cur_size));
214215
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
215-
d_dump_counter, stream);
216-
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
216+
d_dump_counter, stream);
217+
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
218+
cudaMemcpyDeviceToHost, stream);
217219
CUDA_CHECK(cudaStreamSynchronize(stream));
218220
}
219221

@@ -222,8 +224,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
222224
CreateTable(new_max_size, &table_);
223225

224226
if (cur_size > 0) {
225-
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
226-
h_dump_counter, stream);
227+
table_->upsert((const K*)d_keys,
228+
(const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
229+
stream);
227230
cudaStreamSynchronize(stream);
228231
cudaFree(d_keys);
229232
cudaFree(d_values);
@@ -383,6 +386,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
383386
return Status::OK();
384387
}
385388

389+
Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
390+
const size_t buffer_size) {
391+
cudaStream_t _stream;
392+
CUDA_CHECK(cudaStreamCreate(&_stream));
393+
394+
{
395+
tf_shared_lock l(mu_);
396+
table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
397+
CUDA_CHECK(cudaStreamSynchronize(_stream));
398+
}
399+
CUDA_CHECK(cudaStreamDestroy(_stream));
400+
return Status::OK();
401+
}
402+
403+
Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
404+
const size_t buffer_size) {
405+
cudaStream_t _stream;
406+
CUDA_CHECK(cudaStreamCreate(&_stream));
407+
408+
{
409+
tf_shared_lock l(mu_);
410+
411+
string keyfile = filepath + ".keys";
412+
FILE* tmpfd = fopen(keyfile.c_str(), "rb");
413+
if (tmpfd == nullptr) {
414+
return errors::NotFound("Failed to read key file", keyfile);
415+
}
416+
fseek(tmpfd, 0, SEEK_END);
417+
long int filesize = ftell(tmpfd);
418+
if (filesize <= 0) {
419+
fclose(tmpfd);
420+
return errors::NotFound("Empty key file.");
421+
}
422+
size_t size = static_cast<size_t>(filesize) / sizeof(K);
423+
fseek(tmpfd, 0, SEEK_SET);
424+
fclose(tmpfd);
425+
426+
table_->clear(_stream);
427+
CUDA_CHECK(cudaStreamSynchronize(_stream));
428+
RehashIfNeeded(_stream, size);
429+
table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
430+
buffer_size);
431+
CUDA_CHECK(cudaStreamSynchronize(_stream));
432+
}
433+
CUDA_CHECK(cudaStreamDestroy(_stream));
434+
return Status::OK();
435+
}
436+
386437
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
387438
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
388439
TensorShape key_shape() const final { return TensorShape(); }
@@ -621,6 +672,36 @@ REGISTER_KERNEL_BUILDER(
621672
Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
622673
HashTableExportGpuOp);
623674

675+
// Op that export all keys and values to file.
676+
template <class K, class V>
677+
class HashTableExportToFileGpuOp : public OpKernel {
678+
public:
679+
explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
680+
: OpKernel(ctx) {
681+
int64 signed_buffer_size = 0;
682+
ctx->GetAttr("buffer_size", &signed_buffer_size);
683+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
684+
}
685+
686+
void Compute(OpKernelContext* ctx) override {
687+
lookup::LookupInterface* table;
688+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
689+
core::ScopedUnref unref_me(table);
690+
691+
const Tensor& ftensor = ctx->input(1);
692+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
693+
errors::InvalidArgument("filepath must be scalar."));
694+
string filepath = string(ftensor.scalar<tstring>()().data());
695+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
696+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
697+
OP_REQUIRES_OK(
698+
ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
699+
}
700+
701+
private:
702+
size_t buffer_size_;
703+
};
704+
624705
// Clear the table and insert data.
625706
class HashTableImportGpuOp : public OpKernel {
626707
public:
@@ -647,33 +728,76 @@ REGISTER_KERNEL_BUILDER(
647728
Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
648729
HashTableImportGpuOp);
649730

731+
// Op that import from file.
732+
template <class K, class V>
733+
class HashTableImportFromFileGpuOp : public OpKernel {
734+
public:
735+
explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
736+
: OpKernel(ctx) {
737+
int64 signed_buffer_size = 0;
738+
ctx->GetAttr("buffer_size", &signed_buffer_size);
739+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
740+
}
741+
742+
void Compute(OpKernelContext* ctx) override {
743+
lookup::LookupInterface* table;
744+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
745+
core::ScopedUnref unref_me(table);
746+
747+
const Tensor& ftensor = ctx->input(1);
748+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
749+
errors::InvalidArgument("filepath must be scalar."));
750+
string filepath = string(ftensor.scalar<tstring>()().data());
751+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
752+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
753+
OP_REQUIRES_OK(
754+
ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
755+
}
756+
757+
private:
758+
size_t buffer_size_;
759+
};
760+
650761
// Register the CuckooHashTableOfTensors op.
651762

652-
#define REGISTER_KERNEL(key_dtype, value_dtype) \
653-
REGISTER_KERNEL_BUILDER( \
654-
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
655-
.Device(DEVICE_GPU) \
656-
.TypeConstraint<key_dtype>("key_dtype") \
657-
.TypeConstraint<value_dtype>("value_dtype"), \
658-
HashTableGpuOp< \
659-
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
660-
key_dtype, value_dtype>); \
661-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
662-
.Device(DEVICE_GPU) \
663-
.TypeConstraint<key_dtype>("key_dtype") \
664-
.TypeConstraint<value_dtype>("value_dtype"), \
665-
HashTableClearGpuOp<key_dtype, value_dtype>) \
666-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
667-
.Device(DEVICE_GPU) \
668-
.TypeConstraint<key_dtype>("key_dtype") \
669-
.TypeConstraint<value_dtype>("value_dtype"), \
670-
HashTableAccumGpuOp<key_dtype, value_dtype>) \
671-
REGISTER_KERNEL_BUILDER( \
672-
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
673-
.Device(DEVICE_GPU) \
674-
.TypeConstraint<key_dtype>("Tin") \
675-
.TypeConstraint<value_dtype>("Tout"), \
676-
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
763+
#define REGISTER_KERNEL(key_dtype, value_dtype) \
764+
REGISTER_KERNEL_BUILDER( \
765+
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
766+
.Device(DEVICE_GPU) \
767+
.TypeConstraint<key_dtype>("key_dtype") \
768+
.TypeConstraint<value_dtype>("value_dtype"), \
769+
HashTableGpuOp< \
770+
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
771+
key_dtype, value_dtype>); \
772+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
773+
.Device(DEVICE_GPU) \
774+
.TypeConstraint<key_dtype>("key_dtype") \
775+
.TypeConstraint<value_dtype>("value_dtype"), \
776+
HashTableClearGpuOp<key_dtype, value_dtype>); \
777+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
778+
.Device(DEVICE_GPU) \
779+
.TypeConstraint<key_dtype>("key_dtype") \
780+
.TypeConstraint<value_dtype>("value_dtype"), \
781+
HashTableAccumGpuOp<key_dtype, value_dtype>); \
782+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
783+
.Device(DEVICE_GPU) \
784+
.HostMemory("filepath") \
785+
.TypeConstraint<key_dtype>("key_dtype") \
786+
.TypeConstraint<value_dtype>("value_dtype"), \
787+
HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
788+
REGISTER_KERNEL_BUILDER( \
789+
Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
790+
.Device(DEVICE_GPU) \
791+
.HostMemory("filepath") \
792+
.TypeConstraint<key_dtype>("key_dtype") \
793+
.TypeConstraint<value_dtype>("value_dtype"), \
794+
HashTableImportFromFileGpuOp<key_dtype, value_dtype>); \
795+
REGISTER_KERNEL_BUILDER( \
796+
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
797+
.Device(DEVICE_GPU) \
798+
.TypeConstraint<key_dtype>("Tin") \
799+
.TypeConstraint<value_dtype>("Tout"), \
800+
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
677801

678802
REGISTER_KERNEL(int64, float);
679803
REGISTER_KERNEL(int64, Eigen::half);

0 commit comments

Comments
 (0)