diff --git a/auto-configurations/common/spring-ai-autoconfigure-tool/pom.xml b/auto-configurations/common/spring-ai-autoconfigure-tool/pom.xml new file mode 100644 index 00000000000..0f112d5c2b3 --- /dev/null +++ b/auto-configurations/common/spring-ai-autoconfigure-tool/pom.xml @@ -0,0 +1,66 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../../pom.xml + + spring-ai-autoconfigure-tool + jar + Spring AI Tool Auto Configuration + Spring AI Tool Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + diff --git a/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackAutoRegistrar.java b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackAutoRegistrar.java new file mode 100644 index 00000000000..a7e5f8d719c --- /dev/null +++ b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/ToolCallbackAutoRegistrar.java @@ -0,0 +1,250 @@ +package org.springframework.ai.tool.autoconfigure; + +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.autoconfigure.annotation.EnableToolCallbackAutoRegistration; +import org.springframework.ai.tool.method.MethodToolCallbackProvider; +import org.springframework.aop.support.AopUtils; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.context.event.ApplicationReadyEvent; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.ApplicationListener; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.ImportAware; +import org.springframework.core.annotation.AnnotationAttributes; +import org.springframework.core.type.AnnotationMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * {@link ApplicationListener} for {@link ApplicationReadyEvent} that scans for Spring + * beans with {@link Tool @Tool} annotated methods within specified base packages. It then + * registers a {@link MethodToolCallbackProvider} bean containing these tools. + *

+ * This registrar is activated when {@link EnableToolCallbackAutoRegistration} is used on + * a configuration class. It leverages {@link ImportAware} to obtain configuration + * attributes (like base packages) from the enabling annotation and + * {@link ApplicationContextAware} to access the application context. + *

