diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/FluxToolSpecificationPostProcessor.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/FluxToolSpecificationPostProcessor.java new file mode 100644 index 00000000000..f4bb20742da --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/FluxToolSpecificationPostProcessor.java @@ -0,0 +1,244 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.common.autoconfigure.annotations; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.util.ReflectionUtils; + +/** + * Post-processor that wraps AsyncToolSpecifications to handle Flux return types properly + * by collecting all elements before serialization. + * + *

+ * Background: This class fixes Issue #4542 where Flux-returning @McpTool + * methods only return the first element. The root cause is in the external {@code + * org.springaicommunity.mcp.provider.tool.AsyncStatelessMcpToolProvider} library, which + * treats Flux as a single-value Publisher and only takes the first element. + * + *

+ * Solution: This post-processor intercepts tool specifications and wraps + * their call handlers. When a method returns a Flux, it collects all elements into a list + * before passing the result to the MCP serialization layer. + * + *

+ * Note: Users can also work around this issue by returning {@code + * Mono>} instead of {@code Flux} from their {@code @McpTool} methods. + * + * @author liugddx + * @since 1.1.0 + * @see Issue #4542 + */ +public final class FluxToolSpecificationPostProcessor { + + private static final Logger logger = LoggerFactory.getLogger(FluxToolSpecificationPostProcessor.class); + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private FluxToolSpecificationPostProcessor() { + // Utility class - no instances allowed + } + + /** + * Wraps tool specifications to properly handle Flux return types by collecting all + * elements into a list. + * @param originalSpecs the original tool specifications from the annotation provider + * @param toolBeans the bean objects containing @McpTool annotated methods + * @return wrapped tool specifications that properly collect Flux elements + */ + public static List processToolSpecifications( + List originalSpecs, List toolBeans) { + + List processedSpecs = new ArrayList<>(); + + for (McpStatelessServerFeatures.AsyncToolSpecification spec : originalSpecs) { + ToolMethodInfo methodInfo = findToolMethod(toolBeans, spec.tool().name()); + if (methodInfo != null && methodInfo.returnsFlux()) { + logger.info("Detected Flux return type for MCP tool '{}', applying collection wrapper", + spec.tool().name()); + McpStatelessServerFeatures.AsyncToolSpecification wrappedSpec = wrapToolSpecificationForFlux(spec, + methodInfo); + processedSpecs.add(wrappedSpec); + } + else { + processedSpecs.add(spec); + } + } + + return processedSpecs; + } + + /** + * Finds the method annotated with @McpTool that matches the given tool name. + * @param toolBeans the bean objects containing @McpTool annotated methods + * @param toolName the name of the tool to find + * @return the ToolMethodInfo object, or null if not found + */ + private static ToolMethodInfo findToolMethod(List toolBeans, String toolName) { + for (Object bean : toolBeans) { + Class clazz = bean.getClass(); + Method[] methods = ReflectionUtils.getAllDeclaredMethods(clazz); + for (Method method : methods) { + McpTool annotation = method.getAnnotation(McpTool.class); + if (annotation != null && annotation.name().equals(toolName)) { + return new ToolMethodInfo(bean, method); + } + } + } + return null; + } + + /** + * Wraps a tool specification to collect all Flux elements before serialization. + * @param original the original tool specification + * @param methodInfo the method information including bean and method + * @return the wrapped tool specification + */ + private static McpStatelessServerFeatures.AsyncToolSpecification wrapToolSpecificationForFlux( + McpStatelessServerFeatures.AsyncToolSpecification original, ToolMethodInfo methodInfo) { + + BiFunction> originalHandler = original + .callHandler(); + + BiFunction> wrappedHandler = ( + context, request) -> { + try { + // Invoke the method directly to get access to the Flux + Object[] args = buildMethodArguments(methodInfo.method(), request.arguments()); + Object result = ReflectionUtils.invokeMethod(methodInfo.method(), methodInfo.bean(), args); + + if (result instanceof Flux) { + // Collect all Flux elements into a list + Flux flux = (Flux) result; + return flux.collectList().flatMap(list -> { + // Serialize the list to JSON + try { + String jsonContent = objectMapper.writeValueAsString(list); + return Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent(jsonContent)), false)); + } + catch (Exception e) { + logger.error("Failed to serialize Flux result for tool '{}'", original.tool().name(), e); + return Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Error: " + e.getMessage())), true)); + } + }); + } + else { + // Fall back to original handler for non-Flux results + return originalHandler.apply(context, request); + } + } + catch (Exception e) { + logger.error("Failed to invoke tool method '{}'", original.tool().name(), e); + return Mono.just(new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Error: " + e.getMessage())), true)); + } + }; + + return new McpStatelessServerFeatures.AsyncToolSpecification(original.tool(), wrappedHandler); + } + + /** + * Builds method arguments from the request arguments map. + * @param method the method to invoke + * @param requestArgs the arguments from the CallToolRequest + * @return array of method arguments + */ + private static Object[] buildMethodArguments(Method method, Map requestArgs) { + java.lang.reflect.Parameter[] parameters = method.getParameters(); + Object[] args = new Object[parameters.length]; + + for (int i = 0; i < parameters.length; i++) { + java.lang.reflect.Parameter param = parameters[i]; + McpToolParam paramAnnotation = param.getAnnotation(McpToolParam.class); + + if (paramAnnotation != null) { + String paramName = paramAnnotation.name().isEmpty() ? param.getName() : paramAnnotation.name(); + Object value = requestArgs.get(paramName); + + // Type conversion if needed + if (value != null) { + args[i] = objectMapper.convertValue(value, param.getType()); + } + else if (!paramAnnotation.required()) { + args[i] = null; + } + else { + throw new IllegalArgumentException("Required parameter '" + paramName + "' is missing"); + } + } + else { + // Try to match by parameter name + Object value = requestArgs.get(param.getName()); + if (value != null) { + args[i] = objectMapper.convertValue(value, param.getType()); + } + else { + args[i] = null; + } + } + } + + return args; + } + + /** + * Holds information about a tool method. + */ + private static class ToolMethodInfo { + + private final Object bean; + + private final Method method; + + ToolMethodInfo(Object bean, Method method) { + this.bean = bean; + this.method = method; + ReflectionUtils.makeAccessible(method); + } + + Object bean() { + return this.bean; + } + + Method method() { + return this.method; + } + + boolean returnsFlux() { + return Flux.class.isAssignableFrom(this.method.getReturnType()); + } + + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/StatelessServerSpecificationFactoryAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/StatelessServerSpecificationFactoryAutoConfiguration.java index 97d01f82280..8f9430a55f8 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/StatelessServerSpecificationFactoryAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/annotations/StatelessServerSpecificationFactoryAutoConfiguration.java @@ -127,8 +127,12 @@ public List completionS @Bean public List toolSpecs( ServerMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return AsyncMcpAnnotationProviders - .statelessToolSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class)); + List toolBeans = beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class); + List originalSpecs = AsyncMcpAnnotationProviders + .statelessToolSpecifications(toolBeans); + + // Apply post-processing to handle Flux return types (Issue #4542) + return FluxToolSpecificationPostProcessor.processToolSpecifications(originalSpecs, toolBeans); } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/test/java/org/springframework/ai/mcp/server/common/autoconfigure/FluxReturnTypeIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/test/java/org/springframework/ai/mcp/server/common/autoconfigure/FluxReturnTypeIT.java new file mode 100644 index 00000000000..670164c4b66 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/test/java/org/springframework/ai/mcp/server/common/autoconfigure/FluxReturnTypeIT.java @@ -0,0 +1,257 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.common.autoconfigure; + +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.stereotype.Component; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration test to verify the fix for Issue #4542: Stateless Async MCP Server with + * streamable-http returns only the first element from tools with a Flux return type. + * + * @author liugddx + */ +public class FluxReturnTypeIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(McpServerStatelessAutoConfiguration.class, + McpServerAnnotationScannerAutoConfiguration.class, + StatelessToolCallbackConverterAutoConfiguration.class)); + + private final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * This test verifies that @McpTool methods returning Flux now properly return all + * elements after the fix. + */ + @Test + void testFluxReturnTypeReturnsAllElements() { + this.contextRunner.withUserConfiguration(FluxToolConfiguration.class) + .withPropertyValues("spring.ai.mcp.server.type=ASYNC", "spring.ai.mcp.server.protocol=STATELESS", + "spring.ai.mcp.server.annotation.enabled=true") + .run(context -> { + assertThat(context).hasBean("fluxTestTools"); + + // Get the tool specifications + List toolSpecs = context.getBean("toolSpecs", + List.class); + assertThat(toolSpecs).isNotEmpty(); + + // Find the flux-test tool + McpStatelessServerFeatures.AsyncToolSpecification fluxTestTool = toolSpecs.stream() + .filter(spec -> spec.tool().name().equals("flux-test")) + .findFirst() + .orElseThrow(() -> new AssertionError("flux-test tool not found")); + + // Call the tool with count=3 + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("flux-test") + .arguments(Map.of("count", 3)) + .build(); + + McpSchema.CallToolResult result = fluxTestTool.callHandler() + .apply(new McpTransportContext(Map.of()), request) + .block(); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + + // Verify all three elements are present in the result + String content = ((McpSchema.TextContent) result.content().get(0)).text(); + List items = this.objectMapper.readValue(content, List.class); + + assertThat(items).containsExactly("item-1", "item-2", "item-3"); + }); + } + + /** + * This test verifies that @McpTool methods returning Flux with complex + * objects properly return all elements. + */ + @Test + void testFluxReturnTypeWithComplexObjects() { + this.contextRunner.withUserConfiguration(FluxToolConfiguration.class) + .withPropertyValues("spring.ai.mcp.server.type=ASYNC", "spring.ai.mcp.server.protocol=STATELESS", + "spring.ai.mcp.server.annotation.enabled=true") + .run(context -> { + assertThat(context).hasBean("fluxTestTools"); + + List toolSpecs = context.getBean("toolSpecs", + List.class); + + McpStatelessServerFeatures.AsyncToolSpecification fluxDataTool = toolSpecs.stream() + .filter(spec -> spec.tool().name().equals("flux-data-stream")) + .findFirst() + .orElseThrow(() -> new AssertionError("flux-data-stream tool not found")); + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("flux-data-stream") + .arguments(Map.of("category", "test")) + .build(); + + McpSchema.CallToolResult result = fluxDataTool.callHandler() + .apply(new McpTransportContext(Map.of()), request) + .block(); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + + String content = ((McpSchema.TextContent) result.content().get(0)).text(); + List> items = this.objectMapper.readValue(content, List.class); + + assertThat(items).hasSize(3); + assertThat(items.get(0)).containsEntry("id", "id1"); + assertThat(items.get(1)).containsEntry("id", "id2"); + assertThat(items.get(2)).containsEntry("id", "id3"); + }); + } + + /** + * This test demonstrates that the workaround using Mono> continues to work + * properly. + */ + @Test + void testMonoListWorkaround() { + this.contextRunner.withUserConfiguration(MonoListToolConfiguration.class) + .withPropertyValues("spring.ai.mcp.server.type=ASYNC", "spring.ai.mcp.server.protocol=STATELESS", + "spring.ai.mcp.server.annotation.enabled=true") + .run(context -> { + assertThat(context).hasBean("monoListTestTools"); + + List toolSpecs = context.getBean("toolSpecs", + List.class); + + McpStatelessServerFeatures.AsyncToolSpecification monoListTool = toolSpecs.stream() + .filter(spec -> spec.tool().name().equals("mono-list-test")) + .findFirst() + .orElseThrow(() -> new AssertionError("mono-list-test tool not found")); + + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("mono-list-test") + .arguments(Map.of("count", 3)) + .build(); + + McpSchema.CallToolResult result = monoListTool.callHandler() + .apply(new McpTransportContext(Map.of()), request) + .block(); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + + // The workaround should also return all three elements + String content = ((McpSchema.TextContent) result.content().get(0)).text(); + List items = this.objectMapper.readValue(content, List.class); + + assertThat(items).containsExactly("item-1", "item-2", "item-3"); + }); + } + + @Configuration + static class FluxToolConfiguration { + + @Bean + FluxTestTools fluxTestTools() { + return new FluxTestTools(); + } + + } + + @Component + static class FluxTestTools { + + /** + * This method demonstrates the bug: it returns Flux but only the first + * element is returned to the client. + */ + @McpTool(name = "flux-test", description = "Test Flux return type - BUGGY") + public Flux getMultipleItems( + @McpToolParam(description = "Number of items to return", required = true) int count) { + return Flux.range(1, count).map(i -> "item-" + i); + } + + /** + * This method also demonstrates the bug with a more realistic streaming scenario. + */ + @McpTool(name = "flux-data-stream", description = "Stream data items - BUGGY") + public Flux streamDataItems( + @McpToolParam(description = "Category to filter", required = false) String category) { + return Flux.just(new DataItem("id1", "Item 1", category), new DataItem("id2", "Item 2", category), + new DataItem("id3", "Item 3", category)); + } + + } + + @Configuration + static class MonoListToolConfiguration { + + @Bean + MonoListTestTools monoListTestTools() { + return new MonoListTestTools(); + } + + } + + @Component + static class MonoListTestTools { + + /** + * WORKAROUND: Use Mono> instead of Flux to return all elements. + */ + @McpTool(name = "mono-list-test", description = "Test Mono workaround") + public Mono> getMultipleItems( + @McpToolParam(description = "Number of items to return", required = true) int count) { + return Flux.range(1, count).map(i -> "item-" + i).collectList(); + } + + /** + * WORKAROUND: Collect Flux elements into a list before returning. + */ + @McpTool(name = "mono-list-data-stream", description = "Get data items as list") + public Mono> getDataItems( + @McpToolParam(description = "Category to filter", required = false) String category) { + return Flux + .just(new DataItem("id1", "Item 1", category), new DataItem("id2", "Item 2", category), + new DataItem("id3", "Item 3", category)) + .collectList(); + } + + } + + record DataItem(String id, String name, String category) { + } + +}