Skip to content

Commit c2e897d

Browse files
Lifannrhdong
authored andcommitted
Add int64-int64 key-value in combination on GPU.
1 parent 3a55f58 commit c2e897d

File tree

7 files changed

+103
-1
lines changed

7 files changed

+103
-1
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ REGISTER_KERNEL_BUILDER(
662662

663663
REGISTER_KERNEL(int64, float);
664664
REGISTER_KERNEL(int64, Eigen::half);
665+
REGISTER_KERNEL(int64, int64);
665666
REGISTER_KERNEL(int64, int32);
666667
REGISTER_KERNEL(int64, int8);
667668

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ void CreateTableImpl(TableWrapperBase<K, V>** pptable, size_t max_size,
176176

177177
DECLARE_CREATE_TABLE(int64, float);
178178
DECLARE_CREATE_TABLE(int64, Eigen::half);
179+
DECLARE_CREATE_TABLE(int64, int64);
179180
DECLARE_CREATE_TABLE(int64, int32);
180181
DECLARE_CREATE_TABLE(int64, int8);
181182

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
17+
namespace tensorflow {
18+
namespace recommenders_addons {
19+
namespace lookup {
20+
namespace gpu {
21+
DEFINE_CREATE_TABLE(0, int64, int64, 0, 0);
22+
} // namespace gpu
23+
} // namespace lookup
24+
} // namespace recommenders_addons
25+
} // namespace tensorflow
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
17+
namespace tensorflow {
18+
namespace recommenders_addons {
19+
namespace lookup {
20+
namespace gpu {
21+
DEFINE_CREATE_TABLE(1, int64, int64, 0, 5);
22+
} // namespace gpu
23+
} // namespace lookup
24+
} // namespace recommenders_addons
25+
} // namespace tensorflow
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
17+
namespace tensorflow {
18+
namespace recommenders_addons {
19+
namespace lookup {
20+
namespace gpu {
21+
DEFINE_CREATE_TABLE(2, int64, int64, 1, 0);
22+
} // namespace gpu
23+
} // namespace lookup
24+
} // namespace recommenders_addons
25+
} // namespace tensorflow
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
17+
namespace tensorflow {
18+
namespace recommenders_addons {
19+
namespace lookup {
20+
namespace gpu {
21+
DEFINE_CREATE_TABLE(3, int64, int64, 1, 5);
22+
} // namespace gpu
23+
} // namespace lookup
24+
} // namespace recommenders_addons
25+
} // namespace tensorflow

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def test_variable(self):
370370
dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200]
371371
kv_list = [[dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32],
372372
[dtypes.int64, dtypes.half], [dtypes.int64, dtypes.int8],
373-
[dtypes.int32, dtypes.float32]]
373+
[dtypes.int32, dtypes.float32], [dtypes.int64, dtypes.int64]]
374374
else:
375375
dim_list = [1, 8, 16, 128]
376376
kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32],

0 commit comments

Comments
 (0)