Skip to content

Commit b8a5bcf

Browse files
committed
Fix sampling errors caused by float rounding on large-vocab models
1 parent 44dd05b commit b8a5bcf

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

exllamav2/exllamav2_ext/cpp/sampling.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,12 @@ int multinomial_cpu
833833
while (true)
834834
{
835835
if (accum >= random) break;
836-
if (idx == num_candidates - 1) break;
836+
if (idx == num_candidates - 1)
837+
{
838+
// Reached the last candidate: roll back to the nearest candidate with nonzero probability, in case the sampled probability is exactly zero
839+
while (idx > 0 && temp_probs[idx] == 0.0f) idx--;
840+
break;
841+
}
837842
idx++;
838843
accum += temp_probs[idx];
839844
}

exllamav2/exllamav2_ext/cpp/util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
1515
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
1616
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
17+
#define DBGF4(__x, __y, __z, __w) printf("%s, %s, %s, %s: %f, %f, %f, %f\n", #__x, #__y, #__z, #__w, __x, __y, __z, __w)
1718
#define DBGIF(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __y)
1819

1920
#define TIME_START \

exllamav2/exllamav2_ext/ext_sampling.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,24 @@ std::vector<float> sample_basic
230230
random_s = powf(random, expf(-skew));
231231
}
232232

233-
multinomial_cpu(num_candidates, temp_probs, temp_indices, random_s);
233+
// {
234+
// float sum = 0.0f;
235+
// float pmin = temp_probs[0];
236+
// float pmax = pmin;
237+
// for (int i = 0; i < num_candidates; ++i)
238+
// {
239+
// if (temp_probs[i] < pmin) pmin = temp_probs[i];
240+
// if (temp_probs[i] > pmax) pmax = temp_probs[i];
241+
// sum += temp_probs[i];
242+
// }
243+
// DBGF4(pmin, pmax, sum, random_s);
244+
// }
245+
246+
// Scale the random sampling point down slightly to account for FP32 rounding errors during softmax. Probs
247+
// can potentially sum to slightly less than 1 for large-vocab models, which could leave the sampling point above the cumulative sum
248+
float random_s_adj = random_s * 0.9998;
249+
250+
multinomial_cpu(num_candidates, temp_probs, temp_indices, random_s_adj);
234251

235252
output_tokens[i][0] = temp_indices[0];
236253
output_probs[i][0] = temp_probs[0];

0 commit comments

Comments
 (0)