Skip to content

Commit 547135c

Browse files
committed
Skip sampling if only one token allowed by filters
1 parent 0978ba5 commit 547135c

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

exllamav2/generator/sampler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,17 @@ def prep_logit_filter(lf):
278278

279279
if pass_tokens is not None:
280280
assert pass_tokens, "Filter excluded all tokens"
281+
282+
# Special case if a single token passes
283+
if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None:
284+
single_passed_token = next(iter(pass_tokens))
285+
output_tokens = torch.tensor([[single_passed_token]], dtype=torch.long)
286+
output_probs = torch.tensor([[1]], dtype=torch.float)
287+
output_ktokens = none_tensor
288+
output_kprobs = none_tensor
289+
end_filter = (single_passed_token in end_tokens)
290+
return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
291+
281292
if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
282293
pass_tokens = { tokenizer.eos_token_id }
283294
logit_filter = prep_logit_filter(logit_filter)

0 commit comments

Comments
 (0)