Skip to content

Commit 8ed5c06

Browse files
committed
Support function calling for reflection Method with multiple arguments
1 parent 05292ac commit 8ed5c06

File tree

4 files changed

+457
-4
lines changed

4 files changed

+457
-4
lines changed

pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
<azure-open-ai-client.version>1.0.0-beta.10</azure-open-ai-client.version>
154154
<jtokkit.version>1.1.0</jtokkit.version>
155155
<victools.version>4.31.1</victools.version>
156+
<jackson-module-jsonSchema.version>2.17.2</jackson-module-jsonSchema.version>
156157

157158
<!-- NOTE: keep them align -->
158159
<bedrockruntime.version>2.26.7</bedrockruntime.version>

spring-ai-core/pom.xml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
<?xml version="1.0" encoding="UTF-8"?>
22
<project xmlns="http://maven.apache.org/POM/4.0.0"
3-
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
45
<modelVersion>4.0.0</modelVersion>
56
<parent>
67
<groupId>org.springframework.ai</groupId>
@@ -25,6 +26,12 @@
2526

2627
<dependencies>
2728

29+
<dependency>
30+
<groupId>com.fasterxml.jackson.module</groupId>
31+
<artifactId>jackson-module-jsonSchema</artifactId>
32+
<version>${jackson-module-jsonSchema.version}</version>
33+
</dependency>
34+
2835
<dependency>
2936
<groupId>io.swagger.core.v3</groupId>
3037
<artifactId>swagger-annotations</artifactId>
@@ -127,7 +134,7 @@
127134
<scope>test</scope>
128135
</dependency>
129136

130-
</dependencies>
137+
</dependencies>
131138

132139
<profiles>
133140
<profile>
@@ -144,7 +151,8 @@
144151
<configuration>
145152
<sourceDirectory>${basedir}/src/main/resources/antlr4</sourceDirectory>
146153
<outputDirectory>${basedir}/src/main/java</outputDirectory>
147-
<!-- <outputDirectory>${project.build.directory}/generated-sources/antlr4</outputDirectory> -->
154+
<!--
155+
<outputDirectory>${project.build.directory}/generated-sources/antlr4</outputDirectory> -->
148156
<visitor>true</visitor>
149157
</configuration>
150158
<executions>
@@ -161,4 +169,4 @@
161169
</profiles>
162170

163171

