Skip to content

Commit b02d738

Browse files
thorneliurhdong
authored and committed
add HMACCUM redis module and enable BPV2 in redis impl
1 parent c2e897d commit b02d738

File tree

14 files changed

+2365
-97
lines changed

14 files changed

+2365
-97
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_cluster_connection_pool.hpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,107 @@ every bucket has its own BucketContext for sending data---for locating reply-
12441244
return Status::OK();
12451245
}
12461246

1247+
virtual Status MaccumCommand(
1248+
const Tensor &keys, const Tensor &values_or_delta, const Tensor &exists,
1249+
ThreadContext *thread_context, const int64 begin, const int64 max_i,
1250+
const int64 Velems_per_dim0,
1251+
const std::vector<std::string> &keys_prefix_name_slices) override {
1252+
const int &&total = max_i - begin;
1253+
const int &&argc = total * 2 + 4;
1254+
1255+
const static char *redis_command = "HMACCUM";
1256+
const static std::size_t &&redis_command_byte = 7;
1257+
std::string dTypestr = DataTypeString(values_or_delta.dtype());
1258+
size_t dTypeStrsize = dTypestr.size();
1259+
1260+
const K *const pk_raw_end =
1261+
reinterpret_cast<const K *>(keys.tensor_data().data()) + max_i;
1262+
const K *pk_raw =
1263+
reinterpret_cast<const K *>(keys.tensor_data().data()) + begin;
1264+
1265+
const std::size_t &&V_byte_size = Velems_per_dim0 * sizeof(V);
1266+
1267+
const V *pv_raw =
1268+
reinterpret_cast<const V *>(values_or_delta.tensor_data().data()) +
1269+
begin * Velems_per_dim0;
1270+
1271+
const unsigned &storage_slice = redis_connection_params.storage_slice;
1272+
const unsigned &&vector_len =
1273+
(static_cast<int64>(reinterpret_cast<int>(argc)) /
1274+
redis_connection_params.storage_slice) +
1275+
4;
1276+
1277+
thread_context->HandleReserve(storage_slice, vector_len, total);
1278+
1279+
for (unsigned i = 0; i < storage_slice; ++i) {
1280+
thread_context->HandlePushBack(i, redis_command, redis_command_byte);
1281+
thread_context->HandlePushBack(i, keys_prefix_name_slices[i].data(),
1282+
keys_prefix_name_slices[i].size());
1283+
thread_context->HandlePushBack(i, dTypestr.c_str(), dTypeStrsize);
1284+
}
1285+
1286+
VContentAndTypeSizeResult VCATS_temp;
1287+
// std::vector<char> for storage all string in one KV pair
1288+
std::vector<std::vector<char>> buff_temp(total);
1289+
unsigned key_bucket_locs = 0;
1290+
for (int i = 0; pk_raw != pk_raw_end;
1291+
++i, ++pk_raw, pv_raw += Velems_per_dim0) {
1292+
VCATS_temp = VContentAndTypeSize<V>(VCATS_temp, Velems_per_dim0,
1293+
V_byte_size, pv_raw, buff_temp[i]);
1294+
key_bucket_locs =
1295+
KBucketNum<K>(pk_raw, storage_slice); // TODO: change it to AVX512
1296+
1297+
// Direct access to Tensor data in TensorFlow
1298+
thread_context->HandlePushBack(
1299+
key_bucket_locs, KContentPointer<K>(pk_raw), KTypeSize<K>(pk_raw));
1300+
thread_context->HandlePushBack(
1301+
key_bucket_locs, VCATS_temp.VContentPointer, VCATS_temp.VTypeSize);
1302+
}
1303+
1304+
const bool *pe_raw =
1305+
reinterpret_cast<const bool *>(exists.tensor_data().data()) + begin;
1306+
for (unsigned i = 0; i < storage_slice; ++i) {
1307+
thread_context->HandlePushBack(i, KContentPointer<bool>(pe_raw),
1308+
total * KTypeSize<bool>(pe_raw));
1309+
}
1310+
1311+
auto cmd = [](::sw::redis::Connection &connection,
1312+
const ::sw::redis::StringView &hkey,
1313+
const std::vector<const char *> *ptrs_i,
1314+
const std::vector<std::size_t> *sizes_i) {
1315+
assert(strcmp(ptrs_i->front(), "HMACCUM") == 0);
1316+
assert(sizes_i->front() == redis_command_byte);
1317+
assert(std::string(hkey.data()).compare(ptrs_i[1]) == 0);
1318+
1319+
connection.send(static_cast<int>(ptrs_i->size()),
1320+
const_cast<const char **>(ptrs_i->data()),
1321+
sizes_i->data());
1322+
};
1323+
1324+
std::vector<
1325+
std::future<std::unique_ptr<redisReply, ::sw::redis::ReplyDeleter>>>
1326+
results;
1327+
try {
1328+
for (unsigned i = 0; i < storage_slice; ++i) {
1329+
results.emplace_back(
1330+
network_worker_pool->enqueue([this, &cmd, &thread_context, i] {
1331+
return PipeExecWrite(cmd, 6U, thread_context->buckets[i]);
1332+
}));
1333+
}
1334+
for (auto &&result : results) {
1335+
result.wait();
1336+
}
1337+
if (error_ptr) {
1338+
std::rethrow_exception(error_ptr);
1339+
}
1340+
} catch (const std::exception &err) {
1341+
error_ptr = nullptr;
1342+
return errors::Unknown(err.what());
1343+
}
1344+
1345+
return Status::OK();
1346+
}
1347+
12471348
virtual Status DelCommand(
12481349
const Tensor &keys, ThreadContext *thread_context, const int64 begin,
12491350
const int64 max_i,

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_connection_pool.hpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,95 @@ every bucket has its own BucketContext for sending data---for locating reply-
10511051
return Status::OK();
10521052
}
10531053

