Skip to content

Commit 0f2ebb3

Browse files
committed
feat: add integration tests and post-processor for handling Flux return types in MCP tools
Signed-off-by: liugddx <[email protected]>
1 parent a19d9b6 commit 0f2ebb3

File tree

2 files changed

+507
-0
lines changed

2 files changed

+507
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp.server.common.autoconfigure.annotations;
18+
19+
import java.lang.reflect.Method;
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.function.BiFunction;
24+
25+
import com.fasterxml.jackson.databind.ObjectMapper;
26+
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
27+
import io.modelcontextprotocol.spec.McpSchema;
28+
import io.modelcontextprotocol.spec.McpTransportContext;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
import org.springaicommunity.mcp.annotation.McpTool;
32+
import org.springaicommunity.mcp.annotation.McpToolParam;
33+
import reactor.core.publisher.Flux;
34+
import reactor.core.publisher.Mono;
35+
36+
import org.springframework.util.ReflectionUtils;
37+
38+
/**
39+
* Post-processor that wraps AsyncToolSpecifications to handle Flux return types properly
40+
* by collecting all elements before serialization.
41+
*
42+
* <p>
43+
* <strong>Background:</strong> This class fixes Issue #4542 where Flux-returning @McpTool
44+
* methods only return the first element. The root cause is in the external {@code
45+
* org.springaicommunity.mcp.provider.tool.AsyncStatelessMcpToolProvider} library, which
46+
* treats Flux as a single-value Publisher and only takes the first element.
47+
*
48+
* <p>
49+
* <strong>Solution:</strong> This post-processor intercepts tool specifications and wraps
50+
* their call handlers. When a method returns a Flux, it collects all elements into a list
51+
* before passing the result to the MCP serialization layer.
52+
*
53+
* <p>
54+
* <strong>Note:</strong> Users can also work around this issue by returning {@code
55+
* Mono<List<T>>} instead of {@code Flux<T>} from their {@code @McpTool} methods.
56+
*
57+
* @author liugddx
58+
* @since 1.1.0
59+
* @see <a href="https://github.com/spring-projects/spring-ai/issues/4542">Issue #4542</a>
60+
*/
61+
public final class FluxToolSpecificationPostProcessor {
62+
63+
private static final Logger logger = LoggerFactory.getLogger(FluxToolSpecificationPostProcessor.class);
64+
65+
private static final ObjectMapper objectMapper = new ObjectMapper();
66+
67+
private FluxToolSpecificationPostProcessor() {
68+
// Utility class - no instances allowed
69+
}
70+
71+
/**
72+
* Wraps tool specifications to properly handle Flux return types by collecting all
73+
* elements into a list.
74+
* @param originalSpecs the original tool specifications from the annotation provider
75+
* @param toolBeans the bean objects containing @McpTool annotated methods
76+
* @return wrapped tool specifications that properly collect Flux elements
77+
*/
78+
public static List<McpStatelessServerFeatures.AsyncToolSpecification> processToolSpecifications(
79+
List<McpStatelessServerFeatures.AsyncToolSpecification> originalSpecs, List<Object> toolBeans) {
80+
81+
List<McpStatelessServerFeatures.AsyncToolSpecification> processedSpecs = new ArrayList<>();
82+
83+
for (McpStatelessServerFeatures.AsyncToolSpecification spec : originalSpecs) {
84+
ToolMethodInfo methodInfo = findToolMethod(toolBeans, spec.tool().name());
85+
if (methodInfo != null && methodInfo.returnsFlux()) {
86+
logger.info("Detected Flux return type for MCP tool '{}', applying collection wrapper",
87+
spec.tool().name());
88+
McpStatelessServerFeatures.AsyncToolSpecification wrappedSpec = wrapToolSpecificationForFlux(spec,
89+
methodInfo);
90+
processedSpecs.add(wrappedSpec);
91+
}
92+
else {
93+
processedSpecs.add(spec);
94+
}
95+
}
96+
97+
return processedSpecs;
98+
}
99+
100+
/**
101+
* Finds the method annotated with @McpTool that matches the given tool name.
102+
* @param toolBeans the bean objects containing @McpTool annotated methods
103+
* @param toolName the name of the tool to find
104+
* @return the ToolMethodInfo object, or null if not found
105+
*/
106+
private static ToolMethodInfo findToolMethod(List<Object> toolBeans, String toolName) {
107+
for (Object bean : toolBeans) {
108+
Class<?> clazz = bean.getClass();
109+
Method[] methods = ReflectionUtils.getAllDeclaredMethods(clazz);
110+
for (Method method : methods) {
111+
McpTool annotation = method.getAnnotation(McpTool.class);
112+
if (annotation != null && annotation.name().equals(toolName)) {
113+
return new ToolMethodInfo(bean, method);
114+
}
115+
}
116+
}
117+
return null;
118+
}
119+
120+
/**
121+
* Wraps a tool specification to collect all Flux elements before serialization.
122+
* @param original the original tool specification
123+
* @param methodInfo the method information including bean and method
124+
* @return the wrapped tool specification
125+
*/
126+
private static McpStatelessServerFeatures.AsyncToolSpecification wrapToolSpecificationForFlux(
127+
McpStatelessServerFeatures.AsyncToolSpecification original, ToolMethodInfo methodInfo) {
128+
129+
BiFunction<McpTransportContext, McpSchema.CallToolRequest, Mono<McpSchema.CallToolResult>> originalHandler = original
130+
.callHandler();
131+
132+
BiFunction<McpTransportContext, McpSchema.CallToolRequest, Mono<McpSchema.CallToolResult>> wrappedHandler = (
133+
context, request) -> {
134+
try {
135+
// Invoke the method directly to get access to the Flux
136+
Object[] args = buildMethodArguments(methodInfo.method(), request.arguments());
137+
Object result = ReflectionUtils.invokeMethod(methodInfo.method(), methodInfo.bean(), args);
138+
139+
if (result instanceof Flux) {
140+
// Collect all Flux elements into a list
141+
Flux<?> flux = (Flux<?>) result;
142+
return flux.collectList().flatMap(list -> {
143+
// Serialize the list to JSON
144+
try {
145+
String jsonContent = objectMapper.writeValueAsString(list);
146+
return Mono.just(new McpSchema.CallToolResult(
147+
List.of(new McpSchema.TextContent(jsonContent)), false));
148+
}
149+
catch (Exception e) {
150+
logger.error("Failed to serialize Flux result for tool '{}'", original.tool().name(), e);
151+
return Mono.just(new McpSchema.CallToolResult(
152+
List.of(new McpSchema.TextContent("Error: " + e.getMessage())), true));
153+
}
154+
});
155+
}
156+
else {
157+
// Fall back to original handler for non-Flux results
158+
return originalHandler.apply(context, request);
159+
}
160+
}
161+
catch (Exception e) {
162+
logger.error("Failed to invoke tool method '{}'", original.tool().name(), e);
163+
return Mono.just(
164+
new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Error: " + e.getMessage())),
165+
true));
166+
}
167+
};
168+
169+
return new McpStatelessServerFeatures.AsyncToolSpecification(original.tool(), wrappedHandler);
170+
}
171+
172+
/**
173+
* Builds method arguments from the request arguments map.
174+
* @param method the method to invoke
175+
* @param requestArgs the arguments from the CallToolRequest
176+
* @return array of method arguments
177+
*/
178+
private static Object[] buildMethodArguments(Method method, Map<String, Object> requestArgs) {
179+
java.lang.reflect.Parameter[] parameters = method.getParameters();
180+
Object[] args = new Object[parameters.length];
181+
182+
for (int i = 0; i < parameters.length; i++) {
183+
java.lang.reflect.Parameter param = parameters[i];
184+
McpToolParam paramAnnotation = param.getAnnotation(McpToolParam.class);
185+
186+
if (paramAnnotation != null) {
187+
String paramName = paramAnnotation.name().isEmpty() ? param.getName() : paramAnnotation.name();
188+
Object value = requestArgs.get(paramName);
189+
190+
// Type conversion if needed
191+
if (value != null) {
192+
args[i] = objectMapper.convertValue(value, param.getType());
193+
}
194+
else if (!paramAnnotation.required()) {
195+
args[i] = null;
196+
}
197+
else {
198+
throw new IllegalArgumentException("Required parameter '" + paramName + "' is missing");
199+
}
200+
}
201+
else {
202+
// Try to match by parameter name
203+
Object value = requestArgs.get(param.getName());
204+
if (value != null) {
205+
args[i] = objectMapper.convertValue(value, param.getType());
206+
}
207+
else {
208+
args[i] = null;
209+
}
210+
}
211+
}
212+
213+
return args;
214+
}
215+
216+
/**
217+
* Holds information about a tool method.
218+
*/
219+
private static class ToolMethodInfo {
220+
221+
private final Object bean;
222+
223+
private final Method method;
224+
225+
ToolMethodInfo(Object bean, Method method) {
226+
this.bean = bean;
227+
this.method = method;
228+
ReflectionUtils.makeAccessible(method);
229+
}
230+
231+
Object bean() {
232+
return this.bean;
233+
}
234+
235+
Method method() {
236+
return this.method;
237+
}
238+
239+
boolean returnsFlux() {
240+
return Flux.class.isAssignableFrom(this.method.getReturnType());
241+
}
242+
243+
}
244+
245+
}

0 commit comments

Comments
 (0)