from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from copy import copy
import threading
+ from functools import lru_cache
+ import re
# import line_profiler


_tl_tensors = threading.local()
@@ -37,6 +39,12 @@ def _get_output_probs(shape, dtype):
    return _tl_tensors.output_probs


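+ # Trie node for DRY's ngram index: `value` counts how many times the token
+ # path from the root down to this node has occurred in the context, and
+ # `children` maps a next-token ID to the node extending the ngram by one token.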
+ @dataclass
+ class NgramNode:
+     value: int = 0
+     children: dict[int, NgramNode] = field(default_factory = dict)
+
+
class ExLlamaV2Sampler:


    @dataclass
@@ -74,6 +82,15 @@ class Settings:

        post_sampling_hooks: list[ExLlamaV2PostSamplingHook] = field(default_factory = list)

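+         # DRY ("Don't Repeat Yourself") anti-repetition sampling. Repeats of up
+         # to dry_allowed_length tokens go unpenalized; past that, a token that
+         # would extend a repeated ngram is penalized by
+         # dry_multiplier * dry_base ** (match_length - dry_allowed_length),
+         # scaled by how often the ngram has already occurred.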
+         dry_allowed_length: int = 0    # 0 to disable
+         dry_base: float = 2.0
+         dry_multiplier: float = 2.0
+         dry_sequence_breakers: set[int] | None = None
+         dry_max_ngram: int = 20
+
+         ngram_trie: NgramNode | None = None
+         ngram_index: int = 0
+
        @staticmethod
        def greedy(**kwargs) -> ExLlamaV2Sampler.Settings:
            defaults = {
@@ -101,6 +118,11 @@ def greedy_clone(self):
            c.token_frequency_penalty = self.token_frequency_penalty
            c.token_presence_penalty = self.token_presence_penalty
            c.token_bias = None
+             c.dry_allowed_length = self.dry_allowed_length
+             c.dry_base = self.dry_base
+             c.dry_multiplier = self.dry_multiplier
+             c.dry_sequence_breakers = self.dry_sequence_breakers
+             c.dry_max_ngram = self.dry_max_ngram
            c.filters = []
            return c
@@ -139,6 +161,82 @@ def allow_tokens(
                raise ValueError("Incorrect type in allow_tokens list")


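+     # Ngrams never span a sequence breaker, so DRY only tracks repetition
+     # within runs of ordinary text. By default, any token whose piece contains
+     # common punctuation, brackets, a newline, tab or double quote is a
+     # breaker, as is every special token.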
+     @staticmethod
+     @lru_cache(10)
+     def get_dry_default_sequence_breaker_tokens(
+         tokenizer: ExLlamaV2Tokenizer
+     ) -> set[int]:
+         result = set()
+         dry_default_sequence_breaker_chars = r".,!?<>\[\]\(\)\{\}\n\t\""
+         pattern = re.compile(r"[" + dry_default_sequence_breaker_chars + "]")
+         pieces = tokenizer.get_id_to_piece_list(include_special_tokens = True)
+         for t in range(len(pieces)):
+             if pattern.search(pieces[t]):
+                 result.add(t)
+         for t in tokenizer.extended_id_to_piece.keys():
+             result.add(t)
+         return result
+
+
+     @staticmethod
+     def apply_dry(
+         settings: ExLlamaV2Sampler.Settings,
+         tokenizer: ExLlamaV2Tokenizer,
+         sequence_ids: torch.Tensor,
+         logits: torch.Tensor
+     ):
+         if settings.ngram_trie is None:
+             settings.ngram_trie = NgramNode(0, {})
+             settings.ngram_index = 0
+
+         if settings.dry_sequence_breakers is None:
+             settings.dry_sequence_breakers = \
+                 ExLlamaV2Sampler.get_dry_default_sequence_breaker_tokens(tokenizer)
+
+         # Convert sequence IDs to list once since .item() is slow
+         sequence_list = sequence_ids[0].tolist()
+
+         # Update trie with new ngrams
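+         # Index every ngram of up to dry_max_ngram tokens, stopping at sequence
+         # breakers. Counts are only incremented for positions past ngram_index
+         # so tokens already indexed by a previous call aren't counted twice, and
+         # the newest token is left out so the current suffix can't match itself.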
+         seq_len = max(len(sequence_list) - 1, 0)
+         for i in range(max(settings.ngram_index - settings.dry_max_ngram, 0), seq_len):
+             node = settings.ngram_trie
+             for j in range(i, min(i + settings.dry_max_ngram, seq_len)):
+                 t = sequence_list[j]
+                 if t in settings.dry_sequence_breakers:
+                     break
+                 if t not in node.children:
+                     node.children[t] = NgramNode(0, {})
+                 if j >= settings.ngram_index:
+                     node.children[t].value += 1
+                 node = node.children[t]
+         settings.ngram_index = seq_len
+
+         # Find longest ngram
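+         # Walk the trie with suffixes of the context, from the longest (up to
+         # dry_max_ngram tokens) down to dry_allowed_length tokens; the first
+         # full match is the longest suffix that has occurred before.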
+         seq_len = len(sequence_list)
+         beg = max(seq_len - settings.dry_max_ngram, 0)
+         end = max(seq_len - settings.dry_allowed_length + 1, 0)
+         penalty_tokens = None
+         for i in range(beg, end):
+             node = settings.ngram_trie
+             for j in range(i, seq_len):
+                 t = sequence_list[j]
+                 if t not in node.children:
+                     break
+                 node = node.children[t]
+             else:
+                 penalty_tokens = node.children
+                 ngram_prefix_length = j - i + 1
+                 break
+
+         # Apply penalties if a node with children was reached at the end of the
+         # context; each child then counts the occurrences of an ngram one token
+         # longer than the matched prefix of length ngram_prefix_length
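+         # e.g. with dry_allowed_length = 2, dry_base = 2.0 and dry_multiplier
+         # = 2.0, extending a 5-token repeat seen once costs
+         # -2.0 * 2.0 ** (5 - 2) = -16.0 on the logit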
+         if penalty_tokens:
+             indices = torch.tensor([[list(penalty_tokens.keys())]], dtype = torch.long)
+             exc_length = ngram_prefix_length - settings.dry_allowed_length
+             penalty = -settings.dry_multiplier * settings.dry_base ** exc_length
+             penalties = torch.tensor([[[penalty * node.value for node in penalty_tokens.values()]]], dtype = torch.float)
+             logits.scatter_add_(-1, indices, penalties)
+
    @staticmethod
    # @profile
    def sample(
@@ -264,6 +362,11 @@ def prep_logit_filter(lf):
            # logits = logits + settings.token_bias
            ext_c.fast_fadd_cpu(logits, settings.token_bias)

+         # DRY
+
+         if settings.dry_allowed_length:
+             ExLlamaV2Sampler.apply_dry(settings, tokenizer, sequence_ids, logits)
+
        # Evaluate filters

        if len(filters) > 0:
@@ -285,8 +388,8 @@ def prep_logit_filter(lf):
            # Special case if a single token passes
            if len(pass_tokens) == 1 and return_top_tokens == 0 and prefix_token is None:
                single_passed_token = next(iter(pass_tokens))
-                 output_tokens = torch.tensor([[single_passed_token]], dtype = torch.long)
-                 output_probs = torch.tensor([[1]], dtype = torch.float)
+                 output_tokens = torch.tensor([[single_passed_token]], dtype = torch.long)
+                 output_probs = torch.tensor([[1]], dtype = torch.float)
                output_ktokens = none_tensor
                output_kprobs = none_tensor
                end_filter = (single_passed_token in end_tokens)
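
For reference, a minimal sketch of how a caller might enable the new DRY settings; the import path is the usual exllamav2 generator entry point, and the chosen values are illustrative only:

    from exllamav2.generator import ExLlamaV2Sampler

    settings = ExLlamaV2Sampler.Settings()
    settings.dry_allowed_length = 2    # penalize repeats longer than 2 tokens
    settings.dry_base = 2.0            # penalty grows 2x per extra repeated token
    settings.dry_multiplier = 2.0      # overall penalty scale
    settings.dry_max_ngram = 20        # longest repeat tracked in the ngram trie
    # dry_sequence_breakers is populated with the default breaker tokens on first use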