Skip to content

Commit c924b3b

Browse files
committed
refactor(mcp): reorganize MCP tool utilities and client configuration
- Rename ToolUtils to McpToolUtils for better MCP-specific naming - Rename McpToolCallbackProvider to SyncMcpToolCallbackProvider - Add utility methods for handling tool callbacks in McpToolUtils - Extract client configuration logic into new McpClientDefinitions class - Add tool callback support to ChatClient interface and implementations - Remove redundant integration test
1 parent de51c65 commit c924b3b

File tree

11 files changed

+162
-333
lines changed

11 files changed

+162
-333
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/ToolUtils.java renamed to mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919

20+
import io.modelcontextprotocol.client.McpSyncClient;
2021
import io.modelcontextprotocol.server.McpServerFeatures;
2122
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolRegistration;
2223
import io.modelcontextprotocol.spec.McpSchema;
@@ -25,6 +26,7 @@
2526

2627
import org.springframework.ai.model.ModelOptionsUtils;
2728
import org.springframework.ai.tool.ToolCallback;
29+
import org.springframework.util.CollectionUtils;
2830

2931
/**
3032
* Utility class that provides helper methods for working with Model Context Protocol
@@ -46,9 +48,9 @@
4648
*
4749
* @author Christian Tzolov
4850
*/
49-
public final class ToolUtils {
51+
public final class McpToolUtils {
5052

51-
private ToolUtils() {
53+
private McpToolUtils() {
5254
}
5355

5456
/**
@@ -63,7 +65,7 @@ private ToolUtils() {
6365
*/
6466
public static List<McpServerFeatures.SyncToolRegistration> toSyncToolRegistration(
6567
List<ToolCallback> toolCallbacks) {
66-
return toolCallbacks.stream().map(ToolUtils::toSyncToolRegistration).toList();
68+
return toolCallbacks.stream().map(McpToolUtils::toSyncToolRegistration).toList();
6769
}
6870

6971
/**
@@ -129,7 +131,7 @@ public static McpServerFeatures.SyncToolRegistration toSyncToolRegistration(Tool
129131
*/
130132
public static List<McpServerFeatures.AsyncToolRegistration> toAsyncToolRegistration(
131133
List<ToolCallback> toolCallbacks) {
132-
return toolCallbacks.stream().map(ToolUtils::toAsyncToolRegistration).toList();
134+
return toolCallbacks.stream().map(McpToolUtils::toAsyncToolRegistration).toList();
133135
}
134136

135137
/**
@@ -181,4 +183,19 @@ public static McpServerFeatures.AsyncToolRegistration toAsyncToolRegistration(To
181183
.subscribeOn(Schedulers.boundedElastic()));
182184
}
183185

186+
public static List<ToolCallback> getToolCallbacks(McpSyncClient... mcpClients) {
187+
return getToolCallbacks(List.of(mcpClients));
188+
}
189+
190+
public static List<ToolCallback> getToolCallbacks(List<McpSyncClient> mcpClients) {
191+
192+
if (CollectionUtils.isEmpty(mcpClients)) {
193+
return List.of();
194+
}
195+
return mcpClients.stream()
196+
.map(mcpClient -> List.of((new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks())))
197+
.flatMap(List::stream)
198+
.toList();
199+
}
200+
184201
}

mcp/common/src/main/java/org/springframework/ai/mcp/McpToolCallbackProvider.java renamed to mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
import java.util.List;
1919

20+
import io.modelcontextprotocol.client.McpAsyncClient;
2021
import io.modelcontextprotocol.client.McpSyncClient;
2122

2223
import org.springframework.ai.tool.ToolCallback;
2324
import org.springframework.ai.tool.ToolCallbackProvider;
2425
import org.springframework.ai.tool.util.ToolUtils;
26+
import org.springframework.util.CollectionUtils;
2527

2628
/**
2729
* Implementation of {@link ToolCallbackProvider} that discovers and provides MCP tools.
@@ -50,15 +52,15 @@
5052
* @see McpSyncClient
5153
*/
5254

53-
public class McpToolCallbackProvider implements ToolCallbackProvider {
55+
public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {
5456

5557
private final McpSyncClient mcpClient;
5658

5759
/**
5860
* Creates a new {@code McpToolCallbackProvider} instance.
5961
* @param mcpClient the MCP client to use for discovering tools
6062
*/
61-
public McpToolCallbackProvider(McpSyncClient mcpClient) {
63+
public SyncMcpToolCallbackProvider(McpSyncClient mcpClient) {
6264
this.mcpClient = mcpClient;
6365
}
6466

@@ -105,4 +107,15 @@ private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
105107
}
106108
}
107109

110+
public static List<ToolCallback> syncToolCallbacks(List<McpSyncClient> mcpClients) {
111+
112+
if (CollectionUtils.isEmpty(mcpClients)) {
113+
return List.of();
114+
}
115+
return mcpClients.stream()
116+
.map(mcpClient -> List.of((new SyncMcpToolCallbackProvider(mcpClient).getToolCallbacks())))
117+
.flatMap(List::stream)
118+
.toList();
119+
}
120+
108121
}

mcp/common/src/test/java/org/springframework/ai/mcp/McpToolCallbackProviderTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void getToolCallbacksShouldReturnEmptyArrayWhenNoTools() {
4545
when(listToolsResult.tools()).thenReturn(List.of());
4646
when(mcpClient.listTools()).thenReturn(listToolsResult);
4747

48-
McpToolCallbackProvider provider = new McpToolCallbackProvider(mcpClient);
48+
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);
4949

5050
// Act
5151
var callbacks = provider.getToolCallbacks();
@@ -67,7 +67,7 @@ void getToolCallbacksShouldReturnCallbacksForEachTool() {
6767
when(listToolsResult.tools()).thenReturn(List.of(tool1, tool2));
6868
when(mcpClient.listTools()).thenReturn(listToolsResult);
6969

70-
McpToolCallbackProvider provider = new McpToolCallbackProvider(mcpClient);
70+
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);
7171

7272
// Act
7373
var callbacks = provider.getToolCallbacks();
@@ -89,7 +89,7 @@ void getToolCallbacksShouldThrowExceptionForDuplicateToolNames() {
8989
when(listToolsResult.tools()).thenReturn(List.of(tool1, tool2));
9090
when(mcpClient.listTools()).thenReturn(listToolsResult);
9191

92-
McpToolCallbackProvider provider = new McpToolCallbackProvider(mcpClient);
92+
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(mcpClient);
9393

9494
// Act & Assert
9595
assertThatThrownBy(() -> provider.getToolCallbacks()).isInstanceOf(IllegalStateException.class)

mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ToolUtilsTests {
4040

4141
@Test
4242
void constructorShouldBePrivate() throws Exception {
43-
Constructor<ToolUtils> constructor = ToolUtils.class.getDeclaredConstructor();
43+
Constructor<McpToolUtils> constructor = McpToolUtils.class.getDeclaredConstructor();
4444
assertThat(Modifier.isPrivate(constructor.getModifiers())).isTrue();
4545
constructor.setAccessible(true);
4646
constructor.newInstance();
@@ -52,7 +52,7 @@ void toSyncToolRegistrationShouldConvertSingleCallback() {
5252
ToolCallback callback = createMockToolCallback("test", "success");
5353

5454
// Act
55-
SyncToolRegistration registration = ToolUtils.toSyncToolRegistration(callback);
55+
SyncToolRegistration registration = McpToolUtils.toSyncToolRegistration(callback);
5656

5757
// Assert
5858
assertThat(registration).isNotNull();
@@ -70,7 +70,7 @@ void toSyncToolRegistrationShouldHandleError() {
7070
ToolCallback callback = createMockToolCallback("test", new RuntimeException("error"));
7171

7272
// Act
73-
SyncToolRegistration registration = ToolUtils.toSyncToolRegistration(callback);
73+
SyncToolRegistration registration = McpToolUtils.toSyncToolRegistration(callback);
7474

7575
// Assert
7676
assertThat(registration).isNotNull();
@@ -87,7 +87,7 @@ void toSyncToolRegistrationShouldConvertMultipleCallbacks() {
8787
ToolCallback callback2 = createMockToolCallback("test2", "success2");
8888

8989
// Act
90-
List<SyncToolRegistration> registrations = ToolUtils.toSyncToolRegistration(callback1, callback2);
90+
List<SyncToolRegistration> registrations = McpToolUtils.toSyncToolRegistration(callback1, callback2);
9191

9292
// Assert
9393
assertThat(registrations).hasSize(2);
@@ -101,7 +101,7 @@ void toAsyncToolRegistrationShouldConvertSingleCallback() {
101101
ToolCallback callback = createMockToolCallback("test", "success");
102102

103103
// Act
104-
AsyncToolRegistration registration = ToolUtils.toAsyncToolRegistration(callback);
104+
AsyncToolRegistration registration = McpToolUtils.toAsyncToolRegistration(callback);
105105

106106
// Assert
107107
assertThat(registration).isNotNull();
@@ -120,7 +120,7 @@ void toAsyncToolRegistrationShouldHandleError() {
120120
ToolCallback callback = createMockToolCallback("test", new RuntimeException("error"));
121121

122122
// Act
123-
AsyncToolRegistration registration = ToolUtils.toAsyncToolRegistration(callback);
123+
AsyncToolRegistration registration = McpToolUtils.toAsyncToolRegistration(callback);
124124

125125
// Assert
126126
assertThat(registration).isNotNull();
@@ -138,7 +138,7 @@ void toAsyncToolRegistrationShouldConvertMultipleCallbacks() {
138138
ToolCallback callback2 = createMockToolCallback("test2", "success2");
139139

140140
// Act
141-
List<AsyncToolRegistration> registrations = ToolUtils.toAsyncToolRegistration(callback1, callback2);
141+
List<AsyncToolRegistration> registrations = McpToolUtils.toAsyncToolRegistration(callback1, callback2);
142142

143143
// Assert
144144
assertThat(registrations).hasSize(2);

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.converter.StructuredOutputConverter;
3636
import org.springframework.ai.model.Media;
3737
import org.springframework.ai.model.function.FunctionCallback;
38+
import org.springframework.ai.tool.ToolCallback;
3839
import org.springframework.core.ParameterizedTypeReference;
3940
import org.springframework.core.io.Resource;
4041
import org.springframework.lang.Nullable;
@@ -218,6 +219,8 @@ interface ChatClientRequestSpec {
218219

219220
ChatClientRequestSpec tools(FunctionCallback... toolCallbacks);
220221

222+
ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks);
223+
221224
ChatClientRequestSpec tools(Object... toolObjects);
222225

223226
@Deprecated
@@ -283,6 +286,8 @@ interface Builder {
283286

284287
Builder defaultTools(FunctionCallback... toolCallbacks);
285288

289+
Builder defaultTools(List<ToolCallback> toolCallbacks);
290+
286291
Builder defaultTools(Object... toolObjects);
287292

288293
/**

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import io.micrometer.observation.Observation;
3434
import io.micrometer.observation.ObservationRegistry;
3535
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
36+
37+
import org.springframework.ai.tool.ToolCallback;
3638
import org.springframework.ai.tool.ToolCallbacks;
3739
import reactor.core.publisher.Flux;
3840
import reactor.core.scheduler.Schedulers;
@@ -853,6 +855,14 @@ public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
853855
return this;
854856
}
855857

858+
@Override
859+
public ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks) {
860+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
861+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
862+
this.functionCallbacks.addAll(toolCallbacks);
863+
return this;
864+
}
865+
856866
@Override
857867
public ChatClientRequestSpec tools(Object... toolObjects) {
858868
Assert.notNull(toolObjects, "toolObjects cannot be null");

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.chat.model.ToolContext;
3636
import org.springframework.ai.chat.prompt.ChatOptions;
3737
import org.springframework.ai.model.function.FunctionCallback;
38+
import org.springframework.ai.tool.ToolCallback;
3839
import org.springframework.core.io.Resource;
3940
import org.springframework.lang.Nullable;
4041
import org.springframework.util.Assert;
@@ -160,6 +161,12 @@ public Builder defaultTools(FunctionCallback... toolCallbacks) {
160161
return this;
161162
}
162163

164+
@Override
165+
public Builder defaultTools(List<ToolCallback> toolCallbacks) {
166+
this.defaultRequest.tools(toolCallbacks);
167+
return this;
168+
}
169+
163170
@Override
164171
public Builder defaultTools(Object... toolObjects) {
165172
this.defaultRequest.tools(toolObjects);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.autoconfigure.mcp.client.stdio;
17+
18+
import java.util.List;
19+
import java.util.Map;
20+
import java.util.stream.Collectors;
21+
22+
import io.modelcontextprotocol.client.McpClient;
23+
import io.modelcontextprotocol.client.McpSyncClient;
24+
import io.modelcontextprotocol.client.transport.ServerParameters;
25+
import io.modelcontextprotocol.client.transport.StdioClientTransport;
26+
import io.modelcontextprotocol.spec.McpSchema;
27+
28+
/**
29+
* @author Christian Tzolov
30+
* @since 1.0.0
31+
*/
32+
33+
public class McpClientDefinitions {
34+
35+
private final Map<String, McpClientDefinition> definitions;
36+
37+
public McpClientDefinitions(Map<String, McpClientDefinition> definitions) {
38+
this.definitions = definitions;
39+
}
40+
41+
public record McpClientDefinition(String name, ServerParameters serverParameters,
42+
McpStdioClientProperties clientProperties) {
43+
44+
public McpClient.SyncSpec syncSpec() {
45+
46+
var transport = new StdioClientTransport(serverParameters());
47+
48+
McpSchema.ClientCapabilities.Builder capabilitiesBuilder = McpSchema.ClientCapabilities.builder();
49+
50+
McpSchema.Implementation clientInfo = new McpSchema.Implementation(name(), clientProperties.getVersion());
51+
52+
McpClient.SyncSpec clientBilder = McpClient.sync(transport)
53+
.clientInfo(clientInfo)
54+
.requestTimeout(clientProperties.getRequestTimeout());
55+
56+
clientBilder.capabilities(capabilitiesBuilder.build());
57+
58+
return clientBilder;
59+
}
60+
}
61+
62+
public List<String> names() {
63+
return definitions.keySet().stream().collect(Collectors.toList());
64+
}
65+
66+
public McpClientDefinition getMcpClientDefinition(String name) {
67+
return definitions.get(name);
68+
}
69+
70+
public List<McpSyncClient> toMcpSyncClients() {
71+
72+
var clients = definitions.values()
73+
.stream()
74+
.map(McpClientDefinition::syncSpec)
75+
.map(McpClient.SyncSpec::build)
76+
.collect(Collectors.toList());
77+
78+
clients.forEach(McpSyncClient::initialize);
79+
80+
return clients;
81+
}
82+
83+
}

0 commit comments

Comments
 (0)