diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java index e6e60920c8b..0bed3f77765 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java @@ -23,6 +23,15 @@ import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; @@ -38,7 +47,10 @@ import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; +import org.springframework.ai.mcp.annotation.spring.AsyncMcpAnnotationProviders; +import org.springframework.ai.mcp.annotation.spring.SyncMcpAnnotationProviders; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpAsyncAnnotationCustomizer; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpSyncAnnotationCustomizer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpAsyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; @@ -46,6 +58,7 @@ import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -137,60 +150,45 @@ private String connectedClientName(String clientName, String serverConnectionNam } /** - * Creates a list of {@link McpSyncClient} instances based on the available - * transports. + * Creates a {@link McpSyncClientInitializer} that defers client creation until all + * singleton beans have been initialized. * *

- * Each client is configured with: - *

- * - *

- * If initialization is enabled in properties, the clients are automatically - * initialized. + * This ensures that all beans with MCP annotations have been scanned and registered + * before the clients are created, preventing a timing issue where late-initialized + * beans might miss registration. * @param mcpSyncClientConfigurer the configurer for customizing client creation * @param commonProperties common MCP client properties * @param transportsProvider provider of named MCP transports - * @return list of configured MCP sync clients + * @param annotatedBeans registry of beans with MCP annotations + * @return the client initializer that creates clients after singleton instantiation */ @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer, + public McpSyncClientInitializer mcpSyncClientInitializer(McpSyncClientConfigurer mcpSyncClientConfigurer, McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { - - List mcpSyncClients = new ArrayList<>(); - - List namedTransports = transportsProvider.stream().flatMap(List::stream).toList(); - - if (!CollectionUtils.isEmpty(namedTransports)) { - for (NamedClientMcpTransport namedTransport : namedTransports) { - - McpSchema.Implementation clientInfo = new McpSchema.Implementation( - this.connectedClientName(commonProperties.getName(), namedTransport.name()), - namedTransport.name(), commonProperties.getVersion()); - - McpClient.SyncSpec spec = McpClient.sync(namedTransport.transport()) - .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); - - spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec); - - var client = spec.build(); - - if (commonProperties.isInitialized()) { - client.initialize(); - } - - mcpSyncClients.add(client); - } - } + ObjectProvider> transportsProvider, + ObjectProvider annotatedBeansProvider) { + return new McpSyncClientInitializer(this, mcpSyncClientConfigurer, commonProperties, transportsProvider, + annotatedBeansProvider); + } - return mcpSyncClients; + /** + * Provides the list of {@link McpSyncClient} instances created by the initializer. + * + *

+ * This bean is available after all singleton beans have been initialized. + * @param initializer the client initializer + * @return list of configured MCP sync clients + */ + @Bean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", + matchIfMissing = true) + public List mcpSyncClients(McpSyncClientInitializer initializer) { + // Return the client list directly - it will be populated by + // SmartInitializingSingleton + return initializer.getClients(); } /** @@ -222,81 +220,341 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider + * This ensures that all beans with MCP annotations have been scanned and registered + * before the clients are created, preventing a timing issue where late-initialized + * beans might miss registration. + * @param mcpAsyncClientConfigurer the configurer for customizing client creation + * @param commonProperties common MCP client properties + * @param transportsProvider provider of named MCP transports + * @param annotatedBeans registry of beans with MCP annotations + * @return the client initializer that creates clients after singleton instantiation + */ @Bean - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", - matchIfMissing = true) - public McpSyncClientCustomizer mcpAnnotationMcpSyncClientCustomizer(List loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List syncToolListChangedSpecifications, - List syncResourceListChangedSpecifications, - List syncPromptListChangedSpecifications) { - return new McpSyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - syncToolListChangedSpecifications, syncResourceListChangedSpecifications, - syncPromptListChangedSpecifications); + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + public McpAsyncClientInitializer mcpAsyncClientInitializer(McpAsyncClientConfigurer mcpAsyncClientConfigurer, + McpClientCommonProperties commonProperties, + ObjectProvider> transportsProvider, + ObjectProvider annotatedBeansProvider) { + return new McpAsyncClientInitializer(this, mcpAsyncClientConfigurer, commonProperties, transportsProvider, + annotatedBeansProvider); } - // Async client configuration + /** + * Provides the list of {@link McpAsyncClient} instances created by the initializer. + * + *

