diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index 16ae7199f40..97d98b22b91 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.stream.Collectors; @@ -28,6 +29,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -59,6 +61,25 @@ class OpenAiChatModelFunctionCallingIT { @Autowired ChatModel chatModel; + @Test + void functionCallSupplier() { + + Map state = new ConcurrentHashMap<>(); + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light on in the living room") + .functions(FunctionCallback.builder() + .function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON")) + .build()) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + assertThat(state).containsEntry("Light", "ON"); + } + @Test void functionCallTest() { functionCallTest(OpenAiChatOptions.builder() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index e538f8ee3f6..1d56c04d3fd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -395,6 +395,11 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues } ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(inputType); + + if ((inputType == Void.class) && !node.has("properties")) { + node.putObject("properties"); + } + if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI // version of it). toUpperCaseTypeValues(node); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java index 2e68130b0ff..9c947d79e1f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java @@ -18,7 +18,9 @@ import java.lang.reflect.Type; import java.util.Arrays; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -43,7 +45,7 @@ /** * Default implementation of the {@link FunctionCallback.Builder}. - * + * * @author Christian Tzolov * @since 1.0.0 */ @@ -137,6 +139,20 @@ public FunctionInvokingSpec function(String name, BiFunction(name, biFunction); } + @Override + public FunctionInvokingSpec function(String name, Supplier supplier) { + Function function = (input) -> supplier.get(); + return new DefaultFunctionInvokingSpec<>(name, function).inputType(Void.class); + } + + public FunctionInvokingSpec function(String name, Consumer consumer) { + Function function = (I input) -> { + consumer.accept(input); + return null; + }; + return new DefaultFunctionInvokingSpec<>(name, function); + } + @Override public MethodInvokingSpec method(String methodName, Class... argumentTypes) { return new DefaultMethodInvokingSpec(methodName, argumentTypes); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index 4416e239db9..0bd31b380fd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -17,7 +17,9 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.databind.ObjectMapper; @@ -141,6 +143,16 @@ interface Builder { */ FunctionInvokingSpec function(String name, BiFunction biFunction); + /** + * Builds a {@link Supplier} invoking {@link FunctionCallback} instance. + */ + FunctionInvokingSpec function(String name, Supplier supplier); + + /** + * Builds a {@link Consumer} invoking {@link FunctionCallback} instance. + */ + FunctionInvokingSpec function(String name, Consumer consumer); + /** * Builds a Method invoking {@link FunctionCallback} instance. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index d83123314ed..f36ac5bbf81 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -17,9 +17,12 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.annotation.JsonClassDescription; +import kotlin.jvm.functions.Function0; import kotlin.jvm.functions.Function1; import kotlin.jvm.functions.Function2; @@ -30,6 +33,7 @@ import org.springframework.context.annotation.Description; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.KotlinDetector; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; @@ -71,7 +75,8 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) { ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName); - ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0); + ResolvableType functionInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(functionType)) + ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(functionType, 0); Class functionInputClass = functionInputType.toClass(); String functionDescription = defaultDescription; @@ -109,7 +114,7 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable .schemaType(this.schemaType) .description(functionDescription) .function(beanName, KotlinDelegate.wrapKotlinFunction(bean)) - .inputType(functionInputClass) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) .build(); } else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { @@ -117,7 +122,15 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { .description(functionDescription) .schemaType(this.schemaType) .function(beanName, KotlinDelegate.wrapKotlinBiFunction(bean)) - .inputType(functionInputClass) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) + .build(); + } + else if (KotlinDelegate.isKotlinSupplier(functionType.toClass())) { + return FunctionCallback.builder() + .description(functionDescription) + .schemaType(this.schemaType) + .function(beanName, KotlinDelegate.wrapKotlinSupplier(bean)) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) .build(); } } @@ -126,7 +139,7 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { .schemaType(this.schemaType) .description(functionDescription) .function(beanName, function) - .inputType(functionInputClass) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) .build(); } else if (bean instanceof BiFunction) { @@ -134,7 +147,23 @@ else if (bean instanceof BiFunction) { .description(functionDescription) .schemaType(this.schemaType) .function(beanName, (BiFunction) bean) - .inputType(functionInputClass) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) + .build(); + } + else if (bean instanceof Supplier supplier) { + return FunctionCallback.builder() + .description(functionDescription) + .schemaType(this.schemaType) + .function(beanName, supplier) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) + .build(); + } + else if (bean instanceof Consumer consumer) { + return FunctionCallback.builder() + .description(functionDescription) + .schemaType(this.schemaType) + .function(beanName, consumer) + .inputType(ParameterizedTypeReference.forType(functionInputType.getType())) .build(); } else { @@ -150,6 +179,15 @@ public enum SchemaType { private static class KotlinDelegate { + public static boolean isKotlinSupplier(Class clazz) { + return Function0.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static Supplier wrapKotlinSupplier(Object function) { + return () -> ((Function0) function).invoke(); + } + public static boolean isKotlinFunction(Class clazz) { return Function1.class.isAssignableFrom(clazz); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index f5f601724df..f01cef31a5a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -20,8 +20,11 @@ import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; +import kotlin.jvm.functions.Function0; import kotlin.jvm.functions.Function1; import kotlin.jvm.functions.Function2; @@ -44,6 +47,16 @@ */ public abstract class TypeResolverHelper { + /** + * Returns the input class of a given Consumer class. + * @param consumerClass The consumer class. + * @return The input class of the consumer. + */ + public static Class getConsumerInputClass(Class> consumerClass) { + ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class); + return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass()); + } + /** * Returns the input class of a given function class. * @param biFunctionClass The function class. @@ -199,6 +212,12 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType else if (BiFunction.class.isAssignableFrom(resolvableClass)) { functionArgumentResolvableType = functionType.as(BiFunction.class); } + else if (Supplier.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(Supplier.class); + } + else if (Consumer.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(Consumer.class); + } else if (KotlinDetector.isKotlinPresent()) { if (KotlinDelegate.isKotlinFunction(resolvableClass)) { functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType); @@ -206,6 +225,9 @@ else if (KotlinDetector.isKotlinPresent()) { else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) { functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(functionType); } + else if (KotlinDelegate.isKotlinSupplier(resolvableClass)) { + functionArgumentResolvableType = KotlinDelegate.adaptToKotlinSupplierType(functionType); + } } if (functionArgumentResolvableType == ResolvableType.NONE) { @@ -218,6 +240,14 @@ else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) { private static class KotlinDelegate { + public static boolean isKotlinSupplier(Class clazz) { + return Function0.class.isAssignableFrom(clazz); + } + + public static ResolvableType adaptToKotlinSupplierType(ResolvableType resolvableType) { + return resolvableType.as(Function0.class); + } + public static boolean isKotlinFunction(Class clazz) { return Function1.class.isAssignableFrom(clazz); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java index 285a4f740d1..0ba01629eae 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.function; +import java.util.function.Consumer; import java.util.function.Function; import org.junit.jupiter.params.ParameterizedTest; @@ -39,7 +40,7 @@ public class TypeResolverHelperIT { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction", - "scannedStandaloneWeatherFunction", "componentWeatherFunction" }) + "scannedStandaloneWeatherFunction", "componentWeatherFunction", "weatherConsumer" }) void beanInputTypeResolutionWithResolvableType(String beanName) { assertThat(this.applicationContext).isNotNull(); ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName); @@ -89,6 +90,13 @@ StandaloneWeatherFunction standaloneWeatherFunction() { return new StandaloneWeatherFunction(); } + @Bean + Consumer weatherConsumer() { + return (weatherRequest) -> { + System.out.println(weatherRequest); + }; + } + } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java index 8051fe9912f..45e94a1e67f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.function; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; @@ -35,6 +36,12 @@ */ public class TypeResolverHelperTests { + @Test + public void testGetConsumerInputType() { + Class inputType = TypeResolverHelper.getConsumerInputClass(MyConsumer.class); + assertThat(inputType).isEqualTo(Request.class); + } + @Test public void testGetFunctionInputType() { Class inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class); @@ -63,6 +70,14 @@ public String apply(Response response) { } + public static class MyConsumer implements Consumer { + + @Override + public void accept(Request request) { + } + + } + public static class MockWeatherService implements Function { @Override diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 2f2c7d5d4ac..10888b34df2 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -97,6 +97,7 @@ * xref:api/prompt.adoc[] * xref:api/structured-output-converter.adoc[Structured Output] * xref:api/functions.adoc[Function Calling] +** xref:api/function-callback.adoc[FunctionCallback API] * xref:api/multimodality.adoc[Multimodality] * xref:api/etl-pipeline.adoc[] * xref:api/testing.adoc[AI Model Evaluation] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/function-callback.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/function-callback.adoc new file mode 100644 index 00000000000..583c505ac9e --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/function-callback.adoc @@ -0,0 +1,242 @@ += FunctionCallback + +== Overview + +The `FunctionCallback` interface in Spring AI provides a standardized way to implement Large Language Model (LLM) function calling capabilities. It allows developers to register custom functions that can be called by AI models when specific conditions or intents are detected in the prompts. + +== FunctionCallback Interface + +The main interface defines several key methods: + +* `getName()`: Returns the unique function name within the AI model context +* `getDescription()`: Provides a description that helps the model decide when to invoke the function +* `getInputTypeSchema()`: Defines the JSON schema for the function's input parameters +* `call(String functionInput)`: Handles the actual function execution +* `call(String functionInput, ToolContext toolContext)`: Extended version that supports additional context + +== Builder Pattern + +Spring AI provides a fluent builder API for creating `FunctionCallback` implementations. + +This is particularly useful for defining function callbacks that you can register, pragmatically, on the fly, with your `ChatClient` or `ChatModel` model calls. + +The builders helps with complex configurations, such as custom response handling, schema types (e.g. JSONSchema or OpenAPI), and object mapping. + +=== Function-Invoking Approach + +Converts any `java.util.function.Function`, `BiFunction`, `Supplier` or `Consumer` into a `FunctionCallback` that can be called by the AI model. + +NOTE: You can use lambda expressions or method references to define the function logic but you must provide the input type of the function using the `inputType(TYPE)`. + +==== Function + +[source,java] +---- +FunctionCallback callback = FunctionCallback.builder() + .description("Process a new order") + .function("processOrder", (Order order) -> processOrderLogic(order)) + .inputType(Order.class) + .build(); +---- + +==== BiFunction with ToolContext + +[source,java] +---- +FunctionCallback callback = FunctionCallback.builder() + .description("Process a new order with context") + .function("processOrder", (Order order, ToolContext context) -> + processOrderWithContext(order, context)) + .inputType(Order.class) + .build(); +---- + +==== Supplier + +Use `java.util.Supplier` or `java.util.function.Function` to define functions that don't take any input: + +[source,java] +---- +FunctionCallback.builder() + .description("Turns light onn in the living room") + .function("turnsLight", () -> state.put("Light", "ON")) + .inputType(Void.class) + .build(); +---- + +==== Consumer + +Use `java.util.Consumer` or `java.util.function.Function` to define functions that don't produce output: + +[source,java] +---- +record LightInfo(String roomName, boolean isOn) {} + +FunctionCallback.builder() + .description("Turns light on/off in a selected room") + .function("turnsLight", (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + }) + .inputType(LightInfo.class) + .build(); +---- + +==== Generics Input Type + +Use the `ParameterizedTypeReference` to define functions with generic input types: + +[source,java] +---- +record TrainSearchRequest(T data) {} + +record TrainSearchSchedule(String from, String to, String date) {} + +record TrainSearchScheduleResponse(String from, String to, String date, String trainNumber) {} + +FunctionCallback.builder() + .description("Schedule a train reservation") + .function("trainSchedule", (TrainSearchRequest request) -> { + logger.info("Schedule: " + request.data().from() + " to " + request.data().to()); + return new TrainSearchScheduleResponse(request.data().from(), request. data().to(), "", "123"); + }) + .inputType(new ParameterizedTypeReference>() {}) + .build(); +---- + +=== Method Invoking Approach + +Enables method invocation through reflection while automatically handling JSON schema generation and parameter conversion. It’s particularly useful for integrating Java methods as callable functions within AI model interactions. + +The method invoking implements the `FunctionCallback` interface and provides: + +- Automatic JSON schema generation for method parameters +- Support for both static and instance methods +- Any number of parameters (including none) and return values (including void) +- Any parameter/return types (primitives, objects, collections) +- Special handling for `ToolContext` parameters + +==== Static Method Invocation + +You can refer to a static method in a class by providing the method name, parameter types, and the target class. + +[source,java] +---- +public class WeatherService { + public static String getWeather(String city, TemperatureUnit unit) { + return "Temperature in " + city + ": 20" + unit; + } +} + +FunctionCallback callback = FunctionCallback.builder() + .description("Get weather information for a city") + .method("getWeather", String.class, TemperatureUnit.class) + .targetClass(WeatherService.class) + .build(); +---- + +==== Object instance Method Invocation + +You can refer to an instance method in a class by providing the method name, parameter types, and the target object instance. + +[source,java] +---- +public class DeviceController { + public void setDeviceState(String deviceId, boolean state, ToolContext context) { + Map contextData = context.getContext(); + // Implementation using context data + } +} + +DeviceController controller = new DeviceController(); + +String response = ChatClient.create(chatModel).prompt() + .user("Turn on the living room lights") + .functions(FunctionCallback.builder() + .description("Control device state") + .method("setDeviceState", String.class,boolean.class,ToolContext.class) + .targetObject(controller) + .build()) + .toolContext(Map.of("location", "home")) + .call() + .content(); +---- + +TIP: Optionally, using the `.name()`, you can set a custom function name different from the method name. + +== Customization Options + +== Schema Type Support + +The framework supports different schema types for function parameter validation: + +* JSON Schema (default) +* OpenAPI Schema (for Vertex AI compatibility) + +[source,java] +---- +FunctionCallback.builder() + .schemaType(SchemaType.OPEN_API_SCHEMA) + // ... other configuration + .build(); +---- + +=== Custom Response Handling + +[source,java] +---- +FunctionCallback.builder() + .responseConverter(response -> + customResponseFormatter.format(response)) + // ... other configuration + .build(); +---- + +=== Custom Object Mapping + +[source,java] +---- +FunctionCallback.builder() + .objectMapper(customObjectMapper) + // ... other configuration + .build(); +---- + +== Best Practices + +=== Descriptive Names and Descriptions + +* Provide unique function names +* Write comprehensive descriptions to help the model understand when to invoke the function + +=== Input Type & Schema + +* For the function invoking approach, define input types explicitly and use `ParameterizedTypeReference` for generic types. +* Consider using custom schema when auto-generated ones don't meet requirements. + +=== Error Handling + +* Implement proper error handling in function implementations and return the error message in the response +* You can use the ToolContext to provide additional error context when needed + +=== Tool Context Usage + +* Use ToolContext when additional state or context is required that is provided from the User and not part of the function input generated by the AI model. +* Use `BiFunction` to access the ToolContext in the function invocation approach and add `ToolContext` parameter in the method invoking approach. + + +== Notes on Schema Generation + +* The framework automatically generates JSON schemas from Java types +* For function invoking, the schema is generated based on the input type for the function that needs to be set using `inputType(TYPE)`. Use `ParameterizedTypeReference` for generic types. +* Generated schemas respect Jackson annotations on model classes +* You can bypass the automatic generation by providing custom schemas using `inputTypeSchema()` + +== Common Pitfalls to Avoid + +=== Lack of Description +* Always provide explicit descriptions instead of relying on auto-generated ones +* Clear descriptions improve model's function selection accuracy + +=== Schema Mismatches +* Ensure input types match the Function's input parameter types. +* Use `ParameterizedTypeReference` for generic types. \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index 3bee692bc22..8c001dc7eba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -66,7 +66,8 @@ void functionCallTest() { var promptOptions = MistralAiChatOptions.builder() .withFunctionCallbacks(List.of(FunctionCallback.builder() .description("Get payment status of a transaction") - .function("retrievePaymentStatus", transaction -> new Status(DATA.get(transaction).status())) + .function("retrievePaymentStatus", + (Transaction transaction) -> new Status(DATA.get(transaction).status())) .inputType(Transaction.class) .build())) .build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index 940e2da7d24..1ae2a5d84fe 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.openai.tool; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -72,6 +74,36 @@ void functionCallTest() { }); } + @Test + void lambdaFunctionCallTest() { + Map state = new ConcurrentHashMap<>(); + + record LightInfo(String roomName, boolean isOn) { + } + + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // @formatter:off + String content = ChatClient.builder(chatModel).build().prompt() + .user("Turn the light on in the kitchen and in the living room!") + .functions(FunctionCallback.builder() + .description("Turn light on or off in a room") + .function("turnLight", (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + }) + .inputType(LightInfo.class) + .build()) + .call().content(); + // @formatter:on + logger.info("Response: {}", content); + assertThat(state).containsEntry("kitchen", Boolean.TRUE); + assertThat(state).containsEntry("living room", Boolean.TRUE); + }); + } + @Test void functionCallTest2() { this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 6408edc1b9a..b3ec7d8bcb3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -18,10 +18,14 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -52,179 +56,272 @@ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") class FunctionCallbackWithPlainFunctionBeanIT { - private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); + private static final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"), + "spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .withUserConfiguration(Config.class); + private static Map feedback = new ConcurrentHashMap<>(); + + @BeforeEach + void setUp() { + feedback.clear(); + } + + @Test + void functionCallingVoidInput() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOn").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOn")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingSupplier() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOnSupplier").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOnSupplier")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingVoidOutput() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("turnLight").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingConsumer() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLightConsumer").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + + }); + } + + @Test + void trainScheduler() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "Please schedule a train from San Francisco to Los Angeles on 2023-12-25"); + + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("trainReservation") + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: {}", response.getResult().getOutput().getContent()); + }); + } + @Test void functionCallWithDirectBiFunction() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithContext") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithContext") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithContext") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithContext") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithBiFunctionClass() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithClassBiFunction") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithClassBiFunction") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithClassBiFunction") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithClassBiFunction") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - // Test weatherFunctionTwo - response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + // Test weatherFunctionTwo + response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithPortableFunctionCallingOptions() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris?"); + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") - .build(); + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("weatherFunction") + .build(); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response.getResult().getOutput().getContent()); + logger.info("Response: {}", response.getResult().getOutput().getContent()); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); } @Test void streamFunctionCallTest() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { - - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - - Flux response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - - String content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).contains("30", "10", "15"); - - // Test weatherFunctionTwo - response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - - content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); - assertThat(content).contains("30", "10", "15"); - - }); + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + + Flux response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + + // Test weatherFunctionTwo + response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + + content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); + assertThat(content).contains("30", "10", "15"); + + }); } @Configuration @@ -256,6 +353,70 @@ public Function weather return (weatherService::apply); } + record LightInfo(String roomName, boolean isOn) { + } + + @Bean + @Description("Turn light on or off in a room") + public Function turnLight() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + return null; + }; + } + + @Bean + @Description("Turn light on or off in a room") + public Consumer turnLightConsumer() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + }; + } + + @Bean + @Description("Turns light on in the living room") + public Function turnLivingRoomLightOn() { + return (Void v) -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOn", Boolean.TRUE); + return "Done"; + }; + } + + @Bean + @Description("Turns light on in the living room") + public Supplier turnLivingRoomLightOnSupplier() { + return () -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOnSupplier", Boolean.TRUE); + return "Done"; + }; + } + + record TrainSearchSchedule(String from, String to, String date) { + } + + record TrainSearchScheduleResponse(String from, String to, String date, String trainNumber) { + } + + record TrainSearchRequest(T data) { + } + + record TrainSearchResponse(T data) { + } + + @Bean + @Description("Schedule a train reservation") + public Function, TrainSearchResponse> trainReservation() { + return (TrainSearchRequest request) -> { + logger.info("Turning light to [" + request.data().from() + "] in " + request.data().to()); + return new TrainSearchResponse<>( + new TrainSearchScheduleResponse(request.data().from(), request.data().to(), "", "123")); + }; + } + } public static class MyBiFunction