Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spring-ai-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
<version>6.1.0-M3</version>
<scope>compile</scope>
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.springframework.ai;

import org.springframework.ai.annotations.SpringAIFunction;
import org.springframework.ai.model.AbstractToolFunctionCallback;
import org.springframework.ai.model.ToolFunctionCallback;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.support.GenericApplicationContext;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

public class SpringAiFunctionManager implements ApplicationContextAware {

private GenericApplicationContext applicationContext;

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = (GenericApplicationContext) applicationContext;
}

public int size() {
return this.applicationContext.getBeansWithAnnotation(SpringAIFunction.class).size();
}

/**
* @return The list of chat functions
*/
public List<ToolFunctionCallback> getAnnotatedToolFunctionCallbacks() {
var beans = this.applicationContext.getBeansWithAnnotation(SpringAIFunction.class);
List<ToolFunctionCallback> chatFunctions = new ArrayList<>();
beans.forEach((k, v) -> {
if (v instanceof Function<?, ?> function) {
SpringAIFunction aiFunction = applicationContext.findAnnotationOnBean(k, SpringAIFunction.class);
chatFunctions.add(new AnnotationGeneratedFunctionCallback(aiFunction.name(), aiFunction.description(),
aiFunction.classType(), function));
}
});

return chatFunctions;
}

}

class AnnotationGeneratedFunctionCallback extends AbstractToolFunctionCallback<Object, Object> {

private Function function;

protected AnnotationGeneratedFunctionCallback(String name, String description, Class inputType, Function function) {
super(name, description, inputType);
this.function = function;
}

@Override
public Object apply(Object o) {
return function.apply(o);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.springframework.ai.annotations;

import org.springframework.context.annotation.Bean;
import org.springframework.core.annotation.AliasFor;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* An annotation used to define functions for use in
*/
@Bean
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SpringAIFunction {

String name();

String description();

Class classType();

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.List;

import org.springframework.ai.SpringAiFunctionManager;
import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.model.ToolFunctionCallback;
Expand All @@ -31,6 +32,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.util.Assert;
Expand All @@ -56,7 +58,7 @@ public class OpenAiAutoConfiguration {
@ConditionalOnMissingBean
public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties,
OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder,
List<ToolFunctionCallback> toolFunctionCallbacks) {
List<ToolFunctionCallback> toolFunctionCallbacks, SpringAiFunctionManager functionManager) {

String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey()
: commonProperties.getApiKey();
Expand All @@ -73,6 +75,11 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
chatProperties.getOptions().getToolCallbacks().addAll(toolFunctionCallbacks);
}

var annotatedFunctionsList = functionManager.getAnnotatedToolFunctionCallbacks();
if (!annotatedFunctionsList.isEmpty()) {
chatProperties.getOptions().getToolCallbacks().addAll(annotatedFunctionsList);
}

return new OpenAiChatClient(openAiApi, chatProperties.getOptions());
}

Expand Down Expand Up @@ -113,4 +120,12 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp
return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions());
}

@Bean
@ConditionalOnMissingBean
public SpringAiFunctionManager springAiFunctionManager(ApplicationContext context) {
SpringAiFunctionManager manager = new SpringAiFunctionManager();
manager.setApplicationContext(context);
return manager;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.openai.tool;

import java.util.List;
import java.util.function.Function;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.annotations.SpringAIFunction;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;

import static org.assertj.core.api.Assertions.assertThat;

@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
class ToolCallWithSpringAIFunctionAnnotationIT {

private final Logger logger = LoggerFactory.getLogger(ToolCallWithBeanFunctionRegistrationIT.class);

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.withUserConfiguration(Config.class);

@Test
void functionCallTest() {
contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=gpt-4-1106-preview").run(context -> {

OpenAiChatClient chatClient = context.getBean(OpenAiChatClient.class);

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");

ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
OpenAiChatOptions.builder().withEnabledFunction("WeatherInfo").build()));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");

});
}

@Configuration
static class Config {

@SpringAIFunction(name = "WeatherInfo", description = "Get the weather in location",
classType = MockWeatherService.Request.class)
public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction() {
MockWeatherService weatherService = new MockWeatherService();
return (weatherService::apply);
}

}

}