Skip to content

Commit f4119ae

Browse files
committed
Fix background filter eval when draft model used
1 parent 4f83f52 commit f4119ae

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

exllamav2/generator/dynamic.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ class ExLlamaV2DynamicGenerator:
233233

234234
max_sampling_threads: int
235235
min_sampling_threads: int
236-
sampling_pool: ThreadPoolExecutor
237-
filter_pool: ThreadPoolExecutor
238-
filter_queue: list
236+
sampling_pool: ThreadPoolExecutor | None
237+
filter_pool: ThreadPoolExecutor | None
238+
filter_queue: list | None
239239

240240

241241
def __init__(
@@ -255,6 +255,7 @@ def __init__(
255255
max_sampling_threads: int = 16,
256256
min_sampling_threads: int = 3,
257257
paged: bool = True,
258+
filter_background_eval: bool = True,
258259
**kwargs
259260
):
260261
"""
@@ -316,6 +317,10 @@ def __init__(
316317
does not require paged attention support, but in which the max supported batch size is 1. CFG also will
317318
not work in this mode.
318319
320+
:param filter_background_eval:
321+
Try to overlap filter evaluation with model forward pass. This should generally have no downside since
322+
filters are evaluated by the CPU which will otherwise be busywaiting after CUDA workload is scheduled.
323+
319324
:param kwargs:
320325
"""
321326

@@ -449,8 +454,12 @@ def __init__(
449454

450455
# Filter threads
451456

452-
self.filter_pool = ThreadPoolExecutor(max_workers = 16)
453-
self.filter_queue = []
457+
if filter_background_eval:
458+
self.filter_pool = ThreadPoolExecutor(max_workers = 16)
459+
self.filter_queue = []
460+
else:
461+
self.filter_pool = None
462+
self.filter_queue = None
454463

455464
# Temp buffers for defrag
456465

@@ -1243,7 +1252,8 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
12431252
next_k_probs,
12441253
next_prob,
12451254
filter_eos,
1246-
results
1255+
results,
1256+
i == 0
12471257
)
12481258

12491259
if eos:
@@ -1867,7 +1877,8 @@ def receive_sample(
18671877
next_k_probs: torch.Tensor | None,
18681878
next_prob: torch.Tensor | None,
18691879
filter_eos: bool | None,
1870-
results: list
1880+
results: list,
1881+
first_sample_in_sd_batch: bool = True
18711882
):
18721883
page_size = self.generator.page_size
18731884

@@ -1879,15 +1890,16 @@ def receive_sample(
18791890
f.feed(next_token)
18801891
if not f.can_mask_logits() or not f.use_background_worker():
18811892
all_mask = False
1882-
if all_mask:
1883-
# Using logit mask(s)
1884-
for f in self.filters:
1885-
self.generator.filter_queue.append((f, True))
1886-
else:
1887-
# Using allowed token list(s)
1888-
for f in self.filters:
1889-
if f.use_background_worker():
1890-
self.generator.filter_queue.append((f, False))
1893+
if first_sample_in_sd_batch and self.generator.filter_queue is not None:
1894+
if all_mask:
1895+
# Using logit mask(s)
1896+
for f in self.filters:
1897+
self.generator.filter_queue.append((f, True))
1898+
else:
1899+
# Using allowed token list(s)
1900+
for f in self.filters:
1901+
if f.use_background_worker():
1902+
self.generator.filter_queue.append((f, False))
18911903

18921904
# Accept token
18931905

exllamav2/generator/sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ def prep_logit_filter(lf):
442442
"Attempting to use precomputed logit mask, but filter is not precomputing mask"
443443
flat_logits = logits[0][0]
444444
logits = f.mask_logits(flat_logits).view(1, 1, -1)
445+
# not_inf_indices = torch.nonzero(logits != -float('inf'), as_tuple = True)
446+
# txt = [tokenizer.get_id_to_piece_list()[i] for i in not_inf_indices[2].tolist()]
447+
# print(txt)
445448
end_tokens = None
446449

447450
elif len(filters) > 0:

0 commit comments

Comments (0)