@@ -56,10 +56,10 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
56
56
Status status = ReadInt64FromEnvVar (" TF_HASHTABLE_INIT_SIZE" ,
57
57
1024 * 8 , // 8192 KV pairs by default
58
58
&env_var);
59
- min_size_ = (size_t )env_var;
59
+ min_size_ = (size_t )env_var / 2 ;
60
60
max_size_ = (size_t )env_var;
61
61
} else {
62
- min_size_ = init_size;
62
+ min_size_ = init_size / 2 ;
63
63
max_size_ = init_size;
64
64
}
65
65
@@ -180,47 +180,65 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
180
180
}
181
181
182
182
void RehashIfNeeded (cudaStream_t stream) {
183
- K* d_keys;
184
- gpu::ValueArrayBase<V>* d_values;
183
+ RehashIfNeeded (stream, min_size_);
184
+ }
185
+
186
+ void RehashIfNeeded (cudaStream_t stream, size_t expecting) {
187
+ K* d_keys = nullptr ;
188
+ gpu::ValueArrayBase<V>* d_values = nullptr ;
185
189
size_t * d_dump_counter;
186
- size_t new_max_size = max_size_;
190
+ size_t capacity = table_->get_capacity ();
191
+
192
+ size_t cur_size = table_->get_size (stream);
193
+ expecting += cur_size;
194
+ expecting = max (min_size_, expecting);
187
195
188
- size_t total_size = table_->get_size (stream);
189
196
CUDA_CHECK (cudaStreamSynchronize (stream));
190
- if (total_size >= 0.75 * max_size_) {
191
- new_max_size = max_size_ * 2 ;
197
+ size_t new_max_size = capacity;
198
+ bool need_rehash = false ;
199
+ while (expecting > 0.75 * new_max_size) {
200
+ new_max_size *= 2 ;
201
+ need_rehash = true ;
192
202
}
193
- if (total_size < 0.25 * max_size_ && max_size_ > min_size_) {
194
- new_max_size = max_size_ / 2 ;
203
+ if (expecting < 0.25 * capacity) {
204
+ new_max_size = capacity / 2 ;
205
+ need_rehash = true ;
195
206
}
196
- if (new_max_size != max_size_) { // rehash manually.
197
- size_t capacity = table_->get_capacity ();
207
+
208
+ if (need_rehash) {
209
+ LOG (INFO) << " Need to rehash GPU HashTable, where capacity=" << capacity
210
+ << " , current_size=" << cur_size << " and expecting "
211
+ << expecting;
198
212
size_t h_dump_counter = 0 ;
199
- CUDA_CHECK (cudaMallocManaged ((void **)&d_dump_counter, sizeof (size_t )));
200
- CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * capacity));
201
- CUDA_CHECK (cudaMallocManaged ((void **)&d_values,
202
- sizeof (V) * runtime_dim_ * capacity));
203
- table_->dump (d_keys, (gpu::ValueArrayBase<V>*)d_values, 0 , capacity,
213
+
214
+ if (cur_size > 0 ) {
215
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_dump_counter, sizeof (size_t )));
216
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * cur_size));
217
+ CUDA_CHECK (cudaMallocManaged ((void **)&d_values, sizeof (V) * runtime_dim_ * cur_size));
218
+ table_->dump (d_keys, (gpu::ValueArrayBase<V>*)d_values, 0 , capacity,
204
219
d_dump_counter, stream);
205
- CUDA_CHECK (cudaStreamSynchronize (stream));
220
+ cudaMemcpyAsync (&h_dump_counter, d_dump_counter, sizeof (size_t ), cudaMemcpyDeviceToHost, stream);
221
+ CUDA_CHECK (cudaStreamSynchronize (stream));
222
+ }
206
223
207
224
delete table_;
208
225
table_ = NULL ;
209
226
CreateTable (new_max_size, &table_);
210
- CUDA_CHECK (cudaStreamSynchronize (stream));
211
- CUDA_CHECK (cudaMemcpy ((size_t *)&h_dump_counter, (size_t *)d_dump_counter,
212
- sizeof (size_t ), cudaMemcpyDefault));
213
- table_->upsert ((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
227
+
228
+ if (cur_size > 0 ) {
229
+ table_->upsert ((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
214
230
h_dump_counter, stream);
215
- CUDA_CHECK (cudaStreamSynchronize (stream));
216
- CUDA_CHECK (cudaFree (d_keys));
217
- CUDA_CHECK (cudaFree (d_values));
218
- CUDA_CHECK (cudaFree (d_dump_counter));
231
+ cudaStreamSynchronize (stream);
232
+ cudaFree (d_keys);
233
+ cudaFree (d_values);
234
+ cudaFree (d_dump_counter);
235
+ }
219
236
max_size_ = new_max_size;
237
+
220
238
LOG (INFO) << " HashTable on GPU changes to new status: [size="
221
- << total_size << " , max_size=" << max_size_
239
+ << h_dump_counter << " , max_size=" << max_size_
222
240
<< " , load factor=" << std::setprecision (2 )
223
- << (float )total_size / (float )max_size_ << " ]." ;
241
+ << (float )h_dump_counter / (float )max_size_ << " ]." ;
224
242
}
225
243
}
226
244
@@ -231,7 +249,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
231
249
CUDA_CHECK (cudaStreamCreate (&_stream));
232
250
{
233
251
mutex_lock l (mu_);
234
- RehashIfNeeded (_stream);
252
+ RehashIfNeeded (_stream, len );
235
253
table_->upsert ((const K*)keys.tensor_data ().data (),
236
254
(const gpu::ValueArrayBase<V>*)values.tensor_data ().data (),
237
255
len, _stream);
@@ -249,7 +267,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
249
267
CUDA_CHECK (cudaStreamCreate (&_stream));
250
268
{
251
269
mutex_lock l (mu_);
252
- RehashIfNeeded (_stream);
270
+ RehashIfNeeded (_stream, len );
253
271
table_->accum (
254
272
(const K*)keys.tensor_data ().data (),
255
273
(const gpu::ValueArrayBase<V>*)values_or_deltas.tensor_data ().data (),
@@ -304,6 +322,7 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
304
322
if (len > 0 ) {
305
323
cudaStream_t _stream;
306
324
CUDA_CHECK (cudaStreamCreate (&_stream));
325
+ RehashIfNeeded (_stream, len);
307
326
CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * len));
308
327
CUDA_CHECK (
309
328
cudaMallocManaged ((void **)&d_values, sizeof (V) * runtime_dim_ * len));
0 commit comments