Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

package org.springframework.ai.mcp;

import java.util.Map;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import java.util.Map;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -112,19 +113,16 @@ public String call(String functionInput) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
try {
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).map(response -> {
if (response.isError() != null && response.isError()) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
}).block();
}
catch (Exception ex) {
throw new ToolExecutionException(this.getToolDefinition(), ex.getCause());
}

return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> {
// If the tool throws an error during execution
throw new ToolExecutionException(this.getToolDefinition(), exception);
}).map(response -> {
if (response.isError() != null && response.isError()) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
}).block();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@

package org.springframework.ai.mcp;

import java.lang.reflect.InvocationTargetException;
import java.util.Map;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -32,7 +30,6 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.core.log.LogAccessor;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -118,22 +115,24 @@ public ToolDefinition getToolDefinition() {
@Override
public String call(String functionInput) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition

CallToolResult response;
try {
CallToolResult response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
if (response.isError() != null && response.isError()) {
logger.error("Error calling tool: {}", response.content());
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
}
catch (Exception ex) {
logger.error("Exception while tool calling: ", ex);
throw new ToolExecutionException(this.getToolDefinition(), ex.getCause());
throw new ToolExecutionException(this.getToolDefinition(), ex);
}

if (response.isError() != null && response.isError()) {
logger.error("Error calling tool: {}", response.content());
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.mcp;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

import org.springframework.ai.tool.execution.ToolExecutionException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class AsyncMcpToolCallbackTest {

@Mock
private McpAsyncClient mcpClient;

@Mock
private McpSchema.Tool tool;

@Test
void callShouldThrowOnError() {
when(this.tool.name()).thenReturn("testTool");
var clientInfo = new McpSchema.Implementation("testClient", "1.0.0");
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
var callToolResult = McpSchema.CallToolResult.builder().addTextContent("Some error data").isError(true).build();
when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callToolResult));

var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
.cause()
.isInstanceOf(IllegalStateException.class)
.hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]");
}

@Test
void callShouldWrapReactiveErrors() {
when(this.tool.name()).thenReturn("testTool");
var clientInfo = new McpSchema.Implementation("testClient", "1.0.0");
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class)))
.thenReturn(Mono.error(new Exception("Testing tool error")));

var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
.rootCause()
.hasMessage("Testing tool error");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.mcp;

import io.modelcontextprotocol.spec.McpSchema;
import java.util.List;
import java.util.Map;

import io.modelcontextprotocol.client.McpSyncClient;
Expand All @@ -29,8 +31,11 @@
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.content.Content;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -94,4 +99,36 @@ void callShouldIgnoreToolContext() {
assertThat(response).isNotNull();
}

@Test
void callShouldThrowOnError() {
when(this.tool.name()).thenReturn("testTool");
var clientInfo = new Implementation("testClient", "1.0.0");
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
CallToolResult callResult = mock(CallToolResult.class);
when(callResult.isError()).thenReturn(true);
when(callResult.content()).thenReturn(List.of(new McpSchema.TextContent("Some error data")));
when(this.mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult);

SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool);

assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
.cause()
.isInstanceOf(IllegalStateException.class)
.hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]");
}

@Test
void callShouldWrapExceptions() {
when(this.tool.name()).thenReturn("testTool");
var clientInfo = new Implementation("testClient", "1.0.0");
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
when(this.mcpClient.callTool(any(CallToolRequest.class))).thenThrow(new RuntimeException("Testing tool error"));

SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool);

assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
.rootCause()
.hasMessage("Testing tool error");
}

}