Skip to content

Commit 3fdfd57

Browse files
Lifannoppenheimli
authored and committed
Add ops of ExportToFile and ImportFromFile without full volume copying
1 parent 154182f commit 3fdfd57

File tree

7 files changed

+694
-31
lines changed

7 files changed

+694
-31
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface {
304304
return table_->export_values(ctx, value_dim);
305305
}
306306

307+
// Persists the table to `filepath` by delegating to the underlying table's
// save(). The value dimension passed along is the size of dimension 0 of
// value_shape_; `buffer_size` is forwarded unchanged.
Status Save(OpKernelContext* ctx, const string filepath,
            const size_t buffer_size) {
  const int64 dim0 = value_shape_.dim_size(0);
  return table_->save(ctx, dim0, filepath, buffer_size);
}
312+
313+
// Restores the table from `filepath` by delegating to the underlying table's
// load(). The value dimension passed along is the size of dimension 0 of
// value_shape_; `buffer_size` is forwarded unchanged.
Status Load(OpKernelContext* ctx, const string filepath,
            const size_t buffer_size) {
  const int64 dim0 = value_shape_.dim_size(0);
  return table_->load(ctx, dim0, filepath, buffer_size);
}
318+
307319
// Reports the DataType that DataTypeToEnum maps the key template type K to.
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
308320

309321
// Reports the DataType that DataTypeToEnum maps the value template type V to.
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
@@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel {
607619
}
608620
};
609621

622+
// Op that export all keys and values to file.
623+
template <class K, class V>
624+
class HashTableExportToFileOp : public HashTableOpKernel {
625+
public:
626+
explicit HashTableExportToFileOp(OpKernelConstruction* ctx)
627+
: HashTableOpKernel(ctx) {
628+
int64 signed_buffer_size = 0;
629+
ctx->GetAttr("buffer_size", &signed_buffer_size);
630+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
631+
}
632+
633+
void Compute(OpKernelContext* ctx) override {
634+
LookupInterface* table;
635+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
636+
core::ScopedUnref unref_me(table);
637+
638+
const Tensor& ftensor = ctx->input(1);
639+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
640+
errors::InvalidArgument("filepath must be scalar."));
641+
string filepath = string(ftensor.scalar<tstring>()().data());
642+
643+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
644+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
645+
OP_REQUIRES_OK(ctx, table_cuckoo->Save(ctx, filepath, buffer_size_));
646+
}
647+
648+
private:
649+
size_t buffer_size_;
650+
};
651+
610652
// Clear the table and insert data.
611653
class HashTableImportOp : public HashTableOpKernel {
612654
public:
@@ -637,6 +679,36 @@ class HashTableImportOp : public HashTableOpKernel {
637679
}
638680
};
639681

682+
// Op that export all keys and values to file.
683+
template <class K, class V>
684+
class HashTableImportFromFileOp : public HashTableOpKernel {
685+
public:
686+
explicit HashTableImportFromFileOp(OpKernelConstruction* ctx)
687+
: HashTableOpKernel(ctx) {
688+
int64 signed_buffer_size = 0;
689+
ctx->GetAttr("buffer_size", &signed_buffer_size);
690+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
691+
}
692+
693+
void Compute(OpKernelContext* ctx) override {
694+
LookupInterface* table;
695+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
696+
core::ScopedUnref unref_me(table);
697+
698+
const Tensor& ftensor = ctx->input(1);
699+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
700+
errors::InvalidArgument("filepath must be scalar."));
701+
string filepath = string(ftensor.scalar<tstring>()().data());
702+
703+
lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
704+
(lookup::CuckooHashTableOfTensors<K, V>*)table;
705+
OP_REQUIRES_OK(ctx, table_cuckoo->Load(ctx, filepath, buffer_size_));
706+
}
707+
708+
private:
709+
size_t buffer_size_;
710+
};
711+
640712
REGISTER_KERNEL_BUILDER(
641713
Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU),
642714
HashTableFindOp);
@@ -679,7 +751,17 @@ REGISTER_KERNEL_BUILDER(
679751
.Device(DEVICE_CPU) \
680752
.TypeConstraint<key_dtype>("Tin") \
681753
.TypeConstraint<value_dtype>("Tout"), \
682-
HashTableFindWithExistsOp<key_dtype, value_dtype>);
754+
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
755+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
756+
.Device(DEVICE_CPU) \
757+
.TypeConstraint<key_dtype>("key_dtype") \
758+
.TypeConstraint<value_dtype>("value_dtype"), \
759+
HashTableExportToFileOp<key_dtype, value_dtype>); \
760+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
761+
.Device(DEVICE_CPU) \
762+
.TypeConstraint<key_dtype>("key_dtype") \
763+
.TypeConstraint<value_dtype>("value_dtype"), \
764+
HashTableImportFromFileOp<key_dtype, value_dtype>);
683765

684766
REGISTER_KERNEL(int32, double);
685767
REGISTER_KERNEL(int32, float);

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 155 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
#if GOOGLE_CUDA
1616

