Skip to content

Commit 0fa2a96

Browse files
authored
Merge pull request #1039 from Tarjei400/main
Adjusted method signature mapping to json schema, to allow collections in tool arguments
2 parents ed12693 + b8417fb commit 0fa2a96

File tree

3 files changed

+135
-7
lines changed

3 files changed

+135
-7
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import org.jboss.jandex.AnnotationValue;
2727
import org.jboss.jandex.ClassInfo;
2828
import org.jboss.jandex.DotName;
29+
import org.jboss.jandex.FieldInfo;
2930
import org.jboss.jandex.IndexView;
3031
import org.jboss.jandex.MethodInfo;
3132
import org.jboss.jandex.MethodParameterInfo;
33+
import org.jboss.jandex.ParameterizedType;
3234
import org.jboss.jandex.Type;
3335
import org.jboss.logging.Logger;
3436
import org.objectweb.asm.ClassVisitor;
@@ -406,11 +408,27 @@ private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOu
406408

407409
private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) {
408410
Type type = parameter.type();
409-
DotName typeName = parameter.type().name();
410-
411411
AnnotationInstance pInstance = parameter.annotation(P);
412+
412413
JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString());
413414

415+
return toJsonSchemaProperties(type, index, description);
416+
}
417+
418+
private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView index, JsonSchemaProperty description) {
419+
DotName typeName = type.name();
420+
421+
if (type.kind() == Type.Kind.WILDCARD_TYPE) {
422+
Type boundType = type.asWildcardType().extendsBound();
423+
if (boundType == null) {
424+
boundType = type.asWildcardType().superBound();
425+
}
426+
if (boundType != null) {
427+
return toJsonSchemaProperties(boundType, index, description);
428+
} else {
429+
throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type);
430+
}
431+
}
414432
if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName)
415433
|| DotNames.PRIMITIVE_CHAR.equals(typeName)) {
416434
return removeNulls(STRING, description);
@@ -435,17 +453,64 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(MethodParameterInfo
435453
return removeNulls(NUMBER, description);
436454
}
437455

438-
if ((type.kind() == Type.Kind.ARRAY)
439-
|| DotNames.LIST.equals(typeName)
440-
|| DotNames.SET.equals(typeName)) { // TODO something else?
441-
return removeNulls(ARRAY, description); // TODO provide type of array?
456+
// TODO something else?
457+
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
458+
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
459+
: null;
460+
461+
Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0)
462+
: type.asArrayType().component();
463+
464+
Iterable<JsonSchemaProperty> elementProperties = toJsonSchemaProperties(elementType, index, null);
465+
466+
JsonSchemaProperty itemsSchema;
467+
if (isComplexType(elementType)) {
468+
Map<String, Object> fieldDescription = new HashMap<>();
469+
470+
for (JsonSchemaProperty fieldProperty : elementProperties) {
471+
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
472+
}
473+
itemsSchema = JsonSchemaProperty.from("items", fieldDescription);
474+
} else {
475+
itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next());
476+
}
477+
478+
return removeNulls(ARRAY, itemsSchema, description);
442479
}
443480

444481
if (isEnum(type, index)) {
445482
return removeNulls(STRING, enums(enumConstants(type)), description);
446483
}
447484

448-
return removeNulls(OBJECT, description); // TODO provide internals
485+
if (type.kind() == Type.Kind.CLASS) {
486+
Map<String, Object> properties = new HashMap<>();
487+
ClassInfo classInfo = index.getClassByName(type.name());
488+
489+
List<String> required = new ArrayList<>();
490+
if (classInfo != null) {
491+
for (FieldInfo field : classInfo.fields()) {
492+
String fieldName = field.name();
493+
494+
Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
495+
Map<String, Object> fieldDescription = new HashMap<>();
496+
497+
for (JsonSchemaProperty fieldProperty : fieldSchema) {
498+
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
499+
}
500+
501+
properties.put(fieldName, fieldDescription);
502+
}
503+
}
504+
505+
JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties);
506+
return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description);
507+
}
508+
509+
throw new IllegalArgumentException("Unsupported type: " + type);
510+
}
511+
512+
private boolean isComplexType(Type type) {
513+
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
449514
}
450515

451516
private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {

integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.acme.example.openai.aiservices;
22

3+
import java.util.List;
34
import java.util.Map;
45
import java.util.concurrent.ConcurrentHashMap;
56

@@ -26,6 +27,18 @@ public AssistantWithToolsResource(Assistant assistant) {
2627
this.assistant = assistant;
2728
}
2829

30+
public static class TestData {
31+
String foo;
32+
Integer bar;
33+
Double baz;
34+
35+
TestData(String foo, Integer bar, Double baz) {
36+
this.foo = foo;
37+
this.bar = bar;
38+
this.baz = baz;
39+
}
40+
}
41+
2942
@GET
3043
public String get(@RestQuery String message) {
3144
return assistant.chat(message);
@@ -54,6 +67,25 @@ int add(int a, int b) {
5467
double sqrt(int x) {
5568
return Math.sqrt(x);
5669
}
70+
71+
@Tool("Calculates the the sum of all provided numbers")
72+
double sumAll(List<Double> x) {
73+
74+
return x.stream().reduce(0.0, (a, b) -> a + b);
75+
}
76+
77+
@Tool("Evaluate test data object")
78+
public TestData evaluateTestObject(List<TestData> data) {
79+
return new TestData("Empty", 0, 0.0);
80+
}
81+
82+
@Tool("Calculates all factors of the provided integer.")
83+
List<Integer> getFactors(int x) {
84+
return java.util.stream.IntStream.rangeClosed(1, x)
85+
.filter(i -> x % i == 0)
86+
.boxed()
87+
.toList();
88+
}
5789
}
5890

5991
@RequestScoped
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.acme.example.openai.aiservices;
2+
3+
import static io.restassured.RestAssured.given;
4+
import static org.hamcrest.Matchers.containsString;
5+
6+
import java.net.URL;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
import io.quarkus.test.common.http.TestHTTPEndpoint;
11+
import io.quarkus.test.common.http.TestHTTPResource;
12+
import io.quarkus.test.junit.QuarkusTest;
13+
14+
@QuarkusTest
15+
public class AssistantResourceWithToolsTest {
16+
17+
@TestHTTPEndpoint(AssistantWithToolsResource.class)
18+
@TestHTTPResource
19+
URL url;
20+
21+
@Test
22+
public void get() {
23+
given()
24+
.baseUri(url.toString())
25+
.queryParam("message", "This is a test")
26+
.get()
27+
.then()
28+
.statusCode(200)
29+
.body(containsString("MockGPT"));
30+
}
31+
}

0 commit comments

Comments
 (0)