@@ -56,10 +56,10 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
56
56
Status status = ReadInt64FromEnvVar (" TF_HASHTABLE_INIT_SIZE" ,
57
57
1024 * 8 , // 8192 KV pairs by default
58
58
&env_var);
59
- min_size_ = (size_t )env_var;
59
+ min_size_ = (size_t )env_var / 2 ;
60
60
max_size_ = (size_t )env_var;
61
61
} else {
62
- min_size_ = init_size;
62
+ min_size_ = init_size / 2 ;
63
63
max_size_ = init_size;
64
64
}
65
65
@@ -176,47 +176,65 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
176
176
}
177
177
178
178
void RehashIfNeeded (cudaStream_t stream) {
179
- K* d_keys;
180
- gpu::ValueArrayBase<V>* d_values;
179
+ RehashIfNeeded (stream, min_size_);
180
+ }
181
+
182
+ void RehashIfNeeded (cudaStream_t stream, size_t expecting) {
183
+ K* d_keys = nullptr ;
184
+ gpu::ValueArrayBase<V>* d_values = nullptr ;
181
185
size_t * d_dump_counter;
182
- size_t new_max_size = max_size_;
186
+ size_t capacity = table_->get_capacity ();
187
+
188
+ size_t cur_size = table_->get_size (stream);
189
+ expecting += cur_size;
190
+ expecting = max (min_size_, expecting);
183
191
184
- size_t total_size = table_->get_size (stream);
185
192
CUDA_CHECK (cudaStreamSynchronize (stream));
186
- if (total_size >= 0.75 * max_size_) {
187
- new_max_size = max_size_ * 2 ;
193
+ size_t new_max_size = capacity;
194
+ bool need_rehash = false ;
195
+ while (expecting > 0.75 * new_max_size) {
196
+ new_max_size *= 2 ;
197
+ need_rehash = true ;
188
198
}
189
- if (total_size < 0.25 * max_size_ && max_size_ > min_size_) {
190
- new_max_size = max_size_ / 2 ;
199
+ if (expecting < 0.25 * capacity) {
200
+ new_max_size = capacity / 2 ;
201
+ need_rehash = true ;
191
202
}
192
- if (new_max_size != max_size_) { // rehash manually.
193
- size_t capacity = table_->get_capacity ();
203
+
204
+ if (need_rehash) {
205
+ LOG (INFO) << " Need to rehash GPU HashTable, where capacity=" << capacity
206
+ << " , current_size=" << cur_size << " and expecting "
207
+ << expecting;
194
208
size_t h_dump_counter = 0 ;
195
- CUDA_CHECK (cudaMallocManaged ((void **)&d_dump_counter, sizeof (size_t )));
196
- CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * capacity));
197
- CUDA_CHECK (cudaMallocManaged ((void **)&d_values,
198
- sizeof (V) * runtime_dim_ * capacity));
199
- table_->dump (d_keys, (gpu::ValueArrayBase<V>*)d_values, 0 , capacity,
209
+
210
+ if (cur_size > 0 ) {
211
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_dump_counter, sizeof (size_t )));
212
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * cur_size));
213
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_values, sizeof (V) * runtime_dim_ * cur_size));
214
+ table_->dump (d_keys, (gpu::ValueArrayBase<V>*)d_values, 0 , capacity,
200
215
d_dump_counter, stream);
201
- CUDA_CHECK (cudaStreamSynchronize (stream));
216
+ cudaMemcpyAsync (&h_dump_counter, d_dump_counter, sizeof (size_t ), cudaMemcpyDeviceToHost, stream);
217
+ CUDA_CHECK (cudaStreamSynchronize (stream));
218
+ }
202
219
203
220
delete table_;
204
221
table_ = NULL ;
205
222
CreateTable (new_max_size, &table_);
206
- CUDA_CHECK (cudaStreamSynchronize (stream));
207
- CUDA_CHECK (cudaMemcpy ((size_t *)&h_dump_counter, (size_t *)d_dump_counter,
208
- sizeof (size_t ), cudaMemcpyDefault));
209
- table_->upsert ((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
223
+
224
+ if (cur_size > 0 ) {
225
+ table_->upsert ((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
210
226
h_dump_counter, stream);
211
- CUDA_CHECK (cudaStreamSynchronize (stream));
212
- CUDA_CHECK (cudaFree (d_keys));
213
- CUDA_CHECK (cudaFree (d_values));
214
- CUDA_CHECK (cudaFree (d_dump_counter));
227
+ cudaStreamSynchronize (stream);
228
+ cudaFree (d_keys);
229
+ cudaFree (d_values);
230
+ cudaFree (d_dump_counter);
231
+ }
215
232
max_size_ = new_max_size;
233
+
216
234
LOG (INFO) << " HashTable on GPU changes to new status: [size="
217
- << total_size << " , max_size=" << max_size_
235
+ << h_dump_counter << " , max_size=" << max_size_
218
236
<< " , load factor=" << std::setprecision (2 )
219
- << (float )total_size / (float )max_size_ << " ]." ;
237
+ << (float )h_dump_counter / (float )max_size_ << " ]." ;
220
238
}
221
239
}
222
240
@@ -227,7 +245,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
227
245
CUDA_CHECK (cudaStreamCreate (&_stream));
228
246
{
229
247
mutex_lock l (mu_);
230
- RehashIfNeeded (_stream);
248
+ RehashIfNeeded (_stream, len );
231
249
table_->upsert ((const K*)keys.tensor_data ().data (),
232
250
(const gpu::ValueArrayBase<V>*)values.tensor_data ().data (),
233
251
len, _stream);
@@ -245,7 +263,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
245
263
CUDA_CHECK (cudaStreamCreate (&_stream));
246
264
{
247
265
mutex_lock l (mu_);
248
- RehashIfNeeded (_stream);
266
+ RehashIfNeeded (_stream, len );
249
267
table_->accum (
250
268
(const K*)keys.tensor_data ().data (),
251
269
(const gpu::ValueArrayBase<V>*)values_or_deltas.tensor_data ().data (),
@@ -300,6 +318,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
300
318
if (len > 0 ) {
301
319
cudaStream_t _stream;
302
320
CUDA_CHECK (cudaStreamCreate (&_stream));
321
+ RehashIfNeeded (_stream, len);
303
322
CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * len));
304
323
CUDA_CHECK (
305
324
cudaMallocManaged ((void **)&d_values, sizeof (V) * runtime_dim_ * len));
0 commit comments