Skip to content

Commit 73205e4

Browse files
committed
Add function calling support to invoke methods with dynamic arguments and return values
- Introduced `MethodFunctionCallback` class to invoke static and non-static methods with arbitrary parameters and handle method results, including void and complex return types. - Integrated Jackson for JSON schema generation based on method parameters. - Added support for dynamic argument mapping using ObjectMapper for method invocations. - Updated `pom.xml` files to include `jackson-module-jsonSchema` dependency. - Added unit tests to verify `MethodFunctionCallback` behavior with static and non-static methods. Dependencies: - Added `jackson-module-jsonSchema` version 2.17.2 to manage JSON schema generation.
1 parent c3c95a8 commit 73205e4

File tree

4 files changed

+459
-2
lines changed

4 files changed

+459
-2
lines changed

pom.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@
175175
<victools.version>4.31.1</victools.version>
176176
<kotlin.version>1.9.25</kotlin.version>
177177

178+
<jackson-module-jsonSchema.version>2.17.2</jackson-module-jsonSchema.version>
179+
178180
<!-- NOTE: keep them align -->
179181
<bedrockruntime.version>2.26.7</bedrockruntime.version>
180182
<awssdk.version>2.26.7</awssdk.version>

spring-ai-core/pom.xml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@
4242

4343
<dependencies>
4444

45+
<dependency>
46+
<groupId>com.fasterxml.jackson.module</groupId>
47+
<artifactId>jackson-module-jsonSchema</artifactId>
48+
<version>${jackson-module-jsonSchema.version}</version>
49+
</dependency>
50+
4551
<dependency>
4652
<groupId>io.swagger.core.v3</groupId>
4753
<artifactId>swagger-annotations</artifactId>
@@ -171,7 +177,8 @@
171177
<configuration>
172178
<sourceDirectory>${basedir}/src/main/resources/antlr4</sourceDirectory>
173179
<outputDirectory>${basedir}/src/main/java</outputDirectory>
174-
<!-- <outputDirectory>${project.build.directory}/generated-sources/antlr4</outputDirectory> -->
180+
<!--
181+
<outputDirectory>${project.build.directory}/generated-sources/antlr4</outputDirectory> -->
175182
<visitor>true</visitor>
176183
</configuration>
177184
<executions>
@@ -188,4 +195,4 @@
188195
</profiles>
189196

190197

