Skip to content
This repository was archived by the owner on Feb 14, 2025. It is now read-only.

Commit a424b36

Browse files
committed
feat: Add protocol version negotiation
Implement protocol version negotiation between MCP client and server to ensure compatibility and graceful version handling. The server will now suggest the latest supported version when an unsupported version is requested, while the client will verify version compatibility during initialization. - Support multiple protocol versions in both client and server - Negotiate compatible version during initialization - Fall back to latest version when unsupported version requested - Add comprehensive test coverage for version negotiation Resolves #71 Signed-off-by: Christian Tzolov <[email protected]>
1 parent 48ab97b commit a424b36

File tree

8 files changed

+370
-13
lines changed

8 files changed

+370
-13
lines changed

mcp/src/main/java/org/springframework/ai/mcp/client/McpAsyncClient.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ public class McpAsyncClient {
146146
*/
147147
private final McpTransport transport;
148148

149+
/**
150+
* Supported protocol versions.
151+
*/
152+
private List<String> protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
153+
149154
/**
150155
* Create a new McpAsyncClient with the given transport and session request-response
151156
* timeout.
@@ -257,6 +262,8 @@ public McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Imp
257262
Assert.notNull(requestTimeout, "Request timeout must not be null");
258263
Assert.notNull(clientInfo, "Client info must not be null");
259264

265+
this.protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
266+
260267
this.clientInfo = clientInfo;
261268

262269
this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities
@@ -360,8 +367,11 @@ public McpAsyncClient(ClientMcpTransport transport, Duration requestTimeout, Imp
360367
* @return the initialize result.
361368
*/
362369
public Mono<McpSchema.InitializeResult> initialize() {
370+
371+
String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
372+
363373
McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off
364-
McpSchema.LATEST_PROTOCOL_VERSION,
374+
latestVersion,
365375
this.clientCapabilities,
366376
this.clientInfo); // @formatter:on
367377

@@ -378,7 +388,7 @@ public Mono<McpSchema.InitializeResult> initialize() {
378388
initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(),
379389
initializeResult.instructions());
380390

381-
if (!McpSchema.LATEST_PROTOCOL_VERSION.equals(initializeResult.protocolVersion())) {
391+
if (!this.protocolVersions.contains(initializeResult.protocolVersion())) {
382392
return Mono.error(new McpError(
383393
"Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
384394
}
@@ -911,4 +921,13 @@ public Mono<Void> setLoggingLevel(LoggingLevel loggingLevel) {
911921
return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params);
912922
}
913923

924+
/**
925+
* This method is package-private and used for test only. Should not be called by user
926+
* code.
927+
* @param protocolVersions the Client supported protocol versions.
928+
*/
929+
void setProtocolVersions(List<String> protocolVersions) {
930+
this.protocolVersions = protocolVersions;
931+
}
932+
914933
}

mcp/src/main/java/org/springframework/ai/mcp/client/McpClientFeatures.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
6767

6868
/**
6969
* Create an instance and validate the arguments.
70-
* @param clientInfo the client implementation information.
7170
* @param clientCapabilities the client capabilities.
7271
* @param roots the roots.
7372
* @param toolsChangeConsumers the tools change consumers.
@@ -85,7 +84,6 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
8584
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler) {
8685

8786
Assert.notNull(clientInfo, "Client info must not be null");
88-
8987
this.clientInfo = clientInfo;
9088
this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities
9189
: new McpSchema.ClientCapabilities(null,
@@ -182,7 +180,6 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
182180
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler) {
183181

184182
Assert.notNull(clientInfo, "Client info must not be null");
185-
186183
this.clientInfo = clientInfo;
187184
this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities
188185
: new McpSchema.ClientCapabilities(null,

mcp/src/main/java/org/springframework/ai/mcp/server/McpAsyncServer.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,18 @@ public class McpAsyncServer {
122122

123123
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
124124

125+
/**
126+
* Supported protocol versions.
127+
*/
128+
private List<String> protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
129+
125130
/**
126131
* Create a new McpAsyncServer with the given transport and capabilities.
127132
* @param mcpTransport The transport layer implementation for MCP communication.
128133
* @param features The MCP server supported features.
129134
*/
130135
McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) {
136+
131137
this.serverInfo = features.serverInfo();
132138
this.serverCapabilities = features.serverCapabilities();
133139
this.tools.addAll(features.tools());
@@ -210,6 +216,7 @@ public McpAsyncServer(ServerMcpTransport mcpTransport, McpSchema.Implementation
210216
Map<String, ResourceRegistration> resources, List<McpSchema.ResourceTemplate> resourceTemplates,
211217
Map<String, PromptRegistration> prompts, List<Consumer<List<McpSchema.Root>>> rootsChangeConsumers) {
212218

219+
this.protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
213220
this.serverInfo = serverInfo;
214221
if (!Utils.isEmpty(tools)) {
215222
this.tools.addAll(McpServer.mapDeprecatedTools(tools));
@@ -299,12 +306,18 @@ private DefaultMcpSession.RequestHandler<McpSchema.InitializeResult> asyncInitia
299306
initializeRequest.protocolVersion(), initializeRequest.capabilities(),
300307
initializeRequest.clientInfo());
301308

302-
if (!McpSchema.LATEST_PROTOCOL_VERSION.equals(initializeRequest.protocolVersion())) {
303-
return Mono.error(new McpError(
304-
"Unsupported protocol version from client: " + initializeRequest.protocolVersion()));
309+
String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
310+
311+
if (this.protocolVersions.contains(initializeRequest.protocolVersion())) {
312+
serverProtocolVersion = initializeRequest.protocolVersion();
313+
}
314+
else {
315+
logger.warn(
316+
"Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead",
317+
initializeRequest.protocolVersion(), serverProtocolVersion);
305318
}
306319

307-
return Mono.just(new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, this.serverCapabilities,
320+
return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities,
308321
this.serverInfo, null));
309322
};
310323
}
@@ -903,4 +916,13 @@ public Mono<McpSchema.CreateMessageResult> createMessage(McpSchema.CreateMessage
903916
CREATE_MESSAGE_RESULT_TYPE_REF);
904917
}
905918

