@@ -233,9 +233,9 @@ class ExLlamaV2DynamicGenerator:
233
233
234
234
max_sampling_threads : int
235
235
min_sampling_threads : int
236
- sampling_pool : ThreadPoolExecutor
237
- filter_pool : ThreadPoolExecutor
238
- filter_queue : list
236
+ sampling_pool : ThreadPoolExecutor | None
237
+ filter_pool : ThreadPoolExecutor | None
238
+ filter_queue : list | None
239
239
240
240
241
241
def __init__ (
@@ -255,6 +255,7 @@ def __init__(
255
255
max_sampling_threads : int = 16 ,
256
256
min_sampling_threads : int = 3 ,
257
257
paged : bool = True ,
258
+ filter_background_eval : bool = True ,
258
259
** kwargs
259
260
):
260
261
"""
@@ -316,6 +317,10 @@ def __init__(
316
317
does not require paged attention support, but in which the max supported batch size is 1. CFG also will
317
318
not work in this mode.
318
319
320
+ :param filter_background_eval:
321
+ Try to overlap filter evaluation with the model forward pass. This should generally have no downside since
322
+ filters are evaluated by the CPU, which would otherwise be busy-waiting after the CUDA workload is scheduled.
323
+
319
324
:param kwargs:
320
325
"""
321
326
@@ -449,8 +454,12 @@ def __init__(
449
454
450
455
# Filter threads
451
456
452
- self .filter_pool = ThreadPoolExecutor (max_workers = 16 )
453
- self .filter_queue = []
457
+ if filter_background_eval :
458
+ self .filter_pool = ThreadPoolExecutor (max_workers = 16 )
459
+ self .filter_queue = []
460
+ else :
461
+ self .filter_pool = None
462
+ self .filter_queue = None
454
463
455
464
# Temp buffers for defrag
456
465
@@ -1243,7 +1252,8 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
1243
1252
next_k_probs ,
1244
1253
next_prob ,
1245
1254
filter_eos ,
1246
- results
1255
+ results ,
1256
+ i == 0
1247
1257
)
1248
1258
1249
1259
if eos :
@@ -1867,7 +1877,8 @@ def receive_sample(
1867
1877
next_k_probs : torch .Tensor | None ,
1868
1878
next_prob : torch .Tensor | None ,
1869
1879
filter_eos : bool | None ,
1870
- results : list
1880
+ results : list ,
1881
+ first_sample_in_sd_batch : bool = True
1871
1882
):
1872
1883
page_size = self .generator .page_size
1873
1884
@@ -1879,15 +1890,16 @@ def receive_sample(
1879
1890
f .feed (next_token )
1880
1891
if not f .can_mask_logits () or not f .use_background_worker ():
1881
1892
all_mask = False
1882
- if all_mask :
1883
- # Using logit mask(s)
1884
- for f in self .filters :
1885
- self .generator .filter_queue .append ((f , True ))
1886
- else :
1887
- # Using allowed token list(s)
1888
- for f in self .filters :
1889
- if f .use_background_worker ():
1890
- self .generator .filter_queue .append ((f , False ))
1893
+ if first_sample_in_sd_batch and self .generator .filter_queue is not None :
1894
+ if all_mask :
1895
+ # Using logit mask(s)
1896
+ for f in self .filters :
1897
+ self .generator .filter_queue .append ((f , True ))
1898
+ else :
1899
+ # Using allowed token list(s)
1900
+ for f in self .filters :
1901
+ if f .use_background_worker ():
1902
+ self .generator .filter_queue .append ((f , False ))
1891
1903
1892
1904
# Accept token
1893
1905
0 commit comments