Skip to content

Commit ab6121f

Browse files
committed
Add handling for Optional types
1 parent bd88ad9 commit ab6121f

File tree

6 files changed

+160
-53
lines changed

6 files changed

+160
-53
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ public class ToolProcessor {
8383
Object.class);
8484
private static final Logger log = Logger.getLogger(ToolProcessor.class);
8585

86+
public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional");
87+
public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt");
88+
public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong");
89+
public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble");
90+
8691
@BuildStep
8792
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
8893
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
@@ -488,11 +493,18 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
488493
ClassInfo classInfo = index.getClassByName(type.name());
489494

490495
List<String> required = new ArrayList<>();
496+
491497
if (classInfo != null) {
492498
for (FieldInfo field : classInfo.fields()) {
493499
String fieldName = field.name();
500+
Type fieldType = field.type();
494501

495-
Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
502+
boolean isOptional = isJavaOptionalType(fieldType);
503+
if (isOptional) {
504+
fieldType = unwrapOptionalType(fieldType);
505+
}
506+
507+
Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(fieldType, index, null);
496508
Map<String, Object> fieldDescription = new HashMap<>();
497509

498510
for (JsonSchemaProperty fieldProperty : fieldSchema) {
@@ -506,6 +518,10 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
506518
fieldDescription.put("description", String.join(",", descriptionValue));
507519
}
508520
}
521+
if (!isOptional) {
522+
required.add(fieldName);
523+
}
524+
509525
properties.put(fieldName, fieldDescription);
510526
}
511527
}
@@ -517,10 +533,39 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
517533
throw new IllegalArgumentException("Unsupported type: " + type);
518534
}
519535

536+
private boolean isJavaOptionalType(Type type) {
537+
DotName typeName = type.name();
538+
return typeName.equals(DotName.createSimple("java.util.Optional"))
539+
|| typeName.equals(DotName.createSimple("java.util.OptionalInt"))
540+
|| typeName.equals(DotName.createSimple("java.util.OptionalLong"))
541+
|| typeName.equals(DotName.createSimple("java.util.OptionalDouble"));
542+
}
543+
544+
private Type unwrapOptionalType(Type optionalType) {
545+
if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
546+
ParameterizedType parameterizedType = optionalType.asParameterizedType();
547+
return parameterizedType.arguments().get(0);
548+
}
549+
return optionalType;
550+
}
551+
520552
private boolean isComplexType(Type type) {
521553
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
522554
}
523555

