import {ChatWrapper} from "../ChatWrapper.js";
import {
    ChatModelFunctions, ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState, ChatWrapperSettings,
    isChatModelResponseSegment
} from "../types.js";
import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js";
import {ChatModelFunctionsDocumentationGenerator} from "./utils/ChatModelFunctionsDocumentationGenerator.js";

const defaultThinkingBudget = null;

// source: https://huggingface.co/ByteDance-Seed/Seed-OSS-36B-Instruct/blob/main/chat_template.jinja
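// Renders the chat history with `<seed:bos>ROLE` / `<seed:eos>` role markers, documents the
// available tools in a system message, and optionally instructs the model to follow a thinking budget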
export class SeedChatWrapper extends ChatWrapper {
    public readonly wrapperName: string = "Seed";

    public readonly thinkingBudget: number | 0 | null;

    public override readonly settings: ChatWrapperSettings = {
        supportsSystemMessages: true,
        functions: {
            call: {
                optionalPrefixSpace: true,
                prefix: LlamaText(new SpecialTokensText("<seed:tool_call>\n"), "<function="),
                paramsPrefix: LlamaText(new SpecialTokensText(">")),
                suffix: LlamaText(new SpecialTokensText("\n</function>\n</seed:tool_call>\n")),
                emptyCallParamsPlaceholder: {}
            },
            result: {
                prefix: LlamaText(new SpecialTokensText("<seed:bos>tool\n")),
                suffix: LlamaText(new SpecialTokensText("<seed:eos>"))
            }
        },
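        // Seed-OSS emits its chain of thought inside `<seed:think>` tags; mapping them to a
        // "thought" segment lets callers stream or hide the reasoning, and
        // `reopenAfterFunctionCalls` resumes the thought after a tool call resolves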
        segments: {
            thought: {
                prefix: LlamaText(new SpecialTokensText("<seed:think>")),
                suffix: LlamaText(new SpecialTokensText("</seed:think>")),
                reopenAfterFunctionCalls: true
            }
        }
    };

    public constructor(options: {
        /**
         * The thinking budget to instruct the model to conform to.
         *
         * This is purely a request; the model may ignore it.
         *
         * Set to `0` to instruct the model to not use any reasoning.
         *
         * When set to `null`, the instruction will be omitted (unlimited reasoning).
         *
         * Defaults to `null`.
         */
        thinkingBudget?: number | 0 | null
    } = {}) {
        super();

        const {
            thinkingBudget = defaultThinkingBudget
        } = options;

        this.thinkingBudget = thinkingBudget;
    }

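    // Illustrative shape of the rendered context (assuming a system prompt, one user
    // message, and a model turn; the last item omits the trailing `<seed:eos>` so that
    // generation can continue from it):
    //   <seed:bos>system
    //   ...system prompt and tool list...
    //   <seed:eos><seed:bos>user
    //   ...user message...
    //   <seed:eos><seed:bos>assistant
    //   ...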
    public override generateContextState({
        chatHistory, availableFunctions, documentFunctionParams
    }: ChatWrapperGenerateContextStateOptions): ChatWrapperGeneratedContextState {
        const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0;
        const modifiedChatHistory = chatHistory.slice();

        let systemMessage: LlamaText = LlamaText();
        if (modifiedChatHistory[0]?.type === "system") {
            systemMessage = LlamaText.fromJSON(modifiedChatHistory[0].text);
            modifiedChatHistory.shift();
        }

        const contextContent: LlamaText[] = [];

        if (systemMessage.values.length > 0 || hasFunctions)
            contextContent.push(
                LlamaText([
                    new SpecialTokensText("<seed:bos>system\n"),
                    this._getFirstSystemMessage(systemMessage, availableFunctions, {documentParams: documentFunctionParams}),
                    new SpecialTokensText("\n<seed:eos>")
                ])
            );

        const thinkingBudgetSystemMessage = this._getThinkingBudgetSystemMessage();
        if (thinkingBudgetSystemMessage.values.length > 0)
            contextContent.push(
                LlamaText([
                    new SpecialTokensText("<seed:bos>system\n"),
                    thinkingBudgetSystemMessage,
                    new SpecialTokensText("\n<seed:eos>")
                ])
            );

        for (let i = 0; i < modifiedChatHistory.length; i++) {
            const isLastItem = i === modifiedChatHistory.length - 1;
            const item = modifiedChatHistory[i];

            if (item == null)
                continue;

            if (item.type === "system") {
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<seed:bos>system\n"),
                        LlamaText.fromJSON(item.text),
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("\n<seed:eos>")
                    ])
                );
            } else if (item.type === "user") {
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<seed:bos>user\n"),
                        item.text,
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("\n<seed:eos>")
                    ])
                );
            } else if (item.type === "model") {
                // when the thinking budget is 0, inject a stub thought telling the model to answer
                // immediately, unless this response already contains a thought of its own
                const injectNoThinkingThought = this.thinkingBudget === 0 && (
                    isLastItem ||
                    !item.response.some(
                        (response) => (
                            isChatModelResponseSegment(response) && response.segmentType === "thought"
                        )
                    )
                );

                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<seed:bos>assistant\n"),
                        injectNoThinkingThought
                            ? [
                                new SpecialTokensText("<seed:think>\n"),
                                [
                                    new SpecialTokensText("<seed:cot_budget_reflect>"),
                                    "The current thinking budget is 0, so I will directly start answering the question.",
                                    new SpecialTokensText("</seed:cot_budget_reflect>")
                                ],
                                new SpecialTokensText("\n</seed:think>")
                            ]
                            : [],
                        this.generateModelResponseText(item.response, true),
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("\n<seed:eos>")
                    ])
                );
            } else
                void (item satisfies never);
        }

        const contextText = LlamaText(contextContent);

        return {
            contextText,
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialTokensText("<seed:eos>")),
                LlamaText("<seed:eos>")
            ]
        };
    }

    public override generateAvailableFunctionsSystemText(availableFunctions: ChatModelFunctions, {documentParams = true}: {
        documentParams?: boolean
    }) {
        const functionsDocumentationGenerator = new ChatModelFunctionsDocumentationGenerator(availableFunctions);

        if (!functionsDocumentationGenerator.hasAnyFunctions)
            return LlamaText([]);

        return LlamaText.joinValues("\n", [
            "",
            "Tool List:",
            (
                "You are authorized to use the following tools (described in JSON Schema format). " +
                "Before performing any task, you must decide how to call them based on the descriptions and parameters of these tools."
            ),
            functionsDocumentationGenerator.getSeedFunctionSignatures({documentParams}),
            "When invoking tools, strictly adhere to the following format:", // the original text for this is in Chinese, translated to English here
            new SpecialTokensText("<seed:tool_call>\n<function=example_function_name>\n{\"example_parameter_1\": \"value_1\", \"example_parameter_2\": \"This is the value for the second parameter\"}</function>\n</seed:tool_call>")
        ]);
    }

    /** @internal */
    private _getFirstSystemMessage(
        systemPrompt: LlamaText,
        availableFunctions?: ChatModelFunctions,
        {documentParams = true}: {documentParams?: boolean} = {}
    ) {
        const res: LlamaText[] = [];

        const functionsDocumentationGenerator = new ChatModelFunctionsDocumentationGenerator(availableFunctions);

        if (systemPrompt.values.length === 0 && functionsDocumentationGenerator.hasAnyFunctions)
            res.push(
                LlamaText("You are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query.")
            );
        else if (systemPrompt.values.length > 0)
            res.push(systemPrompt);

        if (functionsDocumentationGenerator.hasAnyFunctions)
            res.push(this.generateAvailableFunctionsSystemText(availableFunctions!, {documentParams}));

        return LlamaText(res);
    }

    /** @internal */
    private _getThinkingBudgetSystemMessage() {
        if (this.thinkingBudget == null || this.thinkingBudget < 0)
            return LlamaText([]);

        if (this.thinkingBudget === 0)
            return LlamaText([
                "You are an intelligent assistant that can answer questions in one step without the need for reasoning and thinking, " +
                "that is, your thinking budget is 0. " +
                "Next, please skip the thinking process and directly start answering the user's questions."
            ]);

        let reflectionInterval: number = 1024;
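        // maps a thinking budget ceiling to how often the model is asked to reflect on its
        // remaining budget; budgets above 16384 keep the default interval of 1024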
        const reflectionIntervals = new Map<number, number>([
            [16384, 1024],
            [8192, 1024],
            [4096, 512],
            [2048, 512],
            [1024, 256],
            [512, 128],
            [0, 0]
        ]);
        // entries are ordered from the largest budget ceiling to the smallest:
        // keep tightening the interval while the budget still fits,
        // and stop at the first ceiling the budget exceeds
        for (const [maxBudget, interval] of reflectionIntervals.entries()) {
            if (this.thinkingBudget <= maxBudget) {
                reflectionInterval = interval;
                continue;
            }

            break;
        }

        // note: the caller wraps this message with `<seed:bos>system` / `<seed:eos>`
        return LlamaText([
            "You are an intelligent assistant with reflective ability. In the process of thinking and reasoning, you need to strictly follow the thinking budget, which is ",
            this.thinkingBudget,
            ". That is, you need to complete your thinking within ",
            this.thinkingBudget,
            " tokens and start answering the user's questions. You will reflect on your thinking process every ",
            reflectionInterval,
            " tokens, stating how many tokens have been used and how many are left."
        ]);
    }
}
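
/*
 * Usage sketch (illustrative; assumes node-llama-cpp's `getLlama` and `LlamaChatSession`
 * APIs, and a hypothetical local model file path):
 *
 *     import {getLlama, LlamaChatSession} from "node-llama-cpp";
 *
 *     const llama = await getLlama();
 *     const model = await llama.loadModel({modelPath: "path/to/seed-oss-model.gguf"});
 *     const context = await model.createContext();
 *     const session = new LlamaChatSession({
 *         contextSequence: context.getSequence(),
 *         // request up to ~512 tokens of reasoning; set 0 to skip reasoning entirely
 *         chatWrapper: new SeedChatWrapper({thinkingBudget: 512})
 *     });
 *
 *     console.log(await session.prompt("How many r's are in 'strawberry'?"));
 */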