919+
/**
920+
* This method is package-private and used for test only. Should not be called by user
921+
* code.
922+
* @param protocolVersions the Client supported protocol versions.
923+
*/
924+
void setProtocolVersions(List<String> protocolVersions) {
925+
this.protocolVersions = protocolVersions;
926+
}
927+
906928
}

mcp/src/main/java/org/springframework/ai/mcp/util/Assert.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
package org.springframework.ai.mcp.util;
1818

19+
import java.util.Collection;
20+
1921
import reactor.util.annotation.Nullable;
2022

2123
/**
2224
* Assertion utility class that assists in validating arguments.
23-
*
25+
*
2426
* @author Christian Tzolov
2527
*/
2628

@@ -29,6 +31,18 @@
2931
*/
3032
public final class Assert {
3133

34+
/**
35+
* Assert that the collection is not {@code null} and not empty.
36+
* @param collection the collection to check
37+
* @param message the exception message to use if the assertion fails
38+
* @throws IllegalArgumentException if the collection is {@code null} or empty
39+
*/
40+
public static void notEmpty(@Nullable Collection<?> collection, String message) {
41+
if (collection == null || collection.isEmpty()) {
42+
throw new IllegalArgumentException(message);
43+
}
44+
}
45+
3246
/**
3347
* Assert that an object is not {@code null}.
3448
*

mcp/src/test/java/org/springframework/ai/mcp/MockMcpTransport.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
import org.springframework.ai.mcp.spec.McpSchema.JSONRPCRequest;
3333
import org.springframework.ai.mcp.spec.ServerMcpTransport;
3434

35-
@SuppressWarnings("unused")
35+
/**
36+
* A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport}
37+
* interfaces.
38+
*/
3639
public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport {
3740

3841
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
@@ -91,6 +94,7 @@ public Mono<Void> closeGracefully() {
9194
connected = false;
9295
outgoing.tryEmitComplete();
9396
inbound.tryEmitComplete();
97+
// Wait for all subscribers to complete
9498
return Mono.empty();
9599
});
96100
}
@@ -100,4 +104,4 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
100104
return new ObjectMapper().convertValue(data, typeRef);
101105
}
102106

103-
}
107+
}

mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.ArrayList;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.function.Consumer;
2423
import java.util.function.Function;
2524

