Skip to content

Commit dfc5667

Browse files
committed
Add ops of ExportToFile and ImportFromFile without full volume copying
1 parent 3592668 commit dfc5667

File tree

8 files changed

+757
-31
lines changed

8 files changed

+757
-31
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface {
304304
return table_->export_values(ctx, value_dim);
305305
}
306306

307+
// Saves the table to `filepath` by delegating to `table_->save_to_file`.
//
// Args:
//   ctx: kernel context, forwarded unchanged to the underlying table.
//   filepath: destination path, forwarded unchanged. Taken by const
//     reference so the string is not copied on every call.
//   buffer_size: forwarded unchanged; its unit is defined by
//     `save_to_file`, which is not visible here.
//
// Returns the Status produced by `table_->save_to_file`.
Status SaveToFile(OpKernelContext* ctx, const string& filepath,
                  const size_t buffer_size) {
  // The per-key value width is the first dimension of the value shape.
  int64 value_dim = value_shape_.dim_size(0);
  return table_->save_to_file(ctx, value_dim, filepath, buffer_size);
}
312+
313+
// Loads the table from `filepath` by delegating to
// `table_->load_from_file`.
//
// Args:
//   ctx: kernel context, forwarded unchanged to the underlying table.
//   filepath: source path, forwarded unchanged. Taken by const reference
//     so the string is not copied on every call.
//   buffer_size: forwarded unchanged; its unit is defined by
//     `load_from_file`, which is not visible here.
//
// Returns the Status produced by `table_->load_from_file`.
Status LoadFromFile(OpKernelContext* ctx, const string& filepath,
                    const size_t buffer_size) {
  // The per-key value width is the first dimension of the value shape.
  int64 value_dim = value_shape_.dim_size(0);
  return table_->load_from_file(ctx, value_dim, filepath, buffer_size);
}
318+
307319
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
308320

309321
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
@@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel {
607619
}
608620
};
609621

622+
// Op that export all keys and values to file.
623+
template <class K, class V>
624+
class HashTableExportToFileOp : public HashTableOpKernel {
625+
public:
626+
explicit HashTableExportToFileOp(OpKernelConstruction* ctx)
627+
: HashTableOpKernel(ctx) {
628+
int64 signed_buffer_size = 0;
629+
ctx->GetAttr("buffer_size", &signed_buffer_size);
630+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
631+
}
632+
633+
void Compute(OpKernelContext* ctx) override {
634+
LookupInterface* table;
635+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
636+
core::ScopedUnref unref_me(table);
637+
638+
const Tensor& ftensor = ctx->input(1);
639+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
640+
errors::InvalidArgument("filepath must be scalar."));
641+
string filepath = string(ftensor.scalar<tstring>()().data());
642+
643+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
644+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
645+
OP_REQUIRES_OK(ctx, table_cuckoo->SaveToFile(ctx, filepath, buffer_size_));
646+
}
647+
648+
private:
649+
size_t buffer_size_;
650+
};
651+
610652
// Clear the table and insert data.
611653
class HashTableImportOp : public HashTableOpKernel {
612654
public:
@@ -637,6 +679,37 @@ class HashTableImportOp : public HashTableOpKernel {
637679
}
638680
};
639681

