|
9 | 9 | from copy import copy
|
10 | 10 | import threading
|
11 | 11 | from functools import lru_cache
|
| 12 | +from collections import deque |
12 | 13 | import re
|
13 | 14 | # import line_profiler
|
14 | 15 |
|
@@ -82,14 +83,16 @@ class Settings:
|
82 | 83 |
|
83 | 84 | post_sampling_hooks: list[ExLlamaV2PostSamplingHook] = field(default_factory = list)
|
84 | 85 |
|
85 |
| - dry_allowed_length: int = 0 # 0 to disable |
86 |
| - dry_base: float = 2.0 |
87 |
| - dry_multiplier: float = 2.0 |
88 |
| - dry_sequence_breakers: set[int] | None = None |
| 86 | + dry_allowed_length: int = 2 |
| 87 | + dry_base: float = 1.75 |
| 88 | + dry_multiplier: float = 0.0 # 0 to disable |
| 89 | + dry_sequence_breakers: set[int] | None = None # None to default set derived from special characters (eng) |
| 90 | + dry_range: int = 0 # 0 for unlimited reange |
89 | 91 | dry_max_ngram: int = 20
|
90 | 92 |
|
91 | 93 | ngram_trie: dict[int, NgramNode] = None
|
92 | 94 | ngram_index: int = 0
|
| 95 | + ngram_history: deque[int] = field(default_factory = deque) |
93 | 96 |
|
94 | 97 | @staticmethod
|
95 | 98 | def greedy(**kwargs) -> ExLlamaV2Sampler.Settings:
|
@@ -198,19 +201,43 @@ def apply_dry(
|
198 | 201 |
|
199 | 202 | # Update trie with new ngrams
|
200 | 203 | seq_len = max(len(sequence_list) - 1, 0)
|
201 |
| - for i in range(max(settings.ngram_index - settings.dry_max_ngram, 0), seq_len): |
| 204 | + new_beg = max(settings.ngram_index - settings.dry_max_ngram, 0) |
| 205 | + new_end = seq_len |
| 206 | + if settings.dry_range: |
| 207 | + new_beg = max(new_beg, new_end - settings.dry_range) |
| 208 | + for i in range(new_beg, new_end): |
202 | 209 | node = settings.ngram_trie
|
203 | 210 | for j in range(i, min(i + settings.dry_max_ngram, seq_len)):
|
204 | 211 | t = sequence_list[j]
|
205 | 212 | if t in settings.dry_sequence_breakers:
|
206 | 213 | break
|
207 | 214 | if t not in node.children:
|
208 | 215 | node.children[t] = NgramNode(0, {})
|
209 |
| - if j >= settings.ngram_index: |
210 |
| - node.children[t].value += 1 |
211 | 216 | node = node.children[t]
|
| 217 | + if j >= settings.ngram_index: |
| 218 | + node.value += 1 |
| 219 | + if len(settings.ngram_history) == 0 or settings.ngram_history[-1] < i: |
| 220 | + settings.ngram_history.append(i) |
212 | 221 | settings.ngram_index = seq_len
|
213 | 222 |
|
| 223 | + # Remove old ngrams |
| 224 | + if settings.dry_range > 0: |
| 225 | + assert settings.dry_range > settings.dry_max_ngram |
| 226 | + tail_index = max(len(sequence_list) - settings.dry_range - 1, 0) |
| 227 | + while settings.ngram_history[0] < tail_index: |
| 228 | + i = settings.ngram_history.popleft() |
| 229 | + node = settings.ngram_trie |
| 230 | + for j in range(i, i + settings.dry_max_ngram): |
| 231 | + t = sequence_list[j] |
| 232 | + if t in settings.dry_sequence_breakers: |
| 233 | + break |
| 234 | + assert t in node.children |
| 235 | + node.children[t].value -= 1 |
| 236 | + if node.children[t].value == 0: |
| 237 | + del node.children[t] |
| 238 | + break |
| 239 | + node = node.children[t] |
| 240 | + |
214 | 241 | # Find longest ngram
|
215 | 242 | seq_len = len(sequence_list)
|
216 | 243 | beg = max(seq_len - settings.dry_max_ngram, 0)
|
@@ -364,7 +391,7 @@ def prep_logit_filter(lf):
|
364 | 391 |
|
365 | 392 | # DRY
|
366 | 393 |
|
367 |
| - if settings.dry_allowed_length: |
| 394 | + if settings.dry_multiplier > 0.0: |
368 | 395 | ExLlamaV2Sampler.apply_dry(settings, tokenizer, sequence_ids, logits)
|
369 | 396 |
|
370 | 397 | # Evaluate filters
|
|
0 commit comments