Skip to content

Commit cefbc8c

Browse files
authored
Merge pull request #1147 from quarkiverse/#1143
Add support for structured output in OpenAI
2 parents b779f6b + 4751df3 commit cefbc8c

File tree

18 files changed

+310
-94
lines changed

18 files changed

+310
-94
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
1515
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.IGNORE;
1616
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.OPTIONAL_DENY;
17+
import static io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema;
18+
import static io.quarkiverse.langchain4j.runtime.types.TypeUtil.isMulti;
1719
import static io.quarkus.arc.processor.DotNames.NAMED;
1820

1921
import java.io.IOException;
@@ -61,7 +63,9 @@
6163
import org.objectweb.asm.tree.analysis.AnalyzerException;
6264

6365
import dev.langchain4j.exception.IllegalConfigurationException;
66+
import dev.langchain4j.model.chat.request.json.JsonSchema;
6467
import dev.langchain4j.service.Moderate;
68+
import dev.langchain4j.service.output.JsonSchemas;
6569
import dev.langchain4j.service.output.ServiceOutputParser;
6670
import io.quarkiverse.langchain4j.ModelName;
6771
import io.quarkiverse.langchain4j.RegisterAiService;
@@ -117,6 +121,7 @@
117121
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
118122
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
119123
import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem;
124+
import io.quarkus.deployment.recording.RecorderContext;
120125
import io.quarkus.gizmo.ClassCreator;
121126
import io.quarkus.gizmo.ClassOutput;
122127
import io.quarkus.gizmo.FieldDescriptor;
@@ -922,6 +927,7 @@ public void markIgnoredAnnotations(BuildProducer<MethodParameterIgnoredAnnotatio
922927
public void handleAiServices(
923928
LangChain4jBuildConfig config,
924929
AiServicesRecorder recorder,
930+
RecorderContext recorderContext,
925931
CombinedIndexBuildItem indexBuildItem,
926932
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
927933
List<MethodParameterAllowedAnnotationsBuildItem> methodParameterAllowedAnnotationsItems,
@@ -1178,6 +1184,7 @@ public void handleAiServices(
11781184

11791185
}
11801186

1187+
registerJsonSchema(recorderContext);
11811188
recorder.setMetadata(perClassMetadata);
11821189
}
11831190

@@ -1246,16 +1253,18 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
12461253

12471254
// TODO give user ability to provide custom OutputParser
12481255
String outputFormatInstructions = "";
1249-
if (generateResponseSchema && !returnType.equals(Multi.class))
1256+
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
1257+
if (!returnType.equals(Multi.class)) {
12501258
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
1259+
}
12511260

12521261
List<TemplateParameterInfo> templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates);
12531262
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo = gatherSystemMessageInfo(method, templateParams);
12541263
AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams);
12551264

12561265
AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema,
12571266
systemMessageInfo,
1258-
userMessageInfo.template(), outputFormatInstructions);
1267+
userMessageInfo.template(), outputFormatInstructions, jsonSchemaFrom(returnType));
12591268

12601269
if (!generateResponseSchema && responseSchemaInfo.isInSystemMessage())
12611270
throw new RuntimeException(
@@ -1293,6 +1302,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
12931302
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
12941303
}
12951304

1305+
private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
1306+
if (isMulti(returnType)) {
1307+
return Optional.empty();
1308+
}
1309+
return JsonSchemas.jsonSchemaFrom(returnType);
1310+
}
1311+
12961312
private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
12971313
List<String> methodToolClassNames) {
12981314
List<String> allTools = new ArrayList<>(methodToolClassNames);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package io.quarkiverse.langchain4j.deployment;
2+
3+
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
4+
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
5+
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
6+
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
7+
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
8+
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
9+
import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
10+
import dev.langchain4j.model.chat.request.json.JsonSchema;
11+
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
12+
import io.quarkiverse.langchain4j.runtime.substitution.JsonArraySchemaObjectSubstitution;
13+
import io.quarkiverse.langchain4j.runtime.substitution.JsonBooleanSchemaObjectSubstitution;
14+
import io.quarkiverse.langchain4j.runtime.substitution.JsonEnumSchemaObjectSubstitution;
15+
import io.quarkiverse.langchain4j.runtime.substitution.JsonIntegerSchemaObjectSubstitution;
16+
import io.quarkiverse.langchain4j.runtime.substitution.JsonNumberSchemaObjectSubstitution;
17+
import io.quarkiverse.langchain4j.runtime.substitution.JsonObjectSchemaObjectSubstitution;
18+
import io.quarkiverse.langchain4j.runtime.substitution.JsonReferenceSchemaObjectSubstitution;
19+
import io.quarkiverse.langchain4j.runtime.substitution.JsonSchemaObjectSubstitution;
20+
import io.quarkiverse.langchain4j.runtime.substitution.JsonStringSchemaObjectSubstitution;
21+
import io.quarkus.deployment.recording.RecorderContext;
22+
23+
final class ObjectSubstitutionUtil {
24+
25+
private ObjectSubstitutionUtil() {
26+
}
27+
28+
static void registerJsonSchema(RecorderContext recorderContext) {
29+
recorderContext.registerSubstitution(JsonSchema.class, JsonSchemaObjectSubstitution.Serialized.class,
30+
JsonSchemaObjectSubstitution.class);
31+
recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class,
32+
JsonArraySchemaObjectSubstitution.class);
33+
recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class,
34+
JsonBooleanSchemaObjectSubstitution.class);
35+
recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class,
36+
JsonEnumSchemaObjectSubstitution.class);
37+
recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class,
38+
JsonIntegerSchemaObjectSubstitution.class);
39+
recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class,
40+
JsonNumberSchemaObjectSubstitution.class);
41+
recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class,
42+
JsonObjectSchemaObjectSubstitution.class);
43+
recorderContext.registerSubstitution(JsonReferenceSchema.class,
44+
JsonReferenceSchemaObjectSubstitution.Serialized.class,
45+
JsonReferenceSchemaObjectSubstitution.class);
46+
recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class,
47+
JsonStringSchemaObjectSubstitution.class);
48+
}
49+
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING;
77
import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD;
88
import static io.quarkiverse.langchain4j.deployment.DotNames.UNI;
9+
import static io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema;
910

