diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index 6f7090d8682..3b72c3bae26 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -72,6 +72,12 @@ spring-boot-starter-test test + + org.springframework + spring-context + 6.1.0-M3 + compile + diff --git a/spring-ai-core/src/main/java/org/springframework/ai/SpringAiFunctionManager.java b/spring-ai-core/src/main/java/org/springframework/ai/SpringAiFunctionManager.java new file mode 100644 index 00000000000..dfd6cfdef0f --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/SpringAiFunctionManager.java @@ -0,0 +1,61 @@ +package org.springframework.ai; + +import org.springframework.ai.annotations.SpringAIFunction; +import org.springframework.ai.model.AbstractToolFunctionCallback; +import org.springframework.ai.model.ToolFunctionCallback; +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.support.GenericApplicationContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +public class SpringAiFunctionManager implements ApplicationContextAware { + + private GenericApplicationContext applicationContext; + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = (GenericApplicationContext) applicationContext; + } + + public int size() { + return this.applicationContext.getBeansWithAnnotation(SpringAIFunction.class).size(); + } + + /** + * @return The list of chat functions + */ + public List getAnnotatedToolFunctionCallbacks() { + var beans = this.applicationContext.getBeansWithAnnotation(SpringAIFunction.class); + List chatFunctions = new ArrayList<>(); + beans.forEach((k, v) -> { + if (v instanceof Function function) { + SpringAIFunction aiFunction = applicationContext.findAnnotationOnBean(k, SpringAIFunction.class); + chatFunctions.add(new AnnotationGeneratedFunctionCallback(aiFunction.name(), aiFunction.description(), + aiFunction.classType(), function)); + } + }); + + return chatFunctions; + } + +} + +class AnnotationGeneratedFunctionCallback extends AbstractToolFunctionCallback { + + private Function function; + + protected AnnotationGeneratedFunctionCallback(String name, String description, Class inputType, Function function) { + super(name, description, inputType); + this.function = function; + } + + @Override + public Object apply(Object o) { + return function.apply(o); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/annotations/SpringAIFunction.java b/spring-ai-core/src/main/java/org/springframework/ai/annotations/SpringAIFunction.java new file mode 100644 index 00000000000..0e0bda54594 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/annotations/SpringAIFunction.java @@ -0,0 +1,25 @@ +package org.springframework.ai.annotations; + +import org.springframework.context.annotation.Bean; +import org.springframework.core.annotation.AliasFor; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * An annotation used to define functions for use in + */ +@Bean +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface SpringAIFunction { + + String name(); + + String description(); + + Class classType(); + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index b1c13a32379..f519f31b12e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -18,6 +18,7 @@ import java.util.List; +import org.springframework.ai.SpringAiFunctionManager; import org.springframework.ai.autoconfigure.NativeHints; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.model.ToolFunctionCallback; @@ -31,6 +32,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.util.Assert; @@ -56,7 +58,7 @@ public class OpenAiAutoConfiguration { @ConditionalOnMissingBean public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties, OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder, - List toolFunctionCallbacks) { + List toolFunctionCallbacks, SpringAiFunctionManager functionManager) { String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey() : commonProperties.getApiKey(); @@ -73,6 +75,11 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper chatProperties.getOptions().getToolCallbacks().addAll(toolFunctionCallbacks); } + var annotatedFunctionsList = functionManager.getAnnotatedToolFunctionCallbacks(); + if (!annotatedFunctionsList.isEmpty()) { + chatProperties.getOptions().getToolCallbacks().addAll(annotatedFunctionsList); + } + return new OpenAiChatClient(openAiApi, chatProperties.getOptions()); } @@ -113,4 +120,12 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions()); } + @Bean + @ConditionalOnMissingBean + public SpringAiFunctionManager springAiFunctionManager(ApplicationContext context) { + SpringAiFunctionManager manager = new SpringAiFunctionManager(); + manager.setApplicationContext(context); + return manager; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithSpringAIFunctionAnnotationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithSpringAIFunctionAnnotationIT.java new file mode 100644 index 00000000000..131a5a71831 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithSpringAIFunctionAnnotationIT.java @@ -0,0 +1,81 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.openai.tool; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.annotations.SpringAIFunction; +import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Configuration; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class ToolCallWithSpringAIFunctionAnnotationIT { + + private final Logger logger = LoggerFactory.getLogger(ToolCallWithBeanFunctionRegistrationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=gpt-4-1106-preview").run(context -> { + + OpenAiChatClient chatClient = context.getBean(OpenAiChatClient.class); + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withEnabledFunction("WeatherInfo").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Configuration + static class Config { + + @SpringAIFunction(name = "WeatherInfo", description = "Get the weather in location", + classType = MockWeatherService.Request.class) + public Function weatherFunction() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} \ No newline at end of file