Skip to content
This repository was archived by the owner on Feb 14, 2025. It is now read-only.

Commit 9130e74

Browse files
committed
fix(mcp): Improve transport shutdown and error handling
- Add isClosing flag to coordinate graceful shutdown - Synchronize output stream writes to prevent race conditions - Enhance error handling in stream processing - Add delay for pending message processing during shutdown - Improve connection state tracking in tests
1 parent 759fbf1 commit 9130e74

File tree

3 files changed

+124
-53
lines changed

3 files changed

+124
-53
lines changed

mcp/src/main/java/org/springframework/ai/mcp/client/transport/StdioClientTransport.java

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.IOException;
2121
import java.io.InputStreamReader;
2222
import java.nio.charset.StandardCharsets;
23+
import java.time.Duration;
2324
import java.util.ArrayList;
2425
import java.util.List;
2526
import java.util.concurrent.CompletableFuture;
@@ -77,6 +78,8 @@ public class StdioClientTransport implements McpTransport {
7778

7879
private final Sinks.Many<String> errorSink;
7980

81+
private volatile boolean isClosing = false;
82+
8083
// visible for tests
8184
private Consumer<String> errorHandler = error -> logger.error("Error received: {}", error);
8285

@@ -200,19 +203,32 @@ private void startErrorProcessing() {
200203
try (BufferedReader processErrorReader = new BufferedReader(
201204
new InputStreamReader(process.getErrorStream()))) {
202205
String line;
203-
while ((line = processErrorReader.readLine()) != null) {
206+
while (!isClosing && (line = processErrorReader.readLine()) != null) {
204207
try {
205208
logger.error("Received error line: {}", line);
206-
// TODO: handle errors, etc.
207-
this.errorSink.tryEmitNext(line);
209+
if (!this.errorSink.tryEmitNext(line).isSuccess()) {
210+
if (!isClosing) {
211+
logger.error("Failed to emit error message");
212+
}
213+
break;
214+
}
208215
}
209216
catch (Exception e) {
210-
throw new RuntimeException(e);
217+
if (!isClosing) {
218+
logger.error("Error processing error message", e);
219+
}
220+
break;
211221
}
212222
}
213223
}
214224
catch (IOException e) {
215-
throw new RuntimeException(e);
225+
if (!isClosing) {
226+
logger.error("Error reading from error stream", e);
227+
}
228+
}
229+
finally {
230+
isClosing = true;
231+
errorSink.tryEmitComplete();
216232
}
217233
});
218234
}
@@ -254,21 +270,32 @@ private void startInboundProcessing() {
254270
this.inboundScheduler.schedule(() -> {
255271
try (BufferedReader processReader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
256272
String line;
257-
while ((line = processReader.readLine()) != null) {
273+
while (!isClosing && (line = processReader.readLine()) != null) {
258274
try {
259275
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line);
260276
if (!this.inboundSink.tryEmitNext(message).isSuccess()) {
261-
// TODO: Back off, reschedule, give up?
262-
throw new RuntimeException("Failed to enqueue message");
277+
if (!isClosing) {
278+
logger.error("Failed to enqueue inbound message");
279+
}
280+
break;
263281
}
264282
}
265283
catch (Exception e) {
266-
throw new RuntimeException(e);
284+
if (!isClosing) {
285+
logger.error("Error processing inbound message", e);
286+
}
287+
break;
267288
}
268289
}
269290
}
270291
catch (IOException e) {
271-
throw new RuntimeException(e);
292+
if (!isClosing) {
293+
logger.error("Error reading from input stream", e);
294+
}
295+
}
296+
finally {
297+
isClosing = true;
298+
inboundSink.tryEmitComplete();
272299
}
273300
});
274301
}
@@ -284,7 +311,7 @@ private void startOutboundProcessing() {
284311
// want to ensure that the actual writing happens on a dedicated thread
285312
.publishOn(outboundScheduler)
286313
.handle((message, s) -> {
287-
if (message != null) {
314+
if (message != null && !isClosing) {
288315
try {
289316
String jsonMessage = objectMapper.writeValueAsString(message);
290317
// Escape any embedded newlines in the JSON message as per spec:
@@ -293,9 +320,12 @@ private void startOutboundProcessing() {
293320
// embedded newlines.
294321
jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n");
295322

296-
this.process.getOutputStream().write(jsonMessage.getBytes(StandardCharsets.UTF_8));
297-
this.process.getOutputStream().write("\n".getBytes(StandardCharsets.UTF_8));
298-
this.process.getOutputStream().flush();
323+
var os = this.process.getOutputStream();
324+
synchronized (os) {
325+
os.write(jsonMessage.getBytes(StandardCharsets.UTF_8));
326+
os.write("\n".getBytes(StandardCharsets.UTF_8));
327+
os.flush();
328+
}
299329
s.next(message);
300330
}
301331
catch (IOException e) {
@@ -306,7 +336,16 @@ private void startOutboundProcessing() {
306336
}
307337

308338
protected void handleOutbound(Function<Flux<JSONRPCMessage>, Flux<JSONRPCMessage>> outboundConsumer) {
309-
outboundConsumer.apply(outboundSink.asFlux()).subscribe();
339+
outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> {
340+
isClosing = true;
341+
outboundSink.tryEmitComplete();
342+
}).doOnError(e -> {
343+
if (!isClosing) {
344+
logger.error("Error in outbound processing", e);
345+
isClosing = true;
346+
outboundSink.tryEmitComplete();
347+
}
348+
}).subscribe();
310349
}
311350

312351
/**
@@ -317,7 +356,18 @@ protected void handleOutbound(Function<Flux<JSONRPCMessage>, Flux<JSONRPCMessage
317356
*/
318357
@Override
319358
public Mono<Void> closeGracefully() {
320-
return Mono.fromFuture(() -> {
359+
return Mono.fromRunnable(() -> {
360+
isClosing = true;
361+
logger.debug("Initiating graceful shutdown");
362+
}).then(Mono.defer(() -> {
363+
// First complete all sinks to stop accepting new messages
364+
inboundSink.tryEmitComplete();
365+
outboundSink.tryEmitComplete();
366+
errorSink.tryEmitComplete();
367+
368+
// Give a short time for any pending messages to be processed
369+
return Mono.delay(Duration.ofMillis(100));
370+
})).then(Mono.fromFuture(() -> {
321371
logger.info("Sending TERM to process");
322372
if (this.process != null) {
323373
this.process.destroy();
@@ -326,16 +376,23 @@ public Mono<Void> closeGracefully() {
326376
else {
327377
return CompletableFuture.failedFuture(new RuntimeException("Process not started"));
328378
}
329-
}).doOnNext(process -> {
379+
})).doOnNext(process -> {
330380
if (process.exitValue() != 0) {
331381
logger.warn("Process terminated with code " + process.exitValue());
332382
}
333383
}).then(Mono.fromRunnable(() -> {
334-
// The Threads are blocked on readLine so disposeGracefully would not
335-
// interrupt them, therefore we issue an async hard dispose.
336-
inboundScheduler.dispose();
337-
errorScheduler.dispose();
338-
outboundScheduler.dispose();
384+
try {
385+
// The Threads are blocked on readLine so disposeGracefully would not
386+
// interrupt them, therefore we issue an async hard dispose.
387+
inboundScheduler.dispose();
388+
errorScheduler.dispose();
389+
outboundScheduler.dispose();
390+
391+
logger.info("Graceful shutdown completed");
392+
}
393+
catch (Exception e) {
394+
logger.error("Error during graceful shutdown", e);
395+
}
339396
})).then().subscribeOn(Schedulers.boundedElastic());
340397
}
341398

mcp/src/main/java/org/springframework/ai/mcp/server/transport/StdioServerTransport.java

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ private void startInboundProcessing() {
176176
}
177177
}
178178
finally {
179-
if (!isClosing) {
180-
isClosing = true;
181-
}
179+
isClosing = true;
182180
inboundSink.tryEmitComplete();
183181
}
184182
});
@@ -218,9 +216,7 @@ else if (isClosing) {
218216
}
219217
})
220218
.doOnComplete(() -> {
221-
if (!isClosing) {
222-
isClosing = true;
223-
}
219+
isClosing = true;
224220
outboundSink.tryEmitComplete();
225221
})
226222
.doOnError(e -> {
@@ -240,29 +236,35 @@ public Mono<Void> closeGracefully() {
240236
return Mono.fromRunnable(() -> {
241237
isClosing = true;
242238
logger.debug("Initiating graceful shutdown");
243-
// }).then(Mono.delay(Duration.ofMillis(100))).then(Mono.fromRunnable(() -> {
244-
}).then(Mono.fromRunnable(() -> {
245-
try {
246-
// inboundSink.tryEmitComplete();
247-
// outboundSink.tryEmitComplete();
248-
249-
inboundScheduler.dispose();
250-
outboundScheduler.dispose();
251-
252-
// Wait for schedulers to terminate
253-
if (!inboundScheduler.isDisposed()) {
254-
inboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
239+
})
240+
// .then(Mono.defer(() -> {
241+
// inboundSink.tryEmitComplete();
242+
// outboundSink.tryEmitComplete();
243+
// return Mono.delay(Duration.ofMillis(100));
244+
// }))
245+
246+
.then(Mono.fromRunnable(() -> {
247+
try {
248+
249+
inboundScheduler.dispose();
250+
outboundScheduler.dispose();
251+
252+
// Wait for schedulers to terminate
253+
if (!inboundScheduler.isDisposed()) {
254+
inboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
255+
}
256+
if (!outboundScheduler.isDisposed()) {
257+
outboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
258+
}
259+
260+
logger.info("Graceful shutdown completed");
255261
}
256-
if (!outboundScheduler.isDisposed()) {
257-
outboundScheduler.disposeGracefully().block(Duration.ofSeconds(5));
262+
catch (Exception e) {
263+
logger.error("Error during graceful shutdown", e);
258264
}
259-
260-
logger.info("Graceful shutdown completed");
261-
}
262-
catch (Exception e) {
263-
logger.error("Error during graceful shutdown", e);
264-
}
265-
})).then().subscribeOn(Schedulers.boundedElastic());
265+
}))
266+
.then()
267+
.subscribeOn(Schedulers.boundedElastic());
266268
}
267269

268270
@Override

mcp/src/test/java/org/springframework/ai/mcp/client/McpAsyncClientResponseHandlerTests.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ private static class MockMcpTransport implements McpTransport {
4949

5050
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
5151

52-
private Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
52+
private final Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
5353

54-
private Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
54+
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
5555

56-
private Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
56+
private final Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
5757

5858
public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
5959
if (inbound.tryEmitNext(message).isFailure()) {
@@ -82,17 +82,29 @@ public McpSchema.JSONRPCMessage getLastSentMessage() {
8282
return outboundView.blockFirst();
8383
}
8484

85+
private volatile boolean connected = false;
86+
8587
@Override
8688
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
89+
if (connected) {
90+
return Mono.error(new IllegalStateException("Already connected"));
91+
}
92+
connected = true;
8793
return inbound.asFlux()
8894
.publishOn(Schedulers.boundedElastic())
8995
.flatMap(message -> Mono.just(message).transform(handler))
96+
.doFinally(signal -> connected = false)
9097
.then();
9198
}
9299

93100
@Override
94101
public Mono<Void> closeGracefully() {
95-
return Mono.empty();
102+
return Mono.defer(() -> {
103+
connected = false;
104+
outgoing.tryEmitComplete();
105+
inbound.tryEmitComplete();
106+
return Mono.empty();
107+
});
96108
}
97109

98110
@Override

0 commit comments

Comments
 (0)