Skip to content

Commit f1fea3f

Browse files
committed
Support Prompts for MCP Proxies
1 parent 3bb49fb commit f1fea3f

File tree

6 files changed

+446
-309
lines changed

6 files changed

+446
-309
lines changed

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package software.amazon.smithy.java.mcp.server;
77

8+
import static software.amazon.smithy.java.mcp.model.ListPromptsResult.*;
9+
810
import java.util.List;
911
import java.util.Objects;
1012
import java.util.concurrent.CompletableFuture;
@@ -17,6 +19,7 @@
1719
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
1820
import software.amazon.smithy.java.mcp.model.JsonRpcResponse;
1921
import software.amazon.smithy.java.mcp.model.ListToolsResult;
22+
import software.amazon.smithy.java.mcp.model.PromptInfo;
2023
import software.amazon.smithy.java.mcp.model.ToolInfo;
2124

2225
public abstract class McpServerProxy {
@@ -46,6 +49,24 @@ public List<ToolInfo> listTools() {
4649
}).join();
4750
}
4851

52+
public List<PromptInfo> listPrompts() {
53+
JsonRpcRequest request = JsonRpcRequest.builder()
54+
.method("prompts/list")
55+
.id(generateRequestId())
56+
.jsonrpc("2.0")
57+
.build();
58+
return rpc(request).thenApply(response -> {
59+
if (response.getError() != null) {
60+
throw new RuntimeException("Error listing prompts: " + response.getError().getMessage());
61+
}
62+
return response.getResult()
63+
.asShape(builder())
64+
.getPrompts()
65+
.stream()
66+
.toList();
67+
}).join();
68+
}
69+
4970
public void initialize(
5071
Consumer<JsonRpcResponse> notificationConsumer,
5172
JsonRpcRequest initializeRequest,

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import software.amazon.smithy.java.mcp.model.JsonRpcResponse;
5353
import software.amazon.smithy.java.mcp.model.ListPromptsResult;
5454
import software.amazon.smithy.java.mcp.model.ListToolsResult;
55+
import software.amazon.smithy.java.mcp.model.PromptInfo;
5556
import software.amazon.smithy.java.mcp.model.Prompts;
5657
import software.amazon.smithy.java.mcp.model.ServerInfo;
5758
import software.amazon.smithy.java.mcp.model.TextContent;
@@ -84,7 +85,6 @@ public final class McpService {
8485

8586
private final Map<String, Tool> tools;
8687
private final Map<String, Prompt> prompts;
87-
private final PromptProcessor promptProcessor;
8888
private final String serviceName;
8989
private final String version;
9090
private final Map<String, McpServerProxy> proxies;
@@ -107,8 +107,7 @@ public final class McpService {
107107
this.schemaIndex =
108108
SchemaIndex.compose(services.values().stream().map(Service::schemaIndex).toArray(SchemaIndex[]::new));
109109
this.tools = createTools(services);
110-
this.prompts = PromptLoader.loadPrompts(services.values());
111-
this.promptProcessor = new PromptProcessor();
110+
this.prompts = new ConcurrentHashMap<>(PromptLoader.loadPrompts(services.values()));
112111
this.serviceName = name;
113112
this.version = version;
114113
this.proxies = proxyList.stream().collect(Collectors.toMap(McpServerProxy::name, p -> p));
@@ -236,7 +235,7 @@ private JsonRpcResponse handlePromptsGet(JsonRpcRequest req) {
236235
throw new RuntimeException("Prompt not found: " + promptName);
237236
}
238237

239-
var result = promptProcessor.buildPromptResult(prompt, promptArguments);
238+
var result = prompt.getPromptResult(promptArguments, req.getId());
240239
return createSuccessResponse(req.getId(), result);
241240
}
242241

@@ -342,6 +341,19 @@ public void initializeProxies(Consumer<JsonRpcResponse> responseWriter) {
342341
for (var toolInfo : proxyTools) {
343342
tools.put(toolInfo.getName(), new Tool(toolInfo, proxy.name(), proxy));
344343
}
344+
345+
// Fetch and register prompts from proxy
346+
try {
347+
List<PromptInfo> proxyPrompts = proxy.listPrompts();
348+
for (var promptInfo : proxyPrompts) {
349+
var normalizedName = PromptLoader.normalize(promptInfo.getName());
350+
if (!prompts.containsKey(normalizedName)) {
351+
prompts.put(normalizedName, new Prompt(promptInfo, proxy));
352+
}
353+
}
354+
} catch (Exception e) {
355+
LOG.error("Failed to fetch prompts from proxy: " + proxy.name(), e);
356+
}
345357
}
346358
}
347359
}
@@ -376,6 +388,19 @@ public void addNewProxy(McpServerProxy mcpServerProxy, Consumer<JsonRpcResponse>
376388
} catch (Exception e) {
377389
LOG.error("Failed to fetch tools from proxy", e);
378390
}
391+
392+
// Also fetch prompts from the new proxy
393+
try {
394+
List<PromptInfo> proxyPrompts = mcpServerProxy.listPrompts();
395+
for (var promptInfo : proxyPrompts) {
396+
var normalizedName = PromptLoader.normalize(promptInfo.getName());
397+
if (!prompts.containsKey(normalizedName)) {
398+
prompts.put(normalizedName, new Prompt(promptInfo, mcpServerProxy));
399+
}
400+
}
401+
} catch (Exception e) {
402+
LOG.error("Failed to fetch prompts from proxy: " + mcpServerProxy.name(), e);
403+
}
379404
}
380405

