Skip to content

Commit f6c9757

Browse files
HavenDV and claude committed
fix: prevent stack overflow on large inputs and lazy-load encodings (#75, #48)
Fix #75: The byte-pair encoding path used unbounded stackalloc for both the partsIndexes array (allocated in each caller of BytePairEncoding.FindParts) and the partsRanks array (allocated inside FindParts). For sufficiently large pieces (e.g., large regex matches on Russian text), these allocations can overflow the 1MB default stack. Now uses heap allocation for large pieces (piece length + 1 greater than 512 elements, i.e. pieces of 512 bytes or more) and stackalloc for smaller ones.

Fix #48: ModelToEncoding eagerly created a separate Cl100KBase instance for every model entry that uses it (5 duplicates), each loading the same 100K-entry .tiktoken file. Now uses Lazy<Encoding> singletons so each encoding is loaded once on first access and shared across all models that use the same encoding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent df580f6 commit f6c9757

File tree

2 files changed

+133
-60
lines changed

2 files changed

+133
-60
lines changed
Lines changed: 114 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
1+
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
22
using Bytes = System.ReadOnlyMemory<byte>;
33
#else
44
using Bytes = System.Collections.Generic.IReadOnlyCollection<byte>;
@@ -7,10 +7,14 @@
77
namespace Tiktoken.Core;
88

99
/// <summary>
10-
///
10+
///
1111
/// </summary>
1212
public static class BytePairEncoding
1313
{
14+
// Maximum number of int elements to stackalloc (2 arrays × this size × 4 bytes each).
15+
// 512 elements = 2KB per array = 4KB total on stack, well within safe limits.
16+
private const int MaxStackAllocLength = 512;
17+
1418
private static byte[] GetSlice(this Bytes bytes, int from, int to)
1519
{
1620
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
@@ -19,7 +23,7 @@ private static byte[] GetSlice(this Bytes bytes, int from, int to)
1923
return bytes.Skip(from).Take(to - from).ToArray();
2024
#endif
2125
}
22-
26+
2327
private static int GetLength(this Bytes bytes)
2428
{
2529
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
@@ -28,7 +32,7 @@ private static int GetLength(this Bytes bytes)
2832
return bytes.Count;
2933
#endif
3034
}
31-
35+
3236
private static unsafe bool TryFindMinRank(int* partsRanks, int count, out int result)
3337
{
3438
result = 0;
@@ -41,7 +45,7 @@ private static unsafe bool TryFindMinRank(int* partsRanks, int count, out int re
4145
result = i;
4246
}
4347
}
44-
48+
4549
return minRank != int.MaxValue;
4650
}
4751

@@ -63,17 +67,17 @@ private static unsafe int GetRank(
6367
return rank;
6468
}
6569
}
66-
70+
6771
return int.MaxValue;
6872
}
69-
73+
7074
private static unsafe int FindParts(
7175
Bytes piece,
7276
int* partsIndexes,
77+
int* partsRanks,
78+
int partsLength,
7379
IReadOnlyDictionary<byte[], int> ranks)
7480
{
75-
var partsLength = piece.GetLength() + 1;
76-
var partsRanks = stackalloc int [partsLength];
7781
for (var i = 0; i < partsLength; i++)
7882
{
7983
partsIndexes[i] = i;
@@ -83,15 +87,15 @@ private static unsafe int FindParts(
8387
{
8488
partsRanks[i] = GetRank(i, partsIndexes, partsLength, piece, ranks, length: 2);
8589
}
86-
90+
8791
var count = partsLength - 1;
8892
while (true)
8993
{
9094
if (!TryFindMinRank(partsRanks, count, out var i))
9195
{
9296
break;
9397
}
94-
98+
9599
partsRanks[i] = GetRank(i, partsIndexes, count + 1, piece, ranks, length: 3);
96100
if (i > 0)
97101
{
@@ -104,70 +108,133 @@ private static unsafe int FindParts(
104108
}
105109
count--;
106110
}
107-
111+
108112
return count;
109113
}
110-
114+
111115
internal static unsafe void BytePairEncode(Bytes piece, IReadOnlyDictionary<byte[], int> ranks, List<int> outList)
112116
{
113117
var partsLength = piece.GetLength() + 1;
114-
var partsIndexes = stackalloc int [partsLength];
115-
var count = FindParts(piece, partsIndexes, ranks);
116118

117-
for (var i = 0; i < count; i++)
119+
if (partsLength <= MaxStackAllocLength)
118120
{
119-
var from = partsIndexes[i];
120-
var to = partsIndexes[i + 1];
121-
var slice = piece.GetSlice(from, to);
121+
var partsIndexes = stackalloc int[partsLength];
122+
var partsRanks = stackalloc int[partsLength];
123+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
124+
125+
for (var i = 0; i < count; i++)
126+
{
127+
outList.Add(ranks[piece.GetSlice(partsIndexes[i], partsIndexes[i + 1])]);
128+
}
129+
}
130+
else
131+
{
132+
var heapIndexes = new int[partsLength];
133+
var heapRanks = new int[partsLength];
134+
fixed (int* partsIndexes = heapIndexes)
135+
fixed (int* partsRanks = heapRanks)
136+
{
137+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
122138

123-
outList.Add(ranks[slice]);
139+
for (var i = 0; i < count; i++)
140+
{
141+
outList.Add(ranks[piece.GetSlice(partsIndexes[i], partsIndexes[i + 1])]);
142+
}
143+
}
124144
}
125145
}
126146

127147
internal static unsafe int[] BytePairEncodeToArray(Bytes piece, IReadOnlyDictionary<byte[], int> ranks)
128148
{
129149
var partsLength = piece.GetLength() + 1;
130-
var partsIndexes = stackalloc int [partsLength];
131-
var count = FindParts(piece, partsIndexes, ranks);
132150

133-
var result = new int[count];
134-
for (var i = 0; i < count; i++)
151+
if (partsLength <= MaxStackAllocLength)
135152
{
136-
var from = partsIndexes[i];
137-
var to = partsIndexes[i + 1];
138-
var slice = piece.GetSlice(from, to);
153+
var partsIndexes = stackalloc int[partsLength];
154+
var partsRanks = stackalloc int[partsLength];
155+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
139156

140-
result[i] = ranks[slice];
157+
var result = new int[count];
158+
for (var i = 0; i < count; i++)
159+
{
160+
result[i] = ranks[piece.GetSlice(partsIndexes[i], partsIndexes[i + 1])];
161+
}
162+
return result;
141163
}
164+
else
165+
{
166+
var heapIndexes = new int[partsLength];
167+
var heapRanks = new int[partsLength];
168+
fixed (int* partsIndexes = heapIndexes)
169+
fixed (int* partsRanks = heapRanks)
170+
{
171+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
142172

143-
return result;
173+
var result = new int[count];
174+
for (var i = 0; i < count; i++)
175+
{
176+
result[i] = ranks[piece.GetSlice(partsIndexes[i], partsIndexes[i + 1])];
177+
}
178+
return result;
179+
}
180+
}
144181
}
145-
182+
146183
internal static unsafe List<byte[]> BytePairExplore(Bytes piece, IReadOnlyDictionary<byte[], int> ranks)
147184
{
148185
var partsLength = piece.GetLength() + 1;
149-
var partsIndexes = stackalloc int [partsLength];
150-
var count = FindParts(piece, partsIndexes, ranks);
151-
152-
var outList = new List<byte[]>(count);
153-
for (var i = 0; i < count; i++)
186+
187+
if (partsLength <= MaxStackAllocLength)
154188
{
155-
var from = partsIndexes[i];
156-
var to = partsIndexes[i + 1];
157-
var slice = piece.GetSlice(from, to);
158-
159-
outList.Add(slice);
189+
var partsIndexes = stackalloc int[partsLength];
190+
var partsRanks = stackalloc int[partsLength];
191+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
192+
193+
var outList = new List<byte[]>(count);
194+
for (var i = 0; i < count; i++)
195+
{
196+
outList.Add(piece.GetSlice(partsIndexes[i], partsIndexes[i + 1]));
197+
}
198+
return outList;
199+
}
200+
else
201+
{
202+
var heapIndexes = new int[partsLength];
203+
var heapRanks = new int[partsLength];
204+
fixed (int* partsIndexes = heapIndexes)
205+
fixed (int* partsRanks = heapRanks)
206+
{
207+
var count = FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
208+
209+
var outList = new List<byte[]>(count);
210+
for (var i = 0; i < count; i++)
211+
{
212+
outList.Add(piece.GetSlice(partsIndexes[i], partsIndexes[i + 1]));
213+
}
214+
return outList;
215+
}
160216
}
161-
162-
return outList;
163217
}
164-
218+
165219
internal static unsafe int BytePairEncodeCountTokens(Bytes piece, IReadOnlyDictionary<byte[], int> ranks)
166220
{
167221
var partsLength = piece.GetLength() + 1;
168-
var partsIndexes = stackalloc int [partsLength];
169-
var count = FindParts(piece, partsIndexes, ranks);
170-
171-
return count;
222+
223+
if (partsLength <= MaxStackAllocLength)
224+
{
225+
var partsIndexes = stackalloc int[partsLength];
226+
var partsRanks = stackalloc int[partsLength];
227+
return FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
228+
}
229+
else
230+
{
231+
var heapIndexes = new int[partsLength];
232+
var heapRanks = new int[partsLength];
233+
fixed (int* partsIndexes = heapIndexes)
234+
fixed (int* partsRanks = heapRanks)
235+
{
236+
return FindParts(piece, partsIndexes, partsRanks, partsLength, ranks);
237+
}
238+
}
172239
}
173240
}
Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1-
using Tiktoken.Encodings;
1+
using Tiktoken.Encodings;
22

33
namespace Tiktoken;
44

55
/// <summary>
6-
///
6+
///
77
/// </summary>
88
public static class ModelToEncoding
99
{
10-
private static Dictionary<string, Encoding> Dictionary { get; } = new()
10+
// Lazy singletons — each encoding is loaded only once, on first access.
11+
private static readonly Lazy<Encoding> Cl100K = new(static () => new Cl100KBase());
12+
private static readonly Lazy<Encoding> O200K = new(static () => new O200KBase());
13+
14+
private static Dictionary<string, Lazy<Encoding>> Dictionary { get; } = new()
1115
{
1216
// chat
13-
{ "gpt-4o", new O200KBase() },
14-
{ "gpt-4", new Cl100KBase() },
15-
{ "gpt-3.5-turbo", new Cl100KBase() },
16-
{ "gpt-35-turbo", new Cl100KBase() }, // Azure deployment name
17-
17+
{ "gpt-4o", O200K },
18+
{ "gpt-4", Cl100K },
19+
{ "gpt-3.5-turbo", Cl100K },
20+
{ "gpt-35-turbo", Cl100K }, // Azure deployment name
21+
1822
// embeddings
19-
{ "text-embedding-ada-002", new Cl100KBase() },
20-
{ "text-embedding-3-small", new Cl100KBase() },
21-
{ "text-embedding-3-large", new Cl100KBase() },
23+
{ "text-embedding-ada-002", Cl100K },
24+
{ "text-embedding-3-small", Cl100K },
25+
{ "text-embedding-3-large", Cl100K },
2226
};
2327

2428
/// <summary>
@@ -29,8 +33,10 @@ public static class ModelToEncoding
2933
/// <returns></returns>
3034
public static Encoding? TryFor(string modelName)
3135
{
32-
return Dictionary
36+
var lazy = Dictionary
3337
.FirstOrDefault(a => modelName.StartsWith(a.Key, StringComparison.Ordinal)).Value;
38+
39+
return lazy?.Value;
3440
}
3541

3642
/// <summary>
@@ -44,4 +50,4 @@ public static Encoding For(string modelName)
4450
return TryFor(modelName) ??
4551
throw new ArgumentException($"Model name {modelName} is not supported.");
4652
}
47-
}
53+
}

0 commit comments

Comments
 (0)