+ * This bean is available after all singleton beans have been initialized. + * @param initializer the client initializer + * @return list of configured MCP async clients + */ + @Bean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + public List mcpAsyncClients(McpAsyncClientInitializer initializer) { + // Return the client list directly - it will be populated by + // SmartInitializingSingleton + return initializer.getClients(); + } @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public List mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer, - McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { + public CloseableMcpAsyncClients makeAsyncClientsClosable(List clients) { + return new CloseableMcpAsyncClients(clients); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + McpAsyncClientConfigurer mcpAsyncClientConfigurer(ObjectProvider customizerProvider) { + return new McpAsyncClientConfigurer(customizerProvider.orderedStream().toList()); + } + + /** + * Initializer for MCP synchronous clients that implements + * {@link SmartInitializingSingleton}. + * + *

+ * This class defers the creation of MCP sync clients until after all singleton beans + * have been initialized. This ensures that all beans with MCP-annotated methods have + * been scanned and registered in the {@link ClientMcpAnnotatedBeans} registry before + * the specifications are created and clients are configured. + * + *

+ * The initialization process: + *

    + *
  1. Wait for all singleton beans to be instantiated + *
  2. Re-evaluate specifications from the complete registry + *
  3. Create and configure MCP clients with all registered specifications + *
  4. Initialize clients if configured to do so + *
+ */ + public static class McpSyncClientInitializer implements SmartInitializingSingleton { + + private static final Logger logger = LoggerFactory.getLogger(McpSyncClientInitializer.class); + + private final McpClientAutoConfiguration configuration; + + private final McpSyncClientConfigurer configurer; + + private final McpClientCommonProperties properties; + + private final ObjectProvider> transportsProvider; + + private final ObjectProvider annotatedBeansProvider; + + final List clients = new ArrayList<>(); + + private long initializationTimestamp = -1; + + public McpSyncClientInitializer(McpClientAutoConfiguration configuration, McpSyncClientConfigurer configurer, + McpClientCommonProperties properties, ObjectProvider> transportsProvider, + ObjectProvider annotatedBeansProvider) { + this.configuration = configuration; + this.configurer = configurer; + this.properties = properties; + this.transportsProvider = transportsProvider; + this.annotatedBeansProvider = annotatedBeansProvider; + } + + @Override + public void afterSingletonsInstantiated() { + // Record when initialization starts + this.initializationTimestamp = System.nanoTime(); + + logger.debug("Creating MCP sync clients after all singleton beans have been instantiated"); + + McpSyncClientCustomizer annotationCustomizer = null; - List mcpAsyncClients = new ArrayList<>(); + // Only create annotation customizer if annotated beans registry is available + ClientMcpAnnotatedBeans annotatedBeans = this.annotatedBeansProvider.getIfAvailable(); + if (annotatedBeans != null) { + // Re-create specifications from the now-complete registry + List loggingSpecs = SyncMcpAnnotationProviders + .loggingSpecifications(annotatedBeans.getBeansByAnnotation(McpLogging.class)); - List namedTransports = transportsProvider.stream().flatMap(List::stream).toList(); + List samplingSpecs = SyncMcpAnnotationProviders + .samplingSpecifications(annotatedBeans.getBeansByAnnotation(McpSampling.class)); - if (!CollectionUtils.isEmpty(namedTransports)) { - for (NamedClientMcpTransport namedTransport : namedTransports) { + List elicitationSpecs = SyncMcpAnnotationProviders + .elicitationSpecifications(annotatedBeans.getBeansByAnnotation(McpElicitation.class)); - McpSchema.Implementation clientInfo = new McpSchema.Implementation( - this.connectedClientName(commonProperties.getName(), namedTransport.name()), - commonProperties.getVersion()); + List progressSpecs = SyncMcpAnnotationProviders + .progressSpecifications(annotatedBeans.getBeansByAnnotation(McpProgress.class)); - McpClient.AsyncSpec spec = McpClient.async(namedTransport.transport()) - .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); + List toolListChangedSpecs = SyncMcpAnnotationProviders + .toolListChangedSpecifications(annotatedBeans.getBeansByAnnotation(McpToolListChanged.class)); - spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec); + List resourceListChangedSpecs = SyncMcpAnnotationProviders + .resourceListChangedSpecifications( + annotatedBeans.getBeansByAnnotation(McpResourceListChanged.class)); - var client = spec.build(); + List promptListChangedSpecs = SyncMcpAnnotationProviders + .promptListChangedSpecifications(annotatedBeans.getBeansByAnnotation(McpPromptListChanged.class)); + + // Create the annotation customizer with fresh specifications + annotationCustomizer = new McpSyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, + progressSpecs, toolListChangedSpecs, resourceListChangedSpecs, promptListChangedSpecs); + } + + // Create the clients using the base configurer and annotation customizer (if + // available) + List createdClients = createClients(this.configurer, annotationCustomizer); + this.clients.addAll(createdClients); + + logger.info("Created {} MCP sync client(s)", this.clients.size()); + } - if (commonProperties.isInitialized()) { - client.initialize().block(); + private List createClients(McpSyncClientConfigurer configurer, + McpSyncClientCustomizer annotationCustomizer) { + List mcpSyncClients = new ArrayList<>(); + + List namedTransports = this.transportsProvider.stream() + .flatMap(List::stream) + .toList(); + + if (!CollectionUtils.isEmpty(namedTransports)) { + for (NamedClientMcpTransport namedTransport : namedTransports) { + + McpSchema.Implementation clientInfo = new McpSchema.Implementation( + this.configuration.connectedClientName(this.properties.getName(), namedTransport.name()), + namedTransport.name(), this.properties.getVersion()); + + McpClient.SyncSpec spec = McpClient.sync(namedTransport.transport()) + .clientInfo(clientInfo) + .requestTimeout(this.properties.getRequestTimeout()); + + spec = configurer.configure(namedTransport.name(), spec); + + // Apply annotation customizer after other customizers (if available) + if (annotationCustomizer != null) { + annotationCustomizer.customize(namedTransport.name(), spec); + } + + var client = spec.build(); + + if (this.properties.isInitialized()) { + client.initialize(); + } + + mcpSyncClients.add(client); } + } - mcpAsyncClients.add(client); + return mcpSyncClients; + } + + public List getClients() { + if (this.clients == null) { + throw new IllegalStateException( + "MCP sync clients not yet initialized. They are created after all singleton beans are instantiated."); } + return this.clients; } - return mcpAsyncClients; - } + /** + * Returns the timestamp (in nanoseconds) when afterSingletonsInstantiated() was + * called. This can be used in tests to verify SmartInitializingSingleton timing. + * @return the initialization timestamp, or -1 if not yet initialized + */ + public long getInitializationTimestamp() { + return this.initializationTimestamp; + } - @Bean - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public CloseableMcpAsyncClients makeAsyncClientsClosable(List clients) { - return new CloseableMcpAsyncClients(clients); } - @Bean - @ConditionalOnMissingBean - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - McpAsyncClientConfigurer mcpAsyncClientConfigurer(ObjectProvider customizerProvider) { - return new McpAsyncClientConfigurer(customizerProvider.orderedStream().toList()); - } + /** + * Initializer for MCP asynchronous clients that implements + * {@link SmartInitializingSingleton}. + * + *

+ * This class defers the creation of MCP async clients until after all singleton beans + * have been initialized. This ensures that all beans with MCP-annotated methods have + * been scanned and registered in the {@link ClientMcpAnnotatedBeans} registry before + * the specifications are created and clients are configured. + */ + public static class McpAsyncClientInitializer implements SmartInitializingSingleton { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncClientInitializer.class); + + private final McpClientAutoConfiguration configuration; + + private final McpAsyncClientConfigurer configurer; + + private final McpClientCommonProperties properties; + + private final ObjectProvider> transportsProvider; + + private final ObjectProvider annotatedBeansProvider; + + final List clients = new ArrayList<>(); + + public McpAsyncClientInitializer(McpClientAutoConfiguration configuration, McpAsyncClientConfigurer configurer, + McpClientCommonProperties properties, ObjectProvider> transportsProvider, + ObjectProvider annotatedBeansProvider) { + this.configuration = configuration; + this.configurer = configurer; + this.properties = properties; + this.transportsProvider = transportsProvider; + this.annotatedBeansProvider = annotatedBeansProvider; + } + + @Override + public void afterSingletonsInstantiated() { + logger.debug("Creating MCP async clients after all singleton beans have been instantiated"); + + McpAsyncClientCustomizer annotationCustomizer = null; + + // Only create annotation customizer if annotated beans registry is available + ClientMcpAnnotatedBeans annotatedBeans = this.annotatedBeansProvider.getIfAvailable(); + if (annotatedBeans != null) { + // Re-create specifications from the now-complete registry + List loggingSpecs = AsyncMcpAnnotationProviders + .loggingSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List samplingSpecs = AsyncMcpAnnotationProviders + .samplingSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List elicitationSpecs = AsyncMcpAnnotationProviders + .elicitationSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List progressSpecs = AsyncMcpAnnotationProviders + .progressSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List toolListChangedSpecs = AsyncMcpAnnotationProviders + .toolListChangedSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List resourceListChangedSpecs = AsyncMcpAnnotationProviders + .resourceListChangedSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + List promptListChangedSpecs = AsyncMcpAnnotationProviders + .promptListChangedSpecifications(annotatedBeans.getAllAnnotatedBeans()); + + // Create the annotation customizer with fresh specifications + annotationCustomizer = new McpAsyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, + progressSpecs, toolListChangedSpecs, resourceListChangedSpecs, promptListChangedSpecs); + } + + // Create the clients using the base configurer and annotation customizer (if + // available) + List createdClients = createClients(this.configurer, annotationCustomizer); + this.clients.addAll(createdClients); + + logger.info("Created {} MCP async client(s)", this.clients.size()); + } + + private List createClients(McpAsyncClientConfigurer configurer, + McpAsyncClientCustomizer annotationCustomizer) { + List mcpAsyncClients = new ArrayList<>(); + + List namedTransports = this.transportsProvider.stream() + .flatMap(List::stream) + .toList(); + + if (!CollectionUtils.isEmpty(namedTransports)) { + for (NamedClientMcpTransport namedTransport : namedTransports) { + + McpSchema.Implementation clientInfo = new McpSchema.Implementation( + this.configuration.connectedClientName(this.properties.getName(), namedTransport.name()), + this.properties.getVersion()); + + McpClient.AsyncSpec spec = McpClient.async(namedTransport.transport()) + .clientInfo(clientInfo) + .requestTimeout(this.properties.getRequestTimeout()); + + spec = configurer.configure(namedTransport.name(), spec); + + // Apply annotation customizer after other customizers (if available) + if (annotationCustomizer != null) { + annotationCustomizer.customize(namedTransport.name(), spec); + } + + var client = spec.build(); + + if (this.properties.isInitialized()) { + client.initialize().block(); + } + + mcpAsyncClients.add(client); + } + } + + return mcpAsyncClients; + } + + public List getClients() { + if (this.clients == null) { + throw new IllegalStateException( + "MCP async clients not yet initialized. They are created after all singleton beans are instantiated."); + } + return this.clients; + } - @Bean - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - public McpAsyncClientCustomizer mcpAnnotationMcpAsyncClientCustomizer(List loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List toolListChangedSpecs, - List resourceListChangedSpecs, - List promptListChangedSpecs) { - return new McpAsyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - toolListChangedSpecs, resourceListChangedSpecs, promptListChangedSpecs); } /** @@ -313,13 +571,24 @@ public record CloseableMcpSyncClients(List clients) implements Au public void close() { this.clients.forEach(McpSyncClient::close); } + } + /** + * Record class that implements {@link AutoCloseable} to ensure proper cleanup of MCP + * async clients. + * + *

+ * This class is responsible for closing all MCP async clients when the application + * context is closed, preventing resource leaks. + */ public record CloseableMcpAsyncClients(List clients) implements AutoCloseable { + @Override public void close() { this.clients.forEach(McpAsyncClient::close); } + } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index 89143c324c0..8272ace23ca 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -60,9 +60,17 @@ public McpToolNamePrefixGenerator defaultMcpToolNamePrefixGenerator() { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. + * + *

+ * IMPORTANT: This method receives the same list reference that is populated by + * {@link McpClientAutoConfiguration.McpSyncClientInitializer} in its + * {@code afterSingletonsInstantiated()} method. This ensures that when + * {@code getToolCallbacks()} is called, even if it's called before full + * initialization completes, it will eventually see the populated list. * @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to * filter the discovered tools - * @param syncMcpClients provider of MCP sync clients + * @param syncMcpClients the MCP sync clients list (same reference as returned by + * mcpSyncClients() bean method) * @param mcpToolNamePrefixGenerator the tool name prefix generator * @return list of tool callbacks for MCP integration */ @@ -70,15 +78,14 @@ public McpToolNamePrefixGenerator defaultMcpToolNamePrefixGenerator() { @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider syncClientsToolFilter, - ObjectProvider> syncMcpClients, - ObjectProvider mcpToolNamePrefixGenerator, + List syncMcpClients, ObjectProvider mcpToolNamePrefixGenerator, ObjectProvider toolContextToMcpMetaConverter) { - List mcpClients = syncMcpClients.stream().flatMap(List::stream).toList(); - + // Use mcpClientsReference to share the list reference - it will be populated by + // SmartInitializingSingleton return SyncMcpToolCallbackProvider.builder() - .mcpClients(mcpClients) - .toolFilter(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true))) + .mcpClientsReference(syncMcpClients) + .toolFilter(syncClientsToolFilter.getIfUnique((() -> (mcpClient, tool) -> true))) .toolNamePrefixGenerator( mcpToolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.noPrefix())) .toolContextToMcpMetaConverter( @@ -86,19 +93,34 @@ public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider + * IMPORTANT: This method receives the same list reference that is populated by + * {@link McpClientAutoConfiguration.McpAsyncClientInitializer} in its + * {@code afterSingletonsInstantiated()} method. + * @param asyncClientsToolFilter tool filter for async clients + * @param mcpClients the MCP async clients list (same reference as returned by + * mcpAsyncClients() bean method) + * @param toolNamePrefixGenerator the tool name prefix generator + * @param toolContextToMcpMetaConverter converter for tool context to MCP metadata + * @return async tool callback provider + */ @Bean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider asyncClientsToolFilter, - ObjectProvider> mcpClientsProvider, - ObjectProvider toolNamePrefixGenerator, - ObjectProvider toolContextToMcpMetaConverter) { // TODO - List mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList(); + List mcpClients, ObjectProvider toolNamePrefixGenerator, + ObjectProvider toolContextToMcpMetaConverter) { + + // Use mcpClientsReference to share the list reference - it will be populated by + // SmartInitializingSingleton return AsyncMcpToolCallbackProvider.builder() - .toolFilter(asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true)) + .toolFilter(asyncClientsToolFilter.getIfUnique(() -> (mcpClient, tool) -> true)) .toolNamePrefixGenerator(toolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.noPrefix())) .toolContextToMcpMetaConverter( toolContextToMcpMetaConverter.getIfUnique(() -> ToolContextToMcpMetaConverter.defaultConverter())) - .mcpClients(mcpClients) + .mcpClientsReference(mcpClients) .build(); } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java index 620028f0e63..b4e430db0ce 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java @@ -51,13 +51,33 @@ import org.springframework.context.annotation.Configuration; /** + * Auto-configuration for MCP client specification factory. + * + *

+ * Note: This configuration is now obsolete and disabled by default. + * Specification creation has been moved to + * {@link org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration.McpSyncClientInitializer} + * and + * {@link org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration.McpAsyncClientInitializer} + * which use {@link org.springframework.beans.factory.SmartInitializingSingleton} to defer + * client creation until after all singleton beans have been initialized. This ensures + * that all beans with MCP-annotated methods are scanned before specifications are + * created. + * + *

+ * This class is kept for backwards compatibility but can be safely removed in future + * versions. + * * @author Christian Tzolov * @author Fu Jian + * @deprecated Since 1.1.0, specifications are now created dynamically after all singleton + * beans are initialized. This class will be removed in a future release. */ +@Deprecated(since = "1.1.0", forRemoval = true) @AutoConfiguration(after = McpClientAnnotationScannerAutoConfiguration.class) @ConditionalOnClass(McpLogging.class) @ConditionalOnProperty(prefix = McpClientAnnotationScannerProperties.CONFIG_PREFIX, name = "enabled", - havingValue = "true", matchIfMissing = true) + havingValue = "false") // Disabled by default - changed from "true" to "false" public class McpClientSpecificationFactoryAutoConfiguration { @Configuration(proxyBeanMethods = false) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java index 1d1fbb92ae4..d0689d718d2 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java @@ -88,28 +88,39 @@ public class McpClientAutoConfigurationIT { AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class)); /** - * Tests the default MCP client auto-configuration. - * - * Note: We use 'spring.ai.mcp.client.initialized=false' to prevent the - * auto-configuration from calling client.initialize() explicitly, which would cause a - * 20-second timeout waiting for real MCP protocol communication. This allows us to - * test bean creation and auto-configuration behavior without requiring a full MCP - * server connection. + * Tests that MCP clients are created after all singleton beans have been initialized, + * verifying the SmartInitializingSingleton timing behavior. + *

+ * This test uses a LateInitBean that records its initialization timestamp, and then + * verifies that the MCP client initializer was called AFTER the late bean was + * constructed. This proves that + * SmartInitializingSingleton.afterSingletonsInstantiated() is called after all + * singleton beans (including late-initializing ones) have been fully created. */ @Test - void defaultConfiguration() { - this.contextRunner.withUserConfiguration(TestTransportConfiguration.class) + void clientsCreatedAfterAllSingletons() { + this.contextRunner.withUserConfiguration(TestTransportConfiguration.class, LateInitBeanWithTimestamp.class) .withPropertyValues("spring.ai.mcp.client.initialized=false") .run(context -> { + // Get the late-init bean and its construction timestamp + LateInitBeanWithTimestamp lateBean = context.getBean(LateInitBeanWithTimestamp.class); + long lateBeanTimestamp = lateBean.getInitTimestamp(); + + // Get the initializer and its execution timestamp + var initializer = context.getBean(McpClientAutoConfiguration.McpSyncClientInitializer.class); + long initializerTimestamp = initializer.getInitializationTimestamp(); + + // Verify clients were created List clients = context.getBean("mcpSyncClients", List.class); - assertThat(clients).hasSize(1); + assertThat(clients).isNotNull(); - McpClientCommonProperties properties = context.getBean(McpClientCommonProperties.class); - assertThat(properties.getName()).isEqualTo("spring-ai-mcp-client"); - assertThat(properties.getVersion()).isEqualTo("1.0.0"); - assertThat(properties.getType()).isEqualTo(McpClientCommonProperties.ClientType.SYNC); - assertThat(properties.getRequestTimeout()).isEqualTo(Duration.ofSeconds(20)); - assertThat(properties.isInitialized()).isFalse(); + // THE KEY ASSERTION: Initializer must have been called AFTER late bean + // was constructed + // This proves SmartInitializingSingleton.afterSingletonsInstantiated() + // timing + assertThat(initializerTimestamp) + .as("MCP client initializer should be called AFTER all singleton beans are initialized") + .isGreaterThan(lateBeanTimestamp); }); } @@ -224,6 +235,54 @@ void closeableWrappersCreation() { .hasSingleBean(McpClientAutoConfiguration.CloseableMcpSyncClients.class)); } + /** + * Tests that SmartInitializingSingleton initializers are created and function + * correctly for sync clients. + */ + @Test + void smartInitializingSingletonBehavior() { + this.contextRunner.withUserConfiguration(TestTransportConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.initialized=false") + .run(context -> { + // Verify that McpSyncClientInitializer bean exists + assertThat(context).hasBean("mcpSyncClientInitializer"); + assertThat(context.getBean("mcpSyncClientInitializer")) + .isInstanceOf(McpClientAutoConfiguration.McpSyncClientInitializer.class); + + // Verify that clients list exists and was created by initializer + List clients = context.getBean("mcpSyncClients", List.class); + assertThat(clients).isNotNull(); + + // Verify the initializer has completed + var initializer = context.getBean(McpClientAutoConfiguration.McpSyncClientInitializer.class); + assertThat(initializer.getClients()).isSameAs(clients); + }); + } + + /** + * Tests that SmartInitializingSingleton initializers are created and function + * correctly for async clients. + */ + @Test + void smartInitializingSingletonForAsyncClients() { + this.contextRunner.withUserConfiguration(TestTransportConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=ASYNC", "spring.ai.mcp.client.initialized=false") + .run(context -> { + // Verify that McpAsyncClientInitializer bean exists + assertThat(context).hasBean("mcpAsyncClientInitializer"); + assertThat(context.getBean("mcpAsyncClientInitializer")) + .isInstanceOf(McpClientAutoConfiguration.McpAsyncClientInitializer.class); + + // Verify that clients list exists and was created by initializer + List clients = context.getBean("mcpAsyncClients", List.class); + assertThat(clients).isNotNull(); + + // Verify the initializer has completed + var initializer = context.getBean(McpClientAutoConfiguration.McpAsyncClientInitializer.class); + assertThat(initializer.getClients()).isSameAs(clients); + }); + } + @Configuration static class TestTransportConfiguration { @@ -265,6 +324,55 @@ McpSyncClientCustomizer testCustomizer() { } + @Configuration + static class LateInitBean { + + private final boolean initialized; + + LateInitBean() { + // Simulate late initialization + this.initialized = true; + } + + @Bean + String lateInitBean() { + // This bean method ensures the configuration is instantiated + return "late-init-marker"; + } + + boolean isInitialized() { + return this.initialized; + } + + } + + /** + * A configuration bean that records when it was initialized. Used to verify + * SmartInitializingSingleton timing - that the MCP client initializer is called AFTER + * all singleton beans (including this one) have been constructed. + */ + @Configuration + static class LateInitBeanWithTimestamp { + + private final long initTimestamp; + + LateInitBeanWithTimestamp() { + // Record when this bean was constructed + this.initTimestamp = System.nanoTime(); + } + + @Bean + String lateInitMarker() { + // This bean method ensures the configuration is instantiated + return "late-init-marker"; + } + + long getInitTimestamp() { + return this.initTimestamp; + } + + } + static class CustomClientTransport implements McpClientTransport { @Override diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java index d00e3cc6b35..2b85a654c8d 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java @@ -45,45 +45,42 @@ public class McpClientListChangedAnnotationsScanningIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class)); @ParameterizedTest @ValueSource(strings = { "SYNC", "ASYNC" }) void shouldScanAllThreeListChangedAnnotations(String clientType) { - String prefix = clientType.toLowerCase(); - this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) .withPropertyValues("spring.ai.mcp.client.type=" + clientType) .run(context -> { - // Verify all three annotations were scanned + // Verify all three annotations were scanned and registered McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans annotatedBeans = context .getBean(McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans.class); assertThat(annotatedBeans.getBeansByAnnotation(McpToolListChanged.class)).hasSize(1); assertThat(annotatedBeans.getBeansByAnnotation(McpResourceListChanged.class)).hasSize(1); assertThat(annotatedBeans.getBeansByAnnotation(McpPromptListChanged.class)).hasSize(1); - // Verify all three specification beans were created - assertThat(context).hasBean(prefix + "ToolListChangedSpecs"); - assertThat(context).hasBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).hasBean(prefix + "PromptListChangedSpecs"); + // Verify the annotation scanner configuration is present + assertThat(context).hasSingleBean(McpClientAnnotationScannerAutoConfiguration.class); + + // Note: Specification beans are no longer created as separate beans. + // They are now created dynamically in McpClientAutoConfiguration + // initializers + // after all singleton beans have been instantiated. }); } @ParameterizedTest @ValueSource(strings = { "SYNC", "ASYNC" }) void shouldNotScanAnnotationsWhenScannerDisabled(String clientType) { - String prefix = clientType.toLowerCase(); - this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) .withPropertyValues("spring.ai.mcp.client.type=" + clientType, "spring.ai.mcp.client.annotation-scanner.enabled=false") .run(context -> { - // Verify scanner beans were not created + // Verify scanner configuration was not created when disabled assertThat(context).doesNotHaveBean(McpClientAnnotationScannerAutoConfiguration.class); - assertThat(context).doesNotHaveBean(prefix + "ToolListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "PromptListChangedSpecs"); + assertThat(context) + .doesNotHaveBean(McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans.class); }); } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java new file mode 100644 index 00000000000..c0628a6c098 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -0,0 +1,336 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure; + +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpSyncRequestContext; +import org.springaicommunity.mcp.context.StructuredElicitResult; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties; +import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerStreamableHttpProperties; +import org.springframework.ai.model.anthropic.autoconfigure.AnthropicChatAutoConfiguration; +import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration; +import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.test.util.TestSocketUtils; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") +public class StreamableMcpAnnotationsWithLLMIT { + + private final ApplicationContextRunner serverContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mcp.server.protocol=STREAMABLE") + .withConfiguration(AutoConfigurations.of(McpServerAutoConfiguration.class, + ToolCallbackConverterAutoConfiguration.class, McpServerStreamableHttpWebFluxAutoConfiguration.class, + McpServerAnnotationScannerAutoConfiguration.class, + McpServerSpecificationFactoryAutoConfiguration.class)); + + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() + .withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY")) + .withConfiguration(anthropicAutoConfig(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, + AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class)); + + private static AutoConfigurations anthropicAutoConfig(Class... additional) { + Class[] dependencies = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class, + RestClientAutoConfiguration.class, WebClientAutoConfiguration.class }; + Class[] all = Stream.concat(Arrays.stream(dependencies), Arrays.stream(additional)).toArray(Class[]::new); + return AutoConfigurations.of(all); + } + + private static AtomicInteger toolCouter = new AtomicInteger(0); + + @Test + void clientServerCapabilities() { + + int serverPort = TestSocketUtils.findAvailableTcpPort(); + + this.serverContextRunner.withUserConfiguration(TestMcpServerConfiguration.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.server.name=test-mcp-server", + "spring.ai.mcp.server.version=1.0.0", + "spring.ai.mcp.server.streamable-http.keep-alive-interval=1s", + "spring.ai.mcp.server.streamable-http.mcp-endpoint=/mcp") // @formatter:on + .run(serverContext -> { + // Verify all required beans are present + assertThat(serverContext).hasSingleBean(WebFluxStreamableServerTransportProvider.class); + assertThat(serverContext).hasSingleBean(RouterFunction.class); + assertThat(serverContext).hasSingleBean(McpSyncServer.class); + + // Verify server properties are configured correctly + McpServerProperties properties = serverContext.getBean(McpServerProperties.class); + assertThat(properties.getName()).isEqualTo("test-mcp-server"); + assertThat(properties.getVersion()).isEqualTo("1.0.0"); + + McpServerStreamableHttpProperties streamableHttpProperties = serverContext + .getBean(McpServerStreamableHttpProperties.class); + assertThat(streamableHttpProperties.getMcpEndpoint()).isEqualTo("/mcp"); + assertThat(streamableHttpProperties.getKeepAliveInterval()).isEqualTo(Duration.ofSeconds(1)); + + var httpServer = startHttpServer(serverContext, serverPort); + + this.clientApplicationContext.withUserConfiguration(TestMcpClientConfiguration.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:" + serverPort, + "spring.ai.mcp.client.initialized=false") // @formatter:on + .run(clientContext -> { + + ChatClient.Builder builder = clientContext.getBean(ChatClient.Builder.class); + + ToolCallbackProvider tcp = clientContext.getBean(ToolCallbackProvider.class); + + assertThat(builder).isNotNull(); + + ChatClient chatClient = builder.defaultToolCallbacks(tcp) + .defaultToolContext(Map.of("progressToken", "test-progress-token")) + .build(); + + String cResponse = chatClient.prompt() + .user("What is the weather in Amsterdam right now") + .call() + .content(); + + assertThat(cResponse).isNotEmpty(); + assertThat(cResponse).contains("22"); + + assertThat(toolCouter.get()).isEqualTo(1); + + // PROGRESS + TestMcpClientConfiguration.TestContext testContext = clientContext + .getBean(TestMcpClientConfiguration.TestContext.class); + assertThat(testContext.progressLatch.await(5, TimeUnit.SECONDS)) + .as("Should receive progress notifications in reasonable time") + .isTrue(); + assertThat(testContext.progressNotifications).hasSize(3); + + Map notificationMap = testContext.progressNotifications + .stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("tool call start").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("tool call start").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("tool call start").total()).isEqualTo(1.0); + assertThat(notificationMap.get("tool call start").message()).isEqualTo("tool call start"); + + // Second notification should be 1.0/1.0 progress + assertThat(notificationMap.get("elicitation completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("elicitation completed").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("elicitation completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("elicitation completed").message()) + .isEqualTo("elicitation completed"); + + // Third notification should be 0.5/1.0 progress + assertThat(notificationMap.get("sampling completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("sampling completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").message()).isEqualTo("sampling completed"); + + }); + + stopHttpServer(httpServer); + }); + } + + // Helper methods to start and stop the HTTP server + private static DisposableServer startHttpServer(ApplicationContext serverContext, int port) { + WebFluxStreamableServerTransportProvider mcpStreamableServerTransport = serverContext + .getBean(WebFluxStreamableServerTransportProvider.class); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + return HttpServer.create().port(port).handle(adapter).bindNow(); + } + + private static void stopHttpServer(DisposableServer server) { + if (server != null) { + server.disposeNow(); + } + } + + record ElicitInput(String message) { + } + + public static class TestMcpServerConfiguration { + + @Bean + public McpServerHandlers serverSideSpecProviders() { + return new McpServerHandlers(); + } + + public static class McpServerHandlers { + + @McpTool(description = "Provides weather information by city name") + public String weather(McpSyncRequestContext ctx, @McpToolParam String cityName) { + + toolCouter.incrementAndGet(); + + ctx.info("Weather called!"); + + ctx.progress(p -> p.progress(0.0).total(1.0).message("tool call start")); + + ctx.ping(); // call client ping + + // call elicitation + var elicitationResult = ctx.elicit(e -> e.message("Test message"), ElicitInput.class); + + ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); + + // call sampling + CreateMessageResult samplingResponse = ctx.sample(s -> s.message("Test Sampling Message") + .modelPreferences(pref -> pref.modelHints("OpenAi", "Ollama") + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0))); + + ctx.progress(p -> p.progress(1.0).total(1.0).message("sampling completed")); + + ctx.info("Tool1 Done!"); + + return "Weahter is 22C with rain " + samplingResponse.toString() + ", " + elicitationResult.toString(); + } + + } + + } + + public static class TestMcpClientConfiguration { + + @Bean + public TestContext testContext() { + return new TestContext(); + } + + @Bean + public TestMcpClientHandlers mcpClientHandlers(TestContext testContext) { + return new TestMcpClientHandlers(testContext); + } + + public static class TestContext { + + final AtomicReference loggingNotificationRef = new AtomicReference<>(); + + final CountDownLatch progressLatch = new CountDownLatch(3); + + final List progressNotifications = new CopyOnWriteArrayList<>(); + + } + + public static class TestMcpClientHandlers { + + private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); + + private TestContext testContext; + + public TestMcpClientHandlers(TestContext testContext) { + this.testContext = testContext; + } + + @McpProgress(clients = "server1") + public void progressHandler(ProgressNotification progressNotification) { + logger.info("MCP PROGRESS: [{}] progress: {} total: {} message: {}", + progressNotification.progressToken(), progressNotification.progress(), + progressNotification.total(), progressNotification.message()); + this.testContext.progressNotifications.add(progressNotification); + this.testContext.progressLatch.countDown(); + } + + @McpLogging(clients = "server1") + public void loggingHandler(LoggingMessageNotification loggingMessage) { + this.testContext.loggingNotificationRef.set(loggingMessage); + logger.info("MCP LOGGING: [{}] {}", loggingMessage.level(), loggingMessage.data()); + } + + @McpSampling(clients = "server1") + public CreateMessageResult samplingHandler(CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + + return CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + ElicitInput elicitData = new ElicitInput(request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + } + + } + +} diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index a7dfbc74f40..16bc1318400 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -66,7 +66,11 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC List toolCallbacks, List tcbProviders) { List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); - tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll); + tcbProviders.stream() + .filter(pr -> !pr.getClass().getSimpleName().equals("SyncMcpToolCallbackProvider")) + .filter(pr -> !pr.getClass().getSimpleName().equals("AsyncMcpToolCallbackProvider")) + .map(pr -> List.of(pr.getToolCallbacks())) + .forEach(allFunctionAndToolCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java index af42d744158..7a2ff8da64f 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java @@ -185,6 +185,43 @@ void throwExceptionOnErrorEnabled() { }); } + @Test + void mcpToolCallbackProvidersAreFilteredOut() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(ConfigWithMcpProviders.class) + .run(context -> { + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); + assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class); + + // Regular ToolCallbackProvider should be resolved + assertThat(toolCallbackResolver.resolve("regularTool")).isNotNull(); + assertThat(toolCallbackResolver.resolve("regularTool").getToolDefinition().name()) + .isEqualTo("regularTool"); + + // MCP tools should NOT be resolved (filtered out from static resolver) + // They will be resolved lazily through ChatClient + assertThat(toolCallbackResolver.resolve("syncMcpTool")).isNull(); + assertThat(toolCallbackResolver.resolve("asyncMcpTool")).isNull(); + }); + } + + @Test + void nonMcpToolCallbackProvidersAreIncluded() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(ConfigWithRegularProvider.class) + .run(context -> { + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); + assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class); + + // All tools from regular ToolCallbackProvider should be resolved + assertThat(toolCallbackResolver.resolve("tool1")).isNotNull(); + assertThat(toolCallbackResolver.resolve("tool1").getToolDefinition().name()).isEqualTo("tool1"); + + assertThat(toolCallbackResolver.resolve("tool2")).isNotNull(); + assertThat(toolCallbackResolver.resolve("tool2").getToolDefinition().name()).isEqualTo("tool2"); + }); + } + static class WeatherService { @Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.") @@ -275,4 +312,77 @@ public record Response(String temperature) { } + @Configuration + static class ConfigWithMcpProviders { + + @Bean + public ToolCallbackProvider regularProvider() { + return new StaticToolCallbackProvider(FunctionToolCallback.builder("regularTool", request -> "OK") + .description("Regular tool") + .inputType(Request.class) + .build()); + } + + @Bean + public ToolCallbackProvider syncMcpProvider() { + return new SyncMcpToolCallbackProvider(); + } + + @Bean + public ToolCallbackProvider asyncMcpProvider() { + return new AsyncMcpToolCallbackProvider(); + } + + public record Request(String input) { + } + + } + + @Configuration + static class ConfigWithRegularProvider { + + @Bean + public ToolCallbackProvider multiToolProvider() { + return new StaticToolCallbackProvider( + FunctionToolCallback.builder("tool1", request -> "Result 1") + .description("Tool 1") + .inputType(Request.class) + .build(), + FunctionToolCallback.builder("tool2", request -> "Result 2") + .description("Tool 2") + .inputType(Request.class) + .build()); + } + + public record Request(String input) { + } + + } + + // Mock classes that simulate MCP providers - must match exact names that filter + // checks + static class SyncMcpToolCallbackProvider implements ToolCallbackProvider { + + @Override + public ToolCallback[] getToolCallbacks() { + return new ToolCallback[] { FunctionToolCallback.builder("syncMcpTool", request -> "Sync") + .description("Sync MCP tool") + .inputType(Config.Request.class) + .build() }; + } + + } + + static class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { + + @Override + public ToolCallback[] getToolCallbacks() { + return new ToolCallback[] { FunctionToolCallback.builder("asyncMcpTool", request -> "Async") + .description("Async MCP tool") + .inputType(Config.Request.class) + .build() }; + } + + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 6e67275aa78..a94fd680afd 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -223,7 +223,25 @@ public Builder toolFilter(McpToolFilter toolFilter) { } /** - * Sets MCP clients. + * Sets MCP clients by reference - the list reference will be shared. + *

+ * Use this method when the list will be populated later (e.g., by + * {@code SmartInitializingSingleton}). The provider will see any clients added to + * the list after construction. + * @param mcpClients list of MCP clients (reference will be stored) + * @return this builder + */ + public Builder mcpClientsReference(List mcpClients) { + Assert.notNull(mcpClients, "MCP clients list must not be null"); + this.mcpClients = mcpClients; + return this; + } + + /** + * Sets MCP clients for tool discovery (stores reference directly). + *

+ * Note: Unlike the sync version, this method does not create a defensive copy. + * Use {@link #mcpClientsReference(List)} for clarity when sharing references. * @param mcpClients list of MCP clients * @return this builder */ @@ -237,7 +255,9 @@ public Builder mcpClients(List mcpClients) { * Sets MCP clients. * @param mcpClients MCP clients as varargs * @return this builder + * @deprecated Plese use the mcpClientsReference instead! */ + @Deprecated public Builder mcpClients(McpAsyncClient... mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); this.mcpClients = List.of(mcpClients); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java index 298f07596be..e423eb8101b 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java @@ -191,10 +191,30 @@ public static final class Builder { .defaultConverter(); /** - * Sets MCP clients for tool discovery (replaces existing). + * Sets MCP clients by reference - the list reference will be shared. + *

+ * Use this method when the list will be populated later (e.g., by + * {@code SmartInitializingSingleton}). The provider will see any clients added to + * the list after construction. + * @param mcpClients list of MCP clients (reference will be stored) + * @return this builder + */ + public Builder mcpClientsReference(List mcpClients) { + Assert.notNull(mcpClients, "MCP clients list must not be null"); + this.mcpClients = mcpClients; + return this; + } + + /** + * Sets MCP clients for tool discovery (creates defensive copy). + *

+ * Use this method when passing a fully populated, immutable list. A defensive + * copy will be created to prevent external modifications. * @param mcpClients list of MCP clients * @return this builder + * @deprecated Plese use the mcpClientsReference instead! */ + @Deprecated public Builder mcpClients(List mcpClients) { Assert.notNull(mcpClients, "MCP clients list must not be null"); this.mcpClients = new ArrayList<>(mcpClients); @@ -205,7 +225,9 @@ public Builder mcpClients(List mcpClients) { * Sets MCP clients for tool discovery (replaces existing). * @param mcpClients MCP clients array * @return this builder + * @deprecated Plese use the mcpClientsReference instead! */ + @Deprecated public Builder mcpClients(McpSyncClient... mcpClients) { Assert.notNull(mcpClients, "MCP clients array must not be null"); this.mcpClients = new java.util.ArrayList<>(List.of(mcpClients)); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProviderListReferenceTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProviderListReferenceTest.java new file mode 100644 index 00000000000..7e5f2996786 --- /dev/null +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProviderListReferenceTest.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.ArrayList; +import java.util.List; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.ai.tool.ToolCallback; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests that {@link AsyncMcpToolCallbackProvider} correctly maintains list references + * when using {@code mcpClientsReference()}. + * + * @author Christian Tzolov + */ +class AsyncMcpToolCallbackProviderListReferenceTest { + + @Test + void testMcpClientsReferenceSharesList() { + // Create an empty list that will be populated later + List clientsList = new ArrayList<>(); + + // Create provider with reference to empty list + AsyncMcpToolCallbackProvider provider = AsyncMcpToolCallbackProvider.builder() + .mcpClientsReference(clientsList) + .build(); + + // Initially, no tool callbacks should be available + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).isEmpty(); + + // Now simulate SmartInitializingSingleton populating the list + McpAsyncClient mockClient = mock(McpAsyncClient.class); + McpSchema.Tool mockTool = mock(McpSchema.Tool.class); + when(mockTool.name()).thenReturn("test_tool"); + when(mockTool.description()).thenReturn("Test tool"); + + McpSchema.ListToolsResult toolsResult = mock(McpSchema.ListToolsResult.class); + when(toolsResult.tools()).thenReturn(List.of(mockTool)); + when(mockClient.listTools()).thenReturn(Mono.just(toolsResult)); + + // Mock connection info + when(mockClient.getClientCapabilities()).thenReturn(mock(McpSchema.ClientCapabilities.class)); + when(mockClient.getClientInfo()).thenReturn(mock(McpSchema.Implementation.class)); + when(mockClient.getCurrentInitializationResult()).thenReturn(mock(McpSchema.InitializeResult.class)); + + clientsList.add(mockClient); + + // Now the provider should see the client and return tool callbacks + callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(1); + assertThat(callbacks[0].getToolDefinition().name()).isEqualTo("test_tool"); + } + + @Test + void testMcpClientsStoresReference() { + // Create a list with a client + McpAsyncClient mockClient = mock(McpAsyncClient.class); + McpSchema.Tool mockTool = mock(McpSchema.Tool.class); + when(mockTool.name()).thenReturn("test_tool"); + when(mockTool.description()).thenReturn("Test tool"); + + McpSchema.ListToolsResult toolsResult = mock(McpSchema.ListToolsResult.class); + when(toolsResult.tools()).thenReturn(List.of(mockTool)); + when(mockClient.listTools()).thenReturn(Mono.just(toolsResult)); + + // Mock connection info + when(mockClient.getClientCapabilities()).thenReturn(mock(McpSchema.ClientCapabilities.class)); + when(mockClient.getClientInfo()).thenReturn(mock(McpSchema.Implementation.class)); + when(mockClient.getCurrentInitializationResult()).thenReturn(mock(McpSchema.InitializeResult.class)); + + List clientsList = new ArrayList<>(); + clientsList.add(mockClient); + + // Create provider - async version stores reference (no defensive copy) + AsyncMcpToolCallbackProvider provider = AsyncMcpToolCallbackProvider.builder().mcpClients(clientsList).build(); + + // Provider should see the initial client + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(1); + + // Add another client to the original list + McpAsyncClient mockClient2 = mock(McpAsyncClient.class); + McpSchema.Tool mockTool2 = mock(McpSchema.Tool.class); + when(mockTool2.name()).thenReturn("test_tool_2"); + when(mockTool2.description()).thenReturn("Test tool 2"); + + McpSchema.ListToolsResult toolsResult2 = mock(McpSchema.ListToolsResult.class); + when(toolsResult2.tools()).thenReturn(List.of(mockTool2)); + when(mockClient2.listTools()).thenReturn(Mono.just(toolsResult2)); + + when(mockClient2.getClientCapabilities()).thenReturn(mock(McpSchema.ClientCapabilities.class)); + when(mockClient2.getClientInfo()).thenReturn(mock(McpSchema.Implementation.class)); + when(mockClient2.getCurrentInitializationResult()).thenReturn(mock(McpSchema.InitializeResult.class)); + + clientsList.add(mockClient2); + + // Provider shares the reference, so should see both clients + callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(2); + } + +} diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderListReferenceTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderListReferenceTest.java new file mode 100644 index 00000000000..b5c3aae69a9 --- /dev/null +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderListReferenceTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp; + +import java.util.ArrayList; +import java.util.List; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.tool.ToolCallback; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests that {@link SyncMcpToolCallbackProvider} correctly maintains list references when + * using {@code mcpClientsReference()}. + * + * @author Christian Tzolov + */ +class SyncMcpToolCallbackProviderListReferenceTest { + + @Test + void testMcpClientsReferenceSharesList() { + // Create an empty list that will be populated later + List clientsList = new ArrayList<>(); + + // Create provider with reference to empty list + SyncMcpToolCallbackProvider provider = SyncMcpToolCallbackProvider.builder() + .mcpClientsReference(clientsList) + .build(); + + // Initially, no tool callbacks should be available + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).isEmpty(); + + // Now simulate SmartInitializingSingleton populating the list + McpSyncClient mockClient = mock(McpSyncClient.class); + McpSchema.Tool mockTool = mock(McpSchema.Tool.class); + when(mockTool.name()).thenReturn("test_tool"); + when(mockTool.description()).thenReturn("Test tool"); + + McpSchema.ListToolsResult toolsResult = mock(McpSchema.ListToolsResult.class); + when(toolsResult.tools()).thenReturn(List.of(mockTool)); + when(mockClient.listTools()).thenReturn(toolsResult); + + // Mock connection info + when(mockClient.getClientCapabilities()).thenReturn(mock(McpSchema.ClientCapabilities.class)); + when(mockClient.getClientInfo()).thenReturn(mock(McpSchema.Implementation.class)); + when(mockClient.getCurrentInitializationResult()).thenReturn(mock(McpSchema.InitializeResult.class)); + + clientsList.add(mockClient); + + // Now the provider should see the client and return tool callbacks + callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(1); + assertThat(callbacks[0].getToolDefinition().name()).isEqualTo("test_tool"); + } + + @Test + void testMcpClientsCreatesCopy() { + // Create a list with a client + McpSyncClient mockClient = mock(McpSyncClient.class); + McpSchema.Tool mockTool = mock(McpSchema.Tool.class); + when(mockTool.name()).thenReturn("test_tool"); + when(mockTool.description()).thenReturn("Test tool"); + + McpSchema.ListToolsResult toolsResult = mock(McpSchema.ListToolsResult.class); + when(toolsResult.tools()).thenReturn(List.of(mockTool)); + when(mockClient.listTools()).thenReturn(toolsResult); + + // Mock connection info + when(mockClient.getClientCapabilities()).thenReturn(mock(McpSchema.ClientCapabilities.class)); + when(mockClient.getClientInfo()).thenReturn(mock(McpSchema.Implementation.class)); + when(mockClient.getCurrentInitializationResult()).thenReturn(mock(McpSchema.InitializeResult.class)); + + List clientsList = new ArrayList<>(); + clientsList.add(mockClient); + + // Create provider with defensive copy + SyncMcpToolCallbackProvider provider = SyncMcpToolCallbackProvider.builder().mcpClients(clientsList).build(); + + // Provider should see the initial client + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(1); + + // Clear the original list + clientsList.clear(); + + // Provider still has its copy, so should still return the tool + callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(1); + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 20b207d5c5c..6b740c25821 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -618,6 +618,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List toolCallbacks = new ArrayList<>(); + private final List toolCallbackProviders = new ArrayList<>(); + private final List messages = new ArrayList<>(); private final Map userParams = new HashMap<>(); @@ -648,16 +650,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams, - ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, - ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, - ccr.toolContext, ccr.templateRenderer); + ccr.systemMetadata, ccr.toolCallbacks, ccr.toolCallbackProviders, ccr.messages, ccr.toolNames, + ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry, + ccr.observationConvention, ccr.toolContext, ccr.templateRenderer); } public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, Map userMetadata, @Nullable String systemText, Map systemParams, Map systemMetadata, List toolCallbacks, - List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, - List advisors, Map advisorParams, ObservationRegistry observationRegistry, + List toolCallbackProviders, List messages, List toolNames, + List media, @Nullable ChatOptions chatOptions, List advisors, + Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { @@ -667,6 +670,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe Assert.notNull(systemParams, "systemParams cannot be null"); Assert.notNull(systemMetadata, "systemMetadata cannot be null"); Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); Assert.notNull(messages, "messages cannot be null"); Assert.notNull(toolNames, "toolNames cannot be null"); Assert.notNull(media, "media cannot be null"); @@ -689,6 +693,7 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.toolNames.addAll(toolNames); this.toolCallbacks.addAll(toolCallbacks); + this.toolCallbackProviders.addAll(toolCallbackProviders); this.messages.addAll(messages); this.media.addAll(media); this.advisors.addAll(advisors); @@ -885,9 +890,10 @@ public ChatClientRequestSpec tools(Object... toolObjects) { public ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders) { Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements"); - for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) { - this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); - } + // Store providers for lazy resolution - don't call getToolCallbacks() yet! + // This allows providers that depend on SmartInitializingSingleton to complete + // their initialization before tool callbacks are resolved. + this.toolCallbackProviders.addAll(Arrays.asList(toolCallbackProviders)); return this; } @@ -988,6 +994,8 @@ public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) @Override public CallResponseSpec call() { + // Resolve tool callbacks lazily before building the request + resolveToolCallbacksBeforeExecution(); BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, this.observationRegistry, this.observationConvention); @@ -995,11 +1003,30 @@ public CallResponseSpec call() { @Override public StreamResponseSpec stream() { + // Resolve tool callbacks lazily before building the request + resolveToolCallbacksBeforeExecution(); BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, this.observationRegistry, this.observationConvention); } + /** + * Resolves tool callback providers and adds the results to the toolCallbacks + * list. This method should be called right before execution (call/stream) to + * ensure that all providers have had a chance to complete their initialization, + * including those that depend on SmartInitializingSingleton. + */ + private void resolveToolCallbacksBeforeExecution() { + if (!this.toolCallbackProviders.isEmpty()) { + // Resolve all providers and add their callbacks + for (ToolCallbackProvider provider : this.toolCallbackProviders) { + this.toolCallbacks.addAll(List.of(provider.getToolCallbacks())); + } + // Clear providers list to avoid re-processing on subsequent calls + this.toolCallbackProviders.clear(); + } + } + private BaseAdvisorChain buildAdvisorChain() { // At the stack bottom add the model call advisors. // They play the role of the last advisors in the advisor chain. diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index a937356e543..6778dc222e5 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -65,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(), - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, - customObservationConvention, Map.of(), null); + Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + observationRegistry, customObservationConvention, Map.of(), null); } public ChatClient build() { diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 07adcf72b48..45a1342503b 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -30,6 +30,7 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -50,6 +51,7 @@ import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.support.DefaultConversionService; @@ -61,12 +63,16 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** * Unit tests for {@link DefaultChatClient}. * * @author Thomas Vitale + * @author Christian Tzolov */ class DefaultChatClientTests { @@ -1474,15 +1480,15 @@ void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( chatModel, null, Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), - List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); + List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of(), null); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), Map.of(), - null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), - ObservationRegistry.NOOP, null, Map.of(), null)) + null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), + Map.of(), ObservationRegistry.NOOP, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); } @@ -1490,8 +1496,8 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), null, - List.of(), Map.of(), null, null, Map.of(), null)) + Map.of(), Map.of(), null, Map.of(), Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), + null, List.of(), Map.of(), null, null, Map.of(), null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); } @@ -2197,6 +2203,135 @@ void whenUserConsumerWithNullParamValueThenThrow() { .hasMessage("value cannot be null"); } + @Test + void whenToolCallbackProvidersAddedThenStored() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ToolCallbackProvider provider = mock(ToolCallbackProvider.class); + + ChatClient.ChatClientRequestSpec spec = chatClient.prompt().toolCallbacks(provider); + + // Verify provider was stored (not resolved yet) + assertThat(spec).isInstanceOf(DefaultChatClient.DefaultChatClientRequestSpec.class); + } + + @Test + void whenToolCallbackProvidersResolvedLazily() { + ChatModel chatModel = mock(ChatModel.class); + given(chatModel.call(ArgumentMatchers.any(Prompt.class))) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ToolCallback mockToolCallback = mock(ToolCallback.class); + ToolCallbackProvider mockProvider = mock(ToolCallbackProvider.class); + when(mockProvider.getToolCallbacks()).thenReturn(new ToolCallback[] { mockToolCallback }); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + + // Add provider during configuration + ChatClient.ChatClientRequestSpec spec = chatClient.prompt("question").toolCallbacks(mockProvider); + + // Provider should NOT be resolved yet (getToolCallbacks not called during + // configuration) + verify(mockProvider, never()).getToolCallbacks(); + + // Execute call - this should trigger lazy resolution + spec.call().content(); + + // NOW provider should be resolved (getToolCallbacks called during execution) + verify(mockProvider, times(1)).getToolCallbacks(); + } + + @Test + void whenMultipleToolCallbackProvidersResolvedLazily() { + ChatModel chatModel = mock(ChatModel.class); + given(chatModel.call(org.mockito.ArgumentMatchers.any(Prompt.class))) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + ToolCallback mockToolCallback1 = mock(ToolCallback.class); + ToolCallback mockToolCallback2 = mock(ToolCallback.class); + + ToolCallbackProvider mockProvider1 = mock(ToolCallbackProvider.class); + when(mockProvider1.getToolCallbacks()).thenReturn(new ToolCallback[] { mockToolCallback1 }); + + ToolCallbackProvider mockProvider2 = mock(ToolCallbackProvider.class); + when(mockProvider2.getToolCallbacks()).thenReturn(new ToolCallback[] { mockToolCallback2 }); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + + // Add multiple providers + ChatClient.ChatClientRequestSpec spec = chatClient.prompt("question") + .toolCallbacks(mockProvider1, mockProvider2); + + // Execute call + spec.call().content(); + + // Both providers should be resolved + verify(mockProvider1, times(1)).getToolCallbacks(); + verify(mockProvider2, times(1)).getToolCallbacks(); + } + + @Test + void whenToolCallbackProvidersResolvedLazilyInStream() { + ChatModel chatModel = mock(ChatModel.class); + given(chatModel.stream(org.mockito.ArgumentMatchers.any(Prompt.class))) + .willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))))); + + ToolCallback mockToolCallback = mock(ToolCallback.class); + ToolCallbackProvider mockProvider = mock(ToolCallbackProvider.class); + when(mockProvider.getToolCallbacks()).thenReturn(new ToolCallback[] { mockToolCallback }); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + + // Add provider during configuration + ChatClient.ChatClientRequestSpec spec = chatClient.prompt("question").toolCallbacks(mockProvider); + + // Provider should NOT be resolved yet + verify(mockProvider, never()).getToolCallbacks(); + + // Execute stream - this should trigger lazy resolution + spec.stream().content().blockLast(); + + // NOW provider should be resolved + verify(mockProvider, times(1)).getToolCallbacks(); + } + + @Test + void whenToolCallbackProviderIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + + assertThatThrownBy(() -> spec.toolCallbacks((ToolCallbackProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolCallbackProviders cannot contain null elements"); + } + + @Test + void whenToolCallbackProvidersWithMixedCallbacksAndProviders() { + ChatModel chatModel = mock(ChatModel.class); + given(chatModel.call(org.mockito.ArgumentMatchers.any(Prompt.class))) + .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); + + // Direct callback + ToolCallback directCallback = mock(ToolCallback.class); + + // Provider callback + ToolCallback providerCallback = mock(ToolCallback.class); + ToolCallbackProvider mockProvider = mock(ToolCallbackProvider.class); + when(mockProvider.getToolCallbacks()).thenReturn(new ToolCallback[] { providerCallback }); + + ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build(); + + // Add both direct callbacks and providers + ChatClient.ChatClientRequestSpec spec = chatClient.prompt("question") + .toolCallbacks(directCallback) + .toolCallbacks(mockProvider); + + // Execute call + spec.call().content(); + + // Provider should be resolved + verify(mockProvider, times(1)).getToolCallbacks(); + } + record Person(String name) { }