Skip to content

Commit cf01e53

Browse files
committed
feat: segment options for generic template chat wrappers
1 parent 288fb82 commit cf01e53

File tree

6 files changed

+223
-9
lines changed

6 files changed

+223
-9
lines changed

src/chatWrappers/generic/JinjaTemplateChatWrapper.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import {
66
import {SpecialToken, LlamaText, SpecialTokensText} from "../../utils/LlamaText.js";
77
import {ChatWrapper} from "../../ChatWrapper.js";
88
import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js";
9+
import {
10+
templateSegmentOptionsToChatWrapperSettings, TemplateChatWrapperSegmentsOptions
11+
} from "./utils/templateSegmentOptionsToChatWrapperSettings.js";
912

1013
export type JinjaTemplateChatWrapperOptions = {
1114
template: string,
@@ -61,7 +64,12 @@ export type JinjaTemplateChatWrapperOptions = {
6164
/**
6265
* Additional parameters to use for rendering the Jinja template.
6366
*/
64-
additionalRenderParameters?: Record<string, any>
67+
additionalRenderParameters?: Record<string, any>,
68+
69+
/**
70+
* Format of the segments generated by the model (like thought segments)
71+
*/
72+
segments?: TemplateChatWrapperSegmentsOptions
6573
};
6674

6775
export type JinjaTemplateChatWrapperOptionsConvertMessageFormat = {
@@ -92,6 +100,10 @@ const defaultConvertUnsupportedSystemMessagesToUserMessagesFormat: JinjaTemplate
92100
* // functionCallMessageTemplate: { // optional
93101
* // call: "[[call: {{functionName}}({{functionParams}})]]",
94102
* // result: " [[result: {{functionCallResult}}]]"
103+
* // },
104+
* // segments: {
105+
* // thoughtTemplate: "<think>{{content}}</think>",
106+
* // reopenThoughtAfterFunctionCalls: true
95107
* // }
96108
* });
97109
* ```
@@ -125,7 +137,8 @@ export class JinjaTemplateChatWrapper extends ChatWrapper {
125137
functionCallMessageTemplate,
126138
joinAdjacentMessagesOfTheSameType = true,
127139
trimLeadingWhitespaceInResponses = true,
128-
additionalRenderParameters
140+
additionalRenderParameters,
141+
segments
129142
}: JinjaTemplateChatWrapperOptions) {
130143
super();
131144

@@ -144,7 +157,8 @@ export class JinjaTemplateChatWrapper extends ChatWrapper {
144157

145158
this.settings = {
146159
...ChatWrapper.defaultSettings,
147-
functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions
160+
functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions,
161+
segments: templateSegmentOptionsToChatWrapperSettings(segments)
148162
};
149163

150164
if (this.convertUnsupportedSystemMessagesToUserMessages != null && !this.convertUnsupportedSystemMessagesToUserMessages.format.includes("{{message}}"))

src/chatWrappers/generic/TemplateChatWrapper.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import {SpecialToken, LlamaText, LlamaTextValue, SpecialTokensText} from "../../
33
import {ChatWrapper} from "../../ChatWrapper.js";
44
import {parseTextTemplate} from "../../utils/parseTextTemplate.js";
55
import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js";
6+
import {
7+
templateSegmentOptionsToChatWrapperSettings, TemplateChatWrapperSegmentsOptions
8+
} from "./utils/templateSegmentOptionsToChatWrapperSettings.js";
69

710
export type TemplateChatWrapperOptions = {
811
template: `${"" | `${string}{{systemPrompt}}`}${string}{{history}}${string}{{completion}}${string}`,
@@ -12,7 +15,18 @@ export type TemplateChatWrapperOptions = {
1215
model: `${string}{{message}}${string}`
1316
},
1417
functionCallMessageTemplate?: ChatHistoryFunctionCallMessageTemplate,
15-
joinAdjacentMessagesOfTheSameType?: boolean
18+
19+
/**
20+
* Whether to join adjacent messages of the same type.
21+
*
22+
* Defaults to `true`.
23+
*/
24+
joinAdjacentMessagesOfTheSameType?: boolean,
25+
26+
/**
27+
* Format of the segments generated by the model (like thought segments)
28+
*/
29+
segments?: TemplateChatWrapperSegmentsOptions
1630
};
1731

1832
/**
@@ -33,6 +47,10 @@ export type TemplateChatWrapperOptions = {
3347
* // functionCallMessageTemplate: { // optional
3448
* // call: "[[call: {{functionName}}({{functionParams}})]]",
3549
* // result: " [[result: {{functionCallResult}}]]"
50+
* // },
51+
* // segments: {
52+
* // thoughtTemplate: "<think>{{content}}</think>",
53+
* // reopenThoughtAfterFunctionCalls: true
3654
* // }
3755
* });
3856
* ```
@@ -52,6 +70,8 @@ export type TemplateChatWrapperOptions = {
5270
*
5371
* **`functionCallMessageTemplate`** is used to specify the format in which functions can be called by the model and
5472
* how their results are fed to the model after the function call.
73+
*
74+
* **`segments`** is used to specify the format of the segments generated by the model (like thought segments).
5575
*/
5676
export class TemplateChatWrapper extends ChatWrapper {
5777
public readonly wrapperName = "Template";
@@ -72,7 +92,8 @@ export class TemplateChatWrapper extends ChatWrapper {
7292
template,
7393
historyTemplate,
7494
functionCallMessageTemplate,
75-
joinAdjacentMessagesOfTheSameType = true
95+
joinAdjacentMessagesOfTheSameType = true,
96+
segments
7697
}: TemplateChatWrapperOptions) {
7798
super();
7899

@@ -95,7 +116,8 @@ export class TemplateChatWrapper extends ChatWrapper {
95116

96117
this.settings = {
97118
...ChatWrapper.defaultSettings,
98-
functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions
119+
functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSettings.functions,
120+
segments: templateSegmentOptionsToChatWrapperSettings(segments)
99121
};
100122
}
101123

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import {ChatWrapperSettings} from "../../../types.js";
2+
import {parseTextTemplate} from "../../../utils/parseTextTemplate.js";
3+
import {removeUndefinedFields} from "../../../utils/removeNullFields.js";
4+
5+
export function templateSegmentOptionsToChatWrapperSettings(
6+
templateOptions?: TemplateChatWrapperSegmentsOptions
7+
): ChatWrapperSettings["segments"] {
8+
if (templateOptions == null)
9+
return {};
10+
11+
function getThoughtSegmentOptions(): Exclude<ChatWrapperSettings["segments"], undefined>["thought"] {
12+
if (templateOptions?.thoughtTemplate == null)
13+
return undefined;
14+
15+
const parsedThoughtTemplate = parseTextTemplate(templateOptions.thoughtTemplate, [{
16+
text: "{{content}}",
17+
key: "content"
18+
}]);
19+
20+
const prefix = parsedThoughtTemplate.content.prefix;
21+
if (prefix.length === 0)
22+
throw new Error("Thought template must have text before \"{{content}}\"");
23+
24+
return removeUndefinedFields({
25+
prefix,
26+
suffix: parsedThoughtTemplate.content.suffix || undefined,
27+
reopenAfterFunctionCalls: templateOptions.reopenThoughtAfterFunctionCalls
28+
});
29+
}
30+
31+
return removeUndefinedFields({
32+
closeAllSegments: templateOptions.closeAllSegmentsTemplate || undefined,
33+
reiterateStackAfterFunctionCalls: templateOptions.reiterateStackAfterFunctionCalls,
34+
35+
thought: getThoughtSegmentOptions()
36+
});
37+
}
38+
39+
export type TemplateChatWrapperSegmentsOptions = {
40+
/** Template for a thought segment */
41+
thoughtTemplate?: `${string}{{content}}${string}`,
42+
43+
/**
44+
* Automatically reopen a thought segment after function calls.
45+
*
46+
* Useful for aligning the output of models that assume that a thought segment is already open after function calls.
47+
*
48+
* Defaults to `false`.
49+
*/
50+
reopenThoughtAfterFunctionCalls?: boolean,
51+
52+
/** Consider all segments to be closed when this text is detected */
53+
closeAllSegmentsTemplate?: string,
54+
55+
/**
56+
* After function calls, reiterate the stack of the active segments to remind the model of the context.
57+
*
58+
* Defaults to `false`.
59+
*/
60+
reiterateStackAfterFunctionCalls?: boolean
61+
};

src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./chatWrappe
6262
import {
6363
JinjaTemplateChatWrapper, type JinjaTemplateChatWrapperOptions, type JinjaTemplateChatWrapperOptionsConvertMessageFormat
6464
} from "./chatWrappers/generic/JinjaTemplateChatWrapper.js";
65-
import {ChatHistoryFunctionCallMessageTemplate} from "./chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.js";
6665
import {
6766
resolvableChatWrapperTypeNames, type ResolvableChatWrapperTypeName, specializedChatWrapperTypeNames,
6867
type SpecializedChatWrapperTypeName, templateChatWrapperTypeNames, type TemplateChatWrapperTypeName, resolveChatWrapper,
@@ -112,6 +111,8 @@ import {GgmlType, type GgufTensorInfo} from "./gguf/types/GgufTensorInfoTypes.js
112111
import {type ModelFileAccessTokens} from "./utils/modelFileAccesTokens.js";
113112
import {type OverridesObject} from "./utils/OverridesObject.js";
114113
import type {LlamaClasses} from "./utils/getLlamaClasses.js";
114+
import type {ChatHistoryFunctionCallMessageTemplate} from "./chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.js";
115+
import type {TemplateChatWrapperSegmentsOptions} from "./chatWrappers/generic/utils/templateSegmentOptionsToChatWrapperSettings.js";
115116

116117

117118
export {
@@ -219,6 +220,7 @@ export {
219220
type JinjaTemplateChatWrapperOptions,
220221
type JinjaTemplateChatWrapperOptionsConvertMessageFormat,
221222
type ChatHistoryFunctionCallMessageTemplate,
223+
type TemplateChatWrapperSegmentsOptions,
222224
resolveChatWrapper,
223225
type BuiltInChatWrapperType,
224226
type ResolveChatWrapperOptions,

src/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ export type ChatWrapperSettings = {
8383
},
8484

8585
readonly segments?: {
86-
/** When this text is detected, active text segments are considered closed */
86+
/** Consider all active segments to be closed when this text is detected */
8787
readonly closeAllSegments?: string | LlamaText,
8888

8989
/**
90-
* After function calls, reiterate the stack of the active text segments to remind the model of the context.
90+
* After function calls, reiterate the stack of the active segments to remind the model of the context.
9191
*
9292
* Defaults to `false`.
9393
*/
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import {describe, expect, test} from "vitest";
2+
import {
3+
templateSegmentOptionsToChatWrapperSettings
4+
} from "../../../../../src/chatWrappers/generic/utils/templateSegmentOptionsToChatWrapperSettings.js";
5+
6+
7+
describe("getChatWrapperSegmentsOptionsFromTemplateOption", () => {
8+
test("no options", () => {
9+
expect(templateSegmentOptionsToChatWrapperSettings()).to.eql({});
10+
expect(templateSegmentOptionsToChatWrapperSettings(undefined)).to.eql({});
11+
expect(templateSegmentOptionsToChatWrapperSettings({})).to.eql({});
12+
});
13+
14+
test("no thought content", () => {
15+
try {
16+
templateSegmentOptionsToChatWrapperSettings({
17+
thoughtTemplate: "text" as any
18+
});
19+
expect.unreachable("Parsing a thought template without a prefix should throw an error");
20+
} catch (err) {
21+
expect(err).toMatchInlineSnapshot('[Error: Template must contain "{{content}}" at the beginning]');
22+
}
23+
});
24+
25+
test("no thought prefix", () => {
26+
try {
27+
templateSegmentOptionsToChatWrapperSettings({
28+
thoughtTemplate: "{{content}}suffix"
29+
});
30+
expect.unreachable("Parsing a thought template without a prefix should throw an error");
31+
} catch (err) {
32+
expect(err).toMatchInlineSnapshot('[Error: Thought template must have text before "{{content}}"]');
33+
}
34+
});
35+
36+
test("valid thought prefix", () => {
37+
expect(templateSegmentOptionsToChatWrapperSettings({
38+
thoughtTemplate: "prefix{{content}}"
39+
})).to.eql({
40+
thought: {
41+
prefix: "prefix"
42+
}
43+
});
44+
});
45+
46+
test("valid thought suffix", () => {
47+
expect(templateSegmentOptionsToChatWrapperSettings({
48+
thoughtTemplate: "prefix{{content}}suffix"
49+
})).to.eql({
50+
thought: {
51+
prefix: "prefix",
52+
suffix: "suffix"
53+
}
54+
});
55+
});
56+
57+
test("reopenThoughtAfterFunctionCalls", () => {
58+
expect(templateSegmentOptionsToChatWrapperSettings({
59+
thoughtTemplate: "prefix{{content}}suffix",
60+
reopenThoughtAfterFunctionCalls: true
61+
})).to.eql({
62+
thought: {
63+
prefix: "prefix",
64+
suffix: "suffix",
65+
reopenAfterFunctionCalls: true
66+
}
67+
});
68+
});
69+
70+
test("closeAllSegmentsTemplate", () => {
71+
expect(templateSegmentOptionsToChatWrapperSettings({
72+
thoughtTemplate: "prefix{{content}}suffix",
73+
reopenThoughtAfterFunctionCalls: true,
74+
closeAllSegmentsTemplate: "closeAll"
75+
})).to.eql({
76+
closeAllSegments: "closeAll",
77+
thought: {
78+
prefix: "prefix",
79+
suffix: "suffix",
80+
reopenAfterFunctionCalls: true
81+
}
82+
});
83+
});
84+
85+
test("empty closeAllSegmentsTemplate", () => {
86+
expect(templateSegmentOptionsToChatWrapperSettings({
87+
thoughtTemplate: "prefix{{content}}suffix",
88+
reopenThoughtAfterFunctionCalls: true,
89+
closeAllSegmentsTemplate: ""
90+
})).to.eql({
91+
thought: {
92+
prefix: "prefix",
93+
suffix: "suffix",
94+
reopenAfterFunctionCalls: true
95+
}
96+
});
97+
});
98+
99+
test("reiterateStackAfterFunctionCalls", () => {
100+
expect(templateSegmentOptionsToChatWrapperSettings({
101+
thoughtTemplate: "prefix{{content}}suffix",
102+
reopenThoughtAfterFunctionCalls: true,
103+
closeAllSegmentsTemplate: "closeAll",
104+
reiterateStackAfterFunctionCalls: true
105+
})).to.eql({
106+
closeAllSegments: "closeAll",
107+
reiterateStackAfterFunctionCalls: true,
108+
thought: {
109+
prefix: "prefix",
110+
suffix: "suffix",
111+
reopenAfterFunctionCalls: true
112+
}
113+
});
114+
});
115+
});

0 commit comments

Comments
 (0)