Skip to content

Commit ca7936f

Browse files
committed
feat: Add automatic function description generation for AI function callbacks
- Generate descriptions from function/method names when none provided - Add warning logs recommending explicit description setting - Add integration tests for auto-generated descriptions
1 parent 0b14cee commit ca7936f

File tree

4 files changed

+78
-4
lines changed

4 files changed

+78
-4
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.chat.model.ChatResponse;
4343
import org.springframework.ai.converter.BeanOutputConverter;
4444
import org.springframework.ai.converter.ListOutputConverter;
45+
import org.springframework.ai.model.function.FunctionCallback;
4546
import org.springframework.beans.factory.annotation.Autowired;
4647
import org.springframework.beans.factory.annotation.Value;
4748
import org.springframework.boot.test.context.SpringBootTest;
@@ -221,6 +222,25 @@ void functionCallTest() {
221222
assertThat(response).contains("30", "10", "15");
222223
}
223224

225+
@Test
226+
void functionCallWithGeneratedDescription() {
227+
228+
// @formatter:off
229+
String response = ChatClient.create(this.chatModel).prompt()
230+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
231+
.functions(FunctionCallback.builder()
232+
.function("getCurrentWeatherInLocation", new MockWeatherService())
233+
.inputType(MockWeatherService.Request.class)
234+
.build())
235+
.call()
236+
.content();
237+
// @formatter:on
238+
239+
logger.info("Response: {}", response);
240+
241+
assertThat(response).contains("30", "10", "15");
242+
}
243+
224244
@Test
225245
void defaultFunctionCallTest() {
226246

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodFunctionCallbackIT.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,25 @@ void beforeEach() {
5151
arguments.clear();
5252
}
5353

54+
@Test
55+
void methodGetWeatherGeneratedDescription() {
56+
57+
// @formatter:off
58+
String response = ChatClient.create(this.chatModel).prompt()
59+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
60+
.functions(FunctionCallback.builder()
61+
.method("getWeatherInLocation", String.class, Unit.class)
62+
.targetClass(TestFunctionClass.class)
63+
.build())
64+
.call()
65+
.content();
66+
// @formatter:on
67+
68+
logger.info("Response: {}", response);
69+
70+
assertThat(response).contains("30", "10", "15");
71+
}
72+
5473
@Test
5574
void methodGetWeatherStatic() {
5675

@@ -201,6 +220,10 @@ public static void argumentLessReturnVoid() {
201220
arguments.put("method called", "argumentLessReturnVoid");
202221
}
203222

223+
public static String getWeatherInLocation(String city, Unit unit) {
224+
return getWeatherStatic(city, unit);
225+
}
226+
204227
public static String getWeatherStatic(String city, Unit unit) {
205228

206229
logger.info("City: " + city + " Unit: " + unit);

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import com.fasterxml.jackson.databind.ObjectMapper;
2424
import com.fasterxml.jackson.databind.SerializationFeature;
2525
import com.fasterxml.jackson.databind.json.JsonMapper;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
2628

2729
import org.springframework.ai.chat.model.ToolContext;
2830
import org.springframework.ai.model.ModelOptionsUtils;
@@ -31,9 +33,11 @@
3133
import org.springframework.ai.model.function.FunctionCallback.MethodInvokerBuilder;
3234
import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType;
3335
import org.springframework.ai.util.JacksonUtils;
36+
import org.springframework.ai.util.ParsingUtils;
3437
import org.springframework.core.ParameterizedTypeReference;
3538
import org.springframework.util.Assert;
3639
import org.springframework.util.ReflectionUtils;
40+
import org.springframework.util.StringUtils;
3741

3842
/**
3943
* @author Christian Tzolov
@@ -42,6 +46,8 @@
4246

4347
public class DefaultFunctionCallbackBuilder implements FunctionCallback.Builder {
4448

49+
private final static Logger logger = LoggerFactory.getLogger(DefaultFunctionCallbackBuilder.class);
50+
4551
private String description;
4652

4753
private SchemaType schemaType = SchemaType.JSON_SCHEMA;
@@ -151,7 +157,6 @@ public FunctionInvokerBuilder<I, O> inputType(ParameterizedTypeReference<?> inpu
151157
@Override
152158
public FunctionCallback build() {
153159

154-
Assert.hasText(description, "Description must not be empty");
155160
Assert.notNull(objectMapper, "ObjectMapper must not be null");
156161
Assert.hasText(this.name, "Name must not be empty");
157162
Assert.notNull(responseConverter, "ResponseConverter must not be null");
@@ -165,10 +170,17 @@ public FunctionCallback build() {
165170
BiFunction<I, ToolContext, O> finalBiFunction = (this.biFunction != null) ? this.biFunction
166171
: (request, context) -> this.function.apply(request);
167172

168-
return new FunctionCallbackWrapper(this.name, description, inputTypeSchema, this.inputType,
173+
return new FunctionCallbackWrapper(this.name, this.getDescription(), inputTypeSchema, this.inputType,
169174
(Function<I, String>) responseConverter, objectMapper, finalBiFunction);
170175
}
171176

177+
private String getDescription() {
178+
if (StringUtils.hasText(description)) {
179+
return description;
180+
}
181+
return generateDescription(this.name);
182+
}
183+
172184
}
173185

174186
public class MethodInvokerBuilderImpl implements FunctionCallback.MethodInvokerBuilder {
@@ -215,10 +227,29 @@ public FunctionCallback build() {
215227
Assert.isTrue(this.targetClass != null || this.targetObject != null,
216228
"Target class or object must not be null");
217229
var method = ReflectionUtils.findMethod(targetClass, methodName, argumentTypes);
218-
return new MethodFunctionCallback(this.targetObject, method, description, objectMapper, this.name,
230+
return new MethodFunctionCallback(this.targetObject, method, this.getDescription(), objectMapper, this.name,
219231
responseConverter);
220232
}
221233

234+
private String getDescription() {
235+
if (StringUtils.hasText(description)) {
236+
return description;
237+
}
238+
239+
return generateDescription(StringUtils.hasText(this.name) ? this.name : this.methodName);
240+
}
241+
242+
}
243+
244+
private String generateDescription(String fromName) {
245+
246+
String generatedDescription = ParsingUtils.reConcatenateCamelCase(fromName, " ");
247+
248+
logger.warn("Description is not set! A best effort attempt to generate a description:'{}' from the:'{}'",
249+
generatedDescription, fromName);
250+
logger.warn("It is recommended to set the Description explicitly! Use the 'description()' method!");
251+
252+
return generatedDescription;
222253
}
223254

224255
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodFunctionCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public class MethodFunctionCallback implements FunctionCallback {
122122

123123
this.inputSchema = this.generateJsonSchema(methodParameters);
124124

125-
logger.info("Generated JSON Schema: {}", this.inputSchema);
125+
logger.debug("Generated JSON Schema: {}", this.inputSchema);
126126
}
127127

128128
@Override

0 commit comments

Comments
 (0)