1054+
  // Send one HMACCUM (accumulate) batch to a single (non-cluster) Redis.
  // Builds a single command in bucket 0 of the thread context:
  //   HMACCUM <bucket_key_name> <dtype_str> k1 v1 k2 v2 ... <exists_block>
  // for keys in the half-open range [begin, max_i).
  //
  // keys            -- tensor of K keys (read via raw tensor_data()).
  // values_or_delta -- tensor of V deltas, Velems_per_dim0 elements per key.
  // exists          -- bool tensor flagging which keys already exist.
  // Returns Status::OK() or errors::Unknown wrapping any redis++ exception.
  virtual Status MaccumCommand(
      const Tensor &keys, const Tensor &values_or_delta, const Tensor &exists,
      ThreadContext *thread_context, const int64 begin, const int64 max_i,
      const int64 Velems_per_dim0,
      const std::vector<std::string> &keys_prefix_name_slices) override {
    const int &&total = max_i - begin;
    // argv count: 2 per key (key + value) plus command, bucket name,
    // dtype string and the trailing exists block.
    const int &&argc = total * 2 + 4;

    const static char *redis_command = "HMACCUM";
    const static std::size_t redis_command_byte = 7;
    std::string dTypestr = DataTypeString(values_or_delta.dtype());

    // Single logical bucket (index 0) sized for argc arguments.
    // NOTE(review): third argument is 0 here, while the cluster variant
    // passes `total` -- confirm HandleReserve's third parameter is unused
    // on this path.
    thread_context->HandleReserve(1U, argc, 0);

    // argv pointer/length arrays that connection.send() will consume.
    std::vector<const char *> *ptrs_0 = thread_context->buckets[0]->ptrs.get();
    std::vector<std::size_t> *sizes_0 = thread_context->buckets[0]->sizes.get();

    const K *const pk_raw_end =
        reinterpret_cast<const K *>(keys.tensor_data().data()) + max_i;
    const K *pk_raw =
        reinterpret_cast<const K *>(keys.tensor_data().data()) + begin;

    const std::size_t &&V_byte_size = Velems_per_dim0 * sizeof(V);

    const V *pv_raw =
        reinterpret_cast<const V *>(values_or_delta.tensor_data().data()) +
        begin * Velems_per_dim0;

    // Fixed header: command name, bucket key name, value dtype string.
    auto ptrs_iter = ptrs_0->begin();
    *ptrs_iter = redis_command;
    ++ptrs_iter;
    *ptrs_iter = keys_prefix_name_slices[0].data();
    ++ptrs_iter;
    *ptrs_iter = dTypestr.c_str();
    ++ptrs_iter;

    // Matching lengths for the three header arguments.
    auto sizes_iter = sizes_0->begin();
    *sizes_iter = redis_command_byte;
    ++sizes_iter;
    *sizes_iter = keys_prefix_name_slices[0].size();
    ++sizes_iter;
    *sizes_iter = dTypestr.size();
    ++sizes_iter;

    VContentAndTypeSizeResult VCATS_temp;
    // std::vector<char> for storage all string in one KV pair
    std::vector<std::vector<char>> buff_temp(total);

    // Interleave key/value pointers and sizes: k1 v1 k2 v2 ...
    for (int i = 0; pk_raw != pk_raw_end;
         ++i, ++pk_raw, pv_raw += Velems_per_dim0) {
      VCATS_temp = VContentAndTypeSize<V>(VCATS_temp, Velems_per_dim0,
                                          V_byte_size, pv_raw, buff_temp[i]);

      *ptrs_iter = KContentPointer<K>(
          pk_raw);  // Direct access to Tensor data in TensorFlow
      *(++ptrs_iter) = VCATS_temp.VContentPointer;
      ++ptrs_iter;

      *sizes_iter = KTypeSize<K>(pk_raw);  // key data char size
      *(++sizes_iter) = VCATS_temp.VTypeSize;
      ++sizes_iter;
    }

    // Final argument: the raw exists-bool block for the whole range.
    const bool *pe_raw =
        reinterpret_cast<const bool *>(exists.tensor_data().data()) + begin;
    *ptrs_iter = KContentPointer<bool>(pe_raw);
    *sizes_iter = total * KTypeSize<bool>(pe_raw);

    // Sanity-check that the header slots were not overwritten.
    assert(ptrs_0->front() == redis_command);
    assert(sizes_0->front() == redis_command_byte);

    // Raw variadic send; redis++ handles reply/error translation.
    auto cmd = [](::sw::redis::Connection &connection, const int argc,
                  const std::vector<const char *> *ptrs_0,
                  const std::vector<std::size_t> *sizes_0) {
      connection.send(argc, const_cast<const char **>(ptrs_0->data()),
                      sizes_0->data());
    };

    try {
      redis_conn_write->command(cmd, argc, ptrs_0, sizes_0);
    } catch (const std::exception &err) {
      LOG(ERROR) << "RedisHandler error in MACCUM_COMMAND for HMACCUM "
                 << keys_prefix_name_slices[0] << " -- " << err.what();
      return errors::Unknown(err.what());
    }

    return Status::OK();
  }
