Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,13 @@

package org.springaicommunity.mcp.method.tool;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.stream.Stream;

import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import org.reactivestreams.Publisher;
import org.springaicommunity.mcp.annotation.McpMeta;
import org.springaicommunity.mcp.annotation.McpProgressToken;
import org.springaicommunity.mcp.annotation.McpTool;
import org.springaicommunity.mcp.method.tool.utils.JsonParser;

import com.fasterxml.jackson.core.type.TypeReference;

import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand All @@ -46,109 +37,16 @@
* McpTransportContext)
* @author Christian Tzolov
*/
public abstract class AbstractAsyncMcpToolMethodCallback<T> {
public abstract class AbstractAsyncMcpToolMethodCallback<T> extends AbstractMcpToolMethodCallback<T> {

protected final Class<? extends Throwable> toolCallExceptionClass;

private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE = new TypeReference<Map<String, Object>>() {
// No implementation needed
};

protected final Method toolMethod;

protected final Object toolObject;

protected final ReturnMode returnMode;

protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject,
Class<? extends Throwable> toolCallExceptionClass) {
this.toolMethod = toolMethod;
this.toolObject = toolObject;
this.returnMode = returnMode;
super(returnMode, toolMethod, toolObject);
this.toolCallExceptionClass = toolCallExceptionClass;
}

/**
* Invokes the tool method with the provided arguments.
* @param methodArguments The arguments to pass to the method
* @return The result of the method invocation
* @throws IllegalStateException if the method cannot be accessed
* @throws RuntimeException if there's an error invoking the method
*/
protected Object callMethod(Object[] methodArguments) {
this.toolMethod.setAccessible(true);

Object result;
try {
result = this.toolMethod.invoke(this.toolObject, methodArguments);
}
catch (IllegalAccessException ex) {
throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
}
catch (InvocationTargetException ex) {
throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex);
}
return result;
}

/**
* Builds the method arguments from the context, tool input arguments, and optionally
* the full request.
* @param exchangeOrContext The exchange or context object (e.g.,
* McpAsyncServerExchange or McpTransportContext)
* @param toolInputArguments The input arguments from the tool request
* @param request The full CallToolRequest (optional, can be null)
* @return An array of method arguments
*/
protected Object[] buildMethodArguments(T exchangeOrContext, Map<String, Object> toolInputArguments,
CallToolRequest request) {
return Stream.of(this.toolMethod.getParameters()).map(parameter -> {
// Check if parameter is annotated with @McpProgressToken
if (parameter.isAnnotationPresent(McpProgressToken.class)) {
// Return the progress token from the request
return request != null ? request.progressToken() : null;
}

// Check if parameter is McpMeta type
if (McpMeta.class.isAssignableFrom(parameter.getType())) {
// Return the meta from the request wrapped in McpMeta
return request != null ? new McpMeta(request.meta()) : new McpMeta(null);
}

// Check if parameter is CallToolRequest type
if (CallToolRequest.class.isAssignableFrom(parameter.getType())) {
return request;
}

if (isExchangeOrContextType(parameter.getType())) {
return exchangeOrContext;
}

Object rawArgument = toolInputArguments.get(parameter.getName());
return buildTypedArgument(rawArgument, parameter.getParameterizedType());
}).toArray();
}

/**
* Builds a typed argument from a raw value and type information.
* @param value The raw value
* @param type The target type
* @return The typed argument
*/
protected Object buildTypedArgument(Object value, Type type) {
if (value == null) {
return null;
}

if (type instanceof Class<?>) {
return JsonParser.toTypedObject(value, (Class<?>) type);
}

// For generic types, use the fromJson method that accepts Type
String json = JsonParser.toJson(value);
return JsonParser.fromJson(json, type);
}

/**
* Convert reactive types to Mono<CallToolResult>
* @param result The result from the method invocation
Expand Down Expand Up @@ -233,53 +131,23 @@ protected Mono<CallToolResult> convertToCallToolResult(Object result) {
}

/**
* Map individual values to CallToolResult
* Map individual values to CallToolResult This method delegates to the parent class's
* convertValueToCallToolResult method to avoid code duplication.
* @param value The value to map
* @return A CallToolResult representing the mapped value
*/
protected CallToolResult mapValueToCallToolResult(Object value) {
// Return the result if it's already a CallToolResult
if (value instanceof CallToolResult) {
return (CallToolResult) value;
}

Type returnType = this.toolMethod.getGenericReturnType();

if (returnMode == ReturnMode.VOID || returnType == Void.TYPE || returnType == void.class) {
return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build();
}

if (this.returnMode == ReturnMode.STRUCTURED) {
String jsonOutput = JsonParser.toJson(value);
Object structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE);
return CallToolResult.builder().structuredContent(structuredOutput).build();
}

// Default to text output
if (value == null) {
return CallToolResult.builder().addTextContent("null").build();
}

// For string results in TEXT mode, return the string directly without JSON
// serialization
if (value instanceof String) {
return CallToolResult.builder().addTextContent((String) value).build();
}

// For other types, serialize to JSON
return CallToolResult.builder().addTextContent(JsonParser.toJson(value)).build();
return convertValueToCallToolResult(value);
}

/**
* Creates an error result for exceptions that occur during method invocation.
* @param e The exception that occurred
* @return A Mono<CallToolResult> representing the error
*/
protected Mono<CallToolResult> createErrorResult(Exception e) {
return Mono.just(CallToolResult.builder()
.isError(true)
.addTextContent("Error invoking method: %s".formatted(e.getMessage()))
.build());
protected Mono<CallToolResult> createAsyncErrorResult(Exception e) {
Throwable rootCause = findCauseUsingPlainJava(e);
return Mono.just(CallToolResult.builder().isError(true).addTextContent(rootCause.getMessage()).build());
}

/**
Expand Down
Loading