1011
import java.lang.reflect.Modifier;
1112
import java.util.ArrayList;
@@ -46,21 +47,12 @@
4647
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
4748
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
4849
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
49-
import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
5050
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
5151
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
5252
import dev.langchain4j.model.output.structured.Description;
5353
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
5454
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
5555
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
56-
import io.quarkiverse.langchain4j.runtime.tool.JsonArraySchemaObjectSubstitution;
57-
import io.quarkiverse.langchain4j.runtime.tool.JsonBooleanSchemaObjectSubstitution;
58-
import io.quarkiverse.langchain4j.runtime.tool.JsonEnumSchemaObjectSubstitution;
59-
import io.quarkiverse.langchain4j.runtime.tool.JsonIntegerSchemaObjectSubstitution;
60-
import io.quarkiverse.langchain4j.runtime.tool.JsonNumberSchemaObjectSubstitution;
61-
import io.quarkiverse.langchain4j.runtime.tool.JsonObjectSchemaObjectSubstitution;
62-
import io.quarkiverse.langchain4j.runtime.tool.JsonReferenceSchemaObjectSubstitution;
63-
import io.quarkiverse.langchain4j.runtime.tool.JsonStringSchemaObjectSubstitution;
6456
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
6557
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
6658
import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper;
@@ -342,23 +334,7 @@ public ToolsMetadataBuildItem filterOutRemovedTools(
342334
if (beforeRemoval != null) {
343335
recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class,
344336
ToolSpecificationObjectSubstitution.class);
345-
recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class,
346-
JsonArraySchemaObjectSubstitution.class);
347-
recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class,
348-
JsonBooleanSchemaObjectSubstitution.class);
349-
recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class,
350-
JsonEnumSchemaObjectSubstitution.class);
351-
recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class,
352-
JsonIntegerSchemaObjectSubstitution.class);
353-
recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class,
354-
JsonNumberSchemaObjectSubstitution.class);
355-
recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class,
356-
JsonObjectSchemaObjectSubstitution.class);
357-
recorderContext.registerSubstitution(JsonReferenceSchema.class,
358-
JsonReferenceSchemaObjectSubstitution.Serialized.class,
359-
JsonReferenceSchemaObjectSubstitution.class);
360-
recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class,
361-
JsonStringSchemaObjectSubstitution.class);
337+
registerJsonSchema(recorderContext);
362338
Map<String, List<ToolMethodCreateInfo>> metadataWithoutRemovedBeans = beforeRemoval.getMetadata().entrySet()
363339
.stream()
364340
.filter(entry -> validationPhase.getContext().removedBeans().stream()

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.eclipse.microprofile.config.ConfigProvider;
1212

1313
import dev.langchain4j.agent.tool.ToolSpecification;
14+
import dev.langchain4j.model.chat.request.json.JsonSchema;
1415
import dev.langchain4j.service.tool.ToolExecutor;
1516
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
1617
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
@@ -370,11 +371,12 @@ public record SpanInfo(String name) {
370371
}
371372

372373
public record ResponseSchemaInfo(boolean enabled, boolean isInSystemMessage, Optional<Boolean> isInUserMessage,
373-
String outputFormatInstructions) {
374+
String outputFormatInstructions, Optional<JsonSchema> structuredOutputSchema) {
374375

375376
public static ResponseSchemaInfo of(boolean enabled, Optional<TemplateInfo> systemMessageInfo,
376377
Optional<TemplateInfo> userMessageInfo,
377-
String outputFormatInstructions) {
378+
String outputFormatInstructions,
379+
Optional<JsonSchema> structuredOutputSchema) {
378380

379381
boolean systemMessage = systemMessageInfo.flatMap(TemplateInfo::text)
380382
.map(text -> text.contains(ResponseSchemaUtil.placeholder()))
@@ -385,7 +387,8 @@ public static ResponseSchemaInfo of(boolean enabled, Optional<TemplateInfo> syst
385387
userMessage = Optional.of(userMessageInfo.get().text.get().contains(ResponseSchemaUtil.placeholder()));
386388
}
387389

388-
return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions);
390+
return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions,
391+
structuredOutputSchema);
389392
}
390393
}
391394
}

0 commit comments

Comments
 (0)