191-
</project>
198+
</project>
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model.function;
17+
18+
import java.lang.reflect.Method;
19+
import java.lang.reflect.Modifier;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.stream.Collectors;
23+
import java.util.stream.Stream;
24+
25+
import org.slf4j.Logger;
26+
import org.slf4j.LoggerFactory;
27+
import org.springframework.ai.model.ModelOptionsUtils;
28+
import org.springframework.util.Assert;
29+
30+
import com.fasterxml.jackson.databind.JsonNode;
31+
import com.fasterxml.jackson.databind.ObjectMapper;
32+
import com.fasterxml.jackson.databind.node.ObjectNode;
33+
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
34+
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
35+
36+
/**
37+
* A {@link FunctionCallback} implementation that invokes a method on a given object. It
38+
* supports both static and non-static methods.
39+
*
40+
* Supports methods with arbitrary number of input parameters and methods with void return
41+
* type.
42+
*
43+
* @author Christian Tzolov
44+
* @since 1.0.0
45+
*/
46+
public class MethodFunctionCallback implements FunctionCallback {
47+
48+
private static Logger logger = LoggerFactory.getLogger(MethodFunctionCallback.class);
49+
50+
/**
51+
* Object instance that contains the method to be invoked. If the method is static
52+
* this object can be null.
53+
*/
54+
private final Object functionObject;
55+
56+
/**
57+
* The method to be invoked.
58+
*/
59+
private final Method method;
60+
61+
/**
62+
* Description to help the LLM model to understand woth the method does and when to
63+
* use it.
64+
*/
65+
private final String description;
66+
67+
/**
68+
* Internal ObjectMapper used to serialize/deserialize the method input and output.
69+
*/
70+
private final ObjectMapper mapper;
71+
72+
/**
73+
* The JSON schema generated from the method input parameters.
74+
*/
75+
private final String inputSchema;
76+
77+
public MethodFunctionCallback(Object functionObject, Method method, String description, ObjectMapper mapper) {
78+
79+
Assert.notNull(method, "Method must not be null");
80+
Assert.notNull(mapper, "ObjectMapper must not be null");
81+
Assert.hasText(description, "Description must not be empty");
82+
83+
this.method = method;
84+
this.description = description;
85+
this.mapper = mapper;
86+
this.functionObject = functionObject;
87+
88+
Assert.isTrue(this.functionObject != null || Modifier.isStatic(this.method.getModifiers()),
89+
"Function object must be provided for non-static methods!");
90+
91+
// Generate the JSON schema from the method input parameters
92+
Map<String, Class<?>> methodParameters = Stream.of(method.getParameters())
93+
.collect(Collectors.toMap(param -> param.getName(), param -> param.getType()));
94+
95+
this.inputSchema = this.generateJsonSchema(methodParameters);
96+
97+
logger.info("Generated JSON Schema: \n:" + this.inputSchema);
98+
}
99+
100+
@Override
101+
public String getName() {
102+
return method.getName();
103+
}
104+
105+
@Override
106+
public String getDescription() {
107+
return this.description;
108+
}
109+
110+
@Override
111+
public String getInputTypeSchema() {
112+
return this.inputSchema;
113+
}
114+
115+
@Override
116+
public String call(String functionInput) {
117+
118+
try {
119+
120+
@SuppressWarnings("unchecked")
121+
Map<String, Object> map = this.mapper.readValue(functionInput, Map.class);
122+
123+
Object[] methodArgs = Stream.of(this.method.getParameters()).map(parameter -> {
124+
Object rawValue = map.get(parameter.getName());
125+
Class<?> type = parameter.getType();
126+
return this.toJavaType(rawValue, type);
127+
}).toArray();
128+
129+
Object response = this.method.invoke(this.functionObject, methodArgs);
130+
131+
var returnType = this.method.getReturnType();
132+
if (returnType == Void.TYPE) {
133+
return "Done";
134+
}
135+
136+
if (returnType == Class.class || returnType.isRecord() || returnType == List.class
137+
|| returnType == Map.class) {
138+
return ModelOptionsUtils.toJsonString(response);
139+
140+
}
141+
return "" + response;
142+
}
143+
catch (Exception e) {
144+
throw new RuntimeException(e);
145+
}
146+
147+
}
148+
149+
/**
150+
* Generates a JSON schema from the given named classes.
151+
* @param namedClasses The named classes to generate the schema from.
152+
* @return The generated JSON schema.
153+
*/
154+
protected String generateJsonSchema(Map<String, Class<?>> namedClasses) {
155+
try {
156+
JsonSchemaGenerator schemaGen = new JsonSchemaGenerator(this.mapper);
157+
158+
ObjectNode rootNode = this.mapper.createObjectNode();
159+
rootNode.put("$schema", "https://json-schema.org/draft/2020-12/schema");
160+
rootNode.put("type", "object");
161+
ObjectNode propertiesNode = rootNode.putObject("properties");
162+
163+
for (Map.Entry<String, Class<?>> entry : namedClasses.entrySet()) {
164+
String className = entry.getKey();
165+
Class<?> clazz = entry.getValue();
166+
167+
JsonSchema schema = schemaGen.generateSchema(clazz);
168+
JsonNode schemaNode = this.mapper.valueToTree(schema);
169+
propertiesNode.set(className, schemaNode);
170+
}
171+
172+
return this.mapper.writerWithDefaultPrettyPrinter().writeValueAsString(rootNode);
173+
}
174+
catch (Exception e) {
175+
throw new RuntimeException(e);
176+
}
177+
}
178+
179+
/**
180+
* Converts the given value to the specified Java type.
181+
* @param value The value to convert.
182+
* @param javaType The Java type to convert to.
183+
* @return Returns the converted value.
184+
*/
185+
protected Object toJavaType(Object value, Class<?> javaType) {
186+
187+
if (value == null) {
188+
return null;
189+
}
190+
if (javaType == String.class) {
191+
return value.toString();
192+
}
193+
else if (javaType == Integer.class || javaType == int.class) {
194+
return Integer.parseInt(value.toString());
195+
}
196+
else if (javaType == Long.class || javaType == long.class) {
197+
return Long.parseLong(value.toString());
198+
}
199+
else if (javaType == Double.class || javaType == double.class) {
200+
return Double.parseDouble(value.toString());
201+
}
202+
else if (javaType == Float.class || javaType == float.class) {
203+
return Float.parseFloat(value.toString());
204+
}
205+
else if (javaType == Boolean.class || javaType == boolean.class) {
206+
return Boolean.parseBoolean(value.toString());
207+
}
208+
else if (javaType.isEnum()) {
209+
return Enum.valueOf((Class<Enum>) javaType, value.toString());
210+
}
211+
// else if (type == Class.class || type.isRecord()) {
212+
// return ModelOptionsUtils.mapToClass((Map<String, Object>) value, type);
213+
// }
214+
215+
try {
216+
String json = new ObjectMapper().writeValueAsString(value);
217+
return this.mapper.readValue(json, javaType);
218+
}
219+
catch (Exception e) {
220+
throw new RuntimeException(e);
221+
}
222+
}
223+
224+
/**
225+
* Creates a new {@link Builder} for the {@link MethodFunctionCallback}.
226+
* @return The builder.
227+
*/
228+
public static MethodFunctionCallback.Builder builder() {
229+
return new Builder();
230+
}
231+
232+
/**
233+
* Builder for the {@link MethodFunctionCallback}.
234+
*/
235+
public static class Builder {
236+
237+
private Method method;
238+
239+
private String description;
240+
241+
private ObjectMapper mapper = new ObjectMapper();
242+
243+
private Object functionObject = null;
244+
245+
public MethodFunctionCallback.Builder withFunctionObject(Object functionObject) {
246+
this.functionObject = functionObject;
247+
return this;
248+
}
249+
250+
public MethodFunctionCallback.Builder withMethod(Method method) {
251+
Assert.notNull(method, "Method must not be null");
252+
this.method = method;
253+
return this;
254+
}
255+
256+
public MethodFunctionCallback.Builder withDescription(String description) {
257+
Assert.hasText(description, "Description must not be empty");
258+
this.description = description;
259+
return this;
260+
}
261+
262+
public MethodFunctionCallback.Builder withMapper(ObjectMapper mapper) {
263+
this.mapper = mapper;
264+
return this;
265+
}
266+
267+
public MethodFunctionCallback build() {
268+
return new MethodFunctionCallback(this.functionObject, this.method, this.description, this.mapper);
269+
}
270+
271+
}
272+
273+
}

0 commit comments

Comments
 (0)