Skip to content

Commit 12f08db

Browse files
committed
Accept Sequence return type from ExLlamaV2Filter.next()
1 parent ea27954 commit 12f08db

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

exllamav2/generator/sampler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def sample(
152152
blocked_tokens: list[int] | None = None,
153153
filters: list[ExLlamaV2Filter] | None = None,
154154
filter_prefer_eos: bool = False,
155-
sync: bool = False
155+
sync: bool = False,
156156
):
157157

158158
"""
@@ -273,6 +273,9 @@ def prep_logit_filter(lf):
273273
for f in filters:
274274

275275
pt, et = f.next()
276+
if len(filters) > 1 and not isinstance(pt, set):
277+
pt, et = set(pt), set(et)
278+
276279
if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
277280
if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
278281

@@ -290,9 +293,15 @@ def prep_logit_filter(lf):
290293
return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
291294

292295
if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
293-
pass_tokens = { tokenizer.eos_token_id }
294-
logit_filter = prep_logit_filter(logit_filter)
295-
ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
296+
pass_tokens_list = [tokenizer.eos_token_id]
297+
logit_filter = prep_logit_filter(logit_filter)
298+
ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list)
299+
else:
300+
logit_filter = prep_logit_filter(logit_filter)
301+
if isinstance(pass_tokens, set):
302+
ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
303+
else:
304+
ext_c.logit_filter_exclusive(logit_filter, [pass_tokens])
296305

297306
# Healing
298307

0 commit comments

Comments
 (0)