
Commit 1bb0c0d

fix: more flexible model message prompt completion config
1 parent 30eaa23 commit 1bb0c0d

File tree

1 file changed: +75 -43 lines changed

src/evaluator/LlamaChatSession/LlamaChatSession.ts

Lines changed: 75 additions & 43 deletions
@@ -330,18 +330,26 @@ export type LLamaChatCompletePromptOptions = {
         enabled?: "auto" | boolean,
 
         /**
-         * The user prompt to give the model for the completion.
+         * The messages to append to the chat history to generate a completion as a model response.
          *
-         * Defaults to `"What may I say next?"`
-         */
-        userPrompt?: string,
-
-        /**
-         * The prefix to supply a model message with for the completion.
+         * If the last message is a model message, the prompt will be pushed to it for the completion,
+         * otherwise a new model message will be added with the prompt.
          *
-         * Defaults to `"Here's a possible reply from you:\t"`
+         * It must contain a user message or a system message before the model message.
+         *
+         * Default to:
+         * ```ts
+         * [
+         *     {
+         *         type: "system",
+         *         text: "For your next response predict what the user may send next. No yapping, no whitespace. Match the user's language and tone."
+         *     },
+         *     {type: "user", text: ""},
+         *     {type: "model", response: [""]}
+         * ]
+         * ```
          */
-        modelPrefix?: string
+        appendedMessages?: ChatHistoryItem[]
     }
 };

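For context, a minimal usage sketch of the new option, assuming the existing LlamaChatSession.completePrompt() API; the `session` variable and the message texts below are made up for illustration and are not part of this commit:

// Sketch only: supply custom messages to append for a model-message completion.
const completion = await session.completePrompt("Thanks for the help", {
    completeAsModel: {
        enabled: "auto",
        appendedMessages: [
            {
                type: "system",
                text: "Predict the user's next message. Reply with the prediction only."
            },
            {type: "user", text: ""},
            {type: "model", response: [""]}
        ]
    }
});
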
@@ -391,8 +399,14 @@ export type LlamaChatSessionRepeatPenalty = {
 
 const defaultCompleteAsModel = {
     enabled: "auto",
-    userPrompt: "What may I say next?",
-    modelPrefix: "Here's a possible reply from you:\t"
+    appendedMessages: [
+        {
+            type: "system",
+            text: "For your next response predict what the user may send next. No yapping, no whitespace. Match the user's language and tone."
+        },
+        {type: "user", text: ""},
+        {type: "model", response: [""]}
+    ]
 } as const satisfies LLamaChatCompletePromptOptions["completeAsModel"];
 
 /**
@@ -928,19 +942,56 @@ export class LlamaChatSession {
                 throw new DisposedError();
 
             if (shouldCompleteAsModel) {
-                const completeAsModelUserPrompt = (typeof completeAsModel == "boolean" || completeAsModel === "auto")
-                    ? defaultCompleteAsModel.userPrompt
-                    : completeAsModel?.userPrompt ?? defaultCompleteAsModel.userPrompt;
-                const completeAsModelMessagePrefix = (typeof completeAsModel == "boolean" || completeAsModel === "auto")
-                    ? defaultCompleteAsModel.modelPrefix
-                    : completeAsModel?.modelPrefix ?? defaultCompleteAsModel.modelPrefix;
+                const messagesToAppendOption = (typeof completeAsModel == "boolean" || completeAsModel === "auto")
+                    ? defaultCompleteAsModel.appendedMessages
+                    : completeAsModel?.appendedMessages ?? defaultCompleteAsModel.appendedMessages;
+
+                const messagesToAppend = messagesToAppendOption.length === 0
+                    ? defaultCompleteAsModel.appendedMessages
+                    : messagesToAppendOption;
+
+                const addMessageToChatHistory = (chatHistory: ChatHistoryItem[]): {
+                    history: ChatHistoryItem[],
+                    addedCount: number
+                } => {
+                    const newHistory = chatHistory.slice();
+                    if (messagesToAppend.at(0)?.type === "model")
+                        newHistory.push({type: "user", text: ""});
+
+                    for (let i = 0; i < messagesToAppend.length; i++) {
+                        const item = messagesToAppend[i];
+                        const isLastItem = i === messagesToAppend.length - 1;
+
+                        if (item == null)
+                            continue;
+
+                        if (isLastItem && item.type === "model") {
+                            const newResponse = item.response.slice();
+                            if (typeof newResponse.at(-1) === "string")
+                                newResponse.push((newResponse.pop()! as string) + prompt)
+                            else
+                                newResponse.push(prompt);
+
+                            newHistory.push({
+                                type: "model",
+                                response: newResponse
+                            })
+                        } else
+                            newHistory.push(item);
+                    }
+
+                    if (messagesToAppend.at(-1)?.type !== "model")
+                        newHistory.push({type: "model", response: [prompt]});
+
+                    return {
+                        history: newHistory,
+                        addedCount: newHistory.length - chatHistory.length
+                    };
+                };
 
+                const {history: messagesWithPrompt, addedCount} = addMessageToChatHistory(this._chatHistory);
                 const {response, lastEvaluation, metadata} = await this._chat.generateResponse(
-                    [
-                        ...asWithLastUserMessageRemoved(this._chatHistory),
-                        {type: "user", text: completeAsModelUserPrompt},
-                        {type: "model", response: [completeAsModelMessagePrefix + prompt]}
-                    ] as ChatHistoryItem[],
+                    messagesWithPrompt,
                     {
                         abortOnNonText: true,
                         functions,
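The closure added above is the core of the change: the configured messages are appended to the chat history, and the prompt text is merged into the trailing model message (or a fresh model message is created when none is configured). Below is a simplified, standalone re-implementation for illustration only; `Msg` and `appendForCompletion` are stand-ins invented here, not the library's actual types or code:

// Illustration of the appending behavior, not the library's actual code.
type Msg =
    | {type: "system", text: string}
    | {type: "user", text: string}
    | {type: "model", response: string[]};

function appendForCompletion(history: Msg[], toAppend: Msg[], prompt: string): Msg[] {
    const newHistory = history.slice();

    // A leading model message gets an empty user message inserted before it.
    if (toAppend.at(0)?.type === "model")
        newHistory.push({type: "user", text: ""});

    for (let i = 0; i < toAppend.length; i++) {
        const item = toAppend[i]!;
        const isLast = i === toAppend.length - 1;

        if (isLast && item.type === "model") {
            // Merge the prompt into the last string of the trailing model response.
            const response = item.response.slice();
            if (typeof response.at(-1) === "string")
                response.push(response.pop()! + prompt);
            else
                response.push(prompt);

            newHistory.push({type: "model", response});
        } else
            newHistory.push(item);
    }

    // No configured model message: add one carrying the prompt as a partial response.
    if (toAppend.at(-1)?.type !== "model")
        newHistory.push({type: "model", response: [prompt]});

    return newHistory;
}

// With the default appendedMessages (system + empty user + empty model message),
// the prompt ends up as the model's partial response for the model to continue:
// the returned history ends with {type: "model", response: ["<prompt text>"]}.
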
@@ -968,11 +1019,7 @@ export class LlamaChatSession {
                         lastEvaluationContextWindow: {
                             history: this._lastEvaluation?.contextWindow == null
                                 ? undefined
-                                : [
-                                    ...asWithLastUserMessageRemoved(this._lastEvaluation?.contextWindow),
-                                    {type: "user", text: completeAsModelUserPrompt},
-                                    {type: "model", response: [completeAsModelMessagePrefix + prompt]}
-                                ] as ChatHistoryItem[],
+                                : addMessageToChatHistory(this._lastEvaluation?.contextWindow).history,
                             minimumOverlapPercentageToPreventContextShift: 0.8
                         }
                     }
@@ -981,7 +1028,7 @@ export class LlamaChatSession {
 
                 this._lastEvaluation = {
                     cleanHistory: this._chatHistory,
-                    contextWindow: asWithLastUserMessageRemoved(asWithLastModelMessageRemoved(lastEvaluation.contextWindow)),
+                    contextWindow: lastEvaluation.contextWindow.slice(0, -addedCount),
                     contextShiftMetadata: lastEvaluation.contextShiftMetadata
                 };
                 this._canUseContextWindowForCompletion = this._chatHistory.at(-1)?.type === "user";
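The trimming above works because addMessageToChatHistory reports how many messages it appended; slicing that many items off the evaluated context window restores the history that existed before the completion. A tiny illustration with made-up values:

// Hypothetical values, just to show the slice(0, -addedCount) step.
const contextWindow = ["system", "user: hi", "model: hello", "user: ", "model: how ar"];
const addedCount = 2; // hypothetical count of messages appended for the completion

// Dropping the appended messages from the end yields the clean context window.
const cleanWindow = contextWindow.slice(0, -addedCount);
// cleanWindow === ["system", "user: hi", "model: hello"]
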
@@ -1183,18 +1230,3 @@ function asWithLastUserMessageRemoved(chatHistory?: ChatHistoryItem[]) {
 
     return newChatHistory;
 }
-
-
-function asWithLastModelMessageRemoved(chatHistory: ChatHistoryItem[]): ChatHistoryItem[];
-function asWithLastModelMessageRemoved(chatHistory: ChatHistoryItem[] | undefined): ChatHistoryItem[] | undefined;
-function asWithLastModelMessageRemoved(chatHistory?: ChatHistoryItem[]) {
-    if (chatHistory == null)
-        return chatHistory;
-
-    const newChatHistory = chatHistory.slice();
-
-    while (newChatHistory.at(-1)?.type === "model")
-        newChatHistory.pop();
-
-    return newChatHistory;
-}
