|
4 | 4 | import java.util.List; |
5 | 5 | import java.util.Set; |
6 | 6 | import java.util.function.Function; |
7 | | -import java.util.function.Supplier; |
8 | 7 |
|
9 | 8 | import dev.langchain4j.mcp.McpToolProvider; |
10 | 9 | import dev.langchain4j.mcp.client.DefaultMcpClient; |
11 | 10 | import dev.langchain4j.mcp.client.McpClient; |
12 | 11 | import dev.langchain4j.mcp.client.transport.McpTransport; |
13 | 12 | import dev.langchain4j.mcp.client.transport.stdio.StdioMcpTransport; |
| 13 | +import dev.langchain4j.model.chat.ChatLanguageModel; |
14 | 14 | import dev.langchain4j.service.tool.ToolProvider; |
| 15 | +import io.quarkiverse.langchain4j.ModelName; |
15 | 16 | import io.quarkiverse.langchain4j.mcp.runtime.config.McpBuildTimeConfiguration; |
16 | 17 | import io.quarkiverse.langchain4j.mcp.runtime.config.McpClientBuildTimeConfig; |
17 | 18 | import io.quarkiverse.langchain4j.mcp.runtime.config.McpClientRuntimeConfig; |
|
24 | 25 | @Recorder |
25 | 26 | public class McpRecorder { |
26 | 27 |
|
27 | | - public Supplier<McpClient> mcpClientSupplier(String key, McpBuildTimeConfiguration buildTimeConfiguration, |
| 28 | + public Function<SyntheticCreationalContext<McpClient>, McpClient> mcpClientSupplier(String clientName, |
| 29 | + McpBuildTimeConfiguration buildTimeConfiguration, |
28 | 30 | McpRuntimeConfiguration mcpRuntimeConfiguration) { |
29 | | - return new Supplier<McpClient>() { |
| 31 | + return new Function<>() { |
30 | 32 | @Override |
31 | | - public McpClient get() { |
32 | | - McpTransport transport = null; |
33 | | - McpClientBuildTimeConfig buildTimeConfig = buildTimeConfiguration.clients().get(key); |
34 | | - McpClientRuntimeConfig runtimeConfig = mcpRuntimeConfiguration.clients().get(key); |
35 | | - switch (buildTimeConfig.transportType()) { |
36 | | - case STDIO: |
| 33 | + public McpClient apply(SyntheticCreationalContext<McpClient> context) { |
| 34 | + McpTransport transport; |
| 35 | + McpClientBuildTimeConfig buildTimeConfig = buildTimeConfiguration.clients().get(clientName); |
| 36 | + McpClientRuntimeConfig runtimeConfig = mcpRuntimeConfiguration.clients().get(clientName); |
| 37 | + transport = switch (buildTimeConfig.transportType()) { |
| 38 | + case STDIO -> { |
37 | 39 | List<String> command = runtimeConfig.command().orElseThrow(() -> new ConfigurationException( |
38 | | - "MCP client configuration named " + key + " is missing the 'command' property")); |
39 | | - transport = new StdioMcpTransport.Builder() |
| 40 | + "MCP client configuration named " + clientName + " is missing the 'command' property")); |
| 41 | + yield new StdioMcpTransport.Builder() |
40 | 42 | .command(command) |
41 | 43 | .logEvents(runtimeConfig.logResponses().orElse(false)) |
42 | 44 | .environment(runtimeConfig.environment()) |
43 | 45 | .build(); |
44 | | - break; |
45 | | - case HTTP: |
46 | | - transport = new QuarkusHttpMcpTransport.Builder() |
47 | | - .sseUrl(runtimeConfig.url().orElseThrow(() -> new ConfigurationException( |
48 | | - "MCP client configuration named " + key + " is missing the 'url' property"))) |
49 | | - .logRequests(runtimeConfig.logRequests().orElse(false)) |
50 | | - .logResponses(runtimeConfig.logResponses().orElse(false)) |
51 | | - .build(); |
52 | | - break; |
53 | | - default: |
54 | | - throw new IllegalArgumentException("Unknown transport type: " + buildTimeConfig.transportType()); |
55 | | - } |
56 | | - return new DefaultMcpClient.Builder() |
| 46 | + } |
| 47 | + case HTTP -> new QuarkusHttpMcpTransport.Builder() |
| 48 | + .sseUrl(runtimeConfig.url().orElseThrow(() -> new ConfigurationException( |
| 49 | + "MCP client configuration named " + clientName + " is missing the 'url' property"))) |
| 50 | + .logRequests(runtimeConfig.logRequests().orElse(false)) |
| 51 | + .logResponses(runtimeConfig.logResponses().orElse(false)) |
| 52 | + .build(); |
| 53 | + }; |
| 54 | + McpClient result = new DefaultMcpClient.Builder() |
57 | 55 | .transport(transport) |
58 | 56 | .toolExecutionTimeout(runtimeConfig.toolExecutionTimeout()) |
59 | 57 | .resourcesTimeout(runtimeConfig.resourcesTimeout()) |
60 | 58 | // TODO: it should be possible to choose a log handler class via configuration |
61 | | - .logHandler(new QuarkusDefaultMcpLogHandler(key)) |
| 59 | + .logHandler(new QuarkusDefaultMcpLogHandler(clientName)) |
62 | 60 | .build(); |
| 61 | + if (runtimeConfig.toolValidationModelName().isPresent()) { |
| 62 | + ChatLanguageModel chatLanguageModel; |
| 63 | + if ("default".equals(runtimeConfig.toolValidationModelName().get())) { |
| 64 | + chatLanguageModel = context.getInjectedReference(ChatLanguageModel.class); |
| 65 | + } else { |
| 66 | + chatLanguageModel = context.getInjectedReference(ChatLanguageModel.class, |
| 67 | + ModelName.Literal.of(runtimeConfig.toolValidationModelName().get())); |
| 68 | + } |
| 69 | + result = new ValidatingMcpClient(result, chatLanguageModel); |
| 70 | + } |
| 71 | + return result; |
63 | 72 | } |
64 | 73 | }; |
65 | 74 | } |
|
0 commit comments