Skip to content

Commit c1fed2e

Browse files
committed
Add DRY range parameter
1 parent 3e8e181 commit c1fed2e

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

exllamav2/generator/sampler.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from copy import copy
1010
import threading
1111
from functools import lru_cache
12+
from collections import deque
1213
import re
1314
# import line_profiler
1415

@@ -82,14 +83,16 @@ class Settings:
8283

8384
post_sampling_hooks: list[ExLlamaV2PostSamplingHook] = field(default_factory = list)
8485

85-
dry_allowed_length: int = 0 # 0 to disable
86-
dry_base: float = 2.0
87-
dry_multiplier: float = 2.0
88-
dry_sequence_breakers: set[int] | None = None
86+
dry_allowed_length: int = 2
87+
dry_base: float = 1.75
88+
dry_multiplier: float = 0.0 # 0 to disable
89+
dry_sequence_breakers: set[int] | None = None # None to use the default set derived from special characters (eng)
90+
dry_range: int = 0 # 0 for unlimited range
8991
dry_max_ngram: int = 20
9092

9193
ngram_trie: dict[int, NgramNode] = None
9294
ngram_index: int = 0
95+
ngram_history: deque[int] = field(default_factory = deque)
9396

9497
@staticmethod
9598
def greedy(**kwargs) -> ExLlamaV2Sampler.Settings:
@@ -198,19 +201,43 @@ def apply_dry(
198201

199202
# Update trie with new ngrams
200203
seq_len = max(len(sequence_list) - 1, 0)
201-
for i in range(max(settings.ngram_index - settings.dry_max_ngram, 0), seq_len):
204+
new_beg = max(settings.ngram_index - settings.dry_max_ngram, 0)
205+
new_end = seq_len
206+
if settings.dry_range:
207+
new_beg = max(new_beg, new_end - settings.dry_range)
208+
for i in range(new_beg, new_end):
202209
node = settings.ngram_trie
203210
for j in range(i, min(i + settings.dry_max_ngram, seq_len)):
204211
t = sequence_list[j]
205212
if t in settings.dry_sequence_breakers:
206213
break
207214
if t not in node.children:
208215
node.children[t] = NgramNode(0, {})
209-
if j >= settings.ngram_index:
210-
node.children[t].value += 1
211216
node = node.children[t]
217+
if j >= settings.ngram_index:
218+
node.value += 1
219+
if len(settings.ngram_history) == 0 or settings.ngram_history[-1] < i:
220+
settings.ngram_history.append(i)
212221
settings.ngram_index = seq_len
213222

223+
# Remove old ngrams
224+
if settings.dry_range > 0:
225+
assert settings.dry_range > settings.dry_max_ngram
226+
tail_index = max(len(sequence_list) - settings.dry_range - 1, 0)
227+
while settings.ngram_history[0] < tail_index:
228+
i = settings.ngram_history.popleft()
229+
node = settings.ngram_trie
230+
for j in range(i, i + settings.dry_max_ngram):
231+
t = sequence_list[j]
232+
if t in settings.dry_sequence_breakers:
233+
break
234+
assert t in node.children
235+
node.children[t].value -= 1
236+
if node.children[t].value == 0:
237+
del node.children[t]
238+
break
239+
node = node.children[t]
240+
214241
# Find longest ngram
215242
seq_len = len(sequence_list)
216243
beg = max(seq_len - settings.dry_max_ngram, 0)
@@ -364,7 +391,7 @@ def prep_logit_filter(lf):
364391

365392
# DRY
366393

367-
if settings.dry_allowed_length:
394+
if settings.dry_multiplier > 0.0:
368395
ExLlamaV2Sampler.apply_dry(settings, tokenizer, sequence_ids, logits)
369396

370397
# Evaluate filters

0 commit comments

Comments
 (0)