Commit 3e8e181 (parent: affdc0d)

Add DRY (still needs testing)

1 file changed: exllamav2/generator/sampler.py (105 additions, 2 deletions)
@@ -8,6 +8,8 @@
 from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
 from copy import copy
 import threading
+from functools import lru_cache
+import re
 # import line_profiler

 _tl_tensors = threading.local()
@@ -37,6 +39,12 @@ def _get_output_probs(shape, dtype):
     return _tl_tensors.output_probs


+@dataclass
+class NgramNode:
+    value: int = 0
+    children: dict[int, NgramNode] = field(default_factory = dict)
+
+
 class ExLlamaV2Sampler:

     @dataclass
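
The NgramNode trie maps token IDs to child nodes, and each node's value counts how many times the token path leading to it has occurred in the context. A minimal, self-contained sketch of that idea (illustrative, not part of the commit; the real code also caps the depth at dry_max_ngram and stops at sequence breakers):

from __future__ import annotations
from dataclasses import dataclass, field

@dataclass
class NgramNode:
    value: int = 0
    children: dict[int, NgramNode] = field(default_factory = dict)

root = NgramNode()
tokens = [5, 6, 7, 5, 6, 7]

# Insert every suffix of the sequence, counting visits per node
for i in range(len(tokens)):
    node = root
    for t in tokens[i:]:
        node = node.children.setdefault(t, NgramNode())
        node.value += 1

# The path 5 -> 6 -> 7 occurs twice, so its final node counts 2
print(root.children[5].children[6].children[7].value)  # 2
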
@@ -74,6 +82,15 @@ class Settings:

         post_sampling_hooks: list[ExLlamaV2PostSamplingHook] = field(default_factory = list)

+        dry_allowed_length: int = 0  # 0 to disable
+        dry_base: float = 2.0
+        dry_multiplier: float = 2.0
+        dry_sequence_breakers: set[int] | None = None
+        dry_max_ngram: int = 20
+
+        ngram_trie: NgramNode | None = None
+        ngram_index: int = 0
+
         @staticmethod
         def greedy(**kwargs) -> ExLlamaV2Sampler.Settings:
             defaults = {
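
For context, enabling DRY from user code would look something like this sketch (values are illustrative; dry_allowed_length doubles as the on/off switch, and dry_sequence_breakers is filled in with a default token set when left as None):

from exllamav2.generator import ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings()
settings.dry_allowed_length = 2   # penalize matches longer than 2 tokens; 0 disables DRY
settings.dry_base = 2.0           # penalty grows as dry_base ** (match_length - allowed_length)
settings.dry_multiplier = 2.0     # overall scale of the penalty
settings.dry_max_ngram = 20       # longest ngram tracked in the trie
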
@@ -101,6 +118,11 @@ def greedy_clone(self):
             c.token_frequency_penalty = self.token_frequency_penalty
             c.token_presence_penalty = self.token_presence_penalty
             c.token_bias = None
+            c.dry_allowed_length = self.dry_allowed_length
+            c.dry_base = self.dry_base
+            c.dry_multiplier = self.dry_multiplier
+            c.dry_sequence_breakers = self.dry_sequence_breakers
+            c.dry_max_ngram = self.dry_max_ngram
             c.filters = []
             return c

@@ -139,6 +161,82 @@ def allow_tokens(
                     raise ValueError("Incorrect type in allow_tokens list")


+    @staticmethod
+    @lru_cache(10)
+    def get_dry_default_sequence_breaker_tokens(
+        tokenizer: ExLlamaV2Tokenizer
+    ) -> set[int]:
+        result = set()
+        dry_default_sequence_breaker_chars = r".,!?<>\[\]\(\)\{\}\n\t\""
+        pattern = re.compile(r"[" + dry_default_sequence_breaker_chars + "]")
+        pieces = tokenizer.get_id_to_piece_list(include_special_tokens = True)
+        for t in range(len(pieces)):
+            if bool(pattern.search(pieces[t])):
+                result.add(t)
+        for t in tokenizer.extended_id_to_piece.keys():
+            result.add(t)
+        return result
+
+
+    @staticmethod
+    def apply_dry(
+        settings: ExLlamaV2Sampler.Settings,
+        tokenizer: ExLlamaV2Tokenizer,
+        sequence_ids: torch.Tensor,
+        logits: torch.Tensor
+    ):
+        if settings.ngram_trie is None:
+            settings.ngram_trie = NgramNode(0, {})
+            settings.ngram_index = 0
+
+        if settings.dry_sequence_breakers is None:
+            settings.dry_sequence_breakers = \
+                ExLlamaV2Sampler.get_dry_default_sequence_breaker_tokens(tokenizer)
+
+        # Convert sequence IDs to list once since .item() is slow
+        sequence_list = sequence_ids[0].tolist()
+
+        # Update trie with new ngrams
+        seq_len = max(len(sequence_list) - 1, 0)
+        for i in range(max(settings.ngram_index - settings.dry_max_ngram, 0), seq_len):
+            node = settings.ngram_trie
+            for j in range(i, min(i + settings.dry_max_ngram, seq_len)):
+                t = sequence_list[j]
+                if t in settings.dry_sequence_breakers:
+                    break
+                if t not in node.children:
+                    node.children[t] = NgramNode(0, {})
+                if j >= settings.ngram_index:
+                    node.children[t].value += 1
+                node = node.children[t]
+        settings.ngram_index = seq_len
+
+        # Find longest ngram
+        seq_len = len(sequence_list)
+        beg = max(seq_len - settings.dry_max_ngram, 0)
+        end = max(seq_len - settings.dry_allowed_length + 1, 0)
+        penalty_tokens = None
+        for i in range(beg, end):
+            node = settings.ngram_trie
+            for j in range(i, seq_len):
+                t = sequence_list[j]
+                if t not in node.children:
+                    break
+                node = node.children[t]
+            else:
+                penalty_tokens = node.children
+                ngram_prefix_length = j - i + 1
+                break
+
+        # Apply penalties if a node with children was reached at the end of the context, in which case
+        # those children count all ngrams of length > ngram_prefix_length
+        if penalty_tokens:
+            indices = torch.tensor([[list(penalty_tokens.keys())]], dtype = torch.long)
+            exc_length = ngram_prefix_length - settings.dry_allowed_length
+            penalty = -settings.dry_multiplier * settings.dry_base ** exc_length
+            penalties = torch.tensor([[[penalty * node.value for node in penalty_tokens.values()]]], dtype = torch.float)
+            logits.scatter_add_(-1, indices, penalties)
+
     @staticmethod
     # @profile
     def sample(
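
To make the penalty arithmetic in apply_dry concrete: with dry_allowed_length = 2, a matched ngram prefix of length 4 gives exc_length = 2, so every continuation token is penalized by -dry_multiplier * dry_base ** 2 = -8.0, scaled by how often that continuation was seen. A standalone sketch with made-up token IDs and counts:

import torch

dry_multiplier, dry_base, dry_allowed_length = 2.0, 2.0, 2
ngram_prefix_length = 4  # longest repeated ngram ending at the current position
exc_length = ngram_prefix_length - dry_allowed_length
penalty = -dry_multiplier * dry_base ** exc_length  # -8.0

logits = torch.zeros(1, 1, 8)                           # toy vocab of 8 tokens
indices = torch.tensor([[[3, 5]]], dtype = torch.long)  # continuation token IDs
counts = torch.tensor([[[1.0, 2.0]]])                   # occurrence counts from the trie
logits.scatter_add_(-1, indices, penalty * counts)
print(logits[0, 0, 3].item(), logits[0, 0, 5].item())   # -8.0 -16.0
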
@@ -264,6 +362,11 @@ def prep_logit_filter(lf):
             # logits = logits + settings.token_bias
             ext_c.fast_fadd_cpu(logits, settings.token_bias)

+        # DRY
+
+        if settings.dry_allowed_length:
+            ExLlamaV2Sampler.apply_dry(settings, tokenizer, sequence_ids, logits)
+
         # Evaluate filters

         if len(filters) > 0:
@@ -285,8 +388,8 @@ def prep_logit_filter(lf):
         # Special case if a single token passes
         if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None:
             single_passed_token = next(iter(pass_tokens))
-            output_tokens = torch.tensor([[single_passed_token]], dtype=torch.long)
-            output_probs = torch.tensor([[1]], dtype=torch.float)
+            output_tokens = torch.tensor([[single_passed_token]], dtype = torch.long)
+            output_probs = torch.tensor([[1]], dtype = torch.float)
             output_ktokens = none_tensor
             output_kprobs = none_tensor
             end_filter = (single_passed_token in end_tokens)
