Skip to content

Commit 6757ab8

Browse files
committed
Address review comments
Signed-off-by: Christian Tzolov <[email protected]>
1 parent 65262ec commit 6757ab8

File tree

3 files changed

+50
-31
lines changed

3 files changed

+50
-31
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-server/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientAndWebFluxServerIT.java

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.springframework.beans.factory.ObjectProvider;
4040
import org.springframework.boot.autoconfigure.AutoConfigurations;
4141
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
42-
import org.springframework.boot.test.context.runner.ReactiveWebApplicationContextRunner;
4342
import org.springframework.context.ApplicationContext;
4443
import org.springframework.context.annotation.Bean;
4544
import org.springframework.core.ResolvableType;
@@ -93,11 +92,6 @@ public class SseWebClientAndWebFluxServerIT {
9392
.withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class,
9493
McpClientAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class));
9594

96-
static AtomicReference<LoggingMessageNotification> loggingNotificationRef = new AtomicReference<>();
97-
98-
static CountDownLatch progressLatch = new CountDownLatch(3);
99-
static List<McpSchema.ProgressNotification> progressNotifications = new CopyOnWriteArrayList<>();
100-
10195
@Test
10296
void clientServerCapabilities() {
10397

@@ -184,12 +178,14 @@ void clientServerCapabilities() {
184178
{"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}"""));
185179

186180
// PROGRESS
187-
assertThat(progressLatch.await(5, TimeUnit.SECONDS))
181+
TestContext testContext = clientContext.getBean(TestContext.class);
182+
assertThat(testContext.progressLatch.await(5, TimeUnit.SECONDS))
188183
.as("Should receive progress notifications in reasonable time")
189184
.isTrue();
190-
assertThat(progressNotifications).hasSize(3);
185+
assertThat(testContext.progressNotifications).hasSize(3);
191186

192-
Map<String, McpSchema.ProgressNotification> notificationMap = progressNotifications.stream()
187+
Map<String, McpSchema.ProgressNotification> notificationMap = testContext.progressNotifications
188+
.stream()
193189
.collect(Collectors.toMap(n -> n.message(), n -> n));
194190

195191
// First notification should be 0.0/1.0 progress
@@ -238,7 +234,7 @@ void clientServerCapabilities() {
238234
assertThat(completeResult.meta()).isNull();
239235

240236
// logging message
241-
var logMessage = loggingNotificationRef.get();
237+
var logMessage = testContext.loggingNotificationRef.get();
242238
assertThat(logMessage).isNotNull();
243239
assertThat(logMessage.level()).isEqualTo(LoggingLevel.INFO);
244240
assertThat(logMessage.logger()).isEqualTo("test-logger");
@@ -261,6 +257,16 @@ void clientServerCapabilities() {
261257
});
262258
}
263259