1717
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
18+
1819
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
1920
#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
2021

@@ -214,10 +215,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
214215
if (cur_size > 0) {
215216
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
216217
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
217-
CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
218+
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
219+
sizeof(V) * runtime_dim_ * cur_size));
218220
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
219-
d_dump_counter, stream);
220-
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
221+
d_dump_counter, stream);
222+
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
223+
cudaMemcpyDeviceToHost, stream);
221224
CUDA_CHECK(cudaStreamSynchronize(stream));
222225
}
223226

@@ -226,8 +229,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
226229
CreateTable(new_max_size, &table_);
227230

228231
if (cur_size > 0) {
229-
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
230-
h_dump_counter, stream);
232+
table_->upsert((const K*)d_keys,
233+
(const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
234+
stream);
231235
cudaStreamSynchronize(stream);
232236
cudaFree(d_keys);
233237
cudaFree(d_values);
@@ -387,6 +391,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
387391
return Status::OK();
388392
}
389393

394+
// Dumps the table to `filepath` via table_->dump_to_file() on a CUDA stream
// created for this call. The dump and the following stream synchronization
// run while a shared lock on mu_ is held; the stream is destroyed afterwards.
// `buffer_size` and runtime_dim_ are forwarded to dump_to_file() unchanged.
Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
                          const size_t buffer_size) {
  cudaStream_t dump_stream;
  CUDA_CHECK(cudaStreamCreate(&dump_stream));

  {
    tf_shared_lock read_guard(mu_);
    table_->dump_to_file(ctx, filepath, runtime_dim_, dump_stream,
                         buffer_size);
    // Wait for the work queued on the stream before releasing the lock.
    CUDA_CHECK(cudaStreamSynchronize(dump_stream));
  }

  CUDA_CHECK(cudaStreamDestroy(dump_stream));
  return Status::OK();
}
407+
408+
// Replaces the table contents with the data stored under `filepath`.
//
// The number of keys is derived from the byte size of "<filepath>.keys"
// divided by sizeof(K). The table is then cleared, RehashIfNeeded() is called
// with that key count, and table_->load_from_file() reads the data on a CUDA
// stream created for this call.
//
// Returns NotFound when the key file cannot be opened, or when its size cannot
// be determined or is not positive.
Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
                            const size_t buffer_size) {
  // Probe the key file before acquiring the CUDA stream or the lock, so the
  // early returns below have nothing to release besides the FILE handle.
  const string keyfile = filepath + ".keys";
  FILE* key_fd = fopen(keyfile.c_str(), "rb");
  if (key_fd == nullptr) {
    return errors::NotFound("Failed to read key file: ", keyfile);
  }
  long int filesize = -1;
  if (fseek(key_fd, 0, SEEK_END) == 0) {
    filesize = ftell(key_fd);
  }
  fclose(key_fd);
  if (filesize <= 0) {
    return errors::NotFound("Empty key file.");
  }
  const size_t size = static_cast<size_t>(filesize) / sizeof(K);

  cudaStream_t _stream;
  CUDA_CHECK(cudaStreamCreate(&_stream));

  {
    // Exclusive lock: clear/rehash/load all mutate the table, so concurrent
    // holders of a shared lock on mu_ must be kept out.
    mutex_lock l(mu_);

    table_->clear(_stream);
    CUDA_CHECK(cudaStreamSynchronize(_stream));
    RehashIfNeeded(_stream, size);
    table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
                           buffer_size);
    CUDA_CHECK(cudaStreamSynchronize(_stream));
  }
  CUDA_CHECK(cudaStreamDestroy(_stream));
  return Status::OK();
}
441+
390442
// Reports the DataType that DataTypeToEnum maps the key template type K to.
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
391443
// Reports the DataType that DataTypeToEnum maps the value template type V to.
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
392444
// Always returns a default-constructed TensorShape for keys.
TensorShape key_shape() const final { return TensorShape(); }
@@ -625,6 +677,36 @@ REGISTER_KERNEL_BUILDER(
625677
Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
626678
HashTableExportGpuOp);
627679

