Skip to content

Commit e79c6c6

Browse files
committed
ensure toolObjects are populated after all beans are initialized
Signed-off-by: jitokim <[email protected]>
1 parent df35755 commit e79c6c6

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolAnnotatedBeanProcessor.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@
77
import org.springframework.beans.factory.SmartInitializingSingleton;
88
import org.springframework.beans.factory.config.BeanPostProcessor;
99
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
10+
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
1011
import org.springframework.context.ApplicationContext;
1112
import org.springframework.context.ApplicationContextAware;
1213
import org.springframework.context.ConfigurableApplicationContext;
1314

14-
import java.util.Arrays;
15-
import java.util.HashSet;
16-
import java.util.Set;
15+
import java.util.*;
1716

1817
/**
1918
* {@code ToolAnnotatedBeanProcessor} scans beans after initialization and collects those
2019
* that have methods annotated with {@link Tool} within the specified base packages.
21-
*
20+
* <p>
2221
* The collected beans are then registered with a {@link ToolCallbackProvider}.
2322
*/
2423
public class ToolAnnotatedBeanProcessor
@@ -28,7 +27,7 @@ public class ToolAnnotatedBeanProcessor
2827

2928
private ApplicationContext applicationContext;
3029

31-
private Set<Object> methodToolBeans = new HashSet<>();
30+
private List<Object> methodToolBeans = new ArrayList<>();
3231

3332
public ToolAnnotatedBeanProcessor(Set<String> basePackages) {
3433
this.basePackages = basePackages;
@@ -96,13 +95,19 @@ private boolean hasToolAnnotatedMethod(Class<?> clazz) {
9695
public void afterSingletonsInstantiated() {
9796

9897
if (!methodToolBeans.isEmpty()) {
99-
MethodToolCallbackProvider.Builder builder = MethodToolCallbackProvider.builder();
100-
builder.toolObjects(methodToolBeans.toArray());
101-
MethodToolCallbackProvider provider = builder.build();
98+
10299
ConfigurableListableBeanFactory factory = ((ConfigurableApplicationContext) applicationContext)
103100
.getBeanFactory();
104101

105-
if (!factory.containsBean("methodToolCallbackProvider")) {
102+
if (factory.containsBean("methodToolCallbackProvider")) {
103+
BeanDefinitionRegistry registry = (BeanDefinitionRegistry) applicationContext;
104+
MethodToolCallbackProvider bean = factory.getBean(MethodToolCallbackProvider.class);
105+
bean.setToolObjects(methodToolBeans);
106+
}
107+
else {
108+
MethodToolCallbackProvider.Builder builder = MethodToolCallbackProvider.builder();
109+
builder.toolObjects(methodToolBeans.toArray());
110+
MethodToolCallbackProvider provider = builder.build();
106111
factory.registerSingleton("methodToolCallbackProvider", provider);
107112
}
108113
}

auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackAutoRegistrar.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import org.springframework.ai.tool.ToolCallbackProvider;
44
import org.springframework.ai.tool.annotation.Tool;
55
import org.springframework.ai.tool.autoconfigure.annotation.EnableToolCallbackAutoRegistration;
6+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
67
import org.springframework.beans.factory.config.BeanDefinition;
78
import org.springframework.beans.factory.config.ConstructorArgumentValues;
89
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
910
import org.springframework.beans.factory.support.BeanNameGenerator;
1011
import org.springframework.beans.factory.support.GenericBeanDefinition;
12+
import org.springframework.beans.factory.support.RootBeanDefinition;
1113
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
1214
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
1315
import org.springframework.core.type.AnnotationMetadata;
@@ -36,6 +38,21 @@ public class ToolCallbackAutoRegistrar implements ImportBeanDefinitionRegistrar
3638
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry,
3739
BeanNameGenerator importBeanNameGenerator) {
3840

41+
registerToolAnnotatedBeanProcessor(importingClassMetadata, registry);
42+
43+
registerMethodToolCallbackProvider(registry);
44+
45+
}
46+
47+
private void registerMethodToolCallbackProvider(BeanDefinitionRegistry registry) {
48+
RootBeanDefinition callbackProviderDef = new RootBeanDefinition(MethodToolCallbackProvider.class);
49+
callbackProviderDef
50+
.setInstanceSupplier(() -> MethodToolCallbackProvider.builder().toolObjects(new Object[0]).build());
51+
registry.registerBeanDefinition("methodToolCallbackProvider", callbackProviderDef);
52+
}
53+
54+
private void registerToolAnnotatedBeanProcessor(AnnotationMetadata importingClassMetadata,
55+
BeanDefinitionRegistry registry) {
3956
Set<String> basePackages = getBasePackages(importingClassMetadata);
4057

4158
GenericBeanDefinition beanDefinition = new GenericBeanDefinition();

spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,18 @@ public class MethodToolCallbackProvider implements ToolCallbackProvider {
4949

5050
private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);
5151

52-
private final List<Object> toolObjects;
52+
private List<Object> toolObjects;
5353

5454
private MethodToolCallbackProvider(List<Object> toolObjects) {
5555
Assert.notNull(toolObjects, "toolObjects cannot be null");
5656
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
5757
this.toolObjects = toolObjects;
5858
}
5959

60+
public void setToolObjects(List<Object> toolObjects) {
61+
this.toolObjects = toolObjects;
62+
}
63+
6064
@Override
6165
public ToolCallback[] getToolCallbacks() {
6266
var toolCallbacks = toolObjects.stream()

0 commit comments

Comments
 (0)