Skip to content

Commit ef06515

Browse files
committed
Add auto-registration support for MethodToolCallbackProvider using @tool
- Add @EnableToolCallbackAutoRegistration to enable auto-registration - Add ToolCallbackAutoRegistrar using ImportBeanDefinitionRegistrar - Add ToolAnnotatedBeanProcessor to collect @Tool-annotated methods and register a MethodToolCallbackProvider Signed-off-by: jitokim <[email protected]>
1 parent 4d6ce31 commit ef06515

File tree

7 files changed

+241
-136
lines changed

7 files changed

+241
-136
lines changed

auto-configurations/common/spring-ai-autoconfigure-tool/pom.xml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,6 @@
5656
<artifactId>spring-boot-starter-test</artifactId>
5757
<scope>test</scope>
5858
</dependency>
59-
60-
<dependency>
61-
<groupId>org.mockito</groupId>
62-
<artifactId>mockito-core</artifactId>
63-
<scope>test</scope>
64-
</dependency>
6559
</dependencies>
6660

6761
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package org.springframework.ai.tool.autoconfigure;
2+
3+
import org.springframework.ai.tool.ToolCallbackProvider;
4+
import org.springframework.ai.tool.annotation.Tool;
5+
import org.springframework.ai.tool.method.MethodToolCallbackProvider;
6+
import org.springframework.beans.BeansException;
7+
import org.springframework.beans.factory.SmartInitializingSingleton;
8+
import org.springframework.beans.factory.config.BeanPostProcessor;
9+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
10+
import org.springframework.context.ApplicationContext;
11+
import org.springframework.context.ApplicationContextAware;
12+
import org.springframework.context.ConfigurableApplicationContext;
13+
14+
import java.util.Arrays;
15+
import java.util.HashSet;
16+
import java.util.Set;
17+
18+
/**
19+
* {@code ToolAnnotatedBeanProcessor} scans beans after initialization and collects those
20+
* that have methods annotated with {@link Tool} within the specified base packages.
21+
*
22+
* The collected beans are then registered with a {@link ToolCallbackProvider}.
23+
*/
24+
public class ToolAnnotatedBeanProcessor
25+
implements ApplicationContextAware, BeanPostProcessor, SmartInitializingSingleton {
26+
27+
private final Set<String> basePackages;
28+
29+
private ApplicationContext applicationContext;
30+
31+
private Set<Object> methodToolBeans = new HashSet<>();
32+
33+
public ToolAnnotatedBeanProcessor(Set<String> basePackages) {
34+
this.basePackages = basePackages;
35+
}
36+
37+
@Override
38+
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
39+
this.applicationContext = applicationContext;
40+
}
41+
42+
@Override
43+
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
44+
Class<?> beanClass = bean.getClass();
45+
46+
if (isInBasePackage(beanClass.getPackageName())) {
47+
48+
if (hasToolAnnotatedMethod(beanClass)) {
49+
methodToolBeans.add(bean);
50+
}
51+
}
52+
53+
return bean;
54+
}
55+
56+
/**
57+
* Checks whether the given package name starts with any of the configured base
58+
* packages.
59+
* @param packageName the package name to check
60+
* @return {@code true} if it is within a configured base package; {@code false}
61+
* otherwise
62+
*/
63+
private boolean isInBasePackage(String packageName) {
64+
return basePackages.stream().anyMatch(packageName::startsWith);
65+
}
66+
67+
/**
68+
* Checks if the specified class has any method annotated with {@link Tool}.
69+
* @param clazz the class to inspect
70+
* @return {@code true} if at least one method is annotated with {@link Tool};
71+
* {@code false} otherwise
72+
*/
73+
private boolean hasToolAnnotatedMethod(Class<?> clazz) {
74+
return Arrays.stream(clazz.getMethods()).anyMatch(method -> method.isAnnotationPresent(Tool.class));
75+
}
76+
77+
/**
78+
* Registers a {@link ToolCallbackProvider} bean dynamically after all singleton beans
79+
* have been instantiated.
80+
*
81+
* <p>
82+
* This method is invoked by the Spring container at the end of the singleton bean
83+
* lifecycle. It collects beans containing methods annotated with {@link Tool}, and
84+
* builds a {@link MethodToolCallbackProvider} using those beans. The resulting
85+
* provider is registered into the {@link ApplicationContext} as a singleton bean of
86+
* type {@link ToolCallbackProvider}.
87+
* </p>
88+
*
89+
* <p>
90+
* If no such tool beans are found, or a {@code methodToolCallbackProvider} bean is
91+
* already defined in the context, this method does nothing.
92+
* </p>
93+
*/
94+
95+
@Override
96+
public void afterSingletonsInstantiated() {
97+
98+
if (!methodToolBeans.isEmpty()) {
99+
MethodToolCallbackProvider.Builder builder = MethodToolCallbackProvider.builder();
100+
builder.toolObjects(methodToolBeans.toArray());
101+
MethodToolCallbackProvider provider = builder.build();
102+
ConfigurableListableBeanFactory factory = ((ConfigurableApplicationContext) applicationContext)
103+
.getBeanFactory();
104+
105+
if (!factory.containsBean("methodToolCallbackProvider")) {
106+
factory.registerSingleton("methodToolCallbackProvider", provider);
107+
}
108+
}
109+
110+
}
111+
112+
}
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package org.springframework.ai.tool.autoconfigure;
22