680+
// Op that export all keys and values to file.
681+
template <class K, class V>
682+
class HashTableExportToFileGpuOp : public OpKernel {
683+
public:
684+
explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
685+
: OpKernel(ctx) {
686+
int64 signed_buffer_size = 0;
687+
ctx->GetAttr("buffer_size", &signed_buffer_size);
688+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
689+
}
690+
691+
void Compute(OpKernelContext* ctx) override {
692+
lookup::LookupInterface* table;
693+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
694+
core::ScopedUnref unref_me(table);
695+
696+
const Tensor& ftensor = ctx->input(1);
697+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
698+
errors::InvalidArgument("filepath must be scalar."));
699+
string filepath = string(ftensor.scalar<tstring>()().data());
700+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
701+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
702+
OP_REQUIRES_OK(
703+
ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
704+
}
705+
706+
private:
707+
size_t buffer_size_;
708+
};
709+
628710
// Clear the table and insert data.
629711
class HashTableImportGpuOp : public OpKernel {
630712
public:
@@ -651,33 +733,76 @@ REGISTER_KERNEL_BUILDER(
651733
Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
652734
HashTableImportGpuOp);
653735

736+
// Op that import from file.
737+
template <class K, class V>
738+
class HashTableImportFromFileGpuOp : public OpKernel {
739+
public:
740+
explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
741+
: OpKernel(ctx) {
742+
int64 signed_buffer_size = 0;
743+
ctx->GetAttr("buffer_size", &signed_buffer_size);
744+
buffer_size_ = static_cast<size_t>(signed_buffer_size);
745+
}
746+
747+
void Compute(OpKernelContext* ctx) override {
748+
lookup::LookupInterface* table;
749+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
750+
core::ScopedUnref unref_me(table);
751+
752+
const Tensor& ftensor = ctx->input(1);
753+
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
754+
errors::InvalidArgument("filepath must be scalar."));
755+
string filepath = string(ftensor.scalar<tstring>()().data());
756+
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
757+
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
758+
OP_REQUIRES_OK(
759+
ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
760+
}
761+
762+
private:
763+
size_t buffer_size_;
764+
};
765+
654766
// Register the CuckooHashTableOfTensors op.
655767

656-
#define REGISTER_KERNEL(key_dtype, value_dtype) \
657-
REGISTER_KERNEL_BUILDER( \
658-
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
659-
.Device(DEVICE_GPU) \
660-
.TypeConstraint<key_dtype>("key_dtype") \
661-
.TypeConstraint<value_dtype>("value_dtype"), \
662-
HashTableGpuOp< \
663-
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
664-
key_dtype, value_dtype>); \
665-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
666-
.Device(DEVICE_GPU) \
667-
.TypeConstraint<key_dtype>("key_dtype") \
668-
.TypeConstraint<value_dtype>("value_dtype"), \
669-
HashTableClearGpuOp<key_dtype, value_dtype>) \
670-
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
671-
.Device(DEVICE_GPU) \
672-
.TypeConstraint<key_dtype>("key_dtype") \
673-
.TypeConstraint<value_dtype>("value_dtype"), \
674-
HashTableAccumGpuOp<key_dtype, value_dtype>) \
675-
REGISTER_KERNEL_BUILDER( \
676-
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
677-
.Device(DEVICE_GPU) \
678-
.TypeConstraint<key_dtype>("Tin") \
679-
.TypeConstraint<value_dtype>("Tout"), \
680-
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>)
768+
// Registers, for one (key_dtype, value_dtype) pair, the DEVICE_GPU kernels of
// the ops named (through PREFIX_OP_NAME) CuckooHashTableOfTensors,
// CuckooHashTableClear, CuckooHashTableAccum, CuckooHashTableExportToFile,
// CuckooHashTableImportFromFile and CuckooHashTableFindWithExists.
// The two file ops declare their "filepath" input as host memory.
// FindWithExists is constrained on the "Tin"/"Tout" attrs; every other op is
// constrained on "key_dtype"/"value_dtype".
#define REGISTER_KERNEL(key_dtype, value_dtype) \
769+
REGISTER_KERNEL_BUILDER( \
770+
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
771+
.Device(DEVICE_GPU) \
772+
.TypeConstraint<key_dtype>("key_dtype") \
773+
.TypeConstraint<value_dtype>("value_dtype"), \
774+
HashTableGpuOp< \
775+
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
776+
key_dtype, value_dtype>); \
777+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
778+
.Device(DEVICE_GPU) \
779+
.TypeConstraint<key_dtype>("key_dtype") \
780+
.TypeConstraint<value_dtype>("value_dtype"), \
781+
HashTableClearGpuOp<key_dtype, value_dtype>); \
782+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
783+
.Device(DEVICE_GPU) \
784+
.TypeConstraint<key_dtype>("key_dtype") \
785+
.TypeConstraint<value_dtype>("value_dtype"), \
786+
HashTableAccumGpuOp<key_dtype, value_dtype>); \
787+
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
788+
.Device(DEVICE_GPU) \
789+
.HostMemory("filepath") \
790+
.TypeConstraint<key_dtype>("key_dtype") \
791+
.TypeConstraint<value_dtype>("value_dtype"), \
792+
HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
793+
REGISTER_KERNEL_BUILDER( \
794+
Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
795+
.Device(DEVICE_GPU) \
796+
.HostMemory("filepath") \
797+
.TypeConstraint<key_dtype>("key_dtype") \
798+
.TypeConstraint<value_dtype>("value_dtype"), \
799+
HashTableImportFromFileGpuOp<key_dtype, value_dtype>); \
800+
REGISTER_KERNEL_BUILDER( \
801+
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
802+
.Device(DEVICE_GPU) \
803+
.TypeConstraint<key_dtype>("Tin") \
804+
.TypeConstraint<value_dtype>("Tout"), \
805+
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
681806

682807
REGISTER_KERNEL(int64, float);
683808
REGISTER_KERNEL(int64, Eigen::half);

0 commit comments

Comments
 (0)