diff --git a/README.md b/README.md index db6aeb9..612277a 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ The Spring integration module provides seamless integration with Spring AI and S #### Client - **`@McpLogging`** - Annotates methods that handle logging message notifications from MCP servers (requires `clientId` parameter) -- **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers +- **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers (requires `clientId` parameter) - **`@McpElicitation`** - Annotates methods that handle elicitation requests to gather additional information from users (requires `clientId` parameter) - **`@McpProgress`** - Annotates methods that handle progress notifications for long-running operations (requires `clientId` parameter) - **`@McpToolListChanged`** - Annotates methods that handle tool list change notifications from MCP servers @@ -997,10 +997,11 @@ public class SamplingHandler { /** * Handle sampling requests with a synchronous implementation. + * Note: clientId is now required for all @McpSampling annotations. * @param request The create message request * @return The create message result */ - @McpSampling + @McpSampling(clientId = "default-client") public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) { // Process the request and generate a response return CreateMessageResult.builder() @@ -1029,10 +1030,11 @@ public class AsyncSamplingHandler { /** * Handle sampling requests with an asynchronous implementation. + * Note: clientId is now required for all @McpSampling annotations. * @param request The create message request * @return A Mono containing the create message result */ - @McpSampling + @McpSampling(clientId = "default-client") public Mono handleAsyncSamplingRequest(CreateMessageRequest request) { return Mono.just(CreateMessageResult.builder() .role(Role.ASSISTANT) diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java index 7a084db..0110495 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java @@ -28,7 +28,7 @@ * *

* Example usage:

{@code
- * @McpSampling
+ * @McpSampling(clientId = "test-client")
  * public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
  *     // Process the request and return a result
  *     return CreateMessageResult.builder()
@@ -36,7 +36,7 @@
  *         .build();
  * }
  *
