
Commit 0d5c0bc

Asynchronous filter evaluation
1 parent 12f08db

3 files changed: 76 additions, 13 deletions

exllamav2/generator/dynamic.py

Lines changed: 23 additions & 4 deletions
@@ -232,6 +232,8 @@ class ExLlamaV2DynamicGenerator:
     max_sampling_threads: int
     min_sampling_threads: int
     sampling_pool: ThreadPoolExecutor
+    filter_pool: ThreadPoolExecutor
+    filter_queue: list

     def __init__(
@@ -443,6 +445,11 @@ def __init__(
         if max_sampling_threads > 1:
             self.sampling_pool = ThreadPoolExecutor(max_workers = max_sampling_threads)

+        # Filter threads
+
+        self.filter_pool = ThreadPoolExecutor(max_workers = 16)
+        self.filter_queue = []
+
         # Temp buffers for defrag

         if self.paged:
@@ -1130,6 +1137,14 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
             loras = self.current_loras,
         )["logits"]

+        # GPU workload is scheduled here, so launch any sampling filters that can run while waiting for CUDA
+
+        if self.filter_queue:
+            for f in self.filter_queue:
+                f.background_next(self.filter_pool)
+                time.sleep(0)
+            self.filter_queue.clear()
+
         # Pass logits to jobs for sampling

         batch_logits = self.logits_pinned[:device_logits.shape[0], :device_logits.shape[1], :]
@@ -1729,10 +1744,10 @@ def receive_logits(

         # Start filters

-        # TODO: Try to move filter evaluation to the end of the forward pass, before sampling so it can potentially
-        # occur while waiting for the CUDA queue
         if self.new_tokens == 0:
-            for f in self.filters: f.begin("")
+            for f in self.filters:
+                f.background_drop()
+                f.begin("")

         # Sample

@@ -1780,7 +1795,11 @@ def receive_sample(
         # Feed filters

         if self.new_tokens >= 0:
-            for f in self.filters: f.feed(next_token)
+            for f in self.filters:
+                f.feed(next_token)
+                # Evaluate filter in background when possible
+                if f.use_background_worker():
+                    self.generator.filter_queue.append(f)

         # Accept token

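Taken together, the dynamic.py changes queue any opted-in filter in receive_sample, launch its next() call on the thread pool right after the next forward pass has been scheduled in iterate_gen, and drop any stale result when a generation restarts in receive_logits. Below is a minimal standalone sketch of that overlap pattern, assuming a CPU-bound filter and a GPU-bound forward pass; FakeFilter, gpu_forward and the sleep timings are illustrative stand-ins, not exllamav2 code.

# Sketch only: overlap CPU-bound filter evaluation with a (simulated) GPU forward pass.
import time
from concurrent.futures import ThreadPoolExecutor

class FakeFilter:
    def next(self):
        time.sleep(0.005)            # stand-in for expensive grammar evaluation
        return {1, 2, 3}, {0}        # (pass_tokens, end_tokens)

def gpu_forward():
    time.sleep(0.010)                # stand-in for waiting on the CUDA queue

pool = ThreadPoolExecutor(max_workers = 16)
filter_queue = [FakeFilter()]

# Launch filter evaluation as soon as the GPU workload has been queued
futures = [pool.submit(f.next) for f in filter_queue]
time.sleep(0)                        # yield the GIL so the worker threads can start
filter_queue.clear()

gpu_forward()                        # filters run concurrently with the forward pass

# Collect the results right before sampling
for fut in futures:
    pass_tokens, end_tokens = fut.result()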
exllamav2/generator/filters/base.py

Lines changed: 51 additions & 7 deletions
@@ -1,7 +1,9 @@
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Tokenizer,
-)
+from __future__ import annotations
+from threading import Lock
+from concurrent.futures import ThreadPoolExecutor, Future
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
+from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
+import torch

 class ExLlamaV2Filter:

@@ -11,6 +13,10 @@ class ExLlamaV2Filter:
     tokenizer: ExLlamaV2Tokenizer
     sequence_str: str

+    background_result: Future | None = None
+
+    # For compatibility
+    allow_return_type_list: bool = True

     def __init__(self,
                  model: ExLlamaV2,
@@ -31,13 +37,51 @@ def clone(self, c = None):


     def begin(self, prefix_str):
-        pass
+        raise NotImplementedError


     def feed(self, token):
-        pass
+        raise NotImplementedError


     def next(self):
-        pass
+        raise NotImplementedError
+
+
+    def use_background_worker(self) -> bool:
+        """
+        To indicate whether filter can/should run as a background thread. If True, next() will be called
+        asynchronously after the CUDA workload has been scheduled for the following forward pass, instead of right
+        before sampling. Should be True for any CPU-intensive filter such as a grammar constraint.
+        """
+        return False
+
+
+    def background_next(self, pool: ThreadPoolExecutor):
+        """
+        Schedule next() via the provided thread pool executor
+        """
+        assert self.background_result is None
+        self.background_result = pool.submit(self.next)
+
+
+    def background_drop(self):
+        """
+        Clear the result of an asynchronous filter pass. Used when a complex filter reaches an end state and forces
+        the selection of eos_token_id. next() could still be scheduled after this selection, leaving a pending result
+        that would break subsequent generations with the same filter.
+        """
+        if self.background_result is not None:
+            self.background_result.result()
+            self.background_result = None
+

+    def get_next(self) -> tuple:
+        """
+        Return either next() or the result of any scheduled call to next()
+        """
+        if self.background_result is None:
+            return self.next()
+        r = self.background_result.result()
+        self.background_result = None
+        return r

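The new hooks are opt-in: a subclass that overrides use_background_worker() to return True has its next() call scheduled through background_next() and consumed later through get_next(). A hypothetical subclass might look like the sketch below; DigitsFilter is illustrative only, and the constructor arguments and tokenizer.single_id() usage assume the standard ExLlamaV2Filter / ExLlamaV2Tokenizer interfaces.

# Hypothetical filter that opts into background evaluation; sketch, not library code.
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters.base import ExLlamaV2Filter

class DigitsFilter(ExLlamaV2Filter):
    """Toy constraint that only ever allows single-digit tokens."""

    def __init__(self, model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer):
        super().__init__(model, tokenizer)
        self.digit_tokens = set()

    def begin(self, prefix_str):
        # Precompute the allowed token IDs once per generation
        self.digit_tokens = {self.tokenizer.single_id(str(d)) for d in range(10)}

    def feed(self, token):
        pass                         # this toy constraint keeps no per-token state

    def next(self):
        # Cheap here, but for a real grammar filter this is the expensive call
        # that now overlaps with the forward pass
        return self.digit_tokens, set()

    def use_background_worker(self) -> bool:
        return True                  # ask the generator to run next() on the pool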
exllamav2/generator/sampler.py

Lines changed: 2 additions & 2 deletions
@@ -272,15 +272,15 @@ def prep_logit_filter(lf):
         end_tokens = None
         for f in filters:

-            pt, et = f.next()
+            pt, et = f.get_next()
             if len(filters) > 1 and not isinstance(pt, set):
                 pt, et = set(pt), set(et)

             if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
             if et is not None: end_tokens = et if end_tokens is None else end_tokens | et

         if pass_tokens is not None:
-            assert pass_tokens, "Filter excluded all tokens"
+            assert len(pass_tokens), "Filter excluded all tokens"

             # Special case if a single token passes
             if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None:

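The sampler now consumes each filter through get_next(), so a result computed on a worker thread is picked up transparently and an unscheduled filter still falls back to a direct next() call. As a small worked example of how the loop above combines several filters' results (pass sets intersected, end sets unioned); the token IDs are made up:

# Made-up (pass_tokens, end_tokens) results from two filters, illustrative only.
filter_results = [({1, 2, 3}, {0}), ({2, 3, 4}, {9})]

pass_tokens, end_tokens = None, None
for pt, et in filter_results:
    if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
    if et is not None: end_tokens = et if end_tokens is None else end_tokens | et

assert pass_tokens == {2, 3} and end_tokens == {0, 9}
assert len(pass_tokens), "Filter excluded all tokens"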