682+
// Op that export all keys and values to file.
683+
template <class K, class V>
684+
class HashTableImportFromFileOp : public HashTableOpKernel {
685+
public:
686+
explicit HashTableImportFromFileOp(OpKernelConstruction* ctx)
687+
: HashTableOpKernel(ctx) {
688+
int64 signed_buffer_size = 0;
689+
ctx->GetAttr("buffer_size", &signed_buffer_size);
690+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
691+
}
692+
693+
void Compute(OpKernelContext* ctx) override {
694+
LookupInterface* table;
695+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
696+
core::ScopedUnref unref_me(table);
697+
698+
const Tensor& ftensor = ctx->input(1);
699+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
700+
errors::InvalidArgument("filepath must be scalar."));
701+
string filepath = string(ftensor.scalar<tstring>()().data());
702+
703+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
704+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
705+
OP_REQUIRES_OK(ctx,
706+
table_cuckoo->LoadFromFile(ctx, filepath, buffer_size_));
707+
}
708+
709+
private:
710+
size_t buffer_size_;
711+
};
712+
640713
REGISTER_KERNEL_BUILDER(
641714
Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU),
642715
HashTableFindOp);
@@ -679,7 +752,17 @@ REGISTER_KERNEL_BUILDER(
679752
.Device(DEVICE_CPU) \
680753
.TypeConstraint<key_dtype>("Tin") \
681754
.TypeConstraint<value_dtype>("Tout"), \
682-
HashTableFindWithExistsOp<key_dtype, value_dtype>);
755+
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
756+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
757+
.Device(DEVICE_CPU) \
758+
.TypeConstraint<key_dtype>("key_dtype") \
759+
.TypeConstraint<value_dtype>("value_dtype"), \
760+
HashTableExportToFileOp<key_dtype, value_dtype>); \
761+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
762+
.Device(DEVICE_CPU) \
763+
.TypeConstraint<key_dtype>("key_dtype") \
764+
.TypeConstraint<value_dtype>("value_dtype"), \
765+
HashTableImportFromFileOp<key_dtype, value_dtype>);
683766

684767
REGISTER_KERNEL(int32, double);
685768
REGISTER_KERNEL(int32, float);

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 154 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
214214
if (cur_size > 0) {
215215
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
216216
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
217-
CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
217+
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
218+
sizeof(V) * runtime_dim_ * cur_size));
218219
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
219-
d_dump_counter, stream);
220-
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
220+
d_dump_counter, stream);
221+
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
222+
cudaMemcpyDeviceToHost, stream);
221223
CUDA_CHECK(cudaStreamSynchronize(stream));
222224
}
223225

@@ -226,8 +228,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
226228
CreateTable(new_max_size, &table_);
227229

228230
if (cur_size > 0) {
229-
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
230-
h_dump_counter, stream);
231+
table_->upsert((const K*)d_keys,
232+
(const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
233+
stream);
231234
cudaStreamSynchronize(stream);
232235
cudaFree(d_keys);
233236
cudaFree(d_values);
@@ -387,6 +390,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
387390
return Status::OK();
388391
}
389392