2625
import com.fasterxml.jackson.core.JsonProcessingException;
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp.client;
18+
19+
import java.time.Duration;
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.Test;
23+
import reactor.core.publisher.Mono;
24+
import reactor.test.StepVerifier;
25+
26+
import org.springframework.ai.mcp.MockMcpTransport;
27+
import org.springframework.ai.mcp.spec.McpError;
28+
import org.springframework.ai.mcp.spec.McpSchema;
29+
import org.springframework.ai.mcp.spec.McpSchema.InitializeResult;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
33+
/**
34+
* Tests for MCP protocol version negotiation and compatibility.
35+
*/
36+
class McpClientProtocolVersionTests {
37+
38+
private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30);
39+
40+
private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0");
41+
42+
@Test
43+
void shouldUseLatestVersionByDefault() {
44+
MockMcpTransport transport = new MockMcpTransport();
45+
McpAsyncClient client = McpClient.async(transport)
46+
.clientInfo(CLIENT_INFO)
47+
.requestTimeout(REQUEST_TIMEOUT)
48+
.build();
49+
50+
try {
51+
Mono<InitializeResult> initializeResultMono = client.initialize();
52+
53+
StepVerifier.create(initializeResultMono).then(() -> {
54+
McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest();
55+
assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class);
56+
McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params();
57+
assertThat(initRequest.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION);
58+
59+
transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
60+
new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, null,
61+
new McpSchema.Implementation("test-server", "1.0.0"), null),
62+
null));
63+
}).assertNext(result -> {
64+
assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION);
65+
}).verifyComplete();
66+
67+
}
68+
finally {
69+
// Ensure cleanup happens even if test fails
70+
StepVerifier.create(client.closeGracefully()).verifyComplete();
71+
}
72+
}
73+
74+
@Test
75+
void shouldNegotiateSpecificVersion() {
76+
String oldVersion = "0.1.0";
77+
MockMcpTransport transport = new MockMcpTransport();
78+
McpAsyncClient client = McpClient.async(transport)
79+
.clientInfo(CLIENT_INFO)
80+
.requestTimeout(REQUEST_TIMEOUT)
81+
.build();
82+
83+
client.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION));
84+
85+
try {
86+
Mono<InitializeResult> initializeResultMono = client.initialize();
87+
88+
StepVerifier.create(initializeResultMono).then(() -> {
89+
McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest();
90+
assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class);
91+
McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params();
92+
assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION));
93+
94+
transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
95+
new McpSchema.InitializeResult(oldVersion, null,
96+
new McpSchema.Implementation("test-server", "1.0.0"), null),
97+
null));
98+
}).assertNext(result -> {
99+
assertThat(result.protocolVersion()).isEqualTo(oldVersion);
100+
}).verifyComplete();
101+
}
102+
finally {
103+
StepVerifier.create(client.closeGracefully()).verifyComplete();
104+
}
105+
}
106+
107+
@Test
108+
void shouldFailForUnsupportedVersion() {
109+
String unsupportedVersion = "999.999.999";
110+
MockMcpTransport transport = new MockMcpTransport();
111+
McpAsyncClient client = McpClient.async(transport)
112+
.clientInfo(CLIENT_INFO)
113+
.requestTimeout(REQUEST_TIMEOUT)
114+
.build();
115+
116+
try {
117+
Mono<InitializeResult> initializeResultMono = client.initialize();
118+
119+
StepVerifier.create(initializeResultMono).then(() -> {
120+
McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest();
121+
assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class);
122+
123+
transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
124+
new McpSchema.InitializeResult(unsupportedVersion, null,
125+
new McpSchema.Implementation("test-server", "1.0.0"), null),
126+
null));
127+
}).expectError(McpError.class).verify();
128+
}
129+
finally {
130+
StepVerifier.create(client.closeGracefully()).verifyComplete();
131+
}
132+
}
133+
134+
@Test
135+
void shouldUseHighestVersionWhenMultipleSupported() {
136+
String oldVersion = "0.1.0";
137+
String middleVersion = "0.2.0";
138+
String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION;
139+
140+
MockMcpTransport transport = new MockMcpTransport();
141+
McpAsyncClient client = McpClient.async(transport)
142+
.clientInfo(CLIENT_INFO)
143+
.requestTimeout(REQUEST_TIMEOUT)
144+
.build();
145+
146+
client.setProtocolVersions(List.of(oldVersion, middleVersion, latestVersion));
147+
148+
try {
149+
Mono<InitializeResult> initializeResultMono = client.initialize();
150+
151+
StepVerifier.create(initializeResultMono).then(() -> {
152+
McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest();
153+
McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params();
154+
assertThat(initRequest.protocolVersion()).isEqualTo(latestVersion);
155+
156+
transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
157+
new McpSchema.InitializeResult(latestVersion, null,
158+
new McpSchema.Implementation("test-server", "1.0.0"), null),
159+
null));
160+
}).assertNext(result -> {
161+
assertThat(result.protocolVersion()).isEqualTo(latestVersion);
162+
}).verifyComplete();
163+
}
164+
finally {
165+
StepVerifier.create(client.closeGracefully()).verifyComplete();
166+
}
167+
168+
}
169+
170+
}

0 commit comments

Comments
 (0)