Skip to content

Commit eddc5d8

Browse files
committed
ensure toolObjects are populated after all beans are initialized
Signed-off-by: jitokim <[email protected]>
1 parent 5e5ef3a commit eddc5d8

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolAnnotatedBeanProcessor.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@
77
import org.springframework.beans.factory.SmartInitializingSingleton;
88
import org.springframework.beans.factory.config.BeanPostProcessor;
99
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
10+
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
1011
import org.springframework.context.ApplicationContext;
1112
import org.springframework.context.ApplicationContextAware;
1213
import org.springframework.context.ConfigurableApplicationContext;
1314

14-
import java.util.Arrays;
15-
import java.util.HashSet;
16-
import java.util.Set;
15+
import java.util.*;
1716

1817
/**
1918
* {@code ToolAnnotatedBeanProcessor} scans beans after initialization and collects those
2019
* that have methods annotated with {@link Tool} within the specified base packages.
21-
*
20+
* <p>
2221
* The collected beans are then registered with a {@link ToolCallbackProvider}.
2322
*/
2423
public class ToolAnnotatedBeanProcessor
@@ -28,7 +27,7 @@ public class ToolAnnotatedBeanProcessor
2827

2928
private ApplicationContext applicationContext;
3029

31-
private Set<Object> methodToolBeans = new HashSet<>();
30+
private List<Object> methodToolBeans = new ArrayList<>();
3231

3332
public ToolAnnotatedBeanProcessor(Set<String> basePackages) {
3433
this.basePackages = basePackages;
@@ -96,13 +95,19 @@ private boolean hasToolAnnotatedMethod(Class<?> clazz) {
9695
public void afterSingletonsInstantiated() {
9796

9897
if (!methodToolBeans.isEmpty()) {
99-
MethodToolCallbackProvider.Builder builder = MethodToolCallbackProvider.builder();
100-
builder.toolObjects(methodToolBeans.toArray());
101-
MethodToolCallbackProvider provider = builder.build();
98+
10299
ConfigurableListableBeanFactory factory = ((ConfigurableApplicationContext) applicationContext)
103100
.getBeanFactory();
104101

105-
if (!factory.containsBean("methodToolCallbackProvider")) {
102+
if (factory.containsBean("methodToolCallbackProvider")) {
103+
BeanDefinitionRegistry registry = (BeanDefinitionRegistry) applicationContext;
104+
MethodToolCallbackProvider bean = factory.getBean(MethodToolCallbackProvider.class);
105+
bean.setToolObjects(methodToolBeans);
106+
}
107+
else {
108+
MethodToolCallbackProvider.Builder builder = MethodToolCallbackProvider.builder();
109+
builder.toolObjects(methodToolBeans.toArray());
110+
MethodToolCallbackProvider provider = builder.build();
106111
factory.registerSingleton("methodToolCallbackProvider", provider);
107112
}
108113
}

auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackAutoRegistrar.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import org.springframework.ai.tool.ToolCallbackProvider;
44
import org.springframework.ai.tool.annotation.Tool;
55
import org.springframework.ai.tool.autoconfigure.annotation.EnableToolCallbackAutoRegistration;
6+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
67
import org.springframework.beans.factory.config.BeanDefinition;
78
import org.springframework.beans.factory.config.ConstructorArgumentValues;
89
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
910
import org.springframework.beans.factory.support.BeanNameGenerator;
1011
import org.springframework.beans.factory.support.GenericBeanDefinition;
12+
import org.springframework.beans.factory.support.RootBeanDefinition;
1113
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
1214
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
1315
import org.springframework.core.type.AnnotationMetadata;
@@ -36,6 +38,21 @@ public class ToolCallbackAutoRegistrar implements ImportBeanDefinitionRegistrar
3638
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry,
3739
BeanNameGenerator importBeanNameGenerator) {
3840

41+
registerToolAnnotatedBeanProcessor(importingClassMetadata, registry);
42+
43+
registerMethodToolCallbackProvider(registry);
44+
45+
}
46+
47+
private void registerMethodToolCallbackProvider(BeanDefinitionRegistry registry) {
48+
RootBeanDefinition callbackProviderDef = new RootBeanDefinition(MethodToolCallbackProvider.class);
49+
callbackProviderDef
50+
.setInstanceSupplier(() -> MethodToolCallbackProvider.builder().toolObjects(new Object[0]).build());
51+
registry.registerBeanDefinition("methodToolCallbackProvider", callbackProviderDef);
52+
}
53+
54+
private void registerToolAnnotatedBeanProcessor(AnnotationMetadata importingClassMetadata,
55+
BeanDefinitionRegistry registry) {
3956
Set<String> basePackages = getBasePackages(importingClassMetadata);
4057

4158
GenericBeanDefinition beanDefinition = new GenericBeanDefinition();

spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public final class MethodToolCallbackProvider implements ToolCallbackProvider {
5151

5252
private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);
5353

54-
private final List<Object> toolObjects;
54+
private List<Object> toolObjects;
5555

5656
private MethodToolCallbackProvider(List<Object> toolObjects) {
5757
Assert.notNull(toolObjects, "toolObjects cannot be null");
@@ -78,6 +78,10 @@ private void assertToolAnnotatedMethodsPresent(List<Object> toolObjects) {
7878
}
7979
}
8080

81+
public void setToolObjects(List<Object> toolObjects) {
82+
this.toolObjects = toolObjects;
83+
}
84+
8185
@Override
8286
public ToolCallback[] getToolCallbacks() {
8387
var toolCallbacks = this.toolObjects.stream()

0 commit comments

Comments
 (0)