- * @McpSampling
+ * @McpSampling(clientId = "test-client")
  * public Mono handleAsyncSamplingRequest(CreateMessageRequest request) {
  *     // Process the request asynchronously and return a result
  *     return Mono.just(CreateMessageResult.builder()
@@ -56,8 +56,8 @@
 
 	/**
 	 * Used as connection or client identifier to select the MCP client, the sampling
-	 * method is associated with. If not specified, is applied to all clients.
+	 * method is associated with.
 	 */
-	String clientId() default "";
+	String clientId();
 
 }
diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/AsyncSamplingSpecification.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/AsyncSamplingSpecification.java
index 8276706..b6aec54 100644
--- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/AsyncSamplingSpecification.java
+++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/AsyncSamplingSpecification.java
@@ -1,5 +1,6 @@
 package org.springaicommunity.mcp.method.sampling;
 
+import java.util.Objects;
 import java.util.function.Function;
 
 import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
@@ -9,4 +10,12 @@
 public record AsyncSamplingSpecification(String clientId,
 		Function> samplingHandler) {
 
+	public AsyncSamplingSpecification {
+		Objects.requireNonNull(clientId, "clientId must not be null");
+		if (clientId.trim().isEmpty()) {
+			throw new IllegalArgumentException("clientId must not be empty");
+		}
+		Objects.requireNonNull(samplingHandler, "samplingHandler must not be null");
+	}
+
 }
\ No newline at end of file
diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/SyncSamplingSpecification.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/SyncSamplingSpecification.java
index 9eb006b..11d3e5c 100644
--- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/SyncSamplingSpecification.java
+++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/sampling/SyncSamplingSpecification.java
@@ -1,5 +1,6 @@
 package org.springaicommunity.mcp.method.sampling;
 
+import java.util.Objects;
 import java.util.function.Function;
 
 import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
@@ -8,4 +9,12 @@
 public record SyncSamplingSpecification(String clientId,
 		Function samplingHandler) {
 
+	public SyncSamplingSpecification {
+		Objects.requireNonNull(clientId, "clientId must not be null");
+		if (clientId.trim().isEmpty()) {
+			throw new IllegalArgumentException("clientId must not be empty");
+		}
+		Objects.requireNonNull(samplingHandler, "samplingHandler must not be null");
+	}
+
 }
\ No newline at end of file
diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/AsyncMcpSamplingMethodCallbackExample.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/AsyncMcpSamplingMethodCallbackExample.java
index 3cb833d..52903b3 100644
--- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/AsyncMcpSamplingMethodCallbackExample.java
+++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/AsyncMcpSamplingMethodCallbackExample.java
@@ -25,7 +25,7 @@ public class AsyncMcpSamplingMethodCallbackExample {
 	 * @param request The sampling request
 	 * @return The sampling result as a Mono
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public Mono handleAsyncSamplingRequest(CreateMessageRequest request) {
 		// Process the request asynchronously and return a result
 		return Mono.just(CreateMessageResult.builder()
@@ -40,7 +40,7 @@ public Mono handleAsyncSamplingRequest(CreateMessageRequest
 	 * @param request The sampling request
 	 * @return The sampling result directly
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest request) {
 		// Process the request and return a direct result
 		return CreateMessageResult.builder()
@@ -55,7 +55,7 @@ public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest requ
 	 * @param request The sampling request
 	 * @return A Mono with an invalid type
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public Mono invalidMonoReturnType(CreateMessageRequest request) {
 		return Mono.just("This method has an invalid return type");
 	}
@@ -65,7 +65,7 @@ public Mono invalidMonoReturnType(CreateMessageRequest request) {
 	 * @param invalidParam An invalid parameter type
 	 * @return The sampling result as a Mono
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public Mono invalidParameterType(String invalidParam) {
 		return Mono.just(CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
@@ -78,7 +78,7 @@ public Mono invalidParameterType(String invalidParam) {
 	 * Example method with no parameters.
 	 * @return The sampling result as a Mono
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public Mono noParameters() {
 		return Mono.just(CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
@@ -93,7 +93,7 @@ public Mono noParameters() {
 	 * @param extraParam An extra parameter
 	 * @return The sampling result as a Mono
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public Mono tooManyParameters(CreateMessageRequest request, String extraParam) {
 		return Mono.just(CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/SyncMcpSamplingMethodCallbackExample.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/SyncMcpSamplingMethodCallbackExample.java
index 15d1828..847dfd1 100644
--- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/SyncMcpSamplingMethodCallbackExample.java
+++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/sampling/SyncMcpSamplingMethodCallbackExample.java
@@ -24,7 +24,7 @@ public class SyncMcpSamplingMethodCallbackExample {
 	 * @param request The sampling request
 	 * @return The sampling result
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
 		// Process the request and return a result
 		return CreateMessageResult.builder()
@@ -39,7 +39,7 @@ public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
 	 * @param request The sampling request
 	 * @return A string (invalid return type)
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public String invalidReturnType(CreateMessageRequest request) {
 		return "This method has an invalid return type";
 	}
@@ -49,7 +49,7 @@ public String invalidReturnType(CreateMessageRequest request) {
 	 * @param invalidParam An invalid parameter type
 	 * @return The sampling result
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public CreateMessageResult invalidParameterType(String invalidParam) {
 		return CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
@@ -62,7 +62,7 @@ public CreateMessageResult invalidParameterType(String invalidParam) {
 	 * Example method with no parameters.
 	 * @return The sampling result
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public CreateMessageResult noParameters() {
 		return CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
@@ -77,7 +77,7 @@ public CreateMessageResult noParameters() {
 	 * @param extraParam An extra parameter
 	 * @return The sampling result
 	 */
-	@McpSampling
+	@McpSampling(clientId = "test-client")
 	public CreateMessageResult tooManyParameters(CreateMessageRequest request, String extraParam) {
 		return CreateMessageResult.builder()
 			.role(Role.ASSISTANT)
diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/AsyncMcpSamplingProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/AsyncMcpSamplingProviderTests.java
index d353f5d..be65c7a 100644
--- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/AsyncMcpSamplingProviderTests.java
+++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/AsyncMcpSamplingProviderTests.java
@@ -35,7 +35,7 @@ void testGetSamplingHandler() {
 		// Create a class with only one valid sampling method
 		class SingleValidMethod {
 
-			@McpSampling
+			@McpSampling(clientId = "test-client")
 			public Mono handleAsyncSamplingRequest(CreateMessageRequest request) {
 				return Mono.just(CreateMessageResult.builder()
 					.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)
@@ -77,7 +77,7 @@ void testDirectResultMethod() {
 		// Create a class with only the direct result method
 		class DirectResultOnly {
 
-			@McpSampling
+			@McpSampling(clientId = "test-client")
 			public CreateMessageResult handleDirectSamplingRequest(CreateMessageRequest request) {
 				return CreateMessageResult.builder()
 					.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)
diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/SyncMcpSamplingProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/SyncMcpSamplingProviderTests.java
index 18830a4..bd15dc5 100644
--- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/SyncMcpSamplingProviderTests.java
+++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/sampling/SyncMcpSamplingProviderTests.java
@@ -32,7 +32,7 @@ void testGetSamplingHandler() {
 		// Create a class with only one valid sampling method
 		class SingleValidMethod {
 
-			@McpSampling
+			@McpSampling(clientId = "test-client")
 			public CreateMessageResult handleSamplingRequest(CreateMessageRequest request) {
 				return CreateMessageResult.builder()
 					.role(io.modelcontextprotocol.spec.McpSchema.Role.ASSISTANT)