
Commit 1676d41

feat: improve context shift strategy
1 parent 76dea80 commit 1676d41

File tree

7 files changed: +345 −72 lines changed

src/evaluator/LlamaChat/LlamaChat.ts

Lines changed: 14 additions & 11 deletions

@@ -2166,21 +2166,24 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefined = undefined>
     }

     public async alignCurrentSequenceStateWithCurrentTokens() {
-        let {firstDifferentIndex} = this.llamaChat.sequence.compareContextTokens(this.tokens);
-
-        // we need to decode at least one token to generate a response
-        if (firstDifferentIndex === this.tokens.length && firstDifferentIndex > 0)
-            firstDifferentIndex -= 1;
-
-        this.tokens.splice(0, firstDifferentIndex);
-
-        if (firstDifferentIndex < this.llamaChat.sequence.nextTokenIndex) {
+        if (this.tokens.length === 1 && this.llamaChat.sequence.nextTokenIndex !== 0) {
             await this.llamaChat.sequence.eraseContextTokenRanges([{
-                start: firstDifferentIndex,
+                start: 0,
                 end: this.llamaChat.sequence.nextTokenIndex
             }]);
-            this.ensureNotAborted();
+            return;
         }
+
+        const lastToken = this.tokens[this.tokens.length - 1]!;
+
+        // we need to decode at least one token to generate a response
+        this.tokens.pop();
+        await this.llamaChat.sequence.adaptStateToTokens(this.tokens, false);
+        this.tokens.push(lastToken);
+        this.ensureNotAborted();
+
+        const firstDifferentIndex = this.llamaChat.sequence.nextTokenIndex;
+        this.tokens.splice(0, firstDifferentIndex);
     }

     public async evaluateWithoutGeneratingNewTokens() {
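
The shape of this change: instead of computing `firstDifferentIndex` manually and erasing a single trailing range, state alignment is delegated to the new `LlamaContextSequence.adaptStateToTokens` method (added in `LlamaContext.ts` below), with the last token held back so there is always at least one token left to decode. A minimal sketch of the pattern, using a hypothetical structural stand-in for the parts of the sequence API this commit touches (the same pattern replaces the old `compareContextTokens`-based trimming in `LlamaCompletion.ts` below):

```ts
import type {Token} from "node-llama-cpp";

// structural stand-in for the parts of LlamaContextSequence used here
interface SequenceLike {
    nextTokenIndex: number,
    adaptStateToTokens(tokens: Token[], allowShift?: boolean): Promise<void>
}

async function alignStateWithTokens(sequence: SequenceLike, tokens: Token[]) {
    // hold back the last token: at least one token must be decoded to generate
    // a response, even when the state already matches all of the given tokens
    const lastToken = tokens[tokens.length - 1]!;
    tokens.pop();

    // erase only the mismatching parts of the existing context state
    await sequence.adaptStateToTokens(tokens, false);
    tokens.push(lastToken);

    // everything up to `nextTokenIndex` is already in the state,
    // so only the remaining suffix needs to be evaluated
    tokens.splice(0, sequence.nextTokenIndex);
}
```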

src/evaluator/LlamaChat/utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.ts

Lines changed: 90 additions & 7 deletions

@@ -30,6 +30,8 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
         initialCharactersRemovalCount,
         tokenizer,
         chatWrapper,
+        failedCompressionErrorMessage: "Failed to compress chat history for context shift due to a too long prompt or system message that cannot be compressed without affecting the generation quality. " +
+            "Consider increasing the context size or shortening the long prompt or system message.",
         compressChatHistory({chatHistory, charactersToRemove, estimatedCharactersPerToken}) {
             const res = chatHistory.map(item => structuredClone(item));
             let charactersLeftToRemove = charactersToRemove;
@@ -66,6 +68,8 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
             }

             function removeHistoryThatLedToModelResponseAtIndex(index: number) {
+                let removedItems = 0;
+
                 for (let i = index - 1; i >= 0; i--) {
                     const historyItem = res[i];
@@ -79,13 +83,19 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
                         break; // keep the first system message

                     if (historyItem.type === "user" || historyItem.type === "system") {
-                        const newText = truncateLlamaTextAndRoundToWords(LlamaText.fromJSON(historyItem.text), charactersLeftToRemove);
+                        const newText = truncateLlamaTextAndRoundToWords(
+                            LlamaText.fromJSON(historyItem.text),
+                            charactersLeftToRemove,
+                            undefined,
+                            false
+                        );
                         const newTextString = newText.toString();
                         const historyItemString = LlamaText.fromJSON(historyItem.text).toString();

                         if (newText.values.length === 0) {
                             res.splice(i, 1);
                             i++;
+                            removedItems++;
                             charactersLeftToRemove -= historyItemString.length;
                         } else if (newTextString.length < historyItemString.length) {
                             charactersLeftToRemove -= historyItemString.length - newTextString.length;
@@ -98,6 +108,66 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
                         void (historyItem satisfies never);
                     }
                 }
+
+                return removedItems;
+            }
+
+            function compressHistoryThatLedToModelResponseAtIndex(index: number, keepTokensCount: number = 0) {
+                let removedItems = 0;
+                let promptStartIndex: number | undefined = undefined;
+
+                for (let i = index - 1; i >= 0; i--) {
+                    const historyItem = res[i];
+
+                    if (historyItem == null)
+                        continue;
+
+                    if (historyItem.type === "model") {
+                        promptStartIndex = i + 1;
+                        break;
+                    }
+
+                    if (i === 0 && historyItem.type === "system") {
+                        promptStartIndex = i + 1;
+                        break; // keep the first system message
+                    }
+                }
+
+                if (promptStartIndex == null || promptStartIndex >= index)
+                    return 0;
+
+                for (let i = promptStartIndex; i < index && charactersLeftToRemove > 0; i++) {
+                    const historyItem = res[i];
+
+                    if (historyItem == null || historyItem.type !== "user")
+                        continue;
+
+                    let removeChars = Math.min(charactersLeftToRemove, historyItem.text.length);
+                    if (keepTokensCount > 0) {
+                        removeChars -= Math.floor(keepTokensCount * estimatedCharactersPerToken);
+                        if (removeChars < 0)
+                            removeChars = 0;
+
+                        keepTokensCount -= Math.min(
+                            keepTokensCount,
+                            Math.max(0, historyItem.text.length - removeChars) / estimatedCharactersPerToken
+                        );
+                    }
+
+                    const newText = truncateTextAndRoundToWords(historyItem.text, removeChars, undefined, false);
+                    if (newText.length === 0) {
+                        res.splice(i, 1);
+                        i--;
+                        index--;
+                        removedItems++;
+                        charactersLeftToRemove -= historyItem.text.length;
+                    } else {
+                        charactersLeftToRemove -= historyItem.text.length - newText.length;
+                        historyItem.text = newText;
+                    }
+                }
+
+                return removedItems;
             }

             function compressFirstModelResponse() {
@@ -116,7 +186,7 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
                         continue;

                     if (typeof item === "string") {
-                        const newText = truncateTextAndRoundToWords(item, charactersLeftToRemove);
+                        const newText = truncateTextAndRoundToWords(item, charactersLeftToRemove, undefined, true);

                         if (newText === "") {
                             historyItem.response.splice(t, 1);
@@ -139,14 +209,14 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
                    if (historyItem.response.length === 0) {
                        // if the model response is removed from the history,
                        // the things that led to it are not important anymore
-                       removeHistoryThatLedToModelResponseAtIndex(i);
+                       i -= removeHistoryThatLedToModelResponseAtIndex(i);
                        res.splice(i, 1);
                        i--;
                    }
                }
            }

-            function compressLastModelResponse(minCharactersToKeep: number = 20) {
+            function compressLastModelResponse(minCharactersToKeep: number = 60) {
                const lastHistoryItem = res[res.length - 1];

                if (lastHistoryItem == null || lastHistoryItem.type !== "model")
@@ -157,14 +227,27 @@ export async function eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy(
                if (lastResponseItem == null || typeof lastResponseItem !== "string")
                    return;

-                const nextTextLength = lastResponseItem.length - charactersLeftToRemove;
-                const charactersToRemoveFromText = charactersLeftToRemove + Math.max(0, nextTextLength - minCharactersToKeep);
-                const newText = truncateTextAndRoundToWords(lastResponseItem, charactersToRemoveFromText);
+                compressHistoryThatLedToModelResponseAtIndex(res.length - 1, maxTokensCount / 4);
+
+                if (charactersLeftToRemove <= 0)
+                    return;
+
+                const nextTextLength = Math.max(
+                    Math.min(lastResponseItem.length, minCharactersToKeep),
+                    lastResponseItem.length - charactersLeftToRemove
+                );
+                const charactersToRemoveFromText = lastResponseItem.length - nextTextLength;
+                const newText = truncateTextAndRoundToWords(lastResponseItem, charactersToRemoveFromText, undefined, true);

                if (newText.length < lastResponseItem.length) {
                    lastHistoryItem.response[lastHistoryItem.response.length - 1] = newText;
                    charactersLeftToRemove -= lastResponseItem.length - newText.length;
                }
+
+                if (charactersLeftToRemove <= 0)
+                    return;
+
+                compressHistoryThatLedToModelResponseAtIndex(res.length - 1);
            }

            compressFunctionCalls();
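
The `keepTokensCount` parameter of the new `compressHistoryThatLedToModelResponseAtIndex` converts a token budget into a character budget via `estimatedCharactersPerToken`. A standalone sketch of that arithmetic (the function name and sample numbers are illustrative, not part of the commit):

```ts
// illustrative recreation of the per-item character-budget arithmetic used above
function charactersToTrim(
    textLength: number,
    charactersLeftToRemove: number,
    keepTokensCount: number,
    estimatedCharactersPerToken: number
): number {
    let removeChars = Math.min(charactersLeftToRemove, textLength);

    if (keepTokensCount > 0) {
        // reserve roughly `keepTokensCount` tokens' worth of characters in this item
        removeChars -= Math.floor(keepTokensCount * estimatedCharactersPerToken);
        if (removeChars < 0)
            removeChars = 0;
    }

    return removeChars;
}

// e.g. a 400-character prompt with 300 characters still to remove:
// reserving ~25 tokens at ~4 characters per token protects 100 characters,
// so only 200 characters get trimmed from this item
console.log(charactersToTrim(400, 300, 25, 4)); // 200
```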

src/evaluator/LlamaCompletion.ts

Lines changed: 12 additions & 10 deletions

@@ -660,20 +660,22 @@ export class LlamaCompletion {

        let shouldContextShift = false;

-        let {firstDifferentIndex} = sequence.compareContextTokens(inputTokens);
-
-        // we need to decode at least one token to generate a response
-        if (firstDifferentIndex === inputTokens.length && firstDifferentIndex > 0)
-            firstDifferentIndex -= 1;
-
-        inputTokens.splice(0, firstDifferentIndex);
-
-        if (firstDifferentIndex < sequence.nextTokenIndex) {
+        if (inputTokens.length === 1 && sequence.nextTokenIndex !== 0)
            await sequence.eraseContextTokenRanges([{
-                start: firstDifferentIndex,
+                start: 0,
                end: sequence.nextTokenIndex
            }]);
+        else {
+            const lastToken = inputTokens[inputTokens.length - 1]!;
+
+            // we need to decode at least one token to generate a response
+            inputTokens.pop();
+            await sequence.adaptStateToTokens(inputTokens, false);
+            inputTokens.push(lastToken);
            ensureNotAborted();
+
+            const firstDifferentIndex = sequence.nextTokenIndex;
+            inputTokens.splice(0, firstDifferentIndex);
        }

        const evaluationIterator = sequence.evaluate(inputTokens, removeNullFields({

src/evaluator/LlamaContext/LlamaContext.ts

Lines changed: 65 additions & 2 deletions

@@ -909,6 +909,61 @@ export class LlamaContextSequence {
        };
    }

+    /**
+     * Erase parts of the context state to align it with the given tokens.
+     *
+     * If the given tokens do not align with the current context state, the context state will be erased to align with the given tokens.
+     *
+     * To find the first different token index between the context state and the given tokens, access the `nextTokenIndex` property.
+     *
+     * If `allowShift` is `true` (the default), shifting tokens may happen to align the context state with the given tokens,
+     * which incurs token evaluation of the shifted tokens.
+     */
+    public async adaptStateToTokens(tokens: Token[], allowShift: boolean = true) {
+        if (this.model.fileInsights.isRecurrent || !allowShift) {
+            const {firstDifferentIndex} = this.compareContextTokens(tokens);
+            if (firstDifferentIndex < this._nextTokenIndex)
+                await this.eraseContextTokenRanges([{
+                    start: firstDifferentIndex,
+                    end: this._nextTokenIndex
+                }]);
+
+            return;
+        }
+
+        const eraseRanges: ContextTokensDeleteRange[] = [];
+
+        let tokensIndex = 0;
+        let differentTokenIndex: number | undefined = undefined;
+        for (let i = 0; i < this._contextTokens.length && tokensIndex < tokens.length; i++) {
+            if (compareTokens(this._contextTokens[i], tokens[tokensIndex])) {
+                if (differentTokenIndex != null) {
+                    eraseRanges.push({
+                        start: differentTokenIndex,
+                        end: i
+                    });
+
+                    differentTokenIndex = undefined;
+                }
+
+                tokensIndex++;
+                continue;
+            }
+
+            if (differentTokenIndex == null)
+                differentTokenIndex = i;
+        }
+
+        if (differentTokenIndex != null)
+            eraseRanges.push({
+                start: differentTokenIndex,
+                end: this._nextTokenIndex
+            });
+
+        if (eraseRanges.length > 0)
+            await this.eraseContextTokenRanges(eraseRanges);
+    }
+
    /**
     * Clear the history of the sequence.
     * If `prependBos` was enabled, the BOS token will be prepended to the sequence again.
@@ -975,15 +1030,23 @@ export class LlamaContextSequence {
            if (deletionSuccessful)
                deletionSuccessful &&= this._context._ctx.removeTokenCellsFromSequence(this._sequenceId, range.start, range.end);

-            if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start)
+            if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== range.start) {
                this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, range.start, -removedTokens);
+                const shiftedTokens = range.start - lastDeleteRangeEndPos;
+                this._tokenMeter.useTokens(shiftedTokens, "input");
+            }

            removedTokens += range.end - range.start;
            lastDeleteRangeEndPos = range.end;
        }

-        if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 && lastDeleteRangeEndPos !== this._nextTokenIndex)
+        if (deletionSuccessful && lastDeleteRangeEndPos != null && removedTokens > 0 &&
+            lastDeleteRangeEndPos !== this._nextTokenIndex
+        ) {
            this._context._ctx.shiftSequenceTokenCells(this._sequenceId, lastDeleteRangeEndPos, this._nextTokenIndex, -removedTokens);
+            const shiftedTokens = this._nextTokenIndex - lastDeleteRangeEndPos;
+            this._tokenMeter.useTokens(shiftedTokens, "input");
+        }

        this._nextTokenIndex -= removedTokens;
src/gguf/insights/GgufInsights.ts

Lines changed: 10 additions & 0 deletions

@@ -104,6 +104,16 @@ export class GgufInsights {
        return true;
    }

+    public get isRecurrent() {
+        switch (this._ggufFileInfo.metadata?.general?.architecture) {
+            case GgufArchitectureType.mamba:
+            case GgufArchitectureType.rwkv6:
+                return true;
+        }
+
+        return false;
+    }
+
    public estimateModelResourceRequirements({gpuLayers}: {gpuLayers: number}): GgufInsightsResourceRequirements {
        const {cpu, gpu} = this._getTensorResourceSplit(gpuLayers);
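`adaptStateToTokens` consults this flag because recurrent architectures keep a rolling state rather than a per-token KV cache, so shifting mid-sequence token cells does not apply to them; they take the plain-erasure path instead. A small sketch of querying it (the model path is a placeholder):

```ts
import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/model.gguf"}); // placeholder path

// `true` for mamba and rwkv6 architectures, `false` otherwise
console.log(model.fileInsights.isRecurrent);
```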
