1+ /*
2+ * Copyright 2024 - 2024 the original author or authors.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * https://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+ package org .springframework .ai .model .function ;
17+
18+ import java .lang .reflect .Method ;
19+ import java .lang .reflect .Modifier ;
20+ import java .util .List ;
21+ import java .util .Map ;
22+ import java .util .stream .Collectors ;
23+ import java .util .stream .Stream ;
24+
25+ import org .slf4j .Logger ;
26+ import org .slf4j .LoggerFactory ;
27+ import org .springframework .ai .model .ModelOptionsUtils ;
28+ import org .springframework .util .Assert ;
29+
30+ import com .fasterxml .jackson .databind .JsonNode ;
31+ import com .fasterxml .jackson .databind .ObjectMapper ;
32+ import com .fasterxml .jackson .databind .node .ObjectNode ;
33+ import com .fasterxml .jackson .module .jsonSchema .JsonSchema ;
34+ import com .fasterxml .jackson .module .jsonSchema .JsonSchemaGenerator ;
35+
36+ /**
37+ * A {@link FunctionCallback} implementation that invokes a method on a given object. It
38+ * supports both static and non-static methods. Aslo it supports methods with arbitrary
39+ * number of input parameters and methods with void return type.
40+ *
41+ * @author Christian Tzolov
42+ * @since 1.0.0
43+ */
44+ public class MethodFunctionCallback implements FunctionCallback {
45+
46+ private static Logger logger = LoggerFactory .getLogger (MethodFunctionCallback .class );
47+
48+ /**
49+ * Object instance that contains the method to be invoked. If the method is static
50+ * this object can be null.
51+ */
52+ private final Object functionObject ;
53+
54+ /**
55+ * The method to be invoked.
56+ */
57+ private final Method method ;
58+
59+ /**
60+ * Description to help the LLM model to understand woth the method does and when to
61+ * use it.
62+ */
63+ private final String description ;
64+
65+ /**
66+ * Internal ObjectMapper used to serialize/deserialize the method input and output.
67+ */
68+ private final ObjectMapper mapper ;
69+
70+ /**
71+ * The JSON schema generated from the method input parameters.
72+ */
73+ private final String inptuSchema ;
74+
75+ public MethodFunctionCallback (Object functionObject , Method method , String description , ObjectMapper mapper ) {
76+
77+ Assert .notNull (method , "Method must not be null" );
78+ Assert .notNull (mapper , "ObjectMapper must not be null" );
79+ Assert .hasText (description , "Description must not be empty" );
80+
81+ this .method = method ;
82+ this .description = description ;
83+ this .mapper = mapper ;
84+ this .functionObject = functionObject ;
85+
86+ Assert .isTrue (this .functionObject != null || Modifier .isStatic (this .method .getModifiers ()),
87+ "Function object must be provided for non-static methods!" );
88+
89+ // Generate the JSON schema from the method input parameters
90+ Map <String , Class <?>> methodParameters = Stream .of (method .getParameters ())
91+ .collect (Collectors .toMap (param -> param .getName (), param -> param .getType ()));
92+ this .inptuSchema = generateJsonSchema (methodParameters );
93+
94+ logger .info ("Generated JSON Schema: \n :" + this .inptuSchema );
95+ }
96+
97+ @ Override
98+ public String getName () {
99+ return method .getName ();
100+ }
101+
102+ @ Override
103+ public String getDescription () {
104+ return this .description ;
105+ }
106+
107+ @ Override
108+ public String getInputTypeSchema () {
109+ return this .inptuSchema ;
110+ }
111+
112+ @ Override
113+ public String call (String functionInput ) {
114+
115+ try {
116+
117+ Map <String , Object > map = this .mapper .readValue (functionInput , Map .class );
118+
119+ Object [] methodArgs = Stream .of (this .method .getParameters ()).map (parameter -> {
120+ Object rawValue = map .get (parameter .getName ());
121+ Class <?> type = parameter .getType ();
122+ return this .toJavaType (rawValue , type );
123+ }).toArray ();
124+
125+ Object response = this .method .invoke (this .functionObject , methodArgs );
126+
127+ var returnType = this .method .getReturnType ();
128+ if (returnType == Void .TYPE ) {
129+ return "Done" ;
130+ }
131+
132+ if (returnType == Class .class || returnType .isRecord () || returnType == List .class
133+ || returnType == Map .class ) {
134+ return ModelOptionsUtils .toJsonString (response );
135+
136+ }
137+ return "" + response ;
138+ }
139+ catch (Exception e ) {
140+ throw new RuntimeException (e );
141+ }
142+
143+ }
144+
145+ /**
146+ * Generates a JSON schema from the given named classes.
147+ * @param namedClasses The named classes to generate the schema from.
148+ * @return The generated JSON schema.
149+ */
150+ protected String generateJsonSchema (Map <String , Class <?>> namedClasses ) {
151+ try {
152+ JsonSchemaGenerator schemaGen = new JsonSchemaGenerator (this .mapper );
153+
154+ ObjectNode rootNode = this .mapper .createObjectNode ();
155+ rootNode .put ("$schema" , "https://json-schema.org/draft/2020-12/schema" );
156+ rootNode .put ("type" , "object" );
157+ ObjectNode propertiesNode = rootNode .putObject ("properties" );
158+
159+ for (Map .Entry <String , Class <?>> entry : namedClasses .entrySet ()) {
160+ String className = entry .getKey ();
161+ Class <?> clazz = entry .getValue ();
162+
163+ JsonSchema schema = schemaGen .generateSchema (clazz );
164+ JsonNode schemaNode = this .mapper .valueToTree (schema );
165+ propertiesNode .set (className , schemaNode );
166+ }
167+
168+ return this .mapper .writerWithDefaultPrettyPrinter ().writeValueAsString (rootNode );
169+ }
170+ catch (Exception e ) {
171+ throw new RuntimeException (e );
172+ }
173+ }
174+
175+ /**
176+ * Converts the given value to the specified Java type.
177+ * @param value The value to convert.
178+ * @param javaType The Java type to convert to.
179+ * @return Returns the converted value.
180+ */
181+ protected Object toJavaType (Object value , Class <?> javaType ) {
182+
183+ if (value == null ) {
184+ return null ;
185+ }
186+ if (javaType == String .class ) {
187+ return value .toString ();
188+ }
189+ else if (javaType == Integer .class || javaType == int .class ) {
190+ return Integer .parseInt (value .toString ());
191+ }
192+ else if (javaType == Long .class || javaType == long .class ) {
193+ return Long .parseLong (value .toString ());
194+ }
195+ else if (javaType == Double .class || javaType == double .class ) {
196+ return Double .parseDouble (value .toString ());
197+ }
198+ else if (javaType == Float .class || javaType == float .class ) {
199+ return Float .parseFloat (value .toString ());
200+ }
201+ else if (javaType == Boolean .class || javaType == boolean .class ) {
202+ return Boolean .parseBoolean (value .toString ());
203+ }
204+ else if (javaType .isEnum ()) {
205+ return Enum .valueOf ((Class <Enum >) javaType , value .toString ());
206+ }
207+ // else if (type == Class.class || type.isRecord()) {
208+ // return ModelOptionsUtils.mapToClass((Map<String, Object>) value, type);
209+ // }
210+
211+ try {
212+ String json = new ObjectMapper ().writeValueAsString (value );
213+ return this .mapper .readValue (json , javaType );
214+ }
215+ catch (Exception e ) {
216+ throw new RuntimeException (e );
217+ }
218+ }
219+
220+ /**
221+ * Creates a new {@link Builder} for the {@link MethodFunctionCallback}.
222+ * @return The builder.
223+ */
224+ public static MethodFunctionCallback .Builder builder () {
225+ return new Builder ();
226+ }
227+
228+ /**
229+ * Builder for the {@link MethodFunctionCallback}.
230+ */
231+ public static class Builder {
232+
233+ private Method method ;
234+
235+ private String description ;
236+
237+ private ObjectMapper mapper = new ObjectMapper ();
238+
239+ private Object functionObject = null ;
240+
241+ public MethodFunctionCallback .Builder withFunctionObject (Object functionObject ) {
242+ this .functionObject = functionObject ;
243+ return this ;
244+ }
245+
246+ public MethodFunctionCallback .Builder withMethod (Method method ) {
247+ Assert .notNull (method , "Method must not be null" );
248+ this .method = method ;
249+ return this ;
250+ }
251+
252+ public MethodFunctionCallback .Builder withDescription (String description ) {
253+ Assert .hasText (description , "Description must not be empty" );
254+ this .description = description ;
255+ return this ;
256+ }
257+
258+ public MethodFunctionCallback .Builder withMapper (ObjectMapper mapper ) {
259+ this .mapper = mapper ;
260+ return this ;
261+ }
262+
263+ public MethodFunctionCallback build () {
264+ return new MethodFunctionCallback (this .functionObject , this .method , this .description , this .mapper );
265+ }
266+
267+ }
268+
269+ }
0 commit comments