diff --git a/llama/addon/AddonContext.cpp b/llama/addon/AddonContext.cpp index 574dd79f..2d02326e 100644 --- a/llama/addon/AddonContext.cpp +++ b/llama/addon/AddonContext.cpp @@ -587,7 +587,7 @@ Napi::Value AddonContext::DisposeSequence(const Napi::CallbackInfo& info) { int32_t sequenceId = info[0].As().Int32Value(); - bool result = llama_kv_self_seq_rm(ctx, sequenceId, -1, -1); + bool result = llama_memory_seq_rm(llama_get_memory(ctx), sequenceId, -1, -1); if (!result) { Napi::Error::New(info.Env(), "Failed to dispose sequence").ThrowAsJavaScriptException(); @@ -606,7 +606,7 @@ Napi::Value AddonContext::RemoveTokenCellsFromSequence(const Napi::CallbackInfo& int32_t startPos = info[1].As().Int32Value(); int32_t endPos = info[2].As().Int32Value(); - bool result = llama_kv_self_seq_rm(ctx, sequenceId, startPos, endPos); + bool result = llama_memory_seq_rm(llama_get_memory(ctx), sequenceId, startPos, endPos); return Napi::Boolean::New(info.Env(), result); } @@ -621,7 +621,7 @@ Napi::Value AddonContext::ShiftSequenceTokenCells(const Napi::CallbackInfo& info int32_t endPos = info[2].As().Int32Value(); int32_t shiftDelta = info[3].As().Int32Value(); - llama_kv_self_seq_add(ctx, sequenceId, startPos, endPos, shiftDelta); + llama_memory_seq_add(llama_get_memory(ctx), sequenceId, startPos, endPos, shiftDelta); return info.Env().Undefined(); } @@ -634,7 +634,7 @@ Napi::Value AddonContext::GetSequenceKvCacheMinPosition(const Napi::CallbackInfo int32_t sequenceId = info[0].As().Int32Value(); - const auto minPosition = llama_kv_self_seq_pos_min(ctx, sequenceId); + const auto minPosition = llama_memory_seq_pos_min(llama_get_memory(ctx), sequenceId); return Napi::Number::New(info.Env(), minPosition); } @@ -647,7 +647,7 @@ Napi::Value AddonContext::GetSequenceKvCacheMaxPosition(const Napi::CallbackInfo int32_t sequenceId = info[0].As().Int32Value(); - const auto maxPosition = llama_kv_self_seq_pos_max(ctx, sequenceId); + const auto maxPosition = 
llama_memory_seq_pos_max(llama_get_memory(ctx), sequenceId); return Napi::Number::New(info.Env(), maxPosition); } diff --git a/package-lock.json b/package-lock.json index 2a79518a..0e327ffe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -24,7 +24,7 @@ "ignore": "^7.0.4", "ipull": "^3.9.2", "is-unicode-supported": "^2.1.0", - "lifecycle-utils": "^2.0.0", + "lifecycle-utils": "^2.0.1", "log-symbols": "^7.0.0", "nanoid": "^5.1.5", "node-addon-api": "^8.3.1", @@ -11548,9 +11548,9 @@ } }, "node_modules/lifecycle-utils": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/lifecycle-utils/-/lifecycle-utils-2.0.0.tgz", - "integrity": "sha512-KIkV6NeD2n0jZnO+fdIGKI5Or7alyhb6UTFzeaqf6EnE5y3pdK821+kd7yOMBUL/sPYhHU5ny74J0QKslLikGw==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/lifecycle-utils/-/lifecycle-utils-2.0.1.tgz", + "integrity": "sha512-jVso5WXIHfDL7Lf9sCRbLbPwgpoha5qUPgi+RMNVIMuOcb0nJ9Qr0r1OXbqLaxzBUQBhN8jYy92RLSk2OGJ6Cg==", "license": "MIT" }, "node_modules/lines-and-columns": { diff --git a/package.json b/package.json index cc455955..b63a6e46 100644 --- a/package.json +++ b/package.json @@ -197,7 +197,7 @@ "ignore": "^7.0.4", "ipull": "^3.9.2", "is-unicode-supported": "^2.1.0", - "lifecycle-utils": "^2.0.0", + "lifecycle-utils": "^2.0.1", "log-symbols": "^7.0.0", "nanoid": "^5.1.5", "node-addon-api": "^8.3.1", diff --git a/src/bindings/Llama.ts b/src/bindings/Llama.ts index ad025b09..005e4a7a 100644 --- a/src/bindings/Llama.ts +++ b/src/bindings/Llama.ts @@ -5,7 +5,7 @@ import {DisposedError, EventRelay, withLock} from "lifecycle-utils"; import {getConsoleLogPrefix} from "../utils/getConsoleLogPrefix.js"; import {LlamaModel, LlamaModelOptions} from "../evaluator/LlamaModel/LlamaModel.js"; import {DisposeGuard} from "../utils/DisposeGuard.js"; -import {GbnfJsonSchema} from "../utils/gbnfJson/types.js"; +import {GbnfJsonDefList, GbnfJsonSchema} from "../utils/gbnfJson/types.js"; import {LlamaJsonSchemaGrammar} from 
"../evaluator/LlamaJsonSchemaGrammar.js"; import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js"; import {ThreadsSplitter} from "../utils/ThreadsSplitter.js"; @@ -345,8 +345,11 @@ export class Llama { * @see [Using a JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#json-schema) tutorial * @see [Reducing Hallucinations When Using JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#reducing-json-schema-hallucinations) tutorial */ - public async createGrammarForJsonSchema(schema: Readonly) { - return new LlamaJsonSchemaGrammar(this, schema); + public async createGrammarForJsonSchema< + const T extends GbnfJsonSchema, + const Defs extends GbnfJsonDefList = Record + >(schema: Readonly & GbnfJsonSchema) { + return new LlamaJsonSchemaGrammar(this, schema); } /* eslint-enable @stylistic/max-len */ diff --git a/src/chatWrappers/QwenChatWrapper.ts b/src/chatWrappers/QwenChatWrapper.ts index 4fc3d7dd..f9fc9027 100644 --- a/src/chatWrappers/QwenChatWrapper.ts +++ b/src/chatWrappers/QwenChatWrapper.ts @@ -84,8 +84,8 @@ export class QwenChatWrapper extends ChatWrapper { segments: { reiterateStackAfterFunctionCalls: true, thought: { - prefix: LlamaText(new SpecialTokensText("")), - suffix: LlamaText(new SpecialTokensText("")) + prefix: LlamaText(new SpecialTokensText("\n")), + suffix: LlamaText(new SpecialTokensText("\n")) } } }; @@ -247,7 +247,9 @@ export class QwenChatWrapper extends ChatWrapper { public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): ChatWrapperJinjaMatchConfiguration { return [ [{}, {}, {_requireFunctionCallSettingsExtraction: true}], - [{_lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}] + [{_lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}], + [{thoughts: "discourage"}, {}, {_requireFunctionCallSettingsExtraction: true}], + [{thoughts: "discourage", 
_lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}] ]; } } diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts index 642173cc..6c41e2db 100644 --- a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -671,7 +671,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { return res; }; - const validateThatAllMessageIdsAreUsed = (parts: ReturnType>) => { + const validateThatAllMessageIdsAreUsed = (parts: ReturnType>) => { const messageIdsLeft = new Set(messageIds); for (const part of parts) { diff --git a/src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts b/src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts index 95b69f0f..3a6fcb5d 100644 --- a/src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts +++ b/src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts @@ -94,6 +94,22 @@ export function extractFunctionCallSettingsFromJinjaTemplate({ modelMessage2 ] }]; + const chatHistoryOnlyCall: ChatHistoryItem[] = [...baseChatHistory, { + type: "model", + response: [ + { + type: "functionCall", + name: func1name, + + // convert to number since this will go through JSON.stringify, + // and we want to avoid escaping characters in the rendered output + params: Number(func1params), + result: Number(func1result), + startsNewChunk: true + }, + modelMessage2 + ] + }]; const chatHistory2Calls: ChatHistoryItem[] = [...baseChatHistory, { type: "model", response: [ @@ -257,6 +273,17 @@ export function extractFunctionCallSettingsFromJinjaTemplate({ stringifyFunctionResults: stringifyResult, combineModelMessageAndToolCalls }); + const renderedOnlyCall = getFirstValidResult([ + () => renderTemplate({ + chatHistory: chatHistoryOnlyCall, + functions: functions1, + additionalParams, + stringifyFunctionParams: 
stringifyParams, + stringifyFunctionResults: stringifyResult, + combineModelMessageAndToolCalls + }), + () => undefined + ]); const rendered2Calls = getFirstValidResult([ () => renderTemplate({ chatHistory: chatHistory2Calls, @@ -411,6 +438,38 @@ export function extractFunctionCallSettingsFromJinjaTemplate({ parallelismResultPrefix } = resolveParallelismBetweenSectionsParts(func2ParamsToFunc1Result.text.slice(callSuffixLength, -resultPrefixLength)); + let revivedCallPrefix = reviveSeparatorText(callPrefixText, idToStaticContent, contentIds); + const revivedParallelismCallSectionPrefix = removeCommonRevivedPrefix( + reviveSeparatorText(parallelismCallPrefix, idToStaticContent, contentIds), + !combineModelMessageAndToolCalls + ? textBetween2TextualModelResponses + : LlamaText() + ); + let revivedParallelismCallBetweenCalls = reviveSeparatorText(parallelismBetweenCallsText, idToStaticContent, contentIds); + + if (revivedParallelismCallSectionPrefix.values.length === 0 && renderedOnlyCall != null) { + const userMessage1ToModelMessage1Start = getTextBetweenIds(rendered1Call, userMessage1, modelMessage1); + const onlyCallUserMessage1ToFunc1Name = getTextBetweenIds(renderedOnlyCall, userMessage1, func1name); + + if (userMessage1ToModelMessage1Start.text != null && onlyCallUserMessage1ToFunc1Name.text != null) { + const onlyCallModelMessagePrefixLength = findCommandStartLength( + userMessage1ToModelMessage1Start.text, + onlyCallUserMessage1ToFunc1Name.text + ); + const onlyCallCallPrefixText = onlyCallUserMessage1ToFunc1Name.text.slice(onlyCallModelMessagePrefixLength); + const revivedOnlyCallCallPrefixText = reviveSeparatorText(onlyCallCallPrefixText, idToStaticContent, contentIds); + + const optionalCallPrefix = removeCommonRevivedSuffix(revivedCallPrefix, revivedOnlyCallCallPrefixText); + if (optionalCallPrefix.values.length > 0) { + revivedCallPrefix = removeCommonRevivedPrefix(revivedCallPrefix, optionalCallPrefix); + revivedParallelismCallBetweenCalls = LlamaText([ + 
optionalCallPrefix, + revivedParallelismCallBetweenCalls + ]); + } + } + } + return { stringifyParams, stringifyResult, @@ -418,7 +477,7 @@ export function extractFunctionCallSettingsFromJinjaTemplate({ settings: { call: { optionalPrefixSpace: true, - prefix: reviveSeparatorText(callPrefixText, idToStaticContent, contentIds), + prefix: revivedCallPrefix, paramsPrefix: reviveSeparatorText(callParamsPrefixText, idToStaticContent, contentIds), suffix: reviveSeparatorText(callSuffixText, idToStaticContent, contentIds), emptyCallParamsPlaceholder: {} @@ -445,13 +504,8 @@ export function extractFunctionCallSettingsFromJinjaTemplate({ }, parallelism: { call: { - sectionPrefix: removeCommonRevivedPrefix( - reviveSeparatorText(parallelismCallPrefix, idToStaticContent, contentIds), - !combineModelMessageAndToolCalls - ? textBetween2TextualModelResponses - : LlamaText() - ), - betweenCalls: reviveSeparatorText(parallelismBetweenCallsText, idToStaticContent, contentIds), + sectionPrefix: revivedParallelismCallSectionPrefix, + betweenCalls: revivedParallelismCallBetweenCalls, sectionSuffix: reviveSeparatorText(parallelismCallSuffixText, idToStaticContent, contentIds) }, result: { @@ -524,7 +578,8 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) { } else if (targetValue instanceof SpecialToken && matchStartValue instanceof SpecialToken) { if (targetValue.value === matchStartValue.value) continue; - } + } else if (LlamaText(targetValue ?? "").compare(LlamaText(matchStartValue ?? 
""))) + continue; return LlamaText(target.values.slice(commonStartLength)); } @@ -532,6 +587,39 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) { return LlamaText(target.values.slice(matchStart.values.length)); } +function removeCommonRevivedSuffix(target: LlamaText, matchEnd: LlamaText) { + for ( + let commonEndLength = 0; + commonEndLength < target.values.length && commonEndLength < matchEnd.values.length; + commonEndLength++ + ) { + const targetValue = target.values[target.values.length - commonEndLength - 1]; + const matchEndValue = matchEnd.values[matchEnd.values.length - commonEndLength - 1]; + + if (typeof targetValue === "string" && typeof matchEndValue === "string") { + if (targetValue === matchEndValue) + continue; + } else if (targetValue instanceof SpecialTokensText && matchEndValue instanceof SpecialTokensText) { + const commonLength = findCommonEndLength(targetValue.value, matchEndValue.value); + if (commonLength === targetValue.value.length && commonLength === matchEndValue.value.length) + continue; + + return LlamaText([ + ...target.values.slice(0, target.values.length - commonEndLength - 1), + new SpecialTokensText(targetValue.value.slice(0, targetValue.value.length - commonLength)) + ]); + } else if (targetValue instanceof SpecialToken && matchEndValue instanceof SpecialToken) { + if (targetValue.value === matchEndValue.value) + continue; + } else if (LlamaText(targetValue ?? "").compare(LlamaText(matchEndValue ?? 
""))) + continue; + + return LlamaText(target.values.slice(0, target.values.length - commonEndLength - 1)); + } + + return LlamaText(target.values.slice(0, target.values.length - matchEnd.values.length)); +} + function findCommandStartLength(text1: string, text2: string) { let commonStartLength = 0; while (commonStartLength < text1.length && commonStartLength < text2.length) { diff --git a/src/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.ts b/src/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.ts index 30f434a0..c77d98bf 100644 --- a/src/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.ts +++ b/src/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.ts @@ -8,6 +8,42 @@ export function extractSegmentSettingsFromTokenizerAndChatTemplate( function tryMatchPrefixSuffixPair(tryMatchGroups: [prefix: string, suffix: string][]) { if (chatTemplate != null) { for (const [prefix, suffix] of tryMatchGroups) { + if ( + ( + hasAll(chatTemplate.replaceAll(prefix + "\\n\\n" + suffix, ""), [ + prefix + "\\n\\n", + "\\n\\n" + suffix + ]) + ) || ( + hasAll(chatTemplate.replaceAll(prefix + "\n\n" + suffix, ""), [ + prefix + "\n\n", + "\n\n" + suffix + ]) + ) + ) + return { + prefix: LlamaText(new SpecialTokensText(prefix + "\n\n")), + suffix: LlamaText(new SpecialTokensText("\n\n" + suffix)) + }; + + if ( + ( + hasAll(chatTemplate.replaceAll(prefix + "\\n" + suffix, ""), [ + prefix + "\\n", + "\\n" + suffix + ]) + ) || ( + hasAll(chatTemplate.replaceAll(prefix + "\n" + suffix, ""), [ + prefix + "\n", + "\n" + suffix + ]) + ) + ) + return { + prefix: LlamaText(new SpecialTokensText(prefix + "\n")), + suffix: LlamaText(new SpecialTokensText("\n" + suffix)) + }; + if (chatTemplate.includes(prefix) && chatTemplate.includes(suffix)) return { prefix: LlamaText(new SpecialTokensText(prefix)), @@ -46,3 +82,7 @@ export function 
extractSegmentSettingsFromTokenizerAndChatTemplate( ]) }); } + +function hasAll(text: string, matches: string[]) { + return matches.every((match) => text.includes(match)); +} diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index a23e58e5..39f0bd59 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -12,7 +12,6 @@ import {defineChatSessionFunction} from "../../evaluator/LlamaChatSession/utils/ import {getLlama} from "../../bindings/getLlama.js"; import {LlamaGrammar} from "../../evaluator/LlamaGrammar.js"; import {LlamaChatSession} from "../../evaluator/LlamaChatSession/LlamaChatSession.js"; -import {LlamaJsonSchemaGrammar} from "../../evaluator/LlamaJsonSchemaGrammar.js"; import { BuildGpu, LlamaLogLevel, LlamaLogLevelGreaterThan, nodeLlamaCppGpuOptions, parseNodeLlamaCppGpuOption } from "../../bindings/types.js"; @@ -529,8 +528,7 @@ async function RunChat({ }); const grammar = jsonSchemaGrammarFilePath != null - ? new LlamaJsonSchemaGrammar( - llama, + ? 
await llama.createGrammarForJsonSchema( await fs.readJson( path.resolve(process.cwd(), jsonSchemaGrammarFilePath) ) diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index b445ff2c..e73aaea9 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -3,6 +3,7 @@ import process from "process"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; +import {Template} from "@huggingface/jinja"; import {readGgufFileInfo} from "../../../../gguf/readGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; import {getGgufFileTypeName} from "../../../../gguf/utils/getGgufFileTypeName.js"; @@ -18,6 +19,8 @@ import {toBytes} from "../../../utils/toBytes.js"; import {printDidYouMeanUri} from "../../../utils/resolveCommandGgufPath.js"; import {isModelUri} from "../../../../utils/parseModelUri.js"; +const chatTemplateKey = ".chatTemplate"; + type InspectGgufCommand = { modelPath: string, header?: string[], @@ -54,7 +57,8 @@ export const InspectGgufCommand: CommandModule = { .option("key", { alias: ["k"], type: "string", - description: "A single metadata key to print the value of. If not provided, all metadata will be printed", + description: "A single metadata key to print the value of. If not provided, all metadata will be printed. " + + "If the key is `" + chatTemplateKey + "` then the chat template of the model will be formatted and printed.", group: "Optional:" }) .option("noSplice", { @@ -141,7 +145,9 @@ export const InspectGgufCommand: CommandModule = { if (plainJson || outputToJsonFile != null) { const getOutputJson = () => { if (key != null) { - const keyValue = getGgufMetadataKeyValue(parsedMetadata.metadata, key); + const keyValue = key === chatTemplateKey + ? 
tryFormattingJinja(getGgufMetadataKeyValue(parsedMetadata.metadata, "tokenizer.chat_template")) + : getGgufMetadataKeyValue(parsedMetadata.metadata, key); if (keyValue === undefined) { console.log(`Key not found: ${key}`); process.exit(1); @@ -172,7 +178,9 @@ export const InspectGgufCommand: CommandModule = { console.info(outputJson); } } else if (key != null) { - const keyValue = getGgufMetadataKeyValue(parsedMetadata.metadata, key); + const keyValue = key === chatTemplateKey + ? tryFormattingJinja(getGgufMetadataKeyValue(parsedMetadata.metadata, "tokenizer.chat_template")) + : getGgufMetadataKeyValue(parsedMetadata.metadata, key); if (keyValue === undefined) { console.log(`${chalk.red("Metadata key not found:")} ${key}`); process.exit(1); @@ -237,3 +245,17 @@ function removeAdditionalTensorInfoFields(tensorInfo?: GgufTensorInfo[]) { delete (tensor as {filePart?: GgufTensorInfo["filePart"]}).filePart; } } + +function tryFormattingJinja(template?: string) { + if (typeof template !== "string") + return template; + + try { + const parsedTemplate = new Template(template); + return parsedTemplate.format({ + indent: 4 + }) ?? 
template; + } catch (err) { + return template; + } +} diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 77a171d9..628e220c 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -24,6 +24,7 @@ import {LlamaSampler} from "../LlamaContext/LlamaSampler.js"; import {LlamaModel} from "../LlamaModel/LlamaModel.js"; import {getChatWrapperSegmentDefinition} from "../../utils/getChatWrapperSegmentDefinition.js"; import {jsonDumps} from "../../chatWrappers/utils/jsonDumps.js"; +import {defaultMaxPreloadTokens} from "../LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.js"; import { eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy } from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js"; @@ -721,7 +722,7 @@ export class LlamaChat { onTextChunk, onToken, signal, - maxTokens = Math.min(256, Math.ceil(this.context.contextSize / 2)), + maxTokens = defaultMaxPreloadTokens(this.sequence), temperature, minP, topK, @@ -2177,11 +2178,15 @@ class GenerateResponseState { + const defaultValue: number = 256; + + return sequence.model.fileInsights.swaSize != null + ? 
Math.min( + Math.ceil(sequence.model.fileInsights.swaSize / 2), + defaultValue, + Math.ceil(sequence.contextSize / 2) + ) + : Math.min( + defaultValue, + Math.ceil(sequence.contextSize / 2) + ); +}; const defaultMaxCachedCompletions = 100; export class LlamaChatSessionPromptCompletionEngine { @@ -51,7 +65,7 @@ export class LlamaChatSessionPromptCompletionEngine { /** @internal */ private _disposed = false; private constructor(chatSession: LlamaChatSession, { - maxPreloadTokens = defaultMaxPreloadTokens, + maxPreloadTokens = defaultMaxPreloadTokens(chatSession.sequence), onGeneration, maxCachedCompletions = defaultMaxCachedCompletions, ...options diff --git a/src/evaluator/LlamaChatSession/utils/defineChatSessionFunction.ts b/src/evaluator/LlamaChatSession/utils/defineChatSessionFunction.ts index 397119ea..bd27d644 100644 --- a/src/evaluator/LlamaChatSession/utils/defineChatSessionFunction.ts +++ b/src/evaluator/LlamaChatSession/utils/defineChatSessionFunction.ts @@ -1,4 +1,4 @@ -import {GbnfJsonSchema, GbnfJsonSchemaToType} from "../../../utils/gbnfJson/types.js"; +import {GbnfJsonDefList, GbnfJsonSchema, GbnfJsonSchemaToType} from "../../../utils/gbnfJson/types.js"; import {ChatSessionModelFunction} from "../../../types.js"; /** @@ -9,15 +9,18 @@ import {ChatSessionModelFunction} from "../../../types.js"; * The handler function can return a Promise, and the return value will be awaited before being returned to the model. 
* @param functionDefinition */ -export function defineChatSessionFunction({ +export function defineChatSessionFunction< + const Params extends GbnfJsonSchema, + const Defs extends GbnfJsonDefList +>({ description, params, handler }: { description?: string, - params?: Readonly, - handler: (params: GbnfJsonSchemaToType) => Promise | any -}): ChatSessionModelFunction { + params?: Readonly & GbnfJsonSchema, + handler: (params: GbnfJsonSchemaToType>) => Promise | any +}): ChatSessionModelFunction> { return { description, params, diff --git a/src/evaluator/LlamaJsonSchemaGrammar.ts b/src/evaluator/LlamaJsonSchemaGrammar.ts index 8cf80e42..62ad7e3f 100644 --- a/src/evaluator/LlamaJsonSchemaGrammar.ts +++ b/src/evaluator/LlamaJsonSchemaGrammar.ts @@ -1,4 +1,4 @@ -import {GbnfJsonSchema, GbnfJsonSchemaToType} from "../utils/gbnfJson/types.js"; +import {GbnfJsonDefList, GbnfJsonSchema, GbnfJsonSchemaToType} from "../utils/gbnfJson/types.js"; import {getGbnfGrammarForGbnfJsonSchema} from "../utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.js"; import {validateObjectAgainstGbnfSchema} from "../utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.js"; import {LlamaText} from "../utils/LlamaText.js"; @@ -10,14 +10,17 @@ import {LlamaGrammar} from "./LlamaGrammar.js"; * @see [Using a JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#json-schema) tutorial * @see [Reducing Hallucinations When Using JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#reducing-json-schema-hallucinations) tutorial */ -export class LlamaJsonSchemaGrammar extends LlamaGrammar { +export class LlamaJsonSchemaGrammar< + const T extends GbnfJsonSchema, + const Defs extends GbnfJsonDefList = Record +> extends LlamaGrammar { private readonly _schema: T; /** * Prefer to create a new instance of this class by using `llama.createGrammarForJsonSchema(...)`. * @deprecated Use `llama.createGrammarForJsonSchema(...)` instead. 
*/ - public constructor(llama: Llama, schema: Readonly) { + public constructor(llama: Llama, schema: Readonly & GbnfJsonSchema) { const grammar = getGbnfGrammarForGbnfJsonSchema(schema); super(llama, { diff --git a/src/types.ts b/src/types.ts index 1b7fed75..488cb217 100644 --- a/src/types.ts +++ b/src/types.ts @@ -193,8 +193,8 @@ export type ChatSessionModelFunctions = { export type ChatSessionModelFunction = { readonly description?: string, - readonly params?: Readonly, - readonly handler: (params: GbnfJsonSchemaToType) => any + readonly params?: Params, + readonly handler: (params: GbnfJsonSchemaToType>) => any }; export function isChatModelResponseFunctionCall(item: ChatModelResponse["response"][number] | undefined): item is ChatModelFunctionCall { diff --git a/src/utils/gbnfJson/GbnfGrammarGenerator.ts b/src/utils/gbnfJson/GbnfGrammarGenerator.ts index 272a3ed3..6750ad22 100644 --- a/src/utils/gbnfJson/GbnfGrammarGenerator.ts +++ b/src/utils/gbnfJson/GbnfGrammarGenerator.ts @@ -1,9 +1,16 @@ +import {MultiKeyMap} from "lifecycle-utils"; +import {GbnfJsonSchema} from "./types.js"; + export class GbnfGrammarGenerator { - public rules = new Map(); + public rules = new Map(); public ruleContentToRuleName = new Map(); public literalValueRuleNames = new Map(); + public defRuleNames = new MultiKeyMap<[string, GbnfJsonSchema], string | null>(); + public defScopeDefs = new MultiKeyMap<[string, GbnfJsonSchema], Record>(); + public usedRootRuleName: boolean = false; private ruleId: number = 0; private valueRuleId: number = 0; + private defRuleId: number = 0; public generateRuleName() { const ruleId = this.ruleId; @@ -25,6 +32,24 @@ export class GbnfGrammarGenerator { return ruleName; } + public generateRuleNameForDef(defName: string, def: GbnfJsonSchema): string { + const existingRuleName = this.defRuleNames.get([defName, def]); + if (existingRuleName != null) + return existingRuleName; + + const ruleName = `def${this.defRuleId}`; + this.defRuleId++; + + 
this.defRuleNames.set([defName, def], ruleName); + + return ruleName; + } + + public registerDefs(scopeDefs: Record) { + for (const [defName, def] of Object.entries(scopeDefs)) + this.defScopeDefs.set([defName, def], scopeDefs); + } + public generateGbnfFile(rootGrammar: string) { const rules: {name: string, grammar: string}[] = [{ name: "root", diff --git a/src/utils/gbnfJson/GbnfTerminal.ts b/src/utils/gbnfJson/GbnfTerminal.ts index b8e83fe5..ca0fe99d 100644 --- a/src/utils/gbnfJson/GbnfTerminal.ts +++ b/src/utils/gbnfJson/GbnfTerminal.ts @@ -25,7 +25,22 @@ export abstract class GbnfTerminal { return this.getGrammar(grammarGenerator); } - public resolve(grammarGenerator: GbnfGrammarGenerator): string { + private _getRootRuleName(grammarGenerator: GbnfGrammarGenerator) { + if (this._ruleName != null) + return this._ruleName; + + const ruleName = grammarGenerator.usedRootRuleName + ? this.getRuleName(grammarGenerator) + : "root"; + this._ruleName = ruleName; + + if (ruleName === "root") + grammarGenerator.usedRootRuleName = true; + + return ruleName; + } + + public resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string { if (this._ruleName != null) return this._ruleName; @@ -37,7 +52,12 @@ export abstract class GbnfTerminal { return existingRuleName; } - const ruleName = this.getRuleName(grammarGenerator); + const ruleName = resolveAsRootGrammar + ? 
this._getRootRuleName(grammarGenerator) + : this.getRuleName(grammarGenerator); + + if (resolveAsRootGrammar) + return grammar; if (grammar === ruleName) { this._ruleName = ruleName; diff --git a/src/utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.ts b/src/utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.ts index 1c8619c9..577400a8 100644 --- a/src/utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.ts +++ b/src/utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.ts @@ -14,7 +14,7 @@ export function getGbnfGrammarForGbnfJsonSchema(schema: Readonly const grammarGenerator = new GbnfGrammarGenerator(); const scopeState = new GbnfJsonScopeState({allowNewLines, scopePadSpaces}); const rootTerminal = getGbnfJsonTerminalForGbnfJsonSchema(schema, grammarGenerator, scopeState); - const rootGrammar = rootTerminal.getGrammar(grammarGenerator); + const rootGrammar = rootTerminal.resolve(grammarGenerator, true); return grammarGenerator.generateGbnfFile(rootGrammar + ` "${"\\n".repeat(4)}"` + " [\\n]*"); } diff --git a/src/utils/gbnfJson/terminals/GbnfGrammar.ts b/src/utils/gbnfJson/terminals/GbnfGrammar.ts index 0ef1cc93..f2b48064 100644 --- a/src/utils/gbnfJson/terminals/GbnfGrammar.ts +++ b/src/utils/gbnfJson/terminals/GbnfGrammar.ts @@ -21,10 +21,10 @@ export class GbnfGrammar extends GbnfTerminal { return this.grammar; } - public override resolve(grammarGenerator: GbnfGrammarGenerator): string { + public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string { if (this.resolveToRawGrammar) return this.getGrammar(); - return super.resolve(grammarGenerator); + return super.resolve(grammarGenerator, resolveAsRootGrammar); } } diff --git a/src/utils/gbnfJson/terminals/GbnfNumberValue.ts b/src/utils/gbnfJson/terminals/GbnfNumberValue.ts index fc36ea05..76785e40 100644 --- a/src/utils/gbnfJson/terminals/GbnfNumberValue.ts +++ b/src/utils/gbnfJson/terminals/GbnfNumberValue.ts @@ -14,12 +14,12 @@ export class GbnfNumberValue extends GbnfTerminal 
{ return '"' + JSON.stringify(this.value) + '"'; } - public override resolve(grammarGenerator: GbnfGrammarGenerator): string { + public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string { const grammar = this.getGrammar(); if (grammar.length <= grammarGenerator.getProposedLiteralValueRuleNameLength()) return grammar; - return super.resolve(grammarGenerator); + return super.resolve(grammarGenerator, resolveAsRootGrammar); } protected override generateRuleName(grammarGenerator: GbnfGrammarGenerator): string { diff --git a/src/utils/gbnfJson/terminals/GbnfOr.ts b/src/utils/gbnfJson/terminals/GbnfOr.ts index d22b93d4..8fd9c75a 100644 --- a/src/utils/gbnfJson/terminals/GbnfOr.ts +++ b/src/utils/gbnfJson/terminals/GbnfOr.ts @@ -30,7 +30,7 @@ export class GbnfOr extends GbnfTerminal { return "( " + mappedValues.join(" | ") + " )"; } - public override resolve(grammarGenerator: GbnfGrammarGenerator): string { + public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string { const mappedValues = this.values .map((v) => v.resolve(grammarGenerator)) .filter((value) => value !== "" && value !== grammarNoValue); @@ -40,6 +40,6 @@ export class GbnfOr extends GbnfTerminal { else if (mappedValues.length === 1) return mappedValues[0]!; - return super.resolve(grammarGenerator); + return super.resolve(grammarGenerator, resolveAsRootGrammar); } } diff --git a/src/utils/gbnfJson/terminals/GbnfRef.ts b/src/utils/gbnfJson/terminals/GbnfRef.ts new file mode 100644 index 00000000..c70fab09 --- /dev/null +++ b/src/utils/gbnfJson/terminals/GbnfRef.ts @@ -0,0 +1,53 @@ +import {GbnfTerminal} from "../GbnfTerminal.js"; +import {GbnfGrammarGenerator} from "../GbnfGrammarGenerator.js"; +import {GbnfJsonSchema} from "../types.js"; + + +export class GbnfRef extends GbnfTerminal { + public readonly getValueTerminal: () => GbnfTerminal; + public readonly defName: string; + public readonly def: 
GbnfJsonSchema; + + public constructor({ + getValueTerminal, + defName, + def + }: { + getValueTerminal: () => GbnfTerminal, + defName: string, + def: GbnfJsonSchema + }) { + super(); + this.getValueTerminal = getValueTerminal; + this.defName = defName; + this.def = def; + } + + public override getGrammar(grammarGenerator: GbnfGrammarGenerator): string { + return this.generateRuleName(grammarGenerator); + } + + protected override generateRuleName(grammarGenerator: GbnfGrammarGenerator): string { + if (!grammarGenerator.defRuleNames.has([this.defName, this.def])) { + const alreadyGeneratingGrammarForThisRef = grammarGenerator.defRuleNames.get([this.defName, this.def]) === null; + if (alreadyGeneratingGrammarForThisRef) + return grammarGenerator.generateRuleNameForDef(this.defName, this.def); + + grammarGenerator.defRuleNames.set([this.defName, this.def], null); + const grammar = this.getValueTerminal().resolve(grammarGenerator); + + if (grammarGenerator.rules.has(grammar) && grammarGenerator.defRuleNames.get([this.defName, this.def]) === null) { + grammarGenerator.defRuleNames.set([this.defName, this.def], grammar); + return grammar; + } + + const ruleName = grammarGenerator.generateRuleNameForDef(this.defName, this.def); + grammarGenerator.rules.set(ruleName, grammar); + grammarGenerator.ruleContentToRuleName.set(grammar, ruleName); + + return ruleName; + } + + return grammarGenerator.generateRuleNameForDef(this.defName, this.def); + } +} diff --git a/src/utils/gbnfJson/types.ts b/src/utils/gbnfJson/types.ts index 0f6c798a..dd17147b 100644 --- a/src/utils/gbnfJson/types.ts +++ b/src/utils/gbnfJson/types.ts @@ -1,6 +1,18 @@ export type GbnfJsonSchemaImmutableType = "string" | "number" | "integer" | "boolean" | "null"; -export type GbnfJsonSchema = GbnfJsonBasicSchema | GbnfJsonConstSchema | GbnfJsonEnumSchema | GbnfJsonOneOfSchema | - GbnfJsonStringSchema | GbnfJsonObjectSchema | GbnfJsonArraySchema; + +export type GbnfJsonSchema = Record> = GbnfJsonBasicSchema | 
GbnfJsonConstSchema | + GbnfJsonEnumSchema | GbnfJsonOneOfSchema | GbnfJsonStringSchema | GbnfJsonObjectSchema | + GbnfJsonArraySchema | ( + keyof Defs extends string + ? keyof NoInfer extends never + ? never + : GbnfJsonRefSchema + : never + ); + +export type GbnfJsonDefList> = {}> = { + readonly [key: string]: GbnfJsonSchema> +}; export type GbnfJsonBasicSchema = { readonly type: GbnfJsonSchemaImmutableType | readonly GbnfJsonSchemaImmutableType[], @@ -32,15 +44,17 @@ export type GbnfJsonEnumSchema = { */ readonly description?: string }; -export type GbnfJsonOneOfSchema = { - readonly oneOf: readonly GbnfJsonSchema[], +export type GbnfJsonOneOfSchema> = {}> = { + readonly oneOf: readonly GbnfJsonSchema>[], /** * A description of what you expect the model to set this value to. * * Only passed to the model when using function calling, and has no effect when using JSON Schema grammar directly. */ - readonly description?: string + readonly description?: string, + + readonly $defs?: Defs }; export type GbnfJsonStringSchema = GbnfJsonBasicStringSchema | GbnfJsonFormatStringSchema; export type GbnfJsonBasicStringSchema = { @@ -78,14 +92,17 @@ export type GbnfJsonFormatStringSchema = { */ readonly description?: string }; -export type GbnfJsonObjectSchema = { +export type GbnfJsonObjectSchema< + Keys extends string = string, + Defs extends GbnfJsonDefList> = {} +> = { readonly type: "object", - readonly properties?: {readonly [key in Keys]: GbnfJsonSchema}, + readonly properties?: {readonly [key in Keys]: GbnfJsonSchema>}, /** * Unlike the JSON Schema spec, `additionalProperties` defaults to `false` to avoid breaking existing code. */ - readonly additionalProperties?: boolean | GbnfJsonSchema, + readonly additionalProperties?: boolean | GbnfJsonSchema>, /** * Make sure you define `additionalProperties` for this to have any effect. 
@@ -121,12 +138,14 @@ export type GbnfJsonObjectSchema = { * * Only passed to the model when using function calling, and has no effect when using JSON Schema grammar directly. */ - readonly description?: string + readonly description?: string, + + readonly $defs?: Defs }; -export type GbnfJsonArraySchema = { +export type GbnfJsonArraySchema> = {}> = { readonly type: "array", - readonly items?: GbnfJsonSchema, - readonly prefixItems?: readonly GbnfJsonSchema[], + readonly items?: GbnfJsonSchema>, + readonly prefixItems?: readonly GbnfJsonSchema>[], /** * When using `minItems` and/or `maxItems`, @@ -147,7 +166,23 @@ export type GbnfJsonArraySchema = { * * Only passed to the model when using function calling, and has no effect when using JSON Schema grammar directly. */ - readonly description?: string + readonly description?: string, + + readonly $defs?: Defs +}; +export type GbnfJsonRefSchema> = {}> = { + readonly $ref: keyof NoInfer extends never + ? never + : `#/$defs/${OnlyStringKeys>}`, + + /** + * A description of what you expect the model to set this value to. + * + * Only passed to the model when using function calling, and has no effect when using JSON Schema grammar directly. + */ + readonly description?: string, + + readonly $defs?: Defs }; @@ -156,7 +191,7 @@ export type GbnfJsonArraySchema = { */ export type GbnfJsonSchemaToType = GbnfJsonSchemaToTSType; -export type GbnfJsonSchemaToTSType = +export type GbnfJsonSchemaToTSType> = {}> = Readonly extends T ? undefined : undefined extends T @@ -171,13 +206,15 @@ export type GbnfJsonSchemaToTSType = ? T["const"] : T extends GbnfJsonEnumSchema ? T["enum"][number] - : T extends GbnfJsonOneOfSchema - ? GbnfJsonSchemaToType - : T extends GbnfJsonObjectSchema - ? GbnfJsonObjectSchemaToType - : T extends GbnfJsonArraySchema - ? ArrayTypeToType - : undefined; + : T extends GbnfJsonOneOfSchema> + ? GbnfJsonSchemaToTSType, T["$defs"]>> + : T extends GbnfJsonObjectSchema> + ? 
GbnfJsonObjectSchemaToType> + : T extends GbnfJsonArraySchema> + ? ArrayTypeToType, T["$defs"]>> + : T extends GbnfJsonRefSchema + ? GbnfJsonRefSchemaToType, T["$defs"]>> + : undefined; type GbnfJsonBasicStringSchemaToType = T["maxLength"] extends 0 @@ -187,7 +224,7 @@ type GbnfJsonBasicStringSchemaToType = type GbnfJsonBasicSchemaToType = T extends GbnfJsonSchemaImmutableType ? ImmutableTypeToType - : T extends GbnfJsonSchemaImmutableType[] + : T[number] extends GbnfJsonSchemaImmutableType ? ImmutableTypeToType : never; @@ -205,7 +242,8 @@ type ImmutableTypeToType = : never; type ArrayTypeToType< - T extends GbnfJsonArraySchema, + T extends GbnfJsonArraySchema>, + Defs extends GbnfJsonDefList = {}, MinItems extends number = T["minItems"] extends number ? T["prefixItems"] extends readonly GbnfJsonSchema[] ? keyof T["prefixItems"] extends T["minItems"] @@ -222,45 +260,49 @@ type ArrayTypeToType< ? ( T["maxItems"] extends MinItems ? [ - ...GbnfJsonOrderedArrayTypes, + ...GbnfJsonOrderedArrayTypes>, ...IndexRangeWithSkip< MinItems, T["prefixItems"]["length"], T["items"] extends GbnfJsonSchema - ? GbnfJsonSchemaToType + ? GbnfJsonSchemaToTSType> : GbnfJsonAnyValue > ] : [ - ...GbnfJsonOrderedArrayTypes, + ...GbnfJsonOrderedArrayTypes>, ...( T["items"] extends GbnfJsonSchema - ? GbnfJsonSchemaToType + ? GbnfJsonSchemaToTSType> : GbnfJsonAnyValue )[] ] ) : T["maxItems"] extends MinItems ? [ - ...GbnfJsonOrderedArrayTypes, + ...GbnfJsonOrderedArrayTypes>, ...( T["items"] extends GbnfJsonSchema - ? IndexRangeWithSkip> + ? IndexRangeWithSkip< + T["maxItems"], + T["prefixItems"]["length"], + GbnfJsonSchemaToTSType> + > : IndexRangeWithSkip ) ] : [ - ...GbnfJsonOrderedArrayTypes, + ...GbnfJsonOrderedArrayTypes>, ...IndexRangeWithSkip< MinItems, T["prefixItems"]["length"], T["items"] extends GbnfJsonSchema - ? GbnfJsonSchemaToType + ? GbnfJsonSchemaToTSType> : GbnfJsonAnyValue >, ...( T["items"] extends GbnfJsonSchema - ? GbnfJsonSchemaToType + ? 
GbnfJsonSchemaToTSType> : GbnfJsonAnyValue )[] ] @@ -268,12 +310,18 @@ type ArrayTypeToType< : T["items"] extends GbnfJsonSchema ? ( MinItems extends 0 - ? GbnfJsonSchemaToType[] + ? GbnfJsonSchemaToTSType>[] : T["maxItems"] extends MinItems - ? IndexRange> + ? IndexRange< + T["maxItems"], + GbnfJsonSchemaToTSType> + > : [ - ...IndexRange>, - ...GbnfJsonSchemaToType[] + ...IndexRange< + MinItems, + GbnfJsonSchemaToTSType> + >, + ...GbnfJsonSchemaToTSType>[] ] ) : ( @@ -289,21 +337,31 @@ type ArrayTypeToType< type GbnfJsonObjectSchemaToType< - T extends GbnfJsonObjectSchema, + T extends GbnfJsonObjectSchema>, + Defs extends GbnfJsonDefList = {}, Props extends Readonly> | undefined = T["properties"], AdditionalProps extends true | false | GbnfJsonSchema | undefined = T["additionalProperties"], PropsMap = Props extends undefined ? {} - : {-readonly [P in keyof Props]: GbnfJsonSchemaToType}, + : {-readonly [P in keyof Props]: GbnfJsonSchemaToTSType>}, Res = AdditionalProps extends undefined | false ? PropsMap : AdditionalProps extends true ? PropsMap & {[key: string]: GbnfJsonAnyValue} : AdditionalProps extends GbnfJsonSchema - ? PropsMap & {[key: string]: GbnfJsonSchemaToType} + ? PropsMap & { + [key: string]: GbnfJsonSchemaToTSType> + } : PropsMap > = Res; +type GbnfJsonRefSchemaToType, Defs extends GbnfJsonDefList = {}> = + T["$ref"] extends `#/$defs/${infer Key}` + ? Key extends keyof Defs + ? 
GbnfJsonSchemaToTSType + : never + : never; + type GbnfJsonAnyValue = string | number | boolean | null | GbnfJsonAnyValue[] | {[key: string]: GbnfJsonAnyValue}; export function isGbnfJsonConstSchema(schema: GbnfJsonSchema): schema is GbnfJsonConstSchema { @@ -334,6 +392,10 @@ export function isGbnfJsonArraySchema(schema: GbnfJsonSchema): schema is GbnfJso return (schema as GbnfJsonArraySchema).type === "array"; } +export function isGbnfJsonRefSchema(schema: GbnfJsonSchema): schema is GbnfJsonRefSchema> { + return typeof (schema as GbnfJsonRefSchema).$ref === "string"; +} + export function isGbnfJsonBasicSchemaIncludesType( schema: GbnfJsonBasicSchema, type: T ): schema is GbnfJsonBasicSchema & {type: T | (T | GbnfJsonSchemaImmutableType)[]} { @@ -343,6 +405,19 @@ export function isGbnfJsonBasicSchemaIncludesType = { + [K in keyof T]: K extends string ? K : never; +}[keyof T]; + +type CombineDefs< + Defs1 extends GbnfJsonDefList, + Param2 extends Defs1 | Defs2 | undefined, + Defs2 extends GbnfJsonDefList = {} +> = + undefined extends NoInfer + ? Defs1 + : Defs1 & Param2; + type IndexRange< Length extends number, FillType = number, @@ -371,6 +446,6 @@ type _IndexRangeWithSkip< ? 
Value : _IndexRangeWithSkip<[...Value, FillType], [...ConditionValue, ConditionValue["length"]], MaxLength, FillType>; -type GbnfJsonOrderedArrayTypes = { - -readonly [P in keyof T]: GbnfJsonSchemaToType +type GbnfJsonOrderedArrayTypes = {}> = { + -readonly [P in keyof T]: GbnfJsonSchemaToTSType }; diff --git a/src/utils/gbnfJson/utils/defsScope.ts b/src/utils/gbnfJson/utils/defsScope.ts new file mode 100644 index 00000000..46fedaec --- /dev/null +++ b/src/utils/gbnfJson/utils/defsScope.ts @@ -0,0 +1,24 @@ +import {MultiKeyMap} from "lifecycle-utils"; +import {GbnfJsonSchema} from "../types.js"; + +export class DefScopeDefs { + public defScopeDefs: MultiKeyMap<[string, GbnfJsonSchema], Record> = new MultiKeyMap(); + + public registerDefs(scopeDefs: Record) { + for (const [defName, def] of Object.entries(scopeDefs)) + this.defScopeDefs.set([defName, def], scopeDefs); + } +} + +export function joinDefs( + parent: Record, + current?: Record +) { + if (current == null || Object.keys(current).length === 0) + return parent; + + return { + ...parent, + ...current + }; +} diff --git a/src/utils/gbnfJson/utils/getGbnfJsonTerminalForGbnfJsonSchema.ts b/src/utils/gbnfJson/utils/getGbnfJsonTerminalForGbnfJsonSchema.ts index 7777da11..a78cd82a 100644 --- a/src/utils/gbnfJson/utils/getGbnfJsonTerminalForGbnfJsonSchema.ts +++ b/src/utils/gbnfJson/utils/getGbnfJsonTerminalForGbnfJsonSchema.ts @@ -10,21 +10,83 @@ import {GbnfNull} from "../terminals/GbnfNull.js"; import {GbnfGrammarGenerator} from "../GbnfGrammarGenerator.js"; import { GbnfJsonSchema, isGbnfJsonArraySchema, isGbnfJsonBasicSchemaIncludesType, isGbnfJsonConstSchema, isGbnfJsonEnumSchema, - isGbnfJsonObjectSchema, isGbnfJsonOneOfSchema, isGbnfJsonBasicStringSchema, isGbnfJsonFormatStringSchema + isGbnfJsonObjectSchema, isGbnfJsonOneOfSchema, isGbnfJsonBasicStringSchema, isGbnfJsonFormatStringSchema, isGbnfJsonRefSchema } from "../types.js"; import {getConsoleLogPrefix} from "../../getConsoleLogPrefix.js"; import 
{GbnfAnyJson} from "../terminals/GbnfAnyJson.js"; import {GbnfFormatString} from "../terminals/GbnfFormatString.js"; +import {GbnfRef} from "../terminals/GbnfRef.js"; import {getGbnfJsonTerminalForLiteral} from "./getGbnfJsonTerminalForLiteral.js"; import {GbnfJsonScopeState} from "./GbnfJsonScopeState.js"; +import {joinDefs} from "./defsScope.js"; +const maxNestingScope = 512; export function getGbnfJsonTerminalForGbnfJsonSchema( - schema: GbnfJsonSchema, grammarGenerator: GbnfGrammarGenerator, scopeState: GbnfJsonScopeState = new GbnfJsonScopeState() + schema: GbnfJsonSchema, + grammarGenerator: GbnfGrammarGenerator, + scopeState: GbnfJsonScopeState = new GbnfJsonScopeState(), + defs: Record = {} ): GbnfTerminal { - if (isGbnfJsonOneOfSchema(schema)) { + if (scopeState.currentNestingScope >= maxNestingScope) + throw new Error("Maximum nesting scope exceeded. Ensure that your schema does not have circular references or excessive nesting."); + + if (isGbnfJsonRefSchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + grammarGenerator.registerDefs(currentDefs); + + const ref = schema?.$ref; + const referencePrefix = "#/$defs/"; + if (ref == null || !ref.startsWith(referencePrefix)) { + console.warn( + getConsoleLogPrefix(true, false), + `Reference "${ref}" does not start with "${referencePrefix}". ` + + 'Using an "any" type instead of a reference.' + ); + return new GbnfAnyJson(scopeState); + } + + const defName = ref.slice(referencePrefix.length); + const def = currentDefs[defName]; + if (def == null) { + console.warn( + getConsoleLogPrefix(true, false), + `Reference "${ref}" does not point to an existing definition. ` + + 'Using an "any" type instead of a reference.' 
+ ); + return new GbnfAnyJson(scopeState); + } + + return new GbnfRef({ + getValueTerminal() { + const scopeDefs = grammarGenerator.defScopeDefs.get([defName, def]); + + return getGbnfJsonTerminalForGbnfJsonSchema( + def, + grammarGenerator, + new GbnfJsonScopeState({ + allowNewLines: false, + scopePadSpaces: scopeState.settings.scopePadSpaces + }, 0), + scopeDefs ?? {} + ); + }, + def, + defName + }); + } else if (isGbnfJsonOneOfSchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + grammarGenerator.registerDefs(currentDefs); + const values = schema.oneOf - .map((altSchema) => getGbnfJsonTerminalForGbnfJsonSchema(altSchema, grammarGenerator, scopeState)); + .map((altSchema) => ( + getGbnfJsonTerminalForGbnfJsonSchema( + altSchema, + grammarGenerator, + scopeState, + currentDefs + ) + )); return new GbnfOr(values); } else if (isGbnfJsonConstSchema(schema)) { @@ -33,6 +95,8 @@ export function getGbnfJsonTerminalForGbnfJsonSchema( return new GbnfOr(schema.enum.map((item) => getGbnfJsonTerminalForLiteral(item))); } else if (isGbnfJsonObjectSchema(schema)) { const propertiesEntries = Object.entries(schema.properties ?? {}); + const currentDefs = joinDefs(defs, schema.$defs); + grammarGenerator.registerDefs(currentDefs); let maxProperties = schema.maxProperties; if (schema.properties != null && maxProperties != null && maxProperties < propertiesEntries.length) { @@ -50,19 +114,32 @@ export function getGbnfJsonTerminalForGbnfJsonSchema( return { required: true, key: new GbnfStringValue(propName), - value: getGbnfJsonTerminalForGbnfJsonSchema(propSchema, grammarGenerator, scopeState.getForNewScope()) + value: getGbnfJsonTerminalForGbnfJsonSchema( + propSchema, + grammarGenerator, + scopeState.getForNewScope(), + currentDefs + ) }; }), additionalProperties: (schema.additionalProperties == null || schema.additionalProperties === false) ? undefined : schema.additionalProperties === true ? 
new GbnfAnyJson(scopeState.getForNewScope()) - : getGbnfJsonTerminalForGbnfJsonSchema(schema.additionalProperties, grammarGenerator, scopeState.getForNewScope()), + : getGbnfJsonTerminalForGbnfJsonSchema( + schema.additionalProperties, + grammarGenerator, + scopeState.getForNewScope(), + currentDefs + ), minProperties: schema.minProperties, maxProperties, scopeState }); } else if (isGbnfJsonArraySchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + grammarGenerator.registerDefs(currentDefs); + let maxItems = schema.maxItems; if (schema.prefixItems != null && maxItems != null && maxItems < schema.prefixItems.length) { console.warn( @@ -76,11 +153,11 @@ export function getGbnfJsonTerminalForGbnfJsonSchema( return new GbnfArray({ items: schema.items == null ? undefined - : getGbnfJsonTerminalForGbnfJsonSchema(schema.items, grammarGenerator, scopeState.getForNewScope()), + : getGbnfJsonTerminalForGbnfJsonSchema(schema.items, grammarGenerator, scopeState.getForNewScope(), currentDefs), prefixItems: schema.prefixItems == null ? 
undefined : schema.prefixItems.map((item) => ( - getGbnfJsonTerminalForGbnfJsonSchema(item, grammarGenerator, scopeState.getForNewScope()) + getGbnfJsonTerminalForGbnfJsonSchema(item, grammarGenerator, scopeState.getForNewScope(), currentDefs) )), minItems: schema.minItems, maxItems, diff --git a/src/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.ts b/src/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.ts index 69d13f0f..1bfca85f 100644 --- a/src/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.ts +++ b/src/utils/gbnfJson/utils/validateObjectAgainstGbnfSchema.ts @@ -2,8 +2,9 @@ import { GbnfJsonArraySchema, GbnfJsonConstSchema, GbnfJsonEnumSchema, GbnfJsonObjectSchema, GbnfJsonOneOfSchema, GbnfJsonSchema, GbnfJsonSchemaImmutableType, GbnfJsonSchemaToType, GbnfJsonBasicStringSchema, GbnfJsonFormatStringSchema, isGbnfJsonArraySchema, isGbnfJsonConstSchema, isGbnfJsonEnumSchema, isGbnfJsonObjectSchema, isGbnfJsonOneOfSchema, isGbnfJsonBasicStringSchema, - isGbnfJsonFormatStringSchema + isGbnfJsonFormatStringSchema, isGbnfJsonRefSchema, GbnfJsonRefSchema } from "../types.js"; +import {DefScopeDefs, joinDefs} from "./defsScope.js"; export function validateObjectAgainstGbnfSchema(object: any, schema: unknown): boolean; @@ -37,13 +38,20 @@ class TechnicalValidationError extends Error { } } -function validateObjectWithGbnfSchema(object: any, schema: T): object is GbnfJsonSchemaToType { - if (isGbnfJsonArraySchema(schema)) - return validateArray(object, schema); +function validateObjectWithGbnfSchema( + object: any, + schema: T, + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): object is GbnfJsonSchemaToType { + if (isGbnfJsonRefSchema(schema)) + return validateRef(object, schema, defs, defScopeDefs); + else if (isGbnfJsonArraySchema(schema)) + return validateArray(object, schema, defs, defScopeDefs); else if (isGbnfJsonObjectSchema(schema)) - return validateObject(object, schema); + return validateObject(object, schema, defs, 
defScopeDefs); else if (isGbnfJsonOneOfSchema(schema)) - return validateOneOf(object, schema); + return validateOneOf(object, schema, defs, defScopeDefs); else if (isGbnfJsonBasicStringSchema(schema)) return validateBasicString(object, schema); else if (isGbnfJsonFormatStringSchema(schema)) @@ -70,7 +78,42 @@ function validateObjectWithGbnfSchema(object: any, sch throw new TechnicalValidationError(`Expected type "${schema.type}" but got "${object === null ? "null" : typeof object}"`); } -function validateArray(object: any, schema: T): object is GbnfJsonSchemaToType { +function validateRef>>( + object: any, + schema: T, + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): object is GbnfJsonSchemaToType { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + + const ref = schema.$ref; + const referencePrefix = "#/$defs/"; + + if (ref == null || !ref.startsWith(referencePrefix)) { + // if the $ref is invalid, a warning was already shown when the grammar was generated, + // so we don't perform validation on the object as it's considered an "any" type + return true; + } + + const defName = ref.slice(referencePrefix.length); + const def = currentDefs[defName]; + if (def == null) { + // if the $ref points to a non-existing def, a warning was already shown when the grammar was generated, + // so we don't perform validation on the object as it's considered an "any" type + return true; + } + + const scopeDefs = defScopeDefs.defScopeDefs.get([defName, def]); + return validateObjectWithGbnfSchema(object, def, scopeDefs ?? 
{}, defScopeDefs); +} + +function validateArray( + object: any, + schema: T, + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): object is GbnfJsonSchemaToType { if (!(object instanceof Array)) throw new TechnicalValidationError(`Expected an array but got "${typeof object}"`); @@ -93,28 +136,38 @@ function validateArray(object: any, schema: T): o let res = true; let index = 0; + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); if (schema.prefixItems != null) { for (const item of schema.prefixItems) { - res &&= validateObjectWithGbnfSchema(object[index], item); + res &&= validateObjectWithGbnfSchema(object[index], item, currentDefs, defScopeDefs); index++; } } if (schema.items != null) { for (; index < object.length; index++) - res &&= validateObjectWithGbnfSchema(object[index], schema.items); + res &&= validateObjectWithGbnfSchema(object[index], schema.items, currentDefs, defScopeDefs); } return res; } -function validateObject(object: any, schema: T): object is GbnfJsonSchemaToType { +function validateObject( + object: any, + schema: T, + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): object is GbnfJsonSchemaToType { if (typeof object !== "object" || object === null) throw new TechnicalValidationError(`Expected an object but got "${typeof object}"`); let res = true; + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + const objectKeys = Object.keys(object); const objectKeysSet = new Set(objectKeys); const schemaKeys = Object.keys(schema.properties ?? 
{}); @@ -131,7 +184,7 @@ function validateObject(object: any, schema: T): throw new TechnicalValidationError(`Unexpected keys: ${extraKeys.map((key) => JSON.stringify(key)).join(", ")}`); else if (schema.additionalProperties !== true) { for (const key of extraKeys) - res &&= validateObjectWithGbnfSchema(object[key], schema.additionalProperties); + res &&= validateObjectWithGbnfSchema(object[key], schema.additionalProperties, currentDefs, defScopeDefs); } } @@ -140,7 +193,7 @@ function validateObject(object: any, schema: T): throw new TechnicalValidationError(`Missing keys: ${missingKeys.map((key) => JSON.stringify(key)).join(", ")}`); for (const key of schemaKeys) - res &&= validateObjectWithGbnfSchema(object[key], schema.properties![key]!); + res &&= validateObjectWithGbnfSchema(object[key], schema.properties![key]!, currentDefs, defScopeDefs); if (schema.additionalProperties != null && schema.additionalProperties !== false) { if (objectKeys.length < minProperties) { @@ -159,10 +212,18 @@ function validateObject(object: any, schema: T): return res; } -function validateOneOf(object: any, schema: T): object is GbnfJsonSchemaToType { +function validateOneOf( + object: any, + schema: T, + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): object is GbnfJsonSchemaToType { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + for (const item of schema.oneOf) { try { - return validateObjectWithGbnfSchema(object, item); + return validateObjectWithGbnfSchema(object, item, currentDefs, defScopeDefs); } catch (err) { if (err instanceof TechnicalValidationError) continue; diff --git a/src/utils/getTypeScriptTypeStringForGbnfJsonSchema.ts b/src/utils/getTypeScriptTypeStringForGbnfJsonSchema.ts index aac2043f..5cb0c0c5 100644 --- a/src/utils/getTypeScriptTypeStringForGbnfJsonSchema.ts +++ b/src/utils/getTypeScriptTypeStringForGbnfJsonSchema.ts @@ -1,14 +1,63 @@ import { GbnfJsonSchema, isGbnfJsonArraySchema, 
isGbnfJsonBasicSchemaIncludesType, isGbnfJsonBasicStringSchema, isGbnfJsonConstSchema, - isGbnfJsonEnumSchema, isGbnfJsonFormatStringSchema, isGbnfJsonObjectSchema, isGbnfJsonOneOfSchema + isGbnfJsonEnumSchema, isGbnfJsonFormatStringSchema, isGbnfJsonObjectSchema, isGbnfJsonOneOfSchema, isGbnfJsonRefSchema } from "./gbnfJson/types.js"; +import {DefScopeDefs, joinDefs} from "./gbnfJson/utils/defsScope.js"; const maxTypeRepetition = 10; export function getTypeScriptTypeStringForGbnfJsonSchema(schema: GbnfJsonSchema): string { - if (isGbnfJsonOneOfSchema(schema)) { + return _getTypeScriptTypeStringForGbnfJsonSchema(schema); +} + +function _getTypeScriptTypeStringForGbnfJsonSchema( + schema: GbnfJsonSchema, + printedDefs: Set = new Set(), + defs: Record = {}, + defScopeDefs: DefScopeDefs = new DefScopeDefs() +): string { + if (isGbnfJsonRefSchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + + const ref = schema?.$ref; + const referencePrefix = "#/$defs/"; + if (ref == null || !ref.startsWith(referencePrefix)) + return "any"; + + const defName = ref.slice(referencePrefix.length); + const def = currentDefs[defName]; + if (def == null) + return "any"; + else if (printedDefs.has(def)) { + return [ + "/* ", + defName + .replaceAll("\n", " ") + .replaceAll("*/", "* /"), + " type */ any" + ].join(""); + } + + const scopeDefs = defScopeDefs.defScopeDefs.get([defName, def]); + if (scopeDefs == null) + return "any"; + + printedDefs.add(def); + return [ + "/* Type: ", + defName + .replaceAll("\n", " ") + .replaceAll("*/", "* /"), + " */ ", + _getTypeScriptTypeStringForGbnfJsonSchema(def, printedDefs, scopeDefs, defScopeDefs) + ].join(""); + } else if (isGbnfJsonOneOfSchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + const values = schema.oneOf - .map((altSchema) => getTypeScriptTypeStringForGbnfJsonSchema(altSchema)); + .map((altSchema) => 
_getTypeScriptTypeStringForGbnfJsonSchema(altSchema, printedDefs, currentDefs, defScopeDefs)); return values.join(" | "); } else if (isGbnfJsonConstSchema(schema)) { @@ -19,12 +68,15 @@ export function getTypeScriptTypeStringForGbnfJsonSchema(schema: GbnfJsonSchema) .filter((item) => item !== "") .join(" | "); } else if (isGbnfJsonObjectSchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + let addNewline = false; const valueTypes = Object.entries(schema.properties ?? {}) .map(([propName, propSchema]) => { const escapedValue = JSON.stringify(propName) ?? ""; const keyText = escapedValue.slice(1, -1) === propName ? propName : escapedValue; - const valueType = getTypeScriptTypeStringForGbnfJsonSchema(propSchema); + const valueType = _getTypeScriptTypeStringForGbnfJsonSchema(propSchema, printedDefs, currentDefs, defScopeDefs); if (keyText === "" || valueType === "") return ""; @@ -103,7 +155,7 @@ export function getTypeScriptTypeStringForGbnfJsonSchema(schema: GbnfJsonSchema) : schema.additionalProperties === true ? "{[key: string]: any}" : schema.additionalProperties != null - ? ["{[key: string]: ", getTypeScriptTypeStringForGbnfJsonSchema(schema.additionalProperties), "}"].join("") + ? 
["{[key: string]: ", _getTypeScriptTypeStringForGbnfJsonSchema(schema.additionalProperties), "}"].join("") : undefined; if (valueTypes.length === 0 && additionalPropertiesMapSyntax != null) @@ -113,14 +165,17 @@ export function getTypeScriptTypeStringForGbnfJsonSchema(schema: GbnfJsonSchema) return knownPropertiesMapSyntax; } else if (isGbnfJsonArraySchema(schema)) { + const currentDefs = joinDefs(defs, schema.$defs); + defScopeDefs.registerDefs(currentDefs); + if (schema.maxItems === 0) return "[]"; if (schema.prefixItems != null && schema.prefixItems.length > 0) { - const valueTypes = schema.prefixItems.map((item) => getTypeScriptTypeStringForGbnfJsonSchema(item)); + const valueTypes = schema.prefixItems.map((item) => _getTypeScriptTypeStringForGbnfJsonSchema(item)); const restType = schema.items != null - ? getTypeScriptTypeStringForGbnfJsonSchema(schema.items) + ? _getTypeScriptTypeStringForGbnfJsonSchema(schema.items, printedDefs, currentDefs, defScopeDefs) : "any"; if (schema.minItems != null) { @@ -133,7 +188,7 @@ export function getTypeScriptTypeStringForGbnfJsonSchema(schema: GbnfJsonSchema) return "[" + valueTypes.join(", ") + "]"; } else if (schema.items != null) { - const valuesType = getTypeScriptTypeStringForGbnfJsonSchema(schema.items); + const valuesType = _getTypeScriptTypeStringForGbnfJsonSchema(schema.items, printedDefs, currentDefs, defScopeDefs); if (valuesType === "") return "[]"; diff --git a/test/modelDependent/qwen3-0.6b/functions.test.ts b/test/modelDependent/qwen3-0.6b/functions.test.ts new file mode 100644 index 00000000..07bd2899 --- /dev/null +++ b/test/modelDependent/qwen3-0.6b/functions.test.ts @@ -0,0 +1,250 @@ +import {describe, expect, test} from "vitest"; +import { + defineChatSessionFunction, JinjaTemplateChatWrapper, LlamaChatSession, QwenChatWrapper, resolveChatWrapper +} from "../../../src/index.js"; +import {getModelFile} from "../../utils/modelFiles.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; + 
+describe("qwen3 0.6b", () => { + describe("functions", () => { + test("get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Qwen3-0.6B-Q8_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 1024 + }); + const chatSession = new LlamaChatSession({ + contextSequence: context.getSequence() + }); + expect(chatSession.chatWrapper).to.be.instanceof(QwenChatWrapper); + + const promptOptions: Parameters[1] = { + functions: { + getNthWord: defineChatSessionFunction({ + description: "Get an n-th word", + params: { + type: "object", + properties: { + n: { + enum: [1, 2, 3, 4] + } + } + }, + handler(params) { + return ["very", "secret", "this", "hello"][params.n - 1]; + } + }) + } + } as const; + + const res = await chatSession.prompt("What is the second word? No yapping, no formatting", { + ...promptOptions, + maxTokens: 250, + budgets: { + thoughtTokens: 100 + } + }); + + expect(res.trim()).to.be.eq('The second word is "secret".'); + + const res2 = await chatSession.prompt("Explain what this word means", { + ...promptOptions, + maxTokens: 40, + budgets: { + thoughtTokens: 0 + } + }); + + expect(res2.length).to.be.greaterThan(1); + }); + + test("get n-th word using jinja template", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Qwen3-0.6B-Q8_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 1024 + }); + const chatSession = new LlamaChatSession({ + contextSequence: context.getSequence(), + chatWrapper: resolveChatWrapper(model, { + type: "jinjaTemplate" + }) + }); + expect(chatSession.chatWrapper).to.be.instanceof(JinjaTemplateChatWrapper); + + const promptOptions: Parameters[1] = { + functions: { + getNthWord: defineChatSessionFunction({ + description: "Get an 
n-th word", + params: { + type: "object", + properties: { + n: { + enum: [1, 2, 3, 4] + } + } + }, + handler(params) { + return ["very", "secret", "this", "hello"][params.n - 1]; + } + }) + } + } as const; + + const res = await chatSession.prompt("What is the second word? No yapping, no formatting", { + ...promptOptions, + maxTokens: 250, + budgets: { + thoughtTokens: 100 + } + }); + + expect(res.trim()).to.be.eq('The second word is "secret".'); + + const res2 = await chatSession.prompt("Explain what this word means", { + ...promptOptions, + maxTokens: 40, + budgets: { + thoughtTokens: 0 + } + }); + + expect(res2.length).to.be.greaterThan(1); + }); + + test("$defs and $ref with recursion", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Qwen3-0.6B-Q8_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 1024 + }); + const chatSession = new LlamaChatSession({ + contextSequence: context.getSequence() + }); + expect(chatSession.chatWrapper).to.be.instanceof(QwenChatWrapper); + + const promptOptions = { + functions: { + getNthWord: defineChatSessionFunction({ + description: "Get an n-th word", + params: { + type: "object", + $defs: { + nthWord: { + enum: [1, 2, 3, 4] + } + }, + properties: { + n: { + $ref: "#/$defs/nthWord" + } + } + }, + handler(params) { + return ["very", "secret", "this", "hello"][params.n - 1]; + } + }), + notifyOwner: defineChatSessionFunction({ + description: "Send a notification to the owner, and create sub notifications", + params: { + $ref: "#/$defs/notification", + $defs: { + notification: { + type: "object", + properties: { + message: { + type: "string" + }, + subNotifications: { + type: "array", + items: { + $ref: "#/$defs/notification" + } + } + } + } + } + }, + handler(notification) { + createdNotifications.push(notification); + return "Notification created"; + } + }) + } + } as const 
satisfies Parameters[1]; + const createdNotifications: Parameters[0][] = []; + + const res = await chatSession.prompt("What is the second word? No yapping, no formatting", { + ...promptOptions, + maxTokens: 250, + budgets: { + thoughtTokens: 100 + } + }); + + expect(res.trim()).to.be.eq('The second word is "secret".'); + + const res2 = await chatSession.prompt([ + "The owner has 3 apps: App1, App2, and App3.", + "Notify the owner with a main notifications about 'apps time', with sub notifications for each app with the app's name.", + "Under each app sub-notification add a sub-notification with the app's number." + ].join("\n"), { + ...promptOptions, + maxTokens: 200, + budgets: { + thoughtTokens: 0 + } + }); + + expect(res2.length).to.be.greaterThan(1); + expect(createdNotifications).toMatchInlineSnapshot(` + [ + { + "message": "apps time", + "subNotifications": [ + { + "message": "App1", + "subNotifications": [ + { + "message": "1. App1 sub notification 1", + "subNotifications": [], + }, + ], + }, + { + "message": "App2", + "subNotifications": [ + { + "message": "2. App2 sub notification 2", + "subNotifications": [], + }, + ], + }, + { + "message": "App3", + "subNotifications": [ + { + "message": "3. 
App3 sub notification 3", + "subNotifications": [], + }, + ], + }, + ], + }, + ] + `); + }); + }); +}); diff --git a/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts b/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts index 0cf624e1..81c4b686 100644 --- a/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts +++ b/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts @@ -39,6 +39,58 @@ describe("FunctionaryChatWrapper", () => { return Math.floor(Math.random() * (params.max - params.min + 1) + params.min); } }), + notifyOwner: defineChatSessionFunction({ + description: "Send a notification to the owner, and create sub notifications", + params: { + $ref: "#/$defs/notification", + $defs: { + notification: { + type: "object", + properties: { + message: { + type: "string" + }, + subNotifications: { + type: "array", + items: { + $ref: "#/$defs/notification" + } + } + } + } + } + }, + handler(notification) { + return "Notification created: " + notification.message; + } + }), + notifyOwner2: defineChatSessionFunction({ + description: "Send a notification to the owner, and create sub notifications", + params: { + $ref: "#/$defs/notification", + $defs: { + notification: { + type: "object", + properties: { + message: { + type: "string", + description: "Notification message" + }, + subNotifications: { + type: "array", + description: "Sub notifications", + items: { + $ref: "#/$defs/notification" + } + } + } + } + } + }, + handler(notification) { + return "Notification created: " + notification.message; + } + }), func1: defineChatSessionFunction({ description: "Some function", params: { @@ -188,6 +240,18 @@ describe("FunctionaryChatWrapper", () => { // Get a random number type getRandomNumber = (_: {min: number, max: number}) => any; + // Send a notification to the owner, and create sub notifications + type notifyOwner = (_: /* Type: notification */ {message: string, subNotifications: (/* notification type */ any)[]}) => any; + + // Send a 
notification to the owner, and create sub notifications + type notifyOwner2 = (_: /* Type: notification */ { + // Notification message + message: string, + + // Sub notifications + subNotifications: (/* notification type */ any)[] + }) => any; + // Some function type func1 = (_: { // Some message @@ -403,6 +467,18 @@ describe("FunctionaryChatWrapper", () => { // Get a random number type getRandomNumber = (_: {min: number, max: number}) => any; + // Send a notification to the owner, and create sub notifications + type notifyOwner = (_: /* Type: notification */ {message: string, subNotifications: (/* notification type */ any)[]}) => any; + + // Send a notification to the owner, and create sub notifications + type notifyOwner2 = (_: /* Type: notification */ { + // Notification message + message: string, + + // Sub notifications + subNotifications: (/* notification type */ any)[] + }) => any; + // Some function type func1 = (_: { // Some message diff --git a/test/standalone/llamaEvaluator/LlamaGrammar.test.ts b/test/standalone/llamaEvaluator/LlamaGrammar.test.ts index f7e694c5..cab7d15c 100644 --- a/test/standalone/llamaEvaluator/LlamaGrammar.test.ts +++ b/test/standalone/llamaEvaluator/LlamaGrammar.test.ts @@ -2241,6 +2241,418 @@ describe("grammar for JSON schema", () => { }); }); }); + + describe("definitions and $ref", () => { + test("simple", async () => { + const llama = await getTestLlama(); + const grammar = new LlamaJsonSchemaGrammar(llama, { + type: "object", + properties: { + "message": { + $ref: "#/$defs/messageValue" + }, + "numberOfWordsInMessage": { + $ref: "#/$defs/numberOfWordsInMessageValue" + }, + "feelingGoodPercentage": { + type: ["number"] + }, + "feelingGood": { + $ref: "#/$defs/feeling" + }, + "feelingOverall": { + oneOf: [{ + $ref: "#/$defs/goodConst" + }, { + $ref: "#/$defs/badConst" + }] + }, + "verbsInMessage": { + type: "array", + items: { + $ref: "#/$defs/verb" + } + } + }, + $defs: { + messageValue: { + type: ["string", "null"] + }, + 
numberOfWordsInMessageValue: { + type: "integer" + }, + goodConst: { + const: "good" + }, + badConst: { + const: "bad" + }, + verb: { + type: "string" + }, + feeling: { + oneOf: [{ + $ref: "#/$defs/goodConst" + }, { + $ref: "#/$defs/badConst" + }] + } + } + } as const); + type schemaType = { + "message": string | null, + "numberOfWordsInMessage": number, + "feelingGoodPercentage": number, + "feelingGood": "good" | "bad", + "feelingOverall": "good" | "bad", + "verbsInMessage": string[] + }; + const exampleValidValue = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "good", + "verbsInMessage": ["Hello", "world"] + }; + const exampleValidValue2 = { + "message": null, + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "bad", + "feelingOverall": "bad", + "verbsInMessage": ["Hello", "world"] + }; + const exampleInvalidValue = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "good", + "verbsInMessage": ["Hello", 10] + }; + const exampleInvalidValue2 = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "average", + "verbsInMessage": ["Hello", "world"] + }; + const exampleInvalidValue3 = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "good", + "verbsInMessage": ["Hello", "world", true] + }; + const exampleInvalidValue4 = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "good", + "verbsInMessage": ["Hello", "world", {}] + }; + const exampleInvalidValue5 = { + "message": "Hello, world!", + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": 
"good", + "verbsInMessage": ["Hello", "world", null] + }; + const exampleInvalidValue6 = { + "message": false, + "numberOfWordsInMessage": 3, + "feelingGoodPercentage": 0.5, + "feelingGood": "good", + "feelingOverall": "good", + "verbsInMessage": ["Hello", "world"] + }; + + expect(grammar.grammar).toMatchInlineSnapshot(` + "root ::= "{" whitespace-b-1-4-rule "\\"message\\"" ":" [ ]? rule0 comma-whitespace-b-1-4-rule "\\"numberOfWordsInMessage\\"" ":" [ ]? integer-number-rule comma-whitespace-b-1-4-rule "\\"feelingGoodPercentage\\"" ":" [ ]? fractional-number-rule comma-whitespace-b-1-4-rule "\\"feelingGood\\"" ":" [ ]? rule1 comma-whitespace-b-1-4-rule "\\"feelingOverall\\"" ":" [ ]? rule1 comma-whitespace-b-1-4-rule "\\"verbsInMessage\\"" ":" [ ]? rule2 whitespace-b-0-4-rule "}" "\\n\\n\\n\\n" [\\n]* + string-char-rule ::= [^"\\\\\\x7F\\x00-\\x1F] | "\\\\" ["\\\\/bfnrt] | "\\\\u" [0-9a-fA-F]{4} + string-rule ::= "\\"" string-char-rule* "\\"" + null-rule ::= "null" + rule0 ::= ( string-rule | null-rule ) + comma-whitespace-b-1-4-rule ::= "," ([\\n] (" " | "\\t") | [ ]?) + integer-number-rule ::= "-"? ("0" | [1-9] [0-9]{0,15}) ([eE] [-+]? ("0" | [1-9] [0-9]{0,15}))? + fractional-number-rule ::= "-"? ("0" | [1-9] [0-9]{0,15}) ("." [0-9]{1,16})? ([eE] [-+]? ("0" | [1-9] [0-9]{0,15}))? + val0 ::= "\\"good\\"" + val1 ::= "\\"bad\\"" + rule1 ::= ( val0 | val1 ) + comma-whitespace-b-2-4-rule ::= "," ([\\n] (" "{8} | "\\t\\t") | [ ]?) + whitespace-b-2-4-rule ::= [\\n] (" "{8} | "\\t\\t") | [ ]? + whitespace-b-1-4-rule ::= [\\n] (" " | "\\t") | [ ]? + rule2 ::= "[" whitespace-b-2-4-rule ( string-rule ( comma-whitespace-b-2-4-rule string-rule )* )? whitespace-b-1-4-rule "]" + whitespace-b-0-4-rule ::= [\\n] | [ ]?" 
+ `); + + const parsedValue = grammar.parse(JSON.stringify(exampleValidValue)); + + expectTypeOf(parsedValue).toMatchObjectType(); + expect(parsedValue).toEqual(exampleValidValue); + expect(testGrammar(grammar, exampleValidValue)).to.eql(true); + expect(testGrammar(grammar, exampleValidValue, "pretty")).to.eql(true); + expect(testGrammar(grammar, exampleValidValue, "dumps")).to.eql(true); + + const parsedValue2 = grammar.parse(JSON.stringify(exampleValidValue2)); + + expectTypeOf(parsedValue2).toMatchObjectType(); + expect(parsedValue2).toEqual(exampleValidValue2); + expect(testGrammar(grammar, exampleValidValue2)).to.eql(true); + expect(testGrammar(grammar, exampleValidValue2, "pretty")).to.eql(true); + expect(testGrammar(grammar, exampleValidValue2, "dumps")).to.eql(true); + + try { + grammar.parse(JSON.stringify(exampleInvalidValue)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected type \"string\" but got \"number\"]"); + expect(testGrammar(grammar, exampleInvalidValue)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue2)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 2 schemas but got \"average\"]"); + expect(testGrammar(grammar, exampleInvalidValue2)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue3)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected type \"string\" but got \"boolean\"]"); + expect(testGrammar(grammar, exampleInvalidValue3)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue4)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected type \"string\" but got \"object\"]"); + expect(testGrammar(grammar, exampleInvalidValue4)).to.eql(false); + } + + try { + 
grammar.parse(JSON.stringify(exampleInvalidValue5)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected type \"string\" but got \"null\"]"); + expect(testGrammar(grammar, exampleInvalidValue5)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue6)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot('[Error: Expected one type of ["string", "null"] but got type "boolean"]'); + expect(testGrammar(grammar, exampleInvalidValue6)).to.eql(false); + } + }); + + test("recursive references", async () => { + const llama = await getTestLlama(); + const grammar = await llama.createGrammarForJsonSchema({ + type: "object", + properties: { + "message": { + $ref: "#/$defs/messageValue" + } + }, + $defs: { + messageValue: { + oneOf: [{ + type: "string" + }, { + type: "number" + }, { + $ref: "#/$defs/feeling" + }] + }, + goodConst: { + const: "good" + }, + badConst: { + const: "bad" + }, + feeling: { + oneOf: [{ + $ref: "#/$defs/goodConst" + }, { + $ref: "#/$defs/badConst" + }, { + type: "object", + properties: { + feel: { + oneOf: [{ + $ref: "#/$defs/goodConst" + }, { + $ref: "#/$defs/badConst" + }] + }, + message: { + $ref: "#/$defs/messageValue" + } + } + }] + } + } + } as const); + type schemaType = { + "message": string | number | { + feel: "good" | "bad", + message: schemaType["message"] + } + }; + const exampleValidValue = { + "message": "Hello, world!" + }; + const exampleValidValue2 = { + "message": { + "feel": "good", + "message": { + "feel": "bad", + "message": 6 + } + } + }; + const exampleInvalidValue = { + "message": false + }; + const exampleInvalidValue2 = { + "message": { + "feel": "ok", + "message": "Hello, world!" 
+ } + }; + const exampleInvalidValue3 = { + "message": { + "feel": "good", + "message": { + "feel": "bad", + "message": true + } + } + }; + const exampleInvalidValue4 = { + "message": { + "feel": "good", + "message": { + "feel": "bad", + "message": {} + } + } + }; + const exampleInvalidValue5 = { + "message": { + "feel": "good", + "message": { + "feel": "bad", + "message": { + "feel": "good", + "message": { + "feel": "bad", + "message": null + } + } + } + } + }; + const exampleInvalidValue6 = { + "message": { + "message": "Hello, world!" + } + }; + + expect(grammar.grammar).toMatchInlineSnapshot(` + "root ::= "{" whitespace-b-1-4-rule "\\"message\\"" ":" [ ]? def0 whitespace-b-0-4-rule "}" "\\n\\n\\n\\n" [\\n]* + string-char-rule ::= [^"\\\\\\x7F\\x00-\\x1F] | "\\\\" ["\\\\/bfnrt] | "\\\\u" [0-9a-fA-F]{4} + string-rule ::= "\\"" string-char-rule* "\\"" + fractional-number-rule ::= "-"? ("0" | [1-9] [0-9]{0,15}) ("." [0-9]{1,16})? ([eE] [-+]? ("0" | [1-9] [0-9]{0,15}))? + val0 ::= "\\"good\\"" + val1 ::= "\\"bad\\"" + rule0 ::= ( val0 | val1 ) + comma-whitespace-no-new-lines-rule ::= "," [ ]? + whitespace-no-new-lines-rule ::= [ ]? + rule1 ::= "{" whitespace-no-new-lines-rule "\\"feel\\"" ":" [ ]? rule0 comma-whitespace-no-new-lines-rule "\\"message\\"" ":" [ ]? def0 whitespace-no-new-lines-rule "}" + rule2 ::= ( val0 | val1 | rule1 ) + rule3 ::= ( string-rule | fractional-number-rule | rule2 ) + def0 ::= rule3 + whitespace-b-1-4-rule ::= [\\n] (" " | "\\t") | [ ]? + whitespace-b-0-4-rule ::= [\\n] | [ ]?" 
+ `); + + const parsedValue = grammar.parse(JSON.stringify(exampleValidValue)); + + expectTypeOf>().toExtend(); + expect(parsedValue).toEqual(exampleValidValue); + expect(testGrammar(grammar, exampleValidValue)).to.eql(true); + expect(testGrammar(grammar, exampleValidValue, "pretty")).to.eql(true); + expect(testGrammar(grammar, exampleValidValue, "dumps")).to.eql(true); + + const parsedValue2 = grammar.parse(JSON.stringify(exampleValidValue2)); + + expectTypeOf>().toExtend(); + expect(parsedValue2).toEqual(exampleValidValue2); + expect(testGrammar(grammar, exampleValidValue2)).to.eql(true); + // expect(testGrammar(grammar, exampleValidValue2, "pretty")).to.eql(true); + expect(testGrammar(grammar, exampleValidValue2, "dumps")).to.eql(true); + + try { + grammar.parse(JSON.stringify(exampleInvalidValue)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got false]"); + expect(testGrammar(grammar, exampleInvalidValue)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue2)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got {\"feel\":\"ok\",\"message\":\"Hello, world!\"}]"); + expect(testGrammar(grammar, exampleInvalidValue2)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue3)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got {\"feel\":\"good\",\"message\":{\"feel\":\"bad\",\"message\":true}}]"); + expect(testGrammar(grammar, exampleInvalidValue3)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue4)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got 
{\"feel\":\"good\",\"message\":{\"feel\":\"bad\",\"message\":{}}}]"); + expect(testGrammar(grammar, exampleInvalidValue4)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue5)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got {\"feel\":\"good\",\"message\":{\"feel\":\"bad\",\"message\":{\"feel\":\"good\",\"message\":{\"feel\":\"bad\",\"message\":null}}}}]"); + expect(testGrammar(grammar, exampleInvalidValue5)).to.eql(false); + } + + try { + grammar.parse(JSON.stringify(exampleInvalidValue6)); + expect.unreachable("Parsing should have failed"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Expected one of 3 schemas but got {\"message\":\"Hello, world!\"}]"); + expect(testGrammar(grammar, exampleInvalidValue6)).to.eql(false); + } + }); + }); }); function testGrammar(grammar: LlamaJsonSchemaGrammar, object: any, formattingType: false | "dumps" | "pretty" = false) { @@ -2251,3 +2663,9 @@ function testGrammar(grammar: LlamaJsonSchemaGrammar, object: any, formatti return grammar._testText(JSON.stringify(object) + "\n".repeat(4)); } + +type ExpectTypesMatch = T1 extends T2 + ? T2 extends T1 + ? true + : false + : false;