|
17 | 17 | package org.springframework.ai.mcp.server.common.autoconfigure; |
18 | 18 |
|
19 | 19 | import java.util.List; |
| 20 | +import java.util.concurrent.ConcurrentHashMap; |
| 21 | +import java.util.concurrent.CopyOnWriteArrayList; |
20 | 22 | import java.util.function.BiConsumer; |
21 | 23 | import java.util.function.BiFunction; |
| 24 | +import java.util.stream.Stream; |
22 | 25 |
|
23 | 26 | import io.modelcontextprotocol.client.McpSyncClient; |
24 | 27 | import io.modelcontextprotocol.json.TypeRef; |
25 | 28 | import io.modelcontextprotocol.server.McpAsyncServer; |
26 | 29 | import io.modelcontextprotocol.server.McpAsyncServerExchange; |
27 | 30 | import io.modelcontextprotocol.server.McpServerFeatures; |
28 | 31 | import io.modelcontextprotocol.server.McpServerFeatures.AsyncCompletionSpecification; |
| 32 | +import io.modelcontextprotocol.server.McpServerFeatures.AsyncPromptSpecification; |
| 33 | +import io.modelcontextprotocol.server.McpServerFeatures.AsyncResourceSpecification; |
29 | 34 | import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; |
30 | 35 | import io.modelcontextprotocol.server.McpServerFeatures.SyncCompletionSpecification; |
31 | 36 | import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; |
|
39 | 44 | import io.modelcontextprotocol.spec.McpServerTransportProvider; |
40 | 45 | import org.junit.jupiter.api.Test; |
41 | 46 | import org.mockito.Mockito; |
| 47 | +import org.springaicommunity.mcp.annotation.McpArg; |
| 48 | +import org.springaicommunity.mcp.annotation.McpComplete; |
| 49 | +import org.springaicommunity.mcp.annotation.McpPrompt; |
| 50 | +import org.springaicommunity.mcp.annotation.McpResource; |
| 51 | +import org.springaicommunity.mcp.annotation.McpTool; |
| 52 | +import org.springaicommunity.mcp.annotation.McpToolParam; |
42 | 53 | import reactor.core.publisher.Mono; |
43 | 54 |
|
44 | 55 | import org.springframework.ai.mcp.SyncMcpToolCallback; |
| 56 | +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; |
| 57 | +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration; |
45 | 58 | import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerChangeNotificationProperties; |
46 | 59 | import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties; |
47 | 60 | import org.springframework.ai.tool.ToolCallback; |
|
50 | 63 | import org.springframework.boot.test.context.runner.ApplicationContextRunner; |
51 | 64 | import org.springframework.context.annotation.Bean; |
52 | 65 | import org.springframework.context.annotation.Configuration; |
| 66 | +import org.springframework.stereotype.Component; |
| 67 | +import org.springframework.test.util.ReflectionTestUtils; |
53 | 68 |
|
54 | 69 | import static org.assertj.core.api.Assertions.assertThat; |
55 | 70 | import static org.mockito.Mockito.when; |
@@ -345,6 +360,72 @@ void toolCallbackProviderConfiguration() { |
345 | 360 | .run(context -> assertThat(context).hasSingleBean(ToolCallbackProvider.class)); |
346 | 361 | } |
347 | 362 |
|
| 363 | + @SuppressWarnings("unchecked") |
| 364 | + @Test |
| 365 | + void syncServerSpecificationConfiguration() { |
| 366 | + this.contextRunner |
| 367 | + .withUserConfiguration(McpServerAnnotationScannerAutoConfiguration.class, |
| 368 | + McpServerSpecificationFactoryAutoConfiguration.class) |
| 369 | + .withBean(SyncTestMcpSpecsComponent.class) |
| 370 | + .run(context -> { |
| 371 | + McpSyncServer syncServer = context.getBean(McpSyncServer.class); |
| 372 | + McpAsyncServer asyncServer = (McpAsyncServer) ReflectionTestUtils.getField(syncServer, "asyncServer"); |
| 373 | + |
| 374 | + CopyOnWriteArrayList<AsyncToolSpecification> tools = (CopyOnWriteArrayList<AsyncToolSpecification>) ReflectionTestUtils |
| 375 | + .getField(asyncServer, "tools"); |
| 376 | + assertThat(tools).hasSize(1); |
| 377 | + assertThat(tools.get(0).tool().name()).isEqualTo("add"); |
| 378 | + |
| 379 | + ConcurrentHashMap<String, AsyncResourceSpecification> resources = (ConcurrentHashMap<String, AsyncResourceSpecification>) ReflectionTestUtils |
| 380 | + .getField(asyncServer, "resources"); |
| 381 | + assertThat(resources).hasSize(1); |
| 382 | + assertThat(resources.get("config://{key}")).isNotNull(); |
| 383 | + |
| 384 | + ConcurrentHashMap<String, AsyncPromptSpecification> prompts = (ConcurrentHashMap<String, AsyncPromptSpecification>) ReflectionTestUtils |
| 385 | + .getField(asyncServer, "prompts"); |
| 386 | + assertThat(prompts).hasSize(1); |
| 387 | + assertThat(prompts.get("greeting")).isNotNull(); |
| 388 | + |
| 389 | + ConcurrentHashMap<McpSchema.CompleteReference, AsyncCompletionSpecification> completions = (ConcurrentHashMap<McpSchema.CompleteReference, AsyncCompletionSpecification>) ReflectionTestUtils |
| 390 | + .getField(asyncServer, "completions"); |
| 391 | + assertThat(completions).hasSize(1); |
| 392 | + assertThat(completions.keySet().iterator().next()).isInstanceOf(McpSchema.CompleteReference.class); |
| 393 | + }); |
| 394 | + } |
| 395 | + |
| 396 | + @SuppressWarnings("unchecked") |
| 397 | + @Test |
| 398 | + void asyncServerSpecificationConfiguration() { |
| 399 | + this.contextRunner |
| 400 | + .withUserConfiguration(McpServerAnnotationScannerAutoConfiguration.class, |
| 401 | + McpServerSpecificationFactoryAutoConfiguration.class) |
| 402 | + .withBean(AsyncTestMcpSpecsComponent.class) |
| 403 | + .withPropertyValues("spring.ai.mcp.server.type=async") |
| 404 | + .run(context -> { |
| 405 | + McpAsyncServer asyncServer = context.getBean(McpAsyncServer.class); |
| 406 | + |
| 407 | + CopyOnWriteArrayList<AsyncToolSpecification> tools = (CopyOnWriteArrayList<AsyncToolSpecification>) ReflectionTestUtils |
| 408 | + .getField(asyncServer, "tools"); |
| 409 | + assertThat(tools).hasSize(1); |
| 410 | + assertThat(tools.get(0).tool().name()).isEqualTo("add"); |
| 411 | + |
| 412 | + ConcurrentHashMap<String, AsyncResourceSpecification> resources = (ConcurrentHashMap<String, AsyncResourceSpecification>) ReflectionTestUtils |
| 413 | + .getField(asyncServer, "resources"); |
| 414 | + assertThat(resources).hasSize(1); |
| 415 | + assertThat(resources.get("config://{key}")).isNotNull(); |
| 416 | + |
| 417 | + ConcurrentHashMap<String, AsyncPromptSpecification> prompts = (ConcurrentHashMap<String, AsyncPromptSpecification>) ReflectionTestUtils |
| 418 | + .getField(asyncServer, "prompts"); |
| 419 | + assertThat(prompts).hasSize(1); |
| 420 | + assertThat(prompts.get("greeting")).isNotNull(); |
| 421 | + |
| 422 | + ConcurrentHashMap<McpSchema.CompleteReference, AsyncCompletionSpecification> completions = (ConcurrentHashMap<McpSchema.CompleteReference, AsyncCompletionSpecification>) ReflectionTestUtils |
| 423 | + .getField(asyncServer, "completions"); |
| 424 | + assertThat(completions).hasSize(1); |
| 425 | + assertThat(completions.keySet().iterator().next()).isInstanceOf(McpSchema.CompleteReference.class); |
| 426 | + }); |
| 427 | + } |
| 428 | + |
348 | 429 | @Configuration |
349 | 430 | static class TestResourceConfiguration { |
350 | 431 |
|
@@ -516,4 +597,76 @@ McpServerTransport customTransport() { |
516 | 597 |
|
517 | 598 | } |
518 | 599 |
|
| 600 | + @Component |
| 601 | + static class SyncTestMcpSpecsComponent { |
| 602 | + |
| 603 | + @McpTool(name = "add", description = "Add two numbers together", title = "Add Two Numbers Together", |
| 604 | + annotations = @McpTool.McpAnnotations(title = "Rectangle Area Calculator", readOnlyHint = true, |
| 605 | + destructiveHint = false, idempotentHint = true)) |
| 606 | + public int add(@McpToolParam(description = "First number", required = true) int a, |
| 607 | + @McpToolParam(description = "Second number", required = true) int b) { |
| 608 | + return a + b; |
| 609 | + } |
| 610 | + |
| 611 | + @McpResource(uri = "config://{key}", name = "Configuration", description = "Provides configuration data") |
| 612 | + public String getConfig(String key) { |
| 613 | + return "config value"; |
| 614 | + } |
| 615 | + |
| 616 | + @McpPrompt(name = "greeting", description = "Generate a greeting message") |
| 617 | + public McpSchema.GetPromptResult greeting( |
| 618 | + @McpArg(name = "name", description = "User's name", required = true) String name) { |
| 619 | + |
| 620 | + String message = "Hello, " + name + "! How can I help you today?"; |
| 621 | + |
| 622 | + return new McpSchema.GetPromptResult("Greeting", |
| 623 | + List.of(new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent(message)))); |
| 624 | + } |
| 625 | + |
| 626 | + @McpComplete(prompt = "city-search") |
| 627 | + public List<String> completeCityName(String prefix) { |
| 628 | + return Stream.of("New York", "Los Angeles", "Chicago", "Houston", "Phoenix") |
| 629 | + .filter(city -> city.toLowerCase().startsWith(prefix.toLowerCase())) |
| 630 | + .limit(10) |
| 631 | + .toList(); |
| 632 | + } |
| 633 | + |
| 634 | + } |
| 635 | + |
| 636 | + @Component |
| 637 | + static class AsyncTestMcpSpecsComponent { |
| 638 | + |
| 639 | + @McpTool(name = "add", description = "Add two numbers together", title = "Add Two Numbers Together", |
| 640 | + annotations = @McpTool.McpAnnotations(title = "Rectangle Area Calculator", readOnlyHint = true, |
| 641 | + destructiveHint = false, idempotentHint = true)) |
| 642 | + public Mono<Integer> add(@McpToolParam(description = "First number", required = true) int a, |
| 643 | + @McpToolParam(description = "Second number", required = true) int b) { |
| 644 | + return Mono.just(a + b); |
| 645 | + } |
| 646 | + |
| 647 | + @McpResource(uri = "config://{key}", name = "Configuration", description = "Provides configuration data") |
| 648 | + public Mono<String> getConfig(String key) { |
| 649 | + return Mono.just("config value"); |
| 650 | + } |
| 651 | + |
| 652 | + @McpPrompt(name = "greeting", description = "Generate a greeting message") |
| 653 | + public Mono<McpSchema.GetPromptResult> greeting( |
| 654 | + @McpArg(name = "name", description = "User's name", required = true) String name) { |
| 655 | + |
| 656 | + String message = "Hello, " + name + "! How can I help you today?"; |
| 657 | + |
| 658 | + return Mono.just(new McpSchema.GetPromptResult("Greeting", List |
| 659 | + .of(new McpSchema.PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent(message))))); |
| 660 | + } |
| 661 | + |
| 662 | + @McpComplete(prompt = "city-search") |
| 663 | + public Mono<List<String>> completeCityName(String prefix) { |
| 664 | + return Mono.just(Stream.of("New York", "Los Angeles", "Chicago", "Houston", "Phoenix") |
| 665 | + .filter(city -> city.toLowerCase().startsWith(prefix.toLowerCase())) |
| 666 | + .limit(10) |
| 667 | + .toList()); |
| 668 | + } |
| 669 | + |
| 670 | + } |
| 671 | + |
519 | 672 | } |
0 commit comments