3-
3+
import org.springframework.ai.tool.ToolCallbackProvider;
4+
import org.springframework.ai.tool.annotation.Tool;
45
import org.springframework.ai.tool.autoconfigure.annotation.EnableToolCallbackAutoRegistration;
56
import org.springframework.beans.factory.config.BeanDefinition;
67
import org.springframework.beans.factory.config.ConstructorArgumentValues;
78
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
89
import org.springframework.beans.factory.support.BeanNameGenerator;
910
import org.springframework.beans.factory.support.GenericBeanDefinition;
11+
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
1012
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
1113
import org.springframework.core.type.AnnotationMetadata;
1214
import org.springframework.lang.NonNull;
@@ -16,20 +18,26 @@
1618
import java.util.Set;
1719

1820
/**
19-
* {@link ImportBeanDefinitionRegistrar} implementation that registers a {@link ToolCallbackBeanRegistrar}
20-
* bean based on the metadata from {@link EnableToolCallbackAutoRegistration}.
21+
* {@link ImportBeanDefinitionRegistrar} implementation that registers a
22+
* {@link ToolAnnotatedBeanProcessor} bean based on the metadata from
23+
* {@link EnableToolCallbackAutoRegistration}.
2124
*
22-
* <p>This registrar extracts package scanning information from the annotation attributes
23-
* and registers a {@link ToolCallbackBeanRegistrar} to process beans containing {@code @Tool}-annotated methods.
25+
* <p>
26+
* This registrar extracts package scanning information from the annotation attributes and
27+
* registers a {@link ToolAnnotatedBeanProcessor} to process beans containing
28+
* {@code @Tool}-annotated methods.
2429
*
2530
* @see EnableToolCallbackAutoRegistration
26-
* @see ToolCallbackBeanRegistrar
31+
* @see ToolAnnotatedBeanProcessor
2732
*/
28-
public class AutoToolCallbacksRegistrar implements ImportBeanDefinitionRegistrar {
33+
@ConditionalOnClass({ Tool.class, ToolCallbackProvider.class })
34+
public class ToolCallbackAutoRegistrar implements ImportBeanDefinitionRegistrar {
2935

3036
@Override
31-
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry, BeanNameGenerator importBeanNameGenerator) {
32-
Map<String, Object> attributes = importingClassMetadata.getAnnotationAttributes(EnableToolCallbackAutoRegistration.class.getName());
37+
public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry,
38+
BeanNameGenerator importBeanNameGenerator) {
39+
Map<String, Object> attributes = importingClassMetadata
40+
.getAnnotationAttributes(EnableToolCallbackAutoRegistration.class.getName());
3341

3442
if (attributes == null) {
3543
return;
@@ -38,34 +46,36 @@ public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, B
3846
Set<String> basePackages = getBasePackages(attributes);
3947

4048
GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
41-
beanDefinition.setBeanClass(ToolCallbackBeanRegistrar.class);
49+
beanDefinition.setBeanClass(ToolAnnotatedBeanProcessor.class);
4250
beanDefinition.setScope(BeanDefinition.SCOPE_SINGLETON);
4351

4452
ConstructorArgumentValues args = new ConstructorArgumentValues();
4553
args.addGenericArgumentValue(basePackages);
4654
beanDefinition.setConstructorArgumentValues(args);
4755

48-
registry.registerBeanDefinition("toolScannerConfigurer", beanDefinition);
56+
registry.registerBeanDefinition("ToolAnnotatedBeanProcessor", beanDefinition);
4957
}
5058

5159
/**
52-
* Extracts the base packages to scan from the {@code @EnableToolCallbackAutoRegistration} annotation attributes.
60+
* Extracts the base packages to scan from the
61+
* {@code @EnableToolCallbackAutoRegistration} annotation attributes.
5362
*
54-
* <p>Supports the following attributes:
63+
* <p>
64+
* Supports the following attributes:
5565
* <ul>
56-
* <li>{@code value} - Shorthand for base packages</li>
57-
* <li>{@code basePackages} - Explicit list of packages</li>
58-
* <li>{@code basePackageClasses} - Infers packages from class types</li>
66+
* <li>{@code value} - Shorthand for base packages</li>
67+
* <li>{@code basePackages} - Explicit list of packages</li>
68+
* <li>{@code basePackageClasses} - Infers packages from class types</li>
5969
* </ul>
60-
*
6170
* @param attributes the annotation attributes
6271
* @return a set of base package names to scan
6372
*/
6473
private Set<String> getBasePackages(@NonNull Map<String, Object> attributes) {
6574
Set<String> basePackages = new HashSet<>();
6675

6776
Object[] valuePackages = (Object[]) attributes.get("value");
68-
if (valuePackages == null) valuePackages = new Object[0];
77+
if (valuePackages == null)
78+
valuePackages = new Object[0];
6979

7080
for (Object obj : valuePackages) {
7181
if (obj instanceof String str && !str.isEmpty()) {
@@ -74,7 +84,8 @@ private Set<String> getBasePackages(@NonNull Map<String, Object> attributes) {
7484
}
7585

7686
Object[] basePackagesAttr = (Object[]) attributes.get("basePackages");
77-
if (basePackagesAttr == null) basePackagesAttr = new Object[0];
87+
if (basePackagesAttr == null)
88+
basePackagesAttr = new Object[0];
7889

7990
for (Object obj : basePackagesAttr) {
8091
if (obj instanceof String str && !str.isEmpty()) {
@@ -83,7 +94,8 @@ private Set<String> getBasePackages(@NonNull Map<String, Object> attributes) {
8394
}
8495

8596
Object[] basePackageClasses = (Object[]) attributes.get("basePackageClasses");
86-
if (basePackageClasses == null) basePackageClasses = new Object[0];
97+
if (basePackageClasses == null)
98+
basePackageClasses = new Object[0];
8799

88100
for (Object obj : basePackageClasses) {
89101
if (obj instanceof Class<?>) {
@@ -97,4 +109,5 @@ private Set<String> getBasePackages(@NonNull Map<String, Object> attributes) {
97109

98110
return basePackages;
99111
}
112+
100113
}

auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackBeanRegistrar.java

Lines changed: 0 additions & 93 deletions
This file was deleted.

0 commit comments

Comments
 (0)