Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ The Spring integration module provides seamless integration with Spring AI and S

#### Client
- **`@McpLogging`** - Annotates methods that handle logging message notifications from MCP servers (requires `clientId` parameter)
- **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers
- **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers (requires `clientId` parameter)
- **`@McpElicitation`** - Annotates methods that handle elicitation requests to gather additional information from users (requires `clientId` parameter)
- **`@McpProgress`** - Annotates methods that handle progress notifications for long-running operations (requires `clientId` parameter)
- **`@McpToolListChanged`** - Annotates methods that handle tool list change notifications from MCP servers
Expand Down Expand Up @@ -997,10 +997,11 @@ public class SamplingHandler {

/**
* Handle sampling requests with a synchronous implementation.
* Note: clientId is now required for all @McpSampling annotations.
* @param request The create message request
* @return The create message result
*/
@McpSampling
@McpSampling(clientId = "default-client")
public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
// Process the request and generate a response
return CreateMessageResult.builder()
Expand Down Expand Up @@ -1029,10 +1030,11 @@ public class AsyncSamplingHandler {

/**
* Handle sampling requests with an asynchronous implementation.
* Note: clientId is now required for all @McpSampling annotations.
* @param request The create message request
* @return A Mono containing the create message result
*/
@McpSampling
@McpSampling(clientId = "default-client")
public Mono<CreateMessageResult> handleAsyncSamplingRequest(CreateMessageRequest request) {
return Mono.just(CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@
*
* <p>
* Example usage: <pre>{@code
* &#64;McpSampling
* &#64;McpSampling(clientId = "test-client")
* public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
* // Process the request and return a result
* return CreateMessageResult.builder()
* .message("Generated response")
* .build();
* }
*
* &#64;McpSampling
* &#64;McpSampling(clientId = "test-client")
* public Mono<CreateMessageResult> handleAsyncSamplingRequest(CreateMessageRequest request) {
* // Process the request asynchronously and return a result
* return Mono.just(CreateMessageResult.builder()
Expand All @@ -56,8 +56,8 @@

/**
* Used as connection or client identifier to select the MCP client, the sampling
* method is associated with. If not specified, is applied to all clients.
* method is associated with.
*/
String clientId() default "";
String clientId();

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.springaicommunity.mcp.method.sampling;

import java.util.Objects;
import java.util.function.Function;

import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
Expand All @@ -9,4 +10,12 @@
public record AsyncSamplingSpecification(String clientId,
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler) {

public AsyncSamplingSpecification {
Objects.requireNonNull(clientId, "clientId must not be null");
if (clientId.trim().isEmpty()) {
throw new IllegalArgumentException("clientId must not be empty");
}
Objects.requireNonNull(samplingHandler, "samplingHandler must not be null");
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.springaicommunity.mcp.method.sampling;

import java.util.Objects;
import java.util.function.Function;

import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
Expand All @@ -8,4 +9,12 @@
public record SyncSamplingSpecification(String clientId,
Function<CreateMessageRequest, CreateMessageResult> samplingHandler) {

public SyncSamplingSpecification {
Objects.requireNonNull(clientId, "clientId must not be null");
if (clientId.trim().isEmpty()) {
throw new IllegalArgumentException("clientId must not be empty");
}
Objects.requireNonNull(samplingHandler, "samplingHandler must not be null");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class AsyncMcpSamplingMethodCallbackExample {
* @param request The sampling request
* @return The sampling result as a Mono
*/
@McpSampling
@McpSampling(clientId = "test-client")
public Mono<CreateMessageResult> handleAsyncSamplingRequest(CreateMessageRequest request) {
// Process the request asynchronously and return a result
return Mono.just(CreateMessageResult.builder()
Expand All @@ -40,7 +40,7 @@ public Mono<CreateMessageResult> handleAsyncSamplingRequest(CreateMessageRequest
* @param request The sampling request
* @return The sampling result directly
*/
@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest request) {
// Process the request and return a direct result
return CreateMessageResult.builder()
Expand All @@ -55,7 +55,7 @@ public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest requ
* @param request The sampling request
* @return A Mono with an invalid type
*/
@McpSampling
@McpSampling(clientId = "test-client")
public Mono<String> invalidMonoReturnType(CreateMessageRequest request) {
return Mono.just("This method has an invalid return type");
}
Expand All @@ -65,7 +65,7 @@ public Mono<String> invalidMonoReturnType(CreateMessageRequest request) {
* @param invalidParam An invalid parameter type
* @return The sampling result as a Mono
*/
@McpSampling
@McpSampling(clientId = "test-client")
public Mono<CreateMessageResult> invalidParameterType(String invalidParam) {
return Mono.just(CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand All @@ -78,7 +78,7 @@ public Mono<CreateMessageResult> invalidParameterType(String invalidParam) {
* Example method with no parameters.
* @return The sampling result as a Mono
*/
@McpSampling
@McpSampling(clientId = "test-client")
public Mono<CreateMessageResult> noParameters() {
return Mono.just(CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand All @@ -93,7 +93,7 @@ public Mono<CreateMessageResult> noParameters() {
* @param extraParam An extra parameter
* @return The sampling result as a Mono
*/
@McpSampling
@McpSampling(clientId = "test-client")
public Mono<CreateMessageResult> tooManyParameters(CreateMessageRequest request, String extraParam) {
return Mono.just(CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class SyncMcpSamplingMethodCallbackExample {
* @param request The sampling request
* @return The sampling result
*/
@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
// Process the request and return a result
return CreateMessageResult.builder()
Expand All @@ -39,7 +39,7 @@ public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
* @param request The sampling request
* @return A string (invalid return type)
*/
@McpSampling
@McpSampling(clientId = "test-client")
public String invalidReturnType(CreateMessageRequest request) {
return "This method has an invalid return type";
}
Expand All @@ -49,7 +49,7 @@ public String invalidReturnType(CreateMessageRequest request) {
* @param invalidParam An invalid parameter type
* @return The sampling result
*/
@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult invalidParameterType(String invalidParam) {
return CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand All @@ -62,7 +62,7 @@ public CreateMessageResult invalidParameterType(String invalidParam) {
* Example method with no parameters.
* @return The sampling result
*/
@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult noParameters() {
return CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand All @@ -77,7 +77,7 @@ public CreateMessageResult noParameters() {
* @param extraParam An extra parameter
* @return The sampling result
*/
@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult tooManyParameters(CreateMessageRequest request, String extraParam) {
return CreateMessageResult.builder()
.role(Role.ASSISTANT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void testGetSamplingHandler() {
// Create a class with only one valid sampling method
class SingleValidMethod {

@McpSampling
@McpSampling(clientId = "test-client")
public Mono<CreateMessageResult> handleAsyncSamplingRequest(CreateMessageRequest request) {
return Mono.just(CreateMessageResult.builder()
.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)
Expand Down Expand Up @@ -77,7 +77,7 @@ void testDirectResultMethod() {
// Create a class with only the direct result method
class DirectResultOnly {

@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest request) {
return CreateMessageResult.builder()
.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void testGetSamplingHandler() {
// Create a class with only one valid sampling method
class SingleValidMethod {

@McpSampling
@McpSampling(clientId = "test-client")
public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
return CreateMessageResult.builder()
.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)
Expand Down