381406
/**

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/Prompt.java

Lines changed: 221 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,226 @@
55

66
package software.amazon.smithy.java.mcp.server;
77

8+
import java.util.HashMap;
9+
import java.util.List;
10+
import java.util.Map;
11+
import java.util.Set;
12+
import java.util.regex.Matcher;
13+
import java.util.regex.Pattern;
14+
import java.util.stream.Collectors;
15+
import software.amazon.smithy.java.core.serde.document.Document;
16+
import software.amazon.smithy.java.mcp.model.GetPromptResult;
17+
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
18+
import software.amazon.smithy.java.mcp.model.PromptArgument;
819
import software.amazon.smithy.java.mcp.model.PromptInfo;
20+
import software.amazon.smithy.java.mcp.model.PromptMessage;
21+
import software.amazon.smithy.java.mcp.model.PromptMessageContent;
22+
import software.amazon.smithy.java.mcp.model.PromptMessageContentType;
23+
import software.amazon.smithy.java.mcp.model.PromptRole;
24+
import software.amazon.smithy.utils.SmithyUnstableApi;
925

10-
public record Prompt(PromptInfo promptInfo, String promptTemplate) {}
26+
/**
27+
* Represents a prompt that can be either local (with a template) or proxied to a remote MCP server.
28+
*/
29+
@SmithyUnstableApi
30+
public final class Prompt {
31+
32+
private static final Pattern PROMPT_ARGUMENT_PLACEHOLDER = Pattern.compile("\\{\\{(\\w+)\\}\\}");
33+
34+
private final PromptInfo promptInfo;
35+
private final String promptTemplate;
36+
private final McpServerProxy proxy;
37+
38+
/**
39+
* Creates a local prompt with a template.
40+
*
41+
* @param promptInfo The prompt metadata
42+
* @param promptTemplate The template string containing {{placeholder}} patterns
43+
*/
44+
public Prompt(PromptInfo promptInfo, String promptTemplate) {
45+
this.promptInfo = promptInfo;
46+
this.promptTemplate = promptTemplate;
47+
this.proxy = null;
48+
}
49+
50+
/**
51+
* Creates a proxy prompt that delegates to a remote MCP server.
52+
*
53+
* @param promptInfo The prompt metadata
54+
* @param proxy The MCP server proxy to delegate to
55+
*/
56+
public Prompt(PromptInfo promptInfo, McpServerProxy proxy) {
57+
this.promptInfo = promptInfo;
58+
this.promptTemplate = null;
59+
this.proxy = proxy;
60+
}
61+
62+
/**
63+
* @return The prompt metadata
64+
*/
65+
public PromptInfo promptInfo() {
66+
return promptInfo;
67+
}
68+
69+
/**
70+
* Gets the prompt result, either by processing the local template or by
71+
* forwarding the request to the proxy server.
72+
*
73+
* @param arguments Document containing argument values for template substitution
74+
* @param requestId The request ID to use for proxy calls (may be null for local prompts)
75+
* @return GetPromptResult with processed template or proxy response
76+
*/
77+
public GetPromptResult getPromptResult(Document arguments, Document requestId) {
78+
if (proxy != null) {
79+
return delegateToProxy(arguments, requestId);
80+
}
81+
return buildLocalPromptResult(arguments);
82+
}
83+
84+
/**
85+
* Delegates the prompt request to the proxy server via RPC.
86+
*/
87+
private GetPromptResult delegateToProxy(Document arguments, Document requestId) {
88+
Map<String, Document> params = new HashMap<>();
89+
params.put("name", Document.of(promptInfo.getName()));
90+
if (arguments != null) {
91+
params.put("arguments", arguments);
92+
}
93+
94+
JsonRpcRequest request = JsonRpcRequest.builder()
95+
.method("prompts/get")
96+
.id(requestId)
97+
.params(Document.of(params))
98+
.jsonrpc("2.0")
99+
.build();
100+
101+
return proxy.rpc(request).thenApply(response -> {
102+
if (response.getError() != null) {
103+
throw new RuntimeException("Error getting prompt: " + response.getError().getMessage());
104+
}
105+
return response.getResult().asShape(GetPromptResult.builder());
106+
}).join();
107+
}
108+
109+
/**
110+
* Builds a GetPromptResult from the local template and provided arguments.
111+
*/
112+
private GetPromptResult buildLocalPromptResult(Document arguments) {
113+
if (promptTemplate == null) {
114+
return GetPromptResult.builder()
115+
.description(promptInfo.getDescription())
116+
.messages(List.of(
117+
PromptMessage.builder()
118+
.role(PromptRole.ASSISTANT.getValue())
119+
.content(PromptMessageContent.builder()
120+
.type(PromptMessageContentType.TEXT)
121+
.text("Template is required for the prompt:" + promptInfo.getName())
122+
.build())
123+
.build()))
124+
.build();
125+
}
126+
127+
var requiredArguments = getRequiredArguments();
128+
129+
if (!requiredArguments.isEmpty() && arguments == null) {
130+
return GetPromptResult.builder()
131+
.description(promptInfo.getDescription())
132+
.messages(List.of(PromptMessage.builder()
133+
.role(PromptRole.USER.getValue())
134+
.content(PromptMessageContent.builder()
135+
.type(PromptMessageContentType.TEXT)
136+
.text("Tell user that there are missing arguments for the prompt : "
137+
+ requiredArguments)
138+
.build())
139+
.build()))
140+
.build();
141+
}
142+
143+
String processedText = applyTemplateArguments(promptTemplate, arguments);
144+
145+
return GetPromptResult.builder()
146+
.description(promptInfo.getDescription())
147+
.messages(List.of(
148+
PromptMessage.builder()
149+
.role(PromptRole.USER.getValue())
150+
.content(PromptMessageContent.builder()
151+
.type(PromptMessageContentType.TEXT)
152+
.text(processedText)
153+
.build())
154+
.build()))
155+
.build();
156+
}
157+
158+
/**
159+
* Applies template arguments to a template string.
160+
*
161+
* @param template The template string containing {{placeholder}} patterns
162+
* @param arguments Document containing replacement values
163+
* @return The template with all placeholders replaced
164+
*/
165+
private String applyTemplateArguments(String template, Document arguments) {
166+
// Common cases
167+
if (template == null || arguments == null || template.isEmpty()) {
168+
return template;
169+
}
170+
171+
// Avoid any regex work if there are no potential placeholders
172+
int firstBrace = template.indexOf("{{");
173+
if (firstBrace == -1) {
174+
return template;
175+
}
176+
177+
Matcher matcher = PROMPT_ARGUMENT_PLACEHOLDER.matcher(template);
178+
179+
int matchCount = 0;
180+
int estimatedResultLength = template.length();
181+
Map<String, String> replacementCache = new HashMap<>();
182+
183+
while (matcher.find()) {
184+
matchCount++;
185+
String argName = matcher.group(1);
186+
187+
// Only look up each unique argument once
188+
if (!replacementCache.containsKey(argName)) {
189+
Document argValue = arguments.getMember(argName);
190+
String replacement = (argValue != null) ? argValue.asString() : "";
191+
replacementCache.put(argName, replacement);
192+
193+
// Adjust estimated length (subtract placeholder length, add replacement length)
194+
estimatedResultLength = estimatedResultLength - matcher.group(0).length() + replacement.length();
195+
}
196+
}
197+
198+
// If no matches found, return original template
199+
if (matchCount == 0) {
200+
return template;
201+
}
202+
203+
// Reset matcher for the actual replacement pass
204+
matcher.reset();
205+
206+
StringBuilder result = new StringBuilder(estimatedResultLength);
207+
208+
// Single-pass replacement using cached values
209+
while (matcher.find()) {
210+
String argName = matcher.group(1);
211+
String replacement = replacementCache.get(argName);
212+
matcher.appendReplacement(result, Matcher.quoteReplacement(replacement));
213+
}
214+
215+
matcher.appendTail(result);
216+
217+
return result.toString();
218+
}
219+
220+
/**
221+
* Extracts the set of required argument names from the PromptInfo.
222+
*/
223+
private Set<String> getRequiredArguments() {
224+
return promptInfo.getArguments()
225+
.stream()
226+
.filter(PromptArgument::isRequired)
227+
.map(PromptArgument::getName)
228+
.collect(Collectors.toSet());
229+
}
230+
}

0 commit comments

Comments
 (0)