393+
// Dumps the table to file by forwarding `filepath`, `runtime_dim_` and
// `buffer_size` to `table_->dump_to_file` on a stream created for this call.
//
// Args:
//   ctx: kernel context, forwarded unchanged to the table.
//   filepath: forwarded unchanged. Taken by const reference so the string
//     is not copied on every call.
//   buffer_size: forwarded unchanged; its unit is defined by
//     `dump_to_file`, which is not visible here.
//
// Returns Status::OK() once the stream has been synchronized and destroyed.
// CUDA failures are handled by CUDA_CHECK.
Status ExportValuesToFile(OpKernelContext* ctx, const string& filepath,
                          const size_t buffer_size) {
  cudaStream_t _stream;
  CUDA_CHECK(cudaStreamCreate(&_stream));

  {
    // The dump runs under a shared lock on mu_.
    tf_shared_lock l(mu_);
    table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
    // Wait for the dump before releasing the lock and the stream.
    CUDA_CHECK(cudaStreamSynchronize(_stream));
  }
  CUDA_CHECK(cudaStreamDestroy(_stream));
  return Status::OK();
}
406+
407+
// Replaces the table's contents with the data stored under `filepath`.
//
// The number of keys is the byte size of "<filepath>.keys" divided by
// sizeof(K). The table is then cleared, passed to RehashIfNeeded for that
// many keys, and populated by `table_->load_from_file`.
//
// Args:
//   ctx: kernel context, forwarded unchanged to the table.
//   filepath: base path; the key file is `filepath + ".keys"`.
//   buffer_size: forwarded unchanged; its unit is defined by
//     `load_from_file`, which is not visible here.
//
// Returns NotFound when the key file cannot be opened, cannot be measured,
// or is empty; Status::OK() otherwise. CUDA failures are handled by
// CUDA_CHECK.
Status ImportValuesFromFile(OpKernelContext* ctx, const string& filepath,
                            const size_t buffer_size) {
  // Measure the key file before creating any CUDA resource, so the early
  // error returns below cannot leak a stream.
  const string keyfile = filepath + ".keys";
  FILE* tmpfd = fopen(keyfile.c_str(), "rb");
  if (tmpfd == nullptr) {
    return errors::NotFound("Failed to read key file ", keyfile);
  }
  long int filesize = -1;
  if (fseek(tmpfd, 0, SEEK_END) == 0) {
    filesize = ftell(tmpfd);
  }
  fclose(tmpfd);
  // A failed fseek/ftell leaves filesize negative and is reported the same
  // way as an empty file.
  if (filesize <= 0) {
    return errors::NotFound("Empty key file.");
  }
  const size_t size = static_cast<size_t>(filesize) / sizeof(K);

  cudaStream_t _stream;
  CUDA_CHECK(cudaStreamCreate(&_stream));

  {
    // Exclusive lock: clear, rehash and load all mutate the table, so a
    // shared (reader) lock is not sufficient here.
    // NOTE(review): assumes RehashIfNeeded does not acquire mu_ itself --
    // confirm.
    mutex_lock l(mu_);

    table_->clear(_stream);
    CUDA_CHECK(cudaStreamSynchronize(_stream));
    RehashIfNeeded(_stream, size);
    table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
                           buffer_size);
    CUDA_CHECK(cudaStreamSynchronize(_stream));
  }
  CUDA_CHECK(cudaStreamDestroy(_stream));
  return Status::OK();
}
440+
390441
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
391442
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
392443
TensorShape key_shape() const final { return TensorShape(); }
@@ -625,6 +676,36 @@ REGISTER_KERNEL_BUILDER(
625676
Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
626677
HashTableExportGpuOp);
627678

