@@ -152,7 +152,7 @@ def sample(
152
152
blocked_tokens : list [int ] | None = None ,
153
153
filters : list [ExLlamaV2Filter ] | None = None ,
154
154
filter_prefer_eos : bool = False ,
155
- sync : bool = False
155
+ sync : bool = False ,
156
156
):
157
157
158
158
"""
@@ -273,6 +273,9 @@ def prep_logit_filter(lf):
273
273
for f in filters :
274
274
275
275
pt , et = f .next ()
276
+ if len (filters ) > 1 and not isinstance (pt , set ):
277
+ pt , et = set (pt ), set (et )
278
+
276
279
if pt is not None : pass_tokens = pt if pass_tokens is None else pass_tokens & pt
277
280
if et is not None : end_tokens = et if end_tokens is None else end_tokens | et
278
281
@@ -290,9 +293,15 @@ def prep_logit_filter(lf):
290
293
return output_tokens , output_ktokens , output_kprobs , output_probs , end_filter
291
294
292
295
if filter_prefer_eos and tokenizer .eos_token_id in pass_tokens :
293
- pass_tokens = { tokenizer .eos_token_id }
294
- logit_filter = prep_logit_filter (logit_filter )
295
- ext_c .logit_filter_exclusive (logit_filter , [sorted (list (pass_tokens ))])
296
+ pass_tokens_list = [tokenizer .eos_token_id ]
297
+ logit_filter = prep_logit_filter (logit_filter )
298
+ ext_c .logit_filter_exclusive (logit_filter , pass_tokens_list )
299
+ else :
300
+ logit_filter = prep_logit_filter (logit_filter )
301
+ if isinstance (pass_tokens , set ):
302
+ ext_c .logit_filter_exclusive (logit_filter , [sorted (list (pass_tokens ))])
303
+ else :
304
+ ext_c .logit_filter_exclusive (logit_filter , [pass_tokens ])
296
305
297
306
# Healing
298
307
0 commit comments