Skip to content

Commit 097bd21

Browse files
committed
Bugfix: a single rehash of the GPU hashtable is not enough to accommodate a large insert
1 parent ac8ec80 commit 097bd21

File tree

1 file changed

+49
-30
lines changed

1 file changed

+49
-30
lines changed

tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
5656
Status status = ReadInt64FromEnvVar("TF_HASHTABLE_INIT_SIZE",
5757
1024 * 8, // 8192 KV pairs by default
5858
&env_var);
59-
min_size_ = (size_t)env_var;
59+
min_size_ = (size_t)env_var / 2;
6060
max_size_ = (size_t)env_var;
6161
} else {
62-
min_size_ = init_size;
62+
min_size_ = init_size / 2;
6363
max_size_ = init_size;
6464
}
6565

@@ -176,47 +176,65 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
176176
}
177177

178178
void RehashIfNeeded(cudaStream_t stream) {
179-
K* d_keys;
180-
gpu::ValueArrayBase<V>* d_values;
179+
RehashIfNeeded(stream, min_size_);
180+
}
181+
182+
void RehashIfNeeded(cudaStream_t stream, size_t expecting) {
183+
K* d_keys = nullptr;
184+
gpu::ValueArrayBase<V>* d_values = nullptr;
181185
size_t* d_dump_counter;
182-
size_t new_max_size = max_size_;
186+
size_t capacity = table_->get_capacity();
187+
188+
size_t cur_size = table_->get_size(stream);
189+
expecting += cur_size;
190+
expecting = max(min_size_, expecting);
183191

184-
size_t total_size = table_->get_size(stream);
185192
CUDA_CHECK(cudaStreamSynchronize(stream));
186-
if (total_size >= 0.75 * max_size_) {
187-
new_max_size = max_size_ * 2;
193+
size_t new_max_size = capacity;
194+
bool need_rehash = false;
195+
while (expecting > 0.75 * new_max_size) {
196+
new_max_size *= 2;
197+
need_rehash = true;
188198
}
189-
if (total_size < 0.25 * max_size_ && max_size_ > min_size_) {
190-
new_max_size = max_size_ / 2;
199+
if (expecting < 0.25 * capacity) {
200+
new_max_size = capacity / 2;
201+
need_rehash = true;
191202
}
192-
if (new_max_size != max_size_) { // rehash manually.
193-
size_t capacity = table_->get_capacity();
203+
204+
if (need_rehash) {
205+
LOG(INFO) << "Need to rehash GPU HashTable, where capacity=" << capacity
206+
<< ", current_size=" << cur_size << " and expecting "
207+
<< expecting;
194208
size_t h_dump_counter = 0;
195-
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
196-
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * capacity));
197-
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
198-
sizeof(V) * runtime_dim_ * capacity));
199-
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
209+
210+
if (cur_size > 0) {
211+
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
212+
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
213+
CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size));
214+
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
200215
d_dump_counter, stream);
201-
CUDA_CHECK(cudaStreamSynchronize(stream));
216+
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream);
217+
CUDA_CHECK(cudaStreamSynchronize(stream));
218+
}
202219

203220
delete table_;
204221
table_ = NULL;
205222
CreateTable(new_max_size, &table_);
206-
CUDA_CHECK(cudaStreamSynchronize(stream));
207-
CUDA_CHECK(cudaMemcpy((size_t*)&h_dump_counter, (size_t*)d_dump_counter,
208-
sizeof(size_t), cudaMemcpyDefault));
209-
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
223+
224+
if (cur_size > 0) {
225+
table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
210226
h_dump_counter, stream);
211-
CUDA_CHECK(cudaStreamSynchronize(stream));
212-
CUDA_CHECK(cudaFree(d_keys));
213-
CUDA_CHECK(cudaFree(d_values));
214-
CUDA_CHECK(cudaFree(d_dump_counter));
227+
cudaStreamSynchronize(stream);
228+
cudaFree(d_keys);
229+
cudaFree(d_values);
230+
cudaFree(d_dump_counter);
231+
}
215232
max_size_ = new_max_size;
233+
216234
LOG(INFO) << "HashTable on GPU changes to new status: [size="
217-
<< total_size << ", max_size=" << max_size_
235+
<< h_dump_counter << ", max_size=" << max_size_
218236
<< ", load factor=" << std::setprecision(2)
219-
<< (float)total_size / (float)max_size_ << "].";
237+
<< (float)h_dump_counter / (float)max_size_ << "].";
220238
}
221239
}
222240

@@ -227,7 +245,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
227245
CUDA_CHECK(cudaStreamCreate(&_stream));
228246
{
229247
mutex_lock l(mu_);
230-
RehashIfNeeded(_stream);
248+
RehashIfNeeded(_stream, len);
231249
table_->upsert((const K*)keys.tensor_data().data(),
232250
(const gpu::ValueArrayBase<V>*)values.tensor_data().data(),
233251
len, _stream);
@@ -245,7 +263,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
245263
CUDA_CHECK(cudaStreamCreate(&_stream));
246264
{
247265
mutex_lock l(mu_);
248-
RehashIfNeeded(_stream);
266+
RehashIfNeeded(_stream, len);
249267
table_->accum(
250268
(const K*)keys.tensor_data().data(),
251269
(const gpu::ValueArrayBase<V>*)values_or_deltas.tensor_data().data(),
@@ -300,6 +318,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
300318
if (len > 0) {
301319
cudaStream_t _stream;
302320
CUDA_CHECK(cudaStreamCreate(&_stream));
321+
RehashIfNeeded(_stream, len);
303322
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * len));
304323
CUDA_CHECK(
305324
cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * len));

0 commit comments

Comments
 (0)