260+
private static class TestContext {
261+
262+
final AtomicReference<LoggingMessageNotification> loggingNotificationRef = new AtomicReference<>();
263+
264+
final CountDownLatch progressLatch = new CountDownLatch(3);
265+
266+
final List<McpSchema.ProgressNotification> progressNotifications = new CopyOnWriteArrayList<>();
267+
268+
}
269+
264270
public static class TestMcpServerConfiguration {
265271

266272
@Bean
@@ -432,13 +438,18 @@ private double evaluateExpression(String expression) {
432438
public static class TestMcpClientConfiguration {
433439

434440
@Bean
435-
McpSyncClientCustomizer clientCustomizer() {
441+
public TestContext testContext() {
442+
return new TestContext();
443+
}
444+
445+
@Bean
446+
McpSyncClientCustomizer clientCustomizer(TestContext testContext) {
436447

437448
return (name, mcpClientSpec) -> {
438449

439450
// Add logging handler
440451
mcpClientSpec = mcpClientSpec.loggingConsumer(logingMessage -> {
441-
loggingNotificationRef.set(logingMessage);
452+
testContext.loggingNotificationRef.set(logingMessage);
442453
logger.info("MCP LOGGING: [{}] {}", logingMessage.level(), logingMessage.data());
443454
});
444455

@@ -464,8 +475,8 @@ McpSyncClientCustomizer clientCustomizer() {
464475

465476
// Progress notification
466477
mcpClientSpec.progressConsumer(progressNotification -> {
467-
progressNotifications.add(progressNotification);
468-
progressLatch.countDown();
478+
testContext.progressNotifications.add(progressNotification);
479+
testContext.progressLatch.countDown();
469480

470481
assertThat(progressNotification.progressToken()).isEqualTo("test-progress-token");
471482
// assertThat(progressNotification.progress()).isEqualTo(0.0);

auto-configurations/mcp/spring-ai-autoconfigure-mcp-stateless-server-webflux/src/test/java/org/springframework/ai/mcp/server/stateless/webflux/autoconfigure/StatelessWebClientAndWebFluxServerIT.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import java.util.Map;
2525

2626
import org.junit.jupiter.api.Test;
27-
import org.slf4j.Logger;
28-
import org.slf4j.LoggerFactory;
2927
import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration;
3028
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration;
3129
import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration;
@@ -71,8 +69,6 @@
7169

7270
public class StatelessWebClientAndWebFluxServerIT {
7371

74-
private static final Logger logger = LoggerFactory.getLogger(StatelessWebClientAndWebFluxServerIT.class);
75-
7672
private final ApplicationContextRunner serverContextRunner = new ApplicationContextRunner()
7773
.withConfiguration(AutoConfigurations.of(McpStatelessServerAutoConfiguration.class,
7874
ToolCallbackConverterAutoConfiguration.class, McpStatelessServerWebFluxAutoConfiguration.class));

auto-configurations/mcp/spring-ai-autoconfigure-mcp-streamable-server-webflux/src/test/java/org/springframework/ai/mcp/server/streamable/webflux/autoconfigure/StreamableWebClientAndWebFluxServerIT.java

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ public class StreamableWebClientAndWebFluxServerIT {
9595
.withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class,
9696
McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class));
9797

98-
static AtomicReference<LoggingMessageNotification> loggingNotificationRef = new AtomicReference<>();
99-
100-
static CountDownLatch progressLatch = new CountDownLatch(3);
101-
static List<McpSchema.ProgressNotification> progressNotifications = new CopyOnWriteArrayList<>();
102-
10398
@Test
10499
void clientServerCapabilities() {
105100

@@ -185,12 +180,14 @@ void clientServerCapabilities() {
185180
{"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}"""));
186181

187182
// PROGRESS
188-
assertThat(progressLatch.await(5, TimeUnit.SECONDS))
183+
TestContext testContext = clientContext.getBean(TestContext.class);
184+
assertThat(testContext.progressLatch.await(5, TimeUnit.SECONDS))
189185
.as("Should receive progress notifications in reasonable time")
190186
.isTrue();
191-
assertThat(progressNotifications).hasSize(3);
187+
assertThat(testContext.progressNotifications).hasSize(3);
192188

193-
Map<String, McpSchema.ProgressNotification> notificationMap = progressNotifications.stream()
189+
Map<String, McpSchema.ProgressNotification> notificationMap = testContext.progressNotifications
190+
.stream()
194191
.collect(Collectors.toMap(n -> n.message(), n -> n));
195192

196193
// First notification should be 0.0/1.0 progress
@@ -239,7 +236,7 @@ void clientServerCapabilities() {
239236
assertThat(completeResult.meta()).isNull();
240237

241238
// logging message
242-
var logMessage = loggingNotificationRef.get();
239+
var logMessage = testContext.loggingNotificationRef.get();
243240
assertThat(logMessage).isNotNull();
244241
assertThat(logMessage.level()).isEqualTo(LoggingLevel.INFO);
245242
assertThat(logMessage.logger()).isEqualTo("test-logger");
@@ -430,16 +427,31 @@ private double evaluateExpression(String expression) {
430427

431428
}
432429

430+
private static class TestContext {
431+
432+
final AtomicReference<LoggingMessageNotification> loggingNotificationRef = new AtomicReference<>();
433+
434+
final CountDownLatch progressLatch = new CountDownLatch(3);
435+
436+
final List<McpSchema.ProgressNotification> progressNotifications = new CopyOnWriteArrayList<>();
437+
438+
}
439+
433440
public static class TestMcpClientConfiguration {
434441

435442
@Bean
436-
McpSyncClientCustomizer clientCustomizer() {
443+
public TestContext testContext() {
444+
return new TestContext();
445+
}
446+
447+
@Bean
448+
McpSyncClientCustomizer clientCustomizer(TestContext testContext) {
437449

438450
return (name, mcpClientSpec) -> {
439451

440452
// Add logging handler
441453
mcpClientSpec = mcpClientSpec.loggingConsumer(logingMessage -> {
442-
loggingNotificationRef.set(logingMessage);
454+
testContext.loggingNotificationRef.set(logingMessage);
443455
logger.info("MCP LOGGING: [{}] {}", logingMessage.level(), logingMessage.data());
444456
});
445457

@@ -465,8 +477,8 @@ McpSyncClientCustomizer clientCustomizer() {
465477

466478
// Progress notification
467479
mcpClientSpec.progressConsumer(progressNotification -> {
468-
progressNotifications.add(progressNotification);
469-
progressLatch.countDown();
480+
testContext.progressNotifications.add(progressNotification);
481+
testContext.progressLatch.countDown();
470482
});
471483
};
472484
}

0 commit comments

Comments
 (0)