Skip to content

Commit 0978ba5

Browse files
committed
Optimize logit filtering in sampler
1 parent f1d79c9 commit 0978ba5

File tree

5 files changed

+73
-40
lines changed

5 files changed

+73
-40
lines changed

exllamav2/exllamav2_ext/cpp/sampling.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ int softmax_cpu_nonavx2
129129

130130
for (int i = 0; i < vocab_size; i++)
131131
{
132-
if (!logits_filter[i]) continue;
132+
if (logits_filter && !logits_filter[i]) continue;
133133
if (logits[i] > maxl)
134134
{
135135
maxl = logits[i];
@@ -139,7 +139,7 @@ int softmax_cpu_nonavx2
139139

140140
for (int i = 0; i < vocab_size; i++)
141141
{
142-
if (!logits_filter[i]) continue;
142+
if (logits_filter && !logits_filter[i]) continue;
143143
float l = logits[i] - maxl;
144144
if (exponent == 2.0f)
145145
l *= -l;
@@ -154,7 +154,7 @@ int softmax_cpu_nonavx2
154154

155155
for (int i = 0; i < vocab_size; i++)
156156
{
157-
if (logits_filter[i]) output[i] *= isum;
157+
if (!logits_filter || logits_filter[i]) output[i] *= isum;
158158
else output[i] = 0.0f;
159159
}
160160

exllamav2/exllamav2_ext/cpp/sampling_avx2.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ int softmax_cpu_avx2
3434

3535
// Apply logit filter and find max logit
3636

37-
int i = 0;
38-
for (; i < vocab_size; ++i)
37+
for (int i = 0; i < vocab_size; ++i)
3938
{
4039
float l = logits[i];
41-
bool f = logits_filter[i];
40+
bool f = !logits_filter || logits_filter[i];
4241
l = f ? l : minf;
4342
if (l > maxl)
4443
{
@@ -47,7 +46,8 @@ int softmax_cpu_avx2
4746
}
4847
output[i] = l;
4948
}
50-
for (; i < vocab_size_aligned; i++)
49+
50+
for (int i = vocab_size; i < vocab_size_aligned; i++)
5151
output[i] = minf;
5252

5353
// SIMD values
@@ -61,8 +61,7 @@ int softmax_cpu_avx2
6161
if (exponent == 2.0f)
6262
{
6363
__m256 sign_mask = _mm256_set1_ps(-0.0f);
64-
i = 0;
65-
for (; i < vocab_size_aligned; i += 8)
64+
for (int i = 0; i < vocab_size_aligned; i += 8)
6665
{
6766
__m256 x = _mm256_load_ps(&output[i]);
6867
x = _mm256_sub_ps(x, maxl8);
@@ -87,10 +86,9 @@ int softmax_cpu_avx2
8786
}
8887
else
8988
{
90-
i = 0;
9189
if (itemp == 1.0f)
9290
{
93-
for (; i < vocab_size_aligned; i += 8)
91+
for (int i = 0; i < vocab_size_aligned; i += 8)
9492
{
9593
__m256 x = _mm256_load_ps(&output[i]);
9694
x = _mm256_sub_ps(x, maxl8);
@@ -101,7 +99,7 @@ int softmax_cpu_avx2
10199
}
102100
else
103101
{
104-
for (; i < vocab_size_aligned; i += 8)
102+
for (int i = 0; i < vocab_size_aligned; i += 8)
105103
{
106104
__m256 x = _mm256_load_ps(&output[i]);
107105
x = _mm256_sub_ps(x, maxl8);
@@ -121,8 +119,7 @@ int softmax_cpu_avx2
121119
float isum = 1.0f / esum;
122120
__m256 isum8 = _mm256_set1_ps(isum);
123121

124-
i = 0;
125-
for (; i < vocab_size_aligned; i += 8)
122+
for (int i = 0; i < vocab_size_aligned; i += 8)
126123
{
127124
__m256 x = _mm256_load_ps(&output[i]);
128125
x = _mm256_mul_ps(x, isum8);

exllamav2/exllamav2_ext/ext_sampling.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void apply_rep_penalty
6666

6767
std::vector<float> sample_basic
6868
(
69-
torch::Tensor logits, // shape [bsz, vocab_size]
69+
torch::Tensor logits, // shape [bsz, 1, vocab_size]
7070
float temperature,
7171
int top_k,
7272
float top_p,
@@ -96,10 +96,10 @@ std::vector<float> sample_basic
9696
TORCH_CHECK_DTYPE(output_tokens, kLong);
9797
TORCH_CHECK_DTYPE(output_probs, kFloat);
9898
TORCH_CHECK_DTYPE(logits, kFloat);
99-
TORCH_CHECK_DTYPE(logit_filter, kBool);
99+
TORCH_CHECK_DTYPE_OPT(logit_filter, kBool);
100100

101-
TORCH_CHECK_SHAPES(logit_filter, 0, logits, 0, 1);
102-
TORCH_CHECK_SHAPES(logit_filter, 1, logits, 1, 1);
101+
TORCH_CHECK_SHAPES_OPT(logit_filter, 0, logits, 0, 1);
102+
TORCH_CHECK_SHAPES_OPT(logit_filter, 1, logits, -1, 1);
103103

104104
int vocab_size = logits.size(-1);
105105
int bsz = logits.size(0);
@@ -112,7 +112,7 @@ std::vector<float> sample_basic
112112
if (!output_kprobs.device().is_meta())
113113
num_probs = output_kprobs.size(2);
114114

115-
bool* logits_filter_ptr = (bool*) logit_filter.data_ptr();
115+
bool* logits_filter_ptr = logit_filter.device().is_meta() ? NULL : (bool*) logit_filter.data_ptr();
116116

117117
Py_BEGIN_ALLOW_THREADS
118118

@@ -136,7 +136,7 @@ std::vector<float> sample_basic
136136
vocab_size,
137137
temperature,
138138
logits_ptr + i * vocab_size,
139-
logits_filter_ptr + i * vocab_size,
139+
logits_filter_ptr ? logits_filter_ptr + i * vocab_size : NULL,
140140
exponent,
141141
temp_probs
142142
);
@@ -282,7 +282,7 @@ std::vector<float> sample_basic
282282
void logit_filter_exclusive
283283
(
284284
torch::Tensor filter, // shape [bsz, vocab_size]
285-
const std::vector<std::vector<int>> &exclusive_lists
285+
const py::list& exclusive_lists
286286
)
287287
{
288288
TORCH_CHECK_DTYPE(filter, kBool);
@@ -291,13 +291,15 @@ void logit_filter_exclusive
291291
bool* filter_ptr = (bool*) filter.data_ptr();
292292
unsigned int vocab_size = filter.size(1);
293293

294-
Py_BEGIN_ALLOW_THREADS
294+
// Py_BEGIN_ALLOW_THREADS
295295

296-
for(const auto& list : exclusive_lists)
296+
for(const auto& list_ : exclusive_lists)
297297
{
298+
auto list = list_.cast<py::list>();
299+
298300
unsigned int id = 0;
299301
unsigned int next_id_idx = 0;
300-
unsigned int next_id = list[next_id_idx];
302+
unsigned int next_id = list[next_id_idx].cast<unsigned int>();
301303

302304
while (id < vocab_size)
303305
{
@@ -309,13 +311,13 @@ void logit_filter_exclusive
309311
id++;
310312
next_id_idx++;
311313
if (next_id_idx >= list.size()) next_id = vocab_size;
312-
else next_id = list[next_id_idx];
314+
else next_id = list[next_id_idx].cast<unsigned int>();
313315
}
314316

315317
filter_ptr += vocab_size;
316318
}
317319

318-
Py_END_ALLOW_THREADS
320+
// Py_END_ALLOW_THREADS
319321
}
320322

321323
void fast_fill_cpu_ones_bool(torch::Tensor tensor)

exllamav2/exllamav2_ext/ext_sampling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ std::vector<float> sample_basic
4343
void logit_filter_exclusive
4444
(
4545
torch::Tensor filter, // shape [bsz, vocab_size]
46-
const std::vector<std::vector<int>> &exclusive_lists
46+
const py::list& exclusive_lists
4747
);
4848

4949
void fast_fill_cpu_ones_bool(torch::Tensor tensor);

exllamav2/generator/sampler.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,36 @@
77
from exllamav2.generator.hooks import ExLlamaV2PostSamplingHook
88
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
99
from copy import copy
10+
import threading
1011
# import line_profiler
1112

13+
_tl_tensors = threading.local()
14+
15+
def _get_logit_filter(shape, dtype):
16+
global _tl_tensors
17+
if not hasattr(_tl_tensors, 'logit_filter') \
18+
or _tl_tensors.logit_filter.shape != shape \
19+
or _tl_tensors.logit_filter.dtype != dtype:
20+
_tl_tensors.logit_filter = torch.empty(shape, dtype = dtype)
21+
return _tl_tensors.logit_filter
22+
23+
def _get_output_tokens(shape, dtype):
24+
global _tl_tensors
25+
if not hasattr(_tl_tensors, 'output_tokens') \
26+
or _tl_tensors.output_tokens.shape != shape \
27+
or _tl_tensors.output_tokens.dtype != dtype:
28+
_tl_tensors.output_tokens = torch.empty(shape, dtype = dtype)
29+
return _tl_tensors.output_tokens
30+
31+
def _get_output_probs(shape, dtype):
32+
global _tl_tensors
33+
if not hasattr(_tl_tensors, 'output_probs') \
34+
or _tl_tensors.output_probs.shape != shape \
35+
or _tl_tensors.output_probs.dtype != dtype:
36+
_tl_tensors.output_probs = torch.empty(shape, dtype = dtype)
37+
return _tl_tensors.output_probs
38+
39+
1240
class ExLlamaV2Sampler:
1341

1442
@dataclass
@@ -186,7 +214,7 @@ def sample(
186214
else:
187215
assert batch_size == 1 or len(filters) == 0, "Filters not implemented for batch size > 1"
188216

189-
logits = logits.squeeze(1)
217+
# logits = logits.view(batch_size, vocab_size)
190218

191219
# Sync
192220

@@ -203,8 +231,13 @@ def sample(
203231

204232
# Prepare filter
205233

206-
logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
207-
ext_c.fast_fill_cpu_ones_bool(logit_filter)
234+
logit_filter = None
235+
def prep_logit_filter(lf):
236+
if lf is not None:
237+
return lf
238+
lf = _get_logit_filter((batch_size, vocab_size), torch.bool)
239+
ext_c.fast_fill_cpu_ones_bool(lf)
240+
return lf
208241

209242
# Repetition penalty
210243

@@ -223,7 +256,7 @@ def sample(
223256
# Temporarily ban individual tokens
224257

225258
if blocked_tokens:
226-
logits[:, blocked_tokens] = -1e30
259+
logits[:, :, blocked_tokens] = -1e30
227260

228261
# Token bias
229262

@@ -247,7 +280,7 @@ def sample(
247280
assert pass_tokens, "Filter excluded all tokens"
248281
if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
249282
pass_tokens = { tokenizer.eos_token_id }
250-
# TODO: pass pass_tokens as a numpy array or Python set
283+
logit_filter = prep_logit_filter(logit_filter)
251284
ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
252285

253286
# Healing
@@ -260,6 +293,7 @@ def sample(
260293
for i in range(batch_size):
261294
valid_token_lists.append(prefix_id_to_ids[prefix_token[i, 0].item()])
262295

296+
logit_filter = prep_logit_filter(logit_filter)
263297
ext_c.logit_filter_exclusive(logit_filter, valid_token_lists)
264298

265299
# Begin Mirostat
@@ -272,20 +306,20 @@ def sample(
272306

273307
vs = tokenizer.get_vocab_size()
274308
if vs < logits.shape[-1]:
275-
logits[:, vs:] = float("-inf")
309+
logits[:, :, vs:] = float("-inf")
276310

277311
# Sampling
278312

279-
batch_size = logits.shape[0]
280-
281-
output_tokens = torch.empty((batch_size, 1), device = "cpu", dtype = torch.long)
282-
output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
313+
output_tokens = torch.empty((batch_size, 1), dtype = torch.long)
314+
# output_tokens = _get_output_tokens((batch_size, 1), torch.long)
315+
output_probs = torch.empty((batch_size, 1), dtype = torch.float)
316+
# output_probs = _get_output_probs((batch_size, 1), torch.float)
283317
if return_top_tokens == 0:
284318
output_ktokens = none_tensor
285319
output_kprobs = none_tensor
286320
else:
287-
output_ktokens = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.long)
288-
output_kprobs = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.float)
321+
output_ktokens = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.long)
322+
output_kprobs = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.float)
289323

290324
m = ext_c.sample_basic(
291325
logits,
@@ -301,7 +335,7 @@ def sample(
301335
output_probs,
302336
output_kprobs,
303337
output_ktokens,
304-
logit_filter,
338+
logit_filter if logit_filter is not None else none_tensor,
305339
settings.mirostat,
306340
settings.mirostat_mu if settings.mirostat else [],
307341
settings.mirostat_tau,

0 commit comments

Comments
 (0)