1142+
10541143
virtual Status DelCommand(
10551144
const Tensor &keys, ThreadContext *thread_context, const int64 begin,
10561145
const int64 max_i,

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_connection_util.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,12 @@ class RedisVirtualWrapper {
430430
const int64 begin, const int64 max_i, const int64 Velems_per_dim0,
431431
const std::vector<std::string> &keys_prefix_name_slices) = 0;
432432

433+
  // Send one HMACCUM (accumulate) batch for keys in [begin, max_i):
  // adds `values` (deltas) onto the stored embeddings, with `exists`
  // flagging keys already present. Implemented by both the cluster and
  // the single-node connection pools.
  virtual Status MaccumCommand(
      const Tensor &keys, const Tensor &values, const Tensor &exists,
      ThreadContext *thread_context, const int64 begin, const int64 max_i,
      const int64 Velems_per_dim0,
      const std::vector<std::string> &keys_prefix_name_slices) = 0;
438+
433439
virtual Status DelCommand(
434440
const Tensor &keys, ThreadContext *thread_context, const int64 begin,
435441
const int64 max_i,

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_impl/redis_table_op_util.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,26 @@ Status launchInsertCore(std::shared_ptr<RedisVirtualWrapper> _table_instance,
132132
return statu;
133133
}
134134

135+
// Run one HMACCUM batch: reserve an idle ThreadContext from the shared
// pool, forward the batch to the Redis wrapper, then release the context.
// Mirrors launchInsertCore.
Status launchAccumCore(std::shared_ptr<RedisVirtualWrapper> _table_instance,
                       std::vector<std::string> &keys_prefix_name_slices,
                       const Tensor &keys, const Tensor &values_or_delta,
                       const Tensor &exists, const int64 &Velems_per_flat2_dim0,
                       std::vector<ThreadContext *> &threads_Insert,
                       std::mutex &threads_Accum_mutex, const int64 begin,
                       const int64 end) {
  // Marks the chosen context occupied under the mutex.
  const size_t ctx_id =
      SelectAvailableThreadContext(threads_Insert, threads_Accum_mutex);
  ThreadContext *const ctx = threads_Insert.at(ctx_id);

  // Dispatch to the cluster or single-node implementation.
  Status maccum_status = _table_instance->MaccumCommand(
      keys, values_or_delta, exists, ctx, begin, end, Velems_per_flat2_dim0,
      keys_prefix_name_slices);

  // Release the context regardless of the command's outcome.
  ctx->thread_occupied.store(false, std::memory_order_release);

  return maccum_status;
}
154+
135155
Status launchDeleteCore(std::shared_ptr<RedisVirtualWrapper> _table_instance,
136156
std::vector<std::string> &keys_prefix_name_slices,
137157
const Tensor &keys,

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/redis_table_op.cc

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class RedisTableOfTensors final : public LookupInterface {
7878
std::vector<ThreadContext *> threads_Delete;
7979
std::mutex threads_Find_mutex;
8080
std::mutex threads_Insert_mutex;
81+
std::mutex threads_Accum_mutex;
8182
std::mutex threads_Delete_mutex;
8283

8384
std::vector<aiocb> IMPORT_content;
@@ -211,6 +212,42 @@ class RedisTableOfTensors final : public LookupInterface {
211212
threads_Insert_mutex, 0, total));
212213
}
213214

215+
void launchAccum_parallel(OpKernelContext *ctx,
216+
std::vector<std::string> &keys_prefix_name_slices,
217+
const Tensor &keys, const Tensor &values_or_delta,
218+
const Tensor &exists, const int64 &total,
219+
const int64 &Velems_per_flat2_dim0,
220+
std::vector<ThreadContext *> &threads_Insert) {
221+
const int64 max_parallelism = (total / multi_redis_cmd_max_argc) + 1;
222+
223+
auto shard = [this, &ctx, &total, &keys_prefix_name_slices, &keys,
224+
&values_or_delta, &exists, &Velems_per_flat2_dim0,
225+
&threads_Insert](int64 begin, int64 end) {
226+
const int64 max_i = std::min(total, end);
227+
228+
OP_REQUIRES_OK(
229+
ctx,
230+
launchAccumCore(_table_instance, keys_prefix_name_slices, keys,
231+
values_or_delta, exists, Velems_per_flat2_dim0,
232+
threads_Insert, threads_Accum_mutex, begin, max_i));
233+
};
234+
int64 slices_size = std::min(total, multi_redis_cmd_max_argc - 1);
235+
auto &worker_threads = *ctx->device()->tensorflow_cpu_worker_threads();
236+
Shard(max_parallelism, worker_threads.workers, total, slices_size, shard);
237+
}
238+
239+
  // Serial accumulate path for small batches (argc fits in one Redis
  // command): sends the whole [0, total) range in a single launchAccumCore
  // call. Reuses the threads_Insert context pool and its mutex.
  void launchAccum(OpKernelContext *ctx,
                   std::vector<std::string> &keys_prefix_name_slices,
                   const Tensor &keys, const Tensor &values_or_delta,
                   const Tensor &exists, const int64 &total,
                   const int64 &Velems_per_flat2_dim0,
                   std::vector<ThreadContext *> &threads_Insert) {
    OP_REQUIRES_OK(
        ctx, launchAccumCore(_table_instance, keys_prefix_name_slices, keys,
                             values_or_delta, exists, Velems_per_flat2_dim0,
                             threads_Insert, threads_Insert_mutex, 0, total));
  }
250+
214251
void launchDelete_parallel(OpKernelContext *ctx,
215252
std::vector<std::string> &keys_prefix_name_slices,
216253
const Tensor &keys, const int64 &total,
@@ -691,11 +728,35 @@ class RedisTableOfTensors final : public LookupInterface {
691728
return Status::OK();
692729
}
693730

731+
Status DoAccum(OpKernelContext *ctx, const Tensor &keys,
732+
const Tensor &values_or_delta, const Tensor &exists) {
733+
int64 total = keys.NumElements();
734+
const int64 Velems_per_flat2_dim0 =
735+
values_or_delta.NumElements() / keys.NumElements();
736+
737+
if (total < (multi_redis_cmd_max_argc - 1)) {
738+
launchAccum(ctx, keys_prefix_name_slices, keys, values_or_delta, exists,
739+
total, Velems_per_flat2_dim0, threads_Insert);
740+
} else {
741+
launchAccum_parallel(
742+
ctx, keys_prefix_name_slices, keys, values_or_delta, exists, total,
743+
Velems_per_flat2_dim0,
744+
threads_Insert); // redis commmand args > multi_redis_cmd_max_argc
745+
}
746+
747+
return Status::OK();
748+
}
749+
694750
Status Insert(OpKernelContext *ctx, const Tensor &keys,
695751
const Tensor &values) override {
696752
return DoInsert(false, ctx, keys, values);
697753
}
698754

755+
  // Public entry point for the accumulate op; thin forwarder to DoAccum.
  // `exists` is a bool tensor marking which keys are already in the table.
  Status Accum(OpKernelContext *ctx, const Tensor &keys,
               const Tensor &values_or_delta, const Tensor &exists) {
    return DoAccum(ctx, keys, values_or_delta, exists);
  }
759+
699760
Status Remove(OpKernelContext *ctx, const Tensor &keys) override {
700761
int64 total = keys.NumElements();
701762
if (total > 0) {
@@ -1129,6 +1190,45 @@ class HashTableInsertOp : public HashTableOpKernel {
11291190
}
11301191
};
11311192

1193+
// Table accum op.
1194+
template <class K, class V>
1195+
class HashTableAccumOp : public HashTableOpKernel {
1196+
public:
1197+
using HashTableOpKernel::HashTableOpKernel;
1198+
1199+
void Compute(OpKernelContext *ctx) override {
1200+
LookupInterface *table;
1201+
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
1202+
core::ScopedUnref unref_me(table);
1203+
1204+
RedisTableOfTensors<K, V> *redisTable = (RedisTableOfTensors<K, V> *)table;
1205+
1206+
DataTypeVector expected_inputs = {expected_input_0_, table->key_dtype(),
1207+
table->value_dtype(),
1208+
DataTypeToEnum<bool>::v()};
1209+
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
1210+
1211+
const Tensor &keys = ctx->input(1);
1212+
const Tensor &values_or_deltas = ctx->input(2);
1213+
const Tensor &exists = ctx->input(3);
1214+
OP_REQUIRES(ctx, (values_or_deltas.dtype() != DataTypeToEnum<tstring>::v()),
1215+
errors::InvalidArgument(
1216+
"AccumOP is not supporting tstring value type!"));
1217+
OP_REQUIRES_OK(
1218+
ctx, table->CheckKeyAndValueTensorsForInsert(keys, values_or_deltas));
1219+
1220+
int64 memory_used_before = 0;
1221+
if (ctx->track_allocations()) {
1222+
memory_used_before = table->MemoryUsed();
1223+
}
1224+
OP_REQUIRES_OK(ctx, redisTable->Accum(ctx, keys, values_or_deltas, exists));
1225+
if (ctx->track_allocations()) {
1226+
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
1227+
memory_used_before);
1228+
}
1229+
}
1230+
};
1231+
11321232
// Table remove op.
11331233
class HashTableRemoveOp : public HashTableOpKernel {
11341234
public:
@@ -1275,6 +1375,12 @@ REGISTER_KERNEL_BUILDER(
12751375
.TypeConstraint<key_dtype>("key_dtype") \
12761376
.TypeConstraint<value_dtype>("value_dtype"), \
12771377
redis_table::HashTableClearOp<key_dtype, value_dtype>); \
1378+
REGISTER_KERNEL_BUILDER( \
1379+
Name(PREFIX_OP_NAME(RedisTableAccum)) \
1380+
.Device(DEVICE_CPU) \
1381+
.TypeConstraint<key_dtype>("key_dtype") \
1382+
.TypeConstraint<value_dtype>("value_dtype"), \
1383+
redis_table::HashTableAccumOp<key_dtype, value_dtype>); \
12781384
REGISTER_KERNEL_BUILDER( \
12791385
Name(PREFIX_OP_NAME(RedisTableFindWithExists)) \
12801386
.Device(DEVICE_CPU) \

0 commit comments

Comments
 (0)