Skip to content

Commit a31c6c5

Browse files
committed
fix: handle $defs in a $ref object
1 parent a342b60 commit a31c6c5

13 files changed

+77
-45
lines changed

src/evaluator/LlamaChat/utils/FunctionCallParamsGrammar.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ export class FunctionCallParamsGrammar<const Functions extends ChatModelFunction
6666
function getGbnfGrammarForFunctionParams(paramsSchema: GbnfJsonSchema): string {
6767
const grammarGenerator = new GbnfGrammarGenerator();
6868
const rootTerminal = getGbnfJsonTerminalForGbnfJsonSchema(paramsSchema, grammarGenerator);
69-
const rootGrammar = rootTerminal.getGrammar(grammarGenerator);
69+
const rootGrammar = rootTerminal.resolve(grammarGenerator, true);
7070

7171
return grammarGenerator.generateGbnfFile(rootGrammar + ` "${"\\n".repeat(4)}"`);
7272
}

src/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ export type ChatSessionModelFunctions = {
193193

194194
export type ChatSessionModelFunction<Params extends GbnfJsonSchema | undefined = GbnfJsonSchema | undefined> = {
195195
readonly description?: string,
196-
readonly params?: Readonly<Params>,
197-
readonly handler: (params: GbnfJsonSchemaToType<Params>) => any
196+
readonly params?: Params,
197+
readonly handler: (params: GbnfJsonSchemaToType<NoInfer<Params>>) => any
198198
};
199199

200200
export function isChatModelResponseFunctionCall(item: ChatModelResponse["response"][number] | undefined): item is ChatModelFunctionCall {

src/utils/gbnfJson/GbnfGrammarGenerator.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ import {MultiKeyMap} from "lifecycle-utils";
22
import {GbnfJsonSchema} from "./types.js";
33

44
export class GbnfGrammarGenerator {
5-
public rules = new Map<string, string | null>();
5+
public rules = new Map<string, string>();
66
public ruleContentToRuleName = new Map<string, string>();
77
public literalValueRuleNames = new Map<string | number, string>();
8-
public defRuleNames = new MultiKeyMap<[string, GbnfJsonSchema], string>();
8+
public defRuleNames = new MultiKeyMap<[string, GbnfJsonSchema], string | null>();
99
public defScopeDefs = new MultiKeyMap<[string, GbnfJsonSchema], Record<string, GbnfJsonSchema>>();
10+
public usedRootRuleName: boolean = false;
1011
private ruleId: number = 0;
1112
private valueRuleId: number = 0;
1213
private defRuleId: number = 0;
@@ -31,17 +32,17 @@ export class GbnfGrammarGenerator {
3132
return ruleName;
3233
}
3334

34-
public generateRuleNameForDef(defName: string, def: GbnfJsonSchema): [created: boolean, ruleName: string] {
35+
public generateRuleNameForDef(defName: string, def: GbnfJsonSchema): string {
3536
const existingRuleName = this.defRuleNames.get([defName, def]);
3637
if (existingRuleName != null)
37-
return [false, existingRuleName];
38+
return existingRuleName;
3839

3940
const ruleName = `def${this.defRuleId}`;
4041
this.defRuleId++;
4142

4243
this.defRuleNames.set([defName, def], ruleName);
4344

44-
return [true, ruleName];
45+
return ruleName;
4546
}
4647

4748
public registerDefs(scopeDefs: Record<string, GbnfJsonSchema>) {

src/utils/gbnfJson/GbnfTerminal.ts

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,22 @@ export abstract class GbnfTerminal {
2525
return this.getGrammar(grammarGenerator);
2626
}
2727

28-
public resolve(grammarGenerator: GbnfGrammarGenerator): string {
28+
private _getRootRuleName(grammarGenerator: GbnfGrammarGenerator) {
29+
if (this._ruleName != null)
30+
return this._ruleName;
31+
32+
const ruleName = grammarGenerator.usedRootRuleName
33+
? this.getRuleName(grammarGenerator)
34+
: "root";
35+
this._ruleName = ruleName;
36+
37+
if (ruleName === "root")
38+
grammarGenerator.usedRootRuleName = true;
39+
40+
return ruleName;
41+
}
42+
43+
public resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string {
2944
if (this._ruleName != null)
3045
return this._ruleName;
3146

@@ -37,7 +52,12 @@ export abstract class GbnfTerminal {
3752
return existingRuleName;
3853
}
3954

40-
const ruleName = this.getRuleName(grammarGenerator);
55+
const ruleName = resolveAsRootGrammar
56+
? this._getRootRuleName(grammarGenerator)
57+
: this.getRuleName(grammarGenerator);
58+
59+
if (resolveAsRootGrammar)
60+
return grammar;
4161

4262
if (grammar === ruleName) {
4363
this._ruleName = ruleName;

src/utils/gbnfJson/getGbnfGrammarForGbnfJsonSchema.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export function getGbnfGrammarForGbnfJsonSchema(schema: Readonly<GbnfJsonSchema>
1414
const grammarGenerator = new GbnfGrammarGenerator();
1515
const scopeState = new GbnfJsonScopeState({allowNewLines, scopePadSpaces});
1616
const rootTerminal = getGbnfJsonTerminalForGbnfJsonSchema(schema, grammarGenerator, scopeState);
17-
const rootGrammar = rootTerminal.getGrammar(grammarGenerator);
17+
const rootGrammar = rootTerminal.resolve(grammarGenerator, true);
1818

1919
return grammarGenerator.generateGbnfFile(rootGrammar + ` "${"\\n".repeat(4)}"` + " [\\n]*");
2020
}

src/utils/gbnfJson/terminals/GbnfGrammar.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ export class GbnfGrammar extends GbnfTerminal {
2121
return this.grammar;
2222
}
2323

24-
public override resolve(grammarGenerator: GbnfGrammarGenerator): string {
24+
public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string {
2525
if (this.resolveToRawGrammar)
2626
return this.getGrammar();
2727

28-
return super.resolve(grammarGenerator);
28+
return super.resolve(grammarGenerator, resolveAsRootGrammar);
2929
}
3030
}

src/utils/gbnfJson/terminals/GbnfNumberValue.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ export class GbnfNumberValue extends GbnfTerminal {
1414
return '"' + JSON.stringify(this.value) + '"';
1515
}
1616

17-
public override resolve(grammarGenerator: GbnfGrammarGenerator): string {
17+
public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string {
1818
const grammar = this.getGrammar();
1919
if (grammar.length <= grammarGenerator.getProposedLiteralValueRuleNameLength())
2020
return grammar;
2121

22-
return super.resolve(grammarGenerator);
22+
return super.resolve(grammarGenerator, resolveAsRootGrammar);
2323
}
2424

2525
protected override generateRuleName(grammarGenerator: GbnfGrammarGenerator): string {

src/utils/gbnfJson/terminals/GbnfOr.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export class GbnfOr extends GbnfTerminal {
3030
return "( " + mappedValues.join(" | ") + " )";
3131
}
3232

33-
public override resolve(grammarGenerator: GbnfGrammarGenerator): string {
33+
public override resolve(grammarGenerator: GbnfGrammarGenerator, resolveAsRootGrammar: boolean = false): string {
3434
const mappedValues = this.values
3535
.map((v) => v.resolve(grammarGenerator))
3636
.filter((value) => value !== "" && value !== grammarNoValue);
@@ -40,6 +40,6 @@ export class GbnfOr extends GbnfTerminal {
4040
else if (mappedValues.length === 1)
4141
return mappedValues[0]!;
4242

43-
return super.resolve(grammarGenerator);
43+
return super.resolve(grammarGenerator, resolveAsRootGrammar);
4444
}
4545
}

src/utils/gbnfJson/terminals/GbnfRef.ts

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ export class GbnfRef extends GbnfTerminal {
77
public readonly getValueTerminal: () => GbnfTerminal;
88
public readonly defName: string;
99
public readonly def: GbnfJsonSchema;
10-
private _valueTerminal?: GbnfTerminal;
11-
private _grammar?: string;
1210

1311
public constructor({
1412
getValueTerminal,
@@ -26,28 +24,30 @@ export class GbnfRef extends GbnfTerminal {
2624
}
2725

2826
public override getGrammar(grammarGenerator: GbnfGrammarGenerator): string {
29-
this._createRule(grammarGenerator);
30-
31-
if (this._valueTerminal != null)
32-
return this._valueTerminal.getGrammar(grammarGenerator);
33-
else if (this._grammar != null)
34-
return this._grammar;
35-
36-
return this.getValueTerminal().getGrammar(grammarGenerator);
27+
return this.generateRuleName(grammarGenerator);
3728
}
3829

3930
protected override generateRuleName(grammarGenerator: GbnfGrammarGenerator): string {
40-
return this._createRule(grammarGenerator);
41-
}
31+
if (!grammarGenerator.defRuleNames.has([this.defName, this.def])) {
32+
const alreadyGeneratingGrammarForThisRef = grammarGenerator.defRuleNames.get([this.defName, this.def]) === null;
33+
if (alreadyGeneratingGrammarForThisRef)
34+
return grammarGenerator.generateRuleNameForDef(this.defName, this.def);
35+
36+
grammarGenerator.defRuleNames.set([this.defName, this.def], null);
37+
const grammar = this.getValueTerminal().resolve(grammarGenerator);
38+
39+
if (grammarGenerator.rules.has(grammar) && grammarGenerator.defRuleNames.get([this.defName, this.def]) === null) {
40+
grammarGenerator.defRuleNames.set([this.defName, this.def], grammar);
41+
return grammar;
42+
}
43+
44+
const ruleName = grammarGenerator.generateRuleNameForDef(this.defName, this.def);
45+
grammarGenerator.rules.set(ruleName, grammar);
46+
grammarGenerator.ruleContentToRuleName.set(grammar, ruleName);
4247

43-
private _createRule(grammarGenerator: GbnfGrammarGenerator) {
44-
const [isNew, ruleName] = grammarGenerator.generateRuleNameForDef(this.defName, this.def);
45-
if (!isNew) {
46-
this._grammar = ruleName;
4748
return ruleName;
4849
}
4950

50-
this._valueTerminal = this.getValueTerminal();
51-
return ruleName;
51+
return grammarGenerator.generateRuleNameForDef(this.defName, this.def);
5252
}
5353
}

src/utils/gbnfJson/types.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ export type GbnfJsonRefSchema<Defs extends GbnfJsonDefList<NoInfer<Defs>> = {}>
180180
*
181181
* Only passed to the model when using function calling, and has no effect when using JSON Schema grammar directly.
182182
*/
183-
readonly description?: string
183+
readonly description?: string,
184+
185+
readonly $defs?: Defs
184186
};
185187

186188

@@ -211,7 +213,7 @@ export type GbnfJsonSchemaToTSType<T, Defs extends GbnfJsonDefList<NoInfer<Defs>
211213
: T extends GbnfJsonArraySchema<Record<any, any>>
212214
? ArrayTypeToType<T, CombineDefs<NoInfer<Defs>, T["$defs"]>>
213215
: T extends GbnfJsonRefSchema<any>
214-
? GbnfJsonRefSchemaToType<T, NoInfer<Defs>>
216+
? GbnfJsonRefSchemaToType<T, CombineDefs<NoInfer<Defs>, T["$defs"]>>
215217
: undefined;
216218

217219
type GbnfJsonBasicStringSchemaToType<T extends GbnfJsonBasicStringSchema> =

0 commit comments

Comments
 (0)