1+ /*
2+ * Copyright 2024 - 2024 the original author or authors.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * https://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+ package org .springframework .ai .model .function ;
17+
18+ import java .lang .reflect .Method ;
19+ import java .lang .reflect .Modifier ;
20+ import java .util .List ;
21+ import java .util .Map ;
22+ import java .util .stream .Collectors ;
23+ import java .util .stream .Stream ;
24+
25+ import org .slf4j .Logger ;
26+ import org .slf4j .LoggerFactory ;
27+ import org .springframework .ai .model .ModelOptionsUtils ;
28+ import org .springframework .util .Assert ;
29+
30+ import com .fasterxml .jackson .databind .JsonNode ;
31+ import com .fasterxml .jackson .databind .ObjectMapper ;
32+ import com .fasterxml .jackson .databind .node .ObjectNode ;
33+ import com .fasterxml .jackson .module .jsonSchema .JsonSchema ;
34+ import com .fasterxml .jackson .module .jsonSchema .JsonSchemaGenerator ;
35+
36+ /**
37+ * A {@link FunctionCallback} implementation that invokes a method on a given object. It
38+ * supports both static and non-static methods.
39+ *
40+ * Supports methods with arbitrary number of input parameters and methods with void return
41+ * type.
42+ *
43+ * @author Christian Tzolov
44+ * @since 1.0.0
45+ */
46+ public class MethodFunctionCallback implements FunctionCallback {
47+
48+ private static Logger logger = LoggerFactory .getLogger (MethodFunctionCallback .class );
49+
50+ /**
51+ * Object instance that contains the method to be invoked. If the method is static
52+ * this object can be null.
53+ */
54+ private final Object functionObject ;
55+
56+ /**
57+ * The method to be invoked.
58+ */
59+ private final Method method ;
60+
61+ /**
62+ * Description to help the LLM model to understand woth the method does and when to
63+ * use it.
64+ */
65+ private final String description ;
66+
67+ /**
68+ * Internal ObjectMapper used to serialize/deserialize the method input and output.
69+ */
70+ private final ObjectMapper mapper ;
71+
72+ /**
73+ * The JSON schema generated from the method input parameters.
74+ */
75+ private final String inputSchema ;
76+
77+ public MethodFunctionCallback (Object functionObject , Method method , String description , ObjectMapper mapper ) {
78+
79+ Assert .notNull (method , "Method must not be null" );
80+ Assert .notNull (mapper , "ObjectMapper must not be null" );
81+ Assert .hasText (description , "Description must not be empty" );
82+
83+ this .method = method ;
84+ this .description = description ;
85+ this .mapper = mapper ;
86+ this .functionObject = functionObject ;
87+
88+ Assert .isTrue (this .functionObject != null || Modifier .isStatic (this .method .getModifiers ()),
89+ "Function object must be provided for non-static methods!" );
90+
91+ // Generate the JSON schema from the method input parameters
92+ Map <String , Class <?>> methodParameters = Stream .of (method .getParameters ())
93+ .collect (Collectors .toMap (param -> param .getName (), param -> param .getType ()));
94+
95+ this .inputSchema = this .generateJsonSchema (methodParameters );
96+
97+ logger .info ("Generated JSON Schema: \n :" + this .inputSchema );
98+ }
99+
100+ @ Override
101+ public String getName () {
102+ return method .getName ();
103+ }
104+
105+ @ Override
106+ public String getDescription () {
107+ return this .description ;
108+ }
109+
110+ @ Override
111+ public String getInputTypeSchema () {
112+ return this .inputSchema ;
113+ }
114+
115+ @ Override
116+ public String call (String functionInput ) {
117+
118+ try {
119+
120+ @ SuppressWarnings ("unchecked" )
121+ Map <String , Object > map = this .mapper .readValue (functionInput , Map .class );
122+
123+ Object [] methodArgs = Stream .of (this .method .getParameters ()).map (parameter -> {
124+ Object rawValue = map .get (parameter .getName ());
125+ Class <?> type = parameter .getType ();
126+ return this .toJavaType (rawValue , type );
127+ }).toArray ();
128+
129+ Object response = this .method .invoke (this .functionObject , methodArgs );
130+
131+ var returnType = this .method .getReturnType ();
132+ if (returnType == Void .TYPE ) {
133+ return "Done" ;
134+ }
135+
136+ if (returnType == Class .class || returnType .isRecord () || returnType == List .class
137+ || returnType == Map .class ) {
138+ return ModelOptionsUtils .toJsonString (response );
139+
140+ }
141+ return "" + response ;
142+ }
143+ catch (Exception e ) {
144+ throw new RuntimeException (e );
145+ }
146+
147+ }
148+
149+ /**
150+ * Generates a JSON schema from the given named classes.
151+ * @param namedClasses The named classes to generate the schema from.
152+ * @return The generated JSON schema.
153+ */
154+ protected String generateJsonSchema (Map <String , Class <?>> namedClasses ) {
155+ try {
156+ JsonSchemaGenerator schemaGen = new JsonSchemaGenerator (this .mapper );
157+
158+ ObjectNode rootNode = this .mapper .createObjectNode ();
159+ rootNode .put ("$schema" , "https://json-schema.org/draft/2020-12/schema" );
160+ rootNode .put ("type" , "object" );
161+ ObjectNode propertiesNode = rootNode .putObject ("properties" );
162+
163+ for (Map .Entry <String , Class <?>> entry : namedClasses .entrySet ()) {
164+ String className = entry .getKey ();
165+ Class <?> clazz = entry .getValue ();
166+
167+ JsonSchema schema = schemaGen .generateSchema (clazz );
168+ JsonNode schemaNode = this .mapper .valueToTree (schema );
169+ propertiesNode .set (className , schemaNode );
170+ }
171+
172+ return this .mapper .writerWithDefaultPrettyPrinter ().writeValueAsString (rootNode );
173+ }
174+ catch (Exception e ) {
175+ throw new RuntimeException (e );
176+ }
177+ }
178+
179+ /**
180+ * Converts the given value to the specified Java type.
181+ * @param value The value to convert.
182+ * @param javaType The Java type to convert to.
183+ * @return Returns the converted value.
184+ */
185+ protected Object toJavaType (Object value , Class <?> javaType ) {
186+
187+ if (value == null ) {
188+ return null ;
189+ }
190+ if (javaType == String .class ) {
191+ return value .toString ();
192+ }
193+ else if (javaType == Integer .class || javaType == int .class ) {
194+ return Integer .parseInt (value .toString ());
195+ }
196+ else if (javaType == Long .class || javaType == long .class ) {
197+ return Long .parseLong (value .toString ());
198+ }
199+ else if (javaType == Double .class || javaType == double .class ) {
200+ return Double .parseDouble (value .toString ());
201+ }
202+ else if (javaType == Float .class || javaType == float .class ) {
203+ return Float .parseFloat (value .toString ());
204+ }
205+ else if (javaType == Boolean .class || javaType == boolean .class ) {
206+ return Boolean .parseBoolean (value .toString ());
207+ }
208+ else if (javaType .isEnum ()) {
209+ return Enum .valueOf ((Class <Enum >) javaType , value .toString ());
210+ }
211+ // else if (type == Class.class || type.isRecord()) {
212+ // return ModelOptionsUtils.mapToClass((Map<String, Object>) value, type);
213+ // }
214+
215+ try {
216+ String json = new ObjectMapper ().writeValueAsString (value );
217+ return this .mapper .readValue (json , javaType );
218+ }
219+ catch (Exception e ) {
220+ throw new RuntimeException (e );
221+ }
222+ }
223+
224+ /**
225+ * Creates a new {@link Builder} for the {@link MethodFunctionCallback}.
226+ * @return The builder.
227+ */
228+ public static MethodFunctionCallback .Builder builder () {
229+ return new Builder ();
230+ }
231+
232+ /**
233+ * Builder for the {@link MethodFunctionCallback}.
234+ */
235+ public static class Builder {
236+
237+ private Method method ;
238+
239+ private String description ;
240+
241+ private ObjectMapper mapper = new ObjectMapper ();
242+
243+ private Object functionObject = null ;
244+
245+ public MethodFunctionCallback .Builder withFunctionObject (Object functionObject ) {
246+ this .functionObject = functionObject ;
247+ return this ;
248+ }
249+
250+ public MethodFunctionCallback .Builder withMethod (Method method ) {
251+ Assert .notNull (method , "Method must not be null" );
252+ this .method = method ;
253+ return this ;
254+ }
255+
256+ public MethodFunctionCallback .Builder withDescription (String description ) {
257+ Assert .hasText (description , "Description must not be empty" );
258+ this .description = description ;
259+ return this ;
260+ }
261+
262+ public MethodFunctionCallback .Builder withMapper (ObjectMapper mapper ) {
263+ this .mapper = mapper ;
264+ return this ;
265+ }
266+
267+ public MethodFunctionCallback build () {
268+ return new MethodFunctionCallback (this .functionObject , this .method , this .description , this .mapper );
269+ }
270+
271+ }
272+
273+ }
0 commit comments