679+
// Op that export all keys and values to file.
680+
template <class K, class V>
681+
class HashTableExportToFileGpuOp : public OpKernel {
682+
public:
683+
explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
684+
: OpKernel(ctx) {
685+
int64 signed_buffer_size = 0;
686+
ctx->GetAttr("buffer_size", &signed_buffer_size);
687+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
688+
}
689+
690+
void Compute(OpKernelContext* ctx) override {
691+
lookup::LookupInterface* table;
692+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
693+
core::ScopedUnref unref_me(table);
694+
695+
const Tensor& ftensor = ctx->input(1);
696+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
697+
errors::InvalidArgument("filepath must be scalar."));
698+
string filepath = string(ftensor.scalar<tstring>()().data());
699+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
700+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
701+
OP_REQUIRES_OK(
702+
ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
703+
}
704+
705+
private:
706+
size_t buffer_size_;
707+
};
708+
628709
// Clear the table and insert data.
629710
class HashTableImportGpuOp : public OpKernel {
630711
public:
@@ -651,33 +732,76 @@ REGISTER_KERNEL_BUILDER(
651732
Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
652733
HashTableImportGpuOp);
653734

735+
// Op that import from file.
736+
template <class K, class V>
737+
class HashTableImportFromFileGpuOp : public OpKernel {
738+
public:
739+
explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
740+
: OpKernel(ctx) {
741+
int64 signed_buffer_size = 0;
742+
ctx->GetAttr("buffer_size", &signed_buffer_size);
743+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
744+
}
745+
746+
void Compute(OpKernelContext* ctx) override {
747+
lookup::LookupInterface* table;
748+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
749+
core::ScopedUnref unref_me(table);
750+
751+
const Tensor& ftensor = ctx->input(1);
752+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
753+
errors::InvalidArgument("filepath must be scalar."));
754+
string filepath = string(ftensor.scalar<tstring>()().data());
755+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
756+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
757+
OP_REQUIRES_OK(
758+
ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
759+
}
760+
761+
private:
762+
size_t buffer_size_;
763+
};
764+
654765
// Register the CuckooHashTableOfTensors op.
655766

656-
#define REGISTER_KERNEL(key_dtype, value_dtype) \
657-
REGISTER_KERNEL_BUILDER( \
658-
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
659-
.Device(DEVICE_GPU) \
660-
.TypeConstraint<key_dtype>("key_dtype") \
661-
.TypeConstraint<value_dtype>("value_dtype"), \
662-
HashTableGpuOp< \
663-
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
664-
key_dtype, value_dtype>); \
665-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
666-
.Device(DEVICE_GPU) \
667-
.TypeConstraint<key_dtype>("key_dtype") \
668-
.TypeConstraint<value_dtype>("value_dtype"), \
669-
HashTableClearGpuOp<key_dtype, value_dtype>) \
670-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
671-
.Device(DEVICE_GPU) \
672-
.TypeConstraint<key_dtype>("key_dtype") \
673-
.TypeConstraint<value_dtype>("value_dtype"), \
674-
HashTableAccumGpuOp<key_dtype, value_dtype>) \
675-
REGISTER_KERNEL_BUILDER( \
676-
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
677-
.Device(DEVICE_GPU) \
678-
.TypeConstraint<key_dtype>("Tin") \
679-
.TypeConstraint<value_dtype>("Tout"), \
680-
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
767+
// Registers, for one (key_dtype, value_dtype) pair, the DEVICE_GPU kernels
// of six ops: CuckooHashTableOfTensors, CuckooHashTableClear,
// CuckooHashTableAccum, CuckooHashTableExportToFile,
// CuckooHashTableImportFromFile and CuckooHashTableFindWithExists.
// The two file ops keep their "filepath" input in host memory.
// FindWithExists is constrained on the "Tin"/"Tout" attrs; all other ops on
// "key_dtype"/"value_dtype".
// Every registration, including the last one, ends with its own ';'.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
768+
REGISTER_KERNEL_BUILDER( \
769+
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
770+
.Device(DEVICE_GPU) \
771+
.TypeConstraint<key_dtype>("key_dtype") \
772+
.TypeConstraint<value_dtype>("value_dtype"), \
773+
HashTableGpuOp< \
774+
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
775+
key_dtype, value_dtype>); \
776+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
777+
.Device(DEVICE_GPU) \
778+
.TypeConstraint<key_dtype>("key_dtype") \
779+
.TypeConstraint<value_dtype>("value_dtype"), \
780+
HashTableClearGpuOp<key_dtype, value_dtype>); \
781+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
782+
.Device(DEVICE_GPU) \
783+
.TypeConstraint<key_dtype>("key_dtype") \
784+
.TypeConstraint<value_dtype>("value_dtype"), \
785+
HashTableAccumGpuOp<key_dtype, value_dtype>); \
786+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
787+
.Device(DEVICE_GPU) \
788+
.HostMemory("filepath") \
789+
.TypeConstraint<key_dtype>("key_dtype") \
790+
.TypeConstraint<value_dtype>("value_dtype"), \
791+
HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
792+
REGISTER_KERNEL_BUILDER( \
793+
Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
794+
.Device(DEVICE_GPU) \
795+
.HostMemory("filepath") \
796+
.TypeConstraint<key_dtype>("key_dtype") \
797+
.TypeConstraint<value_dtype>("value_dtype"), \
798+
HashTableImportFromFileGpuOp<key_dtype, value_dtype>); \
799+
REGISTER_KERNEL_BUILDER( \
800+
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
801+
.Device(DEVICE_GPU) \
802+
.TypeConstraint<key_dtype>("Tin") \
803+
.TypeConstraint<value_dtype>("Tout"), \
804+
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
681805

682806
REGISTER_KERNEL(int64, float);
683807
REGISTER_KERNEL(int64, Eigen::half);

0 commit comments

Comments
 (0)