164-
</project>
172+
</project>
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.function;
17+
18+
import java.lang.reflect.Method;
19+
import java.lang.reflect.Modifier;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.stream.Collectors;
23+
import java.util.stream.Stream;
24+
25+
import org.slf4j.Logger;
26+
import org.slf4j.LoggerFactory;
27+
import org.springframework.ai.model.ModelOptionsUtils;
28+
import org.springframework.util.Assert;
29+
30+
import com.fasterxml.jackson.databind.JsonNode;
31+
import com.fasterxml.jackson.databind.ObjectMapper;
32+
import com.fasterxml.jackson.databind.node.ObjectNode;
33+
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
34+
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
35+
36+
/**
37+
* A {@link FunctionCallback} implementation that invokes a method on a given object. It
38+
* supports both static and non-static methods. Aslo it supports methods with arbitrary
39+
* number of input parameters and methods with void return type.
40+
*
41+
* @author Christian Tzolov
42+
* @since 1.0.0
43+
*/
44+
public class MethodFunctionCallback implements FunctionCallback {
45+
46+
private static Logger logger = LoggerFactory.getLogger(MethodFunctionCallback.class);
47+
48+
/**
49+
* Object instance that contains the method to be invoked. If the method is static
50+
* this object can be null.
51+
*/
52+
private final Object functionObject;
53+
54+
/**
55+
* The method to be invoked.
56+
*/
57+
private final Method method;
58+
59+
/**
60+
* Description to help the LLM model to understand woth the method does and when to
61+
* use it.
62+
*/
63+
private final String description;
64+
65+
/**
66+
* Internal ObjectMapper used to serialize/deserialize the method input and output.
67+
*/
68+
private final ObjectMapper mapper;
69+
70+
/**
71+
* The JSON schema generated from the method input parameters.
72+
*/
73+
private final String inptuSchema;
74+
75+
public MethodFunctionCallback(Object functionObject, Method method, String description, ObjectMapper mapper) {
76+
77+
Assert.notNull(method, "Method must not be null");
78+
Assert.notNull(mapper, "ObjectMapper must not be null");
79+
Assert.hasText(description, "Description must not be empty");
80+
81+
this.method = method;
82+
this.description = description;
83+
this.mapper = mapper;
84+
this.functionObject = functionObject;
85+
86+
Assert.isTrue(this.functionObject != null || Modifier.isStatic(this.method.getModifiers()),
87+
"Function object must be provided for non-static methods!");
88+
89+
// Generate the JSON schema from the method input parameters
90+
Map<String, Class<?>> methodParameters = Stream.of(method.getParameters())
91+
.collect(Collectors.toMap(param -> param.getName(), param -> param.getType()));
92+
this.inptuSchema = generateJsonSchema(methodParameters);
93+
94+
logger.info("Generated JSON Schema: \n:" + this.inptuSchema);
95+
}
96+
97+
@Override
98+
public String getName() {
99+
return method.getName();
100+
}
101+
102+
@Override
103+
public String getDescription() {
104+
return this.description;
105+
}
106+
107+
@Override
108+
public String getInputTypeSchema() {
109+
return this.inptuSchema;
110+
}
111+
112+
@Override
113+
public String call(String functionInput) {
114+
115+
try {
116+
117+
Map<String, Object> map = this.mapper.readValue(functionInput, Map.class);
118+
119+
Object[] methodArgs = Stream.of(this.method.getParameters()).map(parameter -> {
120+
Object rawValue = map.get(parameter.getName());
121+
Class<?> type = parameter.getType();
122+
return this.toJavaType(rawValue, type);
123+
}).toArray();
124+
125+
Object response = this.method.invoke(this.functionObject, methodArgs);
126+
127+
var returnType = this.method.getReturnType();
128+
if (returnType == Void.TYPE) {
129+
return "Done";
130+
}
131+
132+
if (returnType == Class.class || returnType.isRecord() || returnType == List.class
133+
|| returnType == Map.class) {
134+
return ModelOptionsUtils.toJsonString(response);
135+
136+
}
137+
return "" + response;
138+
}
139+
catch (Exception e) {
140+
throw new RuntimeException(e);
141+
}
142+
143+
}
144+
145+
/**
146+
* Generates a JSON schema from the given named classes.
147+
* @param namedClasses The named classes to generate the schema from.
148+
* @return The generated JSON schema.
149+
*/
150+
protected String generateJsonSchema(Map<String, Class<?>> namedClasses) {
151+
try {
152+
JsonSchemaGenerator schemaGen = new JsonSchemaGenerator(this.mapper);
153+
154+
ObjectNode rootNode = this.mapper.createObjectNode();
155+
rootNode.put("$schema", "https://json-schema.org/draft/2020-12/schema");
156+
rootNode.put("type", "object");
157+
ObjectNode propertiesNode = rootNode.putObject("properties");
158+
159+
for (Map.Entry<String, Class<?>> entry : namedClasses.entrySet()) {
160+
String className = entry.getKey();
161+
Class<?> clazz = entry.getValue();
162+
163+
JsonSchema schema = schemaGen.generateSchema(clazz);
164+
JsonNode schemaNode = this.mapper.valueToTree(schema);
165+
propertiesNode.set(className, schemaNode);
166+
}
167+
168+
return this.mapper.writerWithDefaultPrettyPrinter().writeValueAsString(rootNode);
169+
}
170+
catch (Exception e) {
171+
throw new RuntimeException(e);
172+
}
173+
}
174+
175+
/**
176+
* Converts the given value to the specified Java type.
177+
* @param value The value to convert.
178+
* @param javaType The Java type to convert to.
179+
* @return Returns the converted value.
180+
*/
181+
protected Object toJavaType(Object value, Class<?> javaType) {
182+
183+
if (value == null) {
184+
return null;
185+
}
186+
if (javaType == String.class) {
187+
return value.toString();
188+
}
189+
else if (javaType == Integer.class || javaType == int.class) {
190+
return Integer.parseInt(value.toString());
191+
}
192+
else if (javaType == Long.class || javaType == long.class) {
193+
return Long.parseLong(value.toString());
194+
}
195+
else if (javaType == Double.class || javaType == double.class) {
196+
return Double.parseDouble(value.toString());
197+
}
198+
else if (javaType == Float.class || javaType == float.class) {
199+
return Float.parseFloat(value.toString());
200+
}
201+
else if (javaType == Boolean.class || javaType == boolean.class) {
202+
return Boolean.parseBoolean(value.toString());
203+
}
204+
else if (javaType.isEnum()) {
205+
return Enum.valueOf((Class<Enum>) javaType, value.toString());
206+
}
207+
// else if (type == Class.class || type.isRecord()) {
208+
// return ModelOptionsUtils.mapToClass((Map<String, Object>) value, type);
209+
// }
210+
211+
try {
212+
String json = new ObjectMapper().writeValueAsString(value);
213+
return this.mapper.readValue(json, javaType);
214+
}
215+
catch (Exception e) {
216+
throw new RuntimeException(e);
217+
}
218+
}
219+
220+
/**
221+
* Creates a new {@link Builder} for the {@link MethodFunctionCallback}.
222+
* @return The builder.
223+
*/
224+
public static MethodFunctionCallback.Builder builder() {
225+
return new Builder();
226+
}
227+
228+
/**
229+
* Builder for the {@link MethodFunctionCallback}.
230+
*/
231+
public static class Builder {
232+
233+
private Method method;
234+
235+
private String description;
236+
237+
private ObjectMapper mapper = new ObjectMapper();
238+
239+
private Object functionObject = null;
240+
241+
public MethodFunctionCallback.Builder withFunctionObject(Object functionObject) {
242+
this.functionObject = functionObject;
243+
return this;
244+
}
245+
246+
public MethodFunctionCallback.Builder withMethod(Method method) {
247+
Assert.notNull(method, "Method must not be null");
248+
this.method = method;
249+
return this;
250+
}
251+
252+
public MethodFunctionCallback.Builder withDescription(String description) {
253+
Assert.hasText(description, "Description must not be empty");
254+
this.description = description;
255+
return this;
256+
}
257+
258+
public MethodFunctionCallback.Builder withMapper(ObjectMapper mapper) {
259+
this.mapper = mapper;
260+
return this;
261+
}
262+
263+
public MethodFunctionCallback build() {
264+
return new MethodFunctionCallback(this.functionObject, this.method, this.description, this.mapper);
265+
}
266+
267+
}
268+
269+
}

0 commit comments

Comments
 (0)