+ * The actual scanning and registration lógica happens once the application is fully + * ready, ensuring all beans are initialized. + * + * @see EnableToolCallbackAutoRegistration + * @see Tool + * @see MethodToolCallbackProvider + */ + +@ConditionalOnClass({ Tool.class, ToolCallbackProvider.class }) +public class ToolCallbackAutoRegistrar + implements ApplicationListener, ImportAware, ApplicationContextAware { + + private static final Logger logger = LoggerFactory.getLogger(ToolCallbackAutoRegistrar.class); + + private Set basePackages; + + private ApplicationContext applicationContext; + + /** + * Sets the {@link AnnotationMetadata} of the + * importing @{@link org.springframework.context.annotation.Configuration} class. This + * method is called by Spring as part of the {@link ImportAware} contract. It extracts + * the {@code basePackages} and other attributes from the + * {@link EnableToolCallbackAutoRegistration} annotation. + * @param importMetadata metadata of the importing configuration class. + */ + @Override + public void setImportMetadata(AnnotationMetadata importMetadata) { + Map attributesMap = importMetadata + .getAnnotationAttributes(EnableToolCallbackAutoRegistration.class.getName()); + AnnotationAttributes attributes = AnnotationAttributes.fromMap(attributesMap); + this.basePackages = getBasePackages(attributes, importMetadata); + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + } + + /** + * Handles the {@link ApplicationReadyEvent}, which is published when the application + * is ready to service requests. This method performs the scan for {@link Tool @Tool} + * annotated methods in beans within the configured base packages and registers a + * {@link MethodToolCallbackProvider}. + * @param event the {@link ApplicationReadyEvent} signalling that the application is + * ready. + */ + @Override + public void onApplicationEvent(ApplicationReadyEvent event) { + // Ensure this listener reacts only to its own application context's ready event, + // especially in hierarchical contexts. + if (!event.getApplicationContext().equals(this.applicationContext)) { + return; + } + + logger.debug("Application ready, scanning for @Tool annotated methods in base packages: {}", this.basePackages); + + ConfigurableApplicationContext configurableContext = (ConfigurableApplicationContext) this.applicationContext; + ConfigurableListableBeanFactory beanFactory = configurableContext.getBeanFactory(); + + List toolBeans = new ArrayList<>(); + String[] beanNames = beanFactory.getBeanDefinitionNames(); + + for (String beanName : beanNames) { + // Check if the bean is a singleton, not abstract, and actually obtainable. + // This avoids issues with beans that are not yet fully initialized or are + // infrastructure beans. + if (beanFactory.isSingleton(beanName) && !beanFactory.getBeanDefinition(beanName).isAbstract() + && beanFactory.containsBean(beanName)) { + Object beanInstance = null; + try { + beanInstance = beanFactory.getBean(beanName); + } + catch (BeansException e) { + // Log and continue, as some beans might not be fully ready or are + // special (e.g., factory beans). + logger.trace("Could not retrieve bean instance for name '{}' during @Tool scan. Skipping.", + beanName, e); + continue; + } + + // Resolve the target class for AOP proxies to find annotations on the + // actual class. + Class targetClass = AopUtils.getTargetClass(beanInstance); + + if (isInBasePackage(targetClass.getPackageName())) { + if (hasToolAnnotatedMethod(targetClass)) { + toolBeans.add(beanInstance); + logger.debug("Found @Tool annotated methods in bean: {} of type {}", beanName, + targetClass.getName()); + } + } + } + } + + if (!toolBeans.isEmpty()) { + // If a MethodToolCallbackProvider bean doesn't already exist, register one + // with the found tools. + if (!beanFactory.containsBean("methodToolCallbackProvider")) { + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(toolBeans.toArray()) + .build(); + beanFactory.registerSingleton("methodToolCallbackProvider", provider); + logger.info("Registered MethodToolCallbackProvider with {} tool bean(s).", toolBeans.size()); + } + else { + // If a bean with this name already exists, log a warning. + // This might happen if the user manually defines a bean with the same + // name. + logger.warn( + "Bean 'methodToolCallbackProvider' already exists. Skipping registration by ToolCallbackAutoRegistrar. " + + "If this is unexpected, check your configuration."); + } + } + else { + logger.debug("No beans with @Tool annotated methods found in the specified base packages."); + // If no tool beans are found and no provider bean exists, register an empty + // provider. + // This ensures that beans depending on MethodToolCallbackProvider can still + // be autowired. + if (!beanFactory.containsBean("methodToolCallbackProvider")) { + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects().build(); // Empty + beanFactory.registerSingleton("methodToolCallbackProvider", provider); + logger.info("Registered an empty MethodToolCallbackProvider as no tool beans were found."); + } + } + } + + /** + * Extracts the base packages to scan from the + * {@link EnableToolCallbackAutoRegistration} annotation attributes. It considers + * {@code value}, {@code basePackages}, and {@code basePackageClasses} attributes. If + * no packages are explicitly defined, it falls back to the package of the class + * annotated with {@link EnableToolCallbackAutoRegistration}. + * @param attributes The attributes of the {@link EnableToolCallbackAutoRegistration} + * annotation. + * @param importingClassMetadata Metadata of the class that imported this registrar + * (the @Configuration class). + * @return A set of package names to scan. + */ + private Set getBasePackages(AnnotationAttributes attributes, AnnotationMetadata importingClassMetadata) { + Set packages = new HashSet<>(); + + // Extract packages from 'value' attribute + for (String pkg : attributes.getStringArray("value")) { + if (pkg != null && !pkg.isEmpty()) { + packages.add(pkg); + } + } + // Extract packages from 'basePackages' attribute + for (String pkg : attributes.getStringArray("basePackages")) { + if (pkg != null && !pkg.isEmpty()) { + packages.add(pkg); + } + } + // Extract packages from 'basePackageClasses' attribute + for (Class clazz : attributes.getClassArray("basePackageClasses")) { + packages.add(clazz.getPackage().getName()); + } + + // Fallback: If no packages are specified, use the package of the importing + // @Configuration class. + if (packages.isEmpty() && importingClassMetadata != null) { + String className = importingClassMetadata.getClassName(); + try { + Class importingClass = Class.forName(className); + Package pkg = importingClass.getPackage(); + if (pkg != null) { + packages.add(pkg.getName()); + logger.debug( + "No explicit base packages configured. Using package of @EnableToolCallbackAutoRegistration class: {}", + pkg.getName()); + } + } + catch (ClassNotFoundException e) { + logger.warn("Could not resolve base package from importing class: {}", className, e); + } + } + + if (packages.isEmpty()) { + logger.warn("No base packages configured for @Tool scanning. Scanning will be effectively disabled."); + } + return packages; + } + + /** + * Checks if the given package name is within any of the configured base packages. + * @param packageName The package name to check. + * @return {@code true} if the package name starts with any of the configured base + * packages, {@code false} otherwise. Returns {@code false} if no base packages are + * defined. + */ + private boolean isInBasePackage(String packageName) { + if (this.basePackages == null || this.basePackages.isEmpty()) { + return false; // No scanning if no base packages are defined. + } + return this.basePackages.stream() + .anyMatch(basePackage -> packageName != null && packageName.startsWith(basePackage)); + } + + /** + * Checks if the given class (or any of its superclasses/interfaces) has at least one + * method annotated with {@link Tool @Tool}. + * @param clazz The class to inspect. + * @return {@code true} if at least one {@link Tool @Tool} annotated method is found, + * {@code false} otherwise. + */ + private boolean hasToolAnnotatedMethod(Class clazz) { + return Arrays.stream(clazz.getMethods()).anyMatch(method -> method.isAnnotationPresent(Tool.class)); + } + +} diff --git a/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistration.java b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistration.java new file mode 100644 index 00000000000..9f9dd67b326 --- /dev/null +++ b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistration.java @@ -0,0 +1,61 @@ +package org.springframework.ai.tool.autoconfigure.annotation; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.autoconfigure.ToolCallbackAutoRegistrar; +import org.springframework.context.annotation.Import; + +import java.lang.annotation.*; + +/** + * Enables automatic registration of {@link Tool}-annotated methods as + * {@link ToolCallback}s. + * + *

+ * When this annotation is used on a configuration class, it imports the + * {@link ToolCallbackAutoRegistrar}, which scans the specified packages for Spring beans + * containing {@code @Tool}-annotated methods. These beans are then registered as + * {@link ToolCallbackProvider}s. + * + *

+ * Usage example: + *

+ *
+ * {@code
+ * Configuration
+ *
+
+EnableToolCallbackAutoRegistration(basePackages = "com.example.tools")
+ * public class MyToolConfig {
+ * }
+ * }
+ * 
+ * + *

+ * You can specify packages to scan in one of three ways: + *

    + *
  • {@code basePackages} - Explicit list of package names
  • + *
  • {@code value} - Alias for {@code basePackages}
  • + *
  • {@code basePackageClasses} - Package names inferred from provided classes
  • + *
+ * + * @see Tool + * @see ToolCallback + * @see ToolCallbackProvider + * @see ToolCallbackAutoRegistrar + */ + +@Target({ ElementType.TYPE, ElementType.METHOD }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Import(ToolCallbackAutoRegistrar.class) +public @interface EnableToolCallbackAutoRegistration { + + String[] basePackages() default {}; + + String[] value() default {}; + + Class[] basePackageClasses() default {}; + +} diff --git a/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/resources/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/resources/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..e7d88717ff2 --- /dev/null +++ b/auto-configurations/common/spring-ai-autoconfigure-tool/src/main/resources/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,16 @@ +# +# Copyright 2025-2025 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +org.springframework.ai.tool.autoconfigure.annotation.EnableToolCallbackAutoRegistration diff --git a/auto-configurations/common/spring-ai-autoconfigure-tool/src/test/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistrationIT.java b/auto-configurations/common/spring-ai-autoconfigure-tool/src/test/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistrationIT.java new file mode 100644 index 00000000000..48f068c8c1a --- /dev/null +++ b/auto-configurations/common/spring-ai-autoconfigure-tool/src/test/java/org/springframework/ai/tool/autoconfigure/annotation/EnableToolCallbackAutoRegistrationIT.java @@ -0,0 +1,71 @@ +package org.springframework.ai.tool.autoconfigure.annotation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.method.MethodToolCallbackProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationListener; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.event.ContextRefreshedEvent; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +@SpringBootTest(classes = EnableToolCallbackAutoRegistrationIT.Config.class) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class EnableToolCallbackAutoRegistrationIT { + + private static final CountDownLatch latch = new CountDownLatch(1); + + @Autowired + private ApplicationContext context; + + @Test + void shouldRegisterToolCallbackProviderBean() throws InterruptedException { + if (!latch.await(3, TimeUnit.SECONDS)) { + fail("Application context was not fully refreshed in time"); + } + + ToolCallbackProvider provider = context.getBean(MethodToolCallbackProvider.class); + + assertThat(provider.getToolCallbacks()).extracting(FunctionCallback::getName).contains("echo"); + + assertThat(provider.getToolCallbacks()).extracting(FunctionCallback::getDescription) + .contains("This is a description"); + } + + @Configuration + @EnableToolCallbackAutoRegistration + static class Config { + + @Bean + public EchoTool echoTool() { + return new EchoTool(); + } + + @Bean + public ApplicationListener latchReleaser() { + return event -> latch.countDown(); + } + + } + + static class EchoTool { + + @Tool(description = "This is a description") + public String echo(String input) { + return input; + } + + } + +} diff --git a/pom.xml b/pom.xml index 6c45eb4d619..602829155e0 100644 --- a/pom.xml +++ b/pom.xml @@ -81,6 +81,7 @@ auto-configurations/common/spring-ai-autoconfigure-retry + auto-configurations/common/spring-ai-autoconfigure-tool auto-configurations/models/tool/spring-ai-autoconfigure-model-tool diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index ca1a57cf65b..010a222bfd7 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -460,6 +460,13 @@ ${project.version} + + + org.springframework.ai + spring-ai-autoconfigure-tool + ${project.version} + + diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java index 666aa6f97f3..d195f3f3a6d 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java @@ -51,7 +51,7 @@ public final class MethodToolCallbackProvider implements ToolCallbackProvider { private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class); - private final List toolObjects; + private List toolObjects; private MethodToolCallbackProvider(List toolObjects) { Assert.notNull(toolObjects, "toolObjects cannot be null"); @@ -78,6 +78,10 @@ private void assertToolAnnotatedMethodsPresent(List toolObjects) { } } + public void setToolObjects(List toolObjects) { + this.toolObjects = toolObjects; + } + @Override public ToolCallback[] getToolCallbacks() { var toolCallbacks = this.toolObjects.stream()