556+
private boolean isOptionalField(FieldInfo field, IndexView index) {
557+
Type fieldType = field.type();
558+
DotName fieldTypeName = fieldType.name();
559+
560+
if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName)
561+
|| OPTIONAL_DOUBLE.equals(fieldTypeName)) {
562+
return true;
563+
}
564+
565+
return false;
566+
567+
}
568+
524569
private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
525570
return stream(properties)
526571
.filter(Objects::nonNull)

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import dev.langchain4j.service.Result;
1414
import dev.langchain4j.service.TokenStream;
1515
import dev.langchain4j.service.TypeUtils;
16-
//import dev.langchain4j.service.output.OutputParser;
1716
import dev.langchain4j.service.output.ServiceOutputParser;
1817
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
1918
import io.smallrye.mutiny.Multi;
@@ -23,13 +22,17 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser {
2322

2423
@Override
2524
public String outputFormatInstructions(Type returnType) {
26-
Class<?> rawClass = getRawClass(returnType);
25+
boolean isOptional = isJavaOptional(returnType);
26+
Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType;
27+
28+
Class<?> rawClass = getRawClass(actualType);
2729

2830
if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class
2931
&& rawClass != Response.class && !Multi.class.equals(rawClass)) {
3032
try {
3133
var schema = this.toJsonSchema(returnType);
32-
return "You must answer strictly with json according to the following json schema format: " + schema;
34+
return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: "
35+
+ schema;
3336
} catch (Exception e) {
3437
return "";
3538
}
@@ -77,7 +80,10 @@ private String extractJsonBlock(String text) {
7780

7881
public String toJsonSchema(Type type) throws Exception {
7982
Map<String, Object> schema = new HashMap<>();
80-
Class<?> rawClass = getRawClass(type);
83+
boolean isOptional = isJavaOptional(type);
84+
Type actualType = isOptional ? unwrapOptionalType(type) : type;
85+
86+
Class<?> rawClass = getRawClass(actualType);
8187

8288
if (type instanceof WildcardType wildcardType) {
8389
Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0]
@@ -104,22 +110,64 @@ public String toJsonSchema(Type type) throws Exception {
104110
schema.put("type", "object");
105111
Map<String, Object> properties = new HashMap<>();
106112

113+
List<String> required = new ArrayList<>();
107114
for (Field field : rawClass.getDeclaredFields()) {
108-
field.setAccessible(true);
109-
Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType());
110-
properties.put(field.getName(), fieldSchema);
111-
if (field.isAnnotationPresent(Description.class)) {
112-
Description description = field.getAnnotation(Description.class);
113-
fieldSchema.put("description", description.value());
115+
try {
116+
field.setAccessible(true);
117+
Type fieldType = field.getGenericType();
118+
119+
// Check if the field is Optional and unwrap it if necessary
120+
boolean fieldIsOptional = isJavaOptional(fieldType);
121+
Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType;
122+
123+
Map<String, Object> fieldSchema = toJsonSchemaMap(fieldActualType);
124+
properties.put(field.getName(), fieldSchema);
125+
126+
if (field.isAnnotationPresent(Description.class)) {
127+
Description description = field.getAnnotation(Description.class);
128+
fieldSchema.put("description", String.join(",", description.value()));
129+
}
130+
131+
// Only add to required if it is not Optional
132+
if (!fieldIsOptional) {
133+
required.add(field.getName());
134+
} else {
135+
fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema
136+
}
137+
138+
} catch (Exception e) {
139+
114140
}
141+
115142
}
116143
schema.put("properties", properties);
144+
if (!required.isEmpty()) {
145+
schema.put("required", required);
146+
}
147+
}
148+
if (isOptional) {
149+
schema.put("nullable", true);
117150
}
118-
119151
ObjectMapper mapper = new ObjectMapper();
120152
return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string
121153
}
122154

155+
private boolean isJavaOptional(Type type) {
156+
if (type instanceof ParameterizedType) {
157+
Type rawType = ((ParameterizedType) type).getRawType();
158+
return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class
159+
|| rawType == OptionalDouble.class;
160+
}
161+
return false;
162+
}
163+
164+
private Type unwrapOptionalType(Type optionalType) {
165+
if (optionalType instanceof ParameterizedType) {
166+
return ((ParameterizedType) optionalType).getActualTypeArguments()[0];
167+
}
168+
return optionalType;
169+
}
170+
123171
private Class<?> getRawClass(Type type) {
124172
if (type instanceof Class<?>) {
125173
return (Class<?>) type;

integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22

33
import java.util.List;
44
import java.util.Map;
5+
import java.util.Optional;
56
import java.util.concurrent.ConcurrentHashMap;
7+
import java.util.function.Supplier;
68

9+
import com.fasterxml.jackson.annotation.JsonProperty;
10+
import dev.langchain4j.data.message.AiMessage;
11+
import dev.langchain4j.model.chat.ChatLanguageModel;
12+
import dev.langchain4j.model.output.Response;
713
import jakarta.annotation.PreDestroy;
814
import jakarta.enterprise.context.RequestScoped;
915
import jakarta.inject.Singleton;
@@ -30,38 +36,40 @@ public AssistantWithToolsResource(Assistant assistant) {
3036

3137
public static class TestData {
3238
@Description("Foo description for structured output")
39+
@JsonProperty("foo")
3340
String foo;
3441

3542
@Description("Foo description for structured output")
43+
@JsonProperty("bar")
3644
Integer bar;
3745

3846
@Description("Foo description for structured output")
39-
Double baz;
47+
@JsonProperty("baz")
48+
Optional<Double> baz;
49+
50+
51+
public TestData() {
52+
}
4053

4154
TestData(String foo, Integer bar, Double baz) {
4255
this.foo = foo;
4356
this.bar = bar;
44-
this.baz = baz;
57+
this.baz = Optional.of(baz);
4558
}
4659
}
4760

61+
4862
@GET
4963
public String get(@RestQuery String message) {
5064
return assistant.chat(message);
5165
}
5266

53-
@GET
54-
@Path("/many")
55-
public List<TestData> getMany(@RestQuery String message) {
56-
return assistant.chats(message);
57-
}
5867

5968
@RegisterAiService(tools = Calculator.class, chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class)
6069
public interface Assistant {
6170

6271
String chat(String userMessage);
6372

64-
List<TestData> chats(String userMessage);
6573
}
6674

6775
@Singleton

integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22

33
import java.util.ArrayList;
44
import java.util.List;
5+
import java.util.Optional;
6+
import java.util.function.Supplier;
57

8+
import com.fasterxml.jackson.annotation.JsonProperty;
9+
import dev.langchain4j.data.message.AiMessage;
10+
import dev.langchain4j.model.chat.ChatLanguageModel;
11+
import dev.langchain4j.model.output.Response;
612
import jakarta.ws.rs.POST;
713
import jakarta.ws.rs.Path;
814

@@ -23,26 +29,41 @@ public EntityMappedResource(EntityMappedDescriber describer) {
2329

2430
public static class TestData {
2531
@Description("Foo description for structured output")
32+
@JsonProperty("foo")
2633
String foo;
2734

2835
@Description("Foo description for structured output")
36+
@JsonProperty("bar")
2937
Integer bar;
3038

3139
@Description("Foo description for structured output")
32-
Double baz;
40+
@JsonProperty("baz")
41+
Optional<Double> baz;
42+
43+
44+
public TestData() {
45+
}
3346

3447
TestData(String foo, Integer bar, Double baz) {
3548
this.foo = foo;
3649
this.bar = bar;
37-
this.baz = baz;
50+
this.baz = Optional.of(baz);
3851
}
3952
}
4053

41-
@POST
42-
public List<String> generate(@RestQuery String message) {
43-
var result = describer.describe(message);
44-
45-
return result;
54+
public static class MirrorModelSupplier implements Supplier<ChatLanguageModel> {
55+
@Override
56+
public ChatLanguageModel get() {
57+
return (messages) -> new Response<>(new AiMessage("""
58+
[
59+
{
60+
"foo": "asd",
61+
"bar": 1,
62+
"baz": 2.0
63+
}
64+
]
65+
"""));
66+
}
4667
}
4768

4869
@POST
@@ -51,14 +72,16 @@ public List<TestData> generateMapped(@RestQuery String message) {
5172
List<TestData> inputs = new ArrayList<>();
5273
inputs.add(new TestData(message, 100, 100.0));
5374

54-
return describer.describeMapped(inputs);
75+
var test = describer.describeMapped(inputs);
76+
return test;
5577
}
5678

57-
@RegisterAiService
79+
80+
81+
@RegisterAiService(chatLanguageModelSupplier = MirrorModelSupplier.class)
5882
public interface EntityMappedDescriber {
5983

60-
@UserMessage("This is a describer returning a collection of strings")
61-
List<String> describe(String url);
84+
6285

6386
@UserMessage("This is a describer returning a collection of mapped entities")
6487
List<TestData> describeMapped(List<TestData> inputs);
Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package org.acme.example.openai.aiservices;
22

33
import static io.restassured.RestAssured.given;
4-
import static org.hamcrest.Matchers.containsString;
4+
import static org.hamcrest.Matchers.*;
55

66
import java.net.URL;
77

@@ -18,16 +18,6 @@ public class AssistantResourceWithEntityMappingTest {
1818
@TestHTTPResource
1919
URL url;
2020

21-
@Test
22-
public void get() {
23-
given()
24-
.baseUri(url.toString())
25-
.queryParam("message", "This is a test")
26-
.post()
27-
.then()
28-
.statusCode(200)
29-
.body(containsString("MockGPT"));
30-
}
3121

3222
@Test
3323
public void getMany() {
@@ -37,6 +27,9 @@ public void getMany() {
3727
.post()
3828
.then()
3929
.statusCode(200)
40-
.body(containsString("MockGPT"));
30+
.body("$", hasSize(1)) // Ensure that the response is an array with exactly one item
31+
.body("[0].foo", equalTo("asd")) // Check that foo is set correctly
32+
.body("[0].bar", equalTo(1)) // Check that bar is 100
33+
.body("[0].baz", equalTo(2.0F));
4134
}
4235
}

integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,4 @@ public void get() {
2929
.body(containsString("MockGPT"));
3030
}
3131

32-
@Test
33-
public void getMany() {
34-
given()
35-
.baseUri(url.toString() + "/many")
36-
.queryParam("message", "This is a test")
37-
.get()
38-
.then()
39-
.statusCode(200)
40-
.body(containsString("MockGPT"));
41-
}
4232
}

0 commit comments

Comments
 (0)