Skip to content

Commit 3592668

Browse files
committed
Bugfix: a single rehash of the GPU hashtable is not enough to accommodate a large insert
1 parent 2ab0abd commit 3592668

File tree

1 file changed

+49
-30
lines changed

1 file changed

+49
-30
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
5656
Status status = ReadInt64FromEnvVar("TF_HASHTABLE_INIT_SIZE",
5757
1024 * 8, // 8192 KV pairs by default
5858
&env_var);
59-
min_size_ = (size_t)env_var;
59+
min_size_ = (size_t)env_var / 2;
6060
max_size_ = (size_t)env_var;
6161
} else {
62-
min_size_ = init_size;
62+
min_size_ = init_size / 2;
6363
max_size_ = init_size;
6464
}
6565

@@ -180,47 +180,65 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
180180
}
181181

182182
void RehashIfNeeded(cudaStream_t stream) {
183-
K* d_keys;
184-
gpu::ValueArrayBase<V>* d_values;
183+
RehashIfNeeded(stream, min_size_);
184+
}
185+
186+
void RehashIfNeeded(cudaStream_t stream, size_t expecting) {
187+
K* d_keys = nullptr;
188+
gpu::ValueArrayBase<V>* d_values = nullptr;
185189
size_t* d_dump_counter;
186-
size_t new_max_size = max_size_;
190+
size_t capacity = table_->get_capacity();
191+
192+
size_t cur_size = table_->get_size(stream);
193+
expecting += cur_size;
194+
expecting = max(min_size_, expecting);
187195

188-
size_t total_size = table_->get_size(stream);
189196
CUDA_CHECK(cudaStreamSynchronize(stream));
190-
if (total_size >= 0.75 * max_size_) {
191-
new_max_size = max_size_ * 2;
197+
size_t new_max_size = capacity;
198+
bool need_rehash = false;
199+
while (expecting > 0.75 * new_max_size) {
200+
new_max_size *= 2;
201+
need_rehash = true;
192202
}
193-
if (total_size < 0.25 * max_size_ && max_size_ > min_size_) {
194-
new_max_size = max_size_ / 2;
203+
if (expecting < 0.25 * capacity) {
204+
new_max_size = capacity / 2;
205+
need_rehash = true;
195206
}
196-
if (new_max_size != max_size_) { // rehash manually.
197-
size_t capacity = table_->get_capacity();
207+
208+
if (need_rehash) {
209+
LOG(INFO) << "Need to rehash GPU HashTable, where capacity=" << capacity
210+
<< ", current_size=" << cur_size << " and expecting "
211+
<< expecting;
198212
size_t h_dump_counter = 0;
199-
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
200-
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * capacity));
201-
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
202-
sizeof(V) * runtime_dim_ * capacity));
203-
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
213+
214+
if (cur_size > 0) {
215+
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
216+
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
217+
CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
218+
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
204219
d_dump_counter, stream);
205-
CUDA_CHECK(cudaStreamSynchronize(stream));
220+
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
221+
CUDA_CHECK(cudaStreamSynchronize(stream));
222+
}
206223

207224
delete table_;
208225
table_ = NULL;
209226
CreateTable(new_max_size, &table_);
210-
CUDA_CHECK(cudaStreamSynchronize(stream));
211-
CUDA_CHECK(cudaMemcpy((size_t*)&h_dump_counter, (size_t*)d_dump_counter,
212-
sizeof(size_t), cudaMemcpyDefault));
213-
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
227+
228+
if (cur_size > 0) {
229+
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
214230
h_dump_counter, stream);
215-
CUDA_CHECK(cudaStreamSynchronize(stream));
216-
CUDA_CHECK(cudaFree(d_keys));
217-
CUDA_CHECK(cudaFree(d_values));
218-
CUDA_CHECK(cudaFree(d_dump_counter));
231+
cudaStreamSynchronize(stream);
232+
cudaFree(d_keys);
233+
cudaFree(d_values);
234+
cudaFree(d_dump_counter);
235+
}
219236
max_size_ = new_max_size;
237+
220238
LOG(INFO) << "HashTable on GPU changes to new status: [size="
221-
<< total_size << ", max_size=" << max_size_
239+
<< h_dump_counter << ", max_size=" << max_size_
222240
<< ", load factor=" << std::setprecision(2)
223-
<< (float)total_size / (float)max_size_ << "].";
241+
<< (float)h_dump_counter / (float)max_size_ << "].";
224242
}
225243
}
226244

@@ -231,7 +249,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
231249
CUDA_CHECK(cudaStreamCreate(&_stream));
232250
{
233251
mutex_lock l(mu_);
234-
RehashIfNeeded(_stream);
252+
RehashIfNeeded(_stream, len);
235253
table_->upsert((const K*)keys.tensor_data().data(),
236254
(const gpu::ValueArrayBase<V>*)values.tensor_data().data(),
237255
len, _stream);
@@ -249,7 +267,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
249267
CUDA_CHECK(cudaStreamCreate(&_stream));
250268
{
251269
mutex_lock l(mu_);
252-
RehashIfNeeded(_stream);
270+
RehashIfNeeded(_stream, len);
253271
table_->accum(
254272
(const K*)keys.tensor_data().data(),
255273
(const gpu::ValueArrayBase<V>*)values_or_deltas.tensor_data().data(),
@@ -304,6 +322,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
304322
if (len > 0) {
305323
cudaStream_t _stream;
306324
CUDA_CHECK(cudaStreamCreate(&_stream));
325+
RehashIfNeeded(_stream, len);
307326
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
308327
CUDA_CHECK(
309328
cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * len));

0 commit comments

Comments
 (0)