Skip to content

Commit 42ad7d2

Browse files
MoFHekarhdong
authored andcommitted
[feat] Add bfloat16 value type support to the HKV for being enhanced by Ampere GPU BF16 training feature.
1 parent a69c805 commit 42ad7d2

File tree

4 files changed

+4
-1
lines changed

4 files changed

+4
-1
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,7 @@ REGISTER_KERNEL(int64, int8);
10481048
REGISTER_KERNEL(int64, int32);
10491049
REGISTER_KERNEL(int64, int64);
10501050
REGISTER_KERNEL(int64, Eigen::half);
1051+
REGISTER_KERNEL(int64, Eigen::bfloat16);
10511052

10521053
#undef REGISTER_KERNEL
10531054

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv_impl.cu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ DEFINE_PURE_GPU_HASHTABLE(int64, int8);
2929
DEFINE_PURE_GPU_HASHTABLE(int64, int32);
3030
DEFINE_PURE_GPU_HASHTABLE(int64, int64);
3131
DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::half);
32+
DEFINE_PURE_GPU_HASHTABLE(int64, Eigen::bfloat16);
3233

3334
#undef DEFINE_PURE_GPU_HASHTABLE
3435

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def test_variable(self):
384384
dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 200]
385385
kv_list = [[dtypes.int64, dtypes.float32], [dtypes.int64, dtypes.int32],
386386
[dtypes.int64, dtypes.half], [dtypes.int64, dtypes.int8],
387-
[dtypes.int64, dtypes.int64]]
387+
[dtypes.int64, dtypes.int64], [dtypes.int64, dtypes.bfloat16]]
388388
else:
389389
dim_list = [1, 8, 16, 128]
390390
kv_list = [[dtypes.int32, dtypes.double], [dtypes.int32, dtypes.float32],

tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ def _get_default_devices():
589589
[dtypes.int64, dtypes.int32],
590590
[dtypes.int64, dtypes.int64],
591591
[dtypes.int64, dtypes.half],
592+
[dtypes.int64, dtypes.bfloat16],
592593
]
593594
if is_macos() and is_arm64():
594595
if value_dtype == dtypes.half or value_dtype == dtypes.bfloat16:

0 commit comments

Comments
 (0)