Skip to content

Commit ad810e8

Browse files
committed
Roots integration plus tests
1 parent 0d72f29 commit ad810e8

File tree

7 files changed

+259
-0
lines changed

7 files changed

+259
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package io.quarkiverse.langchain4j.mcp.test;
2+
3+
import static io.quarkiverse.langchain4j.mcp.test.McpServerHelper.skipTestsIfJbangNotAvailable;
4+
import static io.quarkiverse.langchain4j.mcp.test.McpServerHelper.startServerHttp;
5+
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.AfterAll;
9+
import org.junit.jupiter.api.BeforeAll;
10+
import org.junit.jupiter.api.extension.RegisterExtension;
11+
12+
import io.quarkus.test.QuarkusUnitTest;
13+
14+
class McpRootsHttpTransportTest extends McpRootsTestBase {
15+
16+
private static Process process;
17+
18+
@RegisterExtension
19+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
20+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
21+
.addClasses(McpServerHelper.class))
22+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.transport-type", "http")
23+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.url", "http://localhost:8082/mcp/sse")
24+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-requests", "true")
25+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-responses", "true")
26+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.roots", "David's workspace=file:///home/david/workspace")
27+
.overrideConfigKey("quarkus.log.category.\"io.quarkiverse\".level", "DEBUG");
28+
29+
@BeforeAll
30+
static void setup() throws Exception {
31+
skipTestsIfJbangNotAvailable();
32+
process = startServerHttp("roots_mcp_server.java");
33+
}
34+
35+
@AfterAll
36+
static void teardown() throws Exception {
37+
if (process != null && process.isAlive()) {
38+
process.destroyForcibly();
39+
}
40+
}
41+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package io.quarkiverse.langchain4j.mcp.test;
2+
3+
import static io.quarkiverse.langchain4j.mcp.test.McpServerHelper.skipTestsIfJbangNotAvailable;
4+
5+
import org.jboss.shrinkwrap.api.ShrinkWrap;
6+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
7+
import org.junit.jupiter.api.BeforeAll;
8+
import org.junit.jupiter.api.extension.RegisterExtension;
9+
10+
import io.quarkus.test.QuarkusUnitTest;
11+
12+
class McpRootsStdioTransportTest extends McpRootsTestBase {
13+
14+
@RegisterExtension
15+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
16+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
17+
.addClasses(McpServerHelper.class))
18+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.transport-type", "stdio")
19+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.command",
20+
"jbang,--quiet,--fresh,run,src/test/resources/roots_mcp_server.java")
21+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-requests", "true")
22+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-responses", "true")
23+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.roots", "David's workspace=file:///home/david/workspace")
24+
.overrideConfigKey("quarkus.log.category.\"io.quarkiverse\".level", "DEBUG");
25+
26+
@BeforeAll
27+
static void setup() {
28+
skipTestsIfJbangNotAvailable();
29+
}
30+
31+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package io.quarkiverse.langchain4j.mcp.test;
2+
3+
import static io.quarkiverse.langchain4j.mcp.test.McpServerHelper.skipTestsIfJbangNotAvailable;
4+
import static io.quarkiverse.langchain4j.mcp.test.McpServerHelper.startServerHttp;
5+
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.AfterAll;
9+
import org.junit.jupiter.api.BeforeAll;
10+
import org.junit.jupiter.api.Disabled;
11+
import org.junit.jupiter.api.extension.RegisterExtension;
12+
13+
import io.quarkus.test.QuarkusUnitTest;
14+
15+
@Disabled("With streamable-http, server requesting roots from the client doesn't work atm")
16+
class McpRootsStreamableHttpTransportTest extends McpRootsTestBase {
17+
18+
private static Process process;
19+
20+
@RegisterExtension
21+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
22+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
23+
.addClasses(McpServerHelper.class))
24+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.transport-type", "streamable-http")
25+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.url", "http://localhost:8082/mcp")
26+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-requests", "true")
27+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.log-responses", "true")
28+
.overrideConfigKey("quarkus.langchain4j.mcp.client1.roots", "David's workspace=file:///home/david/workspace")
29+
.overrideConfigKey("quarkus.log.category.\"io.quarkiverse\".level", "DEBUG");
30+
31+
@BeforeAll
32+
static void setup() throws Exception {
33+
skipTestsIfJbangNotAvailable();
34+
process = startServerHttp("roots_mcp_server.java");
35+
}
36+
37+
@AfterAll
38+
static void teardown() throws Exception {
39+
if (process != null && process.isAlive()) {
40+
process.destroyForcibly();
41+
}
42+
}
43+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package io.quarkiverse.langchain4j.mcp.test;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
8+
import jakarta.inject.Inject;
9+
10+
import org.junit.jupiter.api.Test;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
14+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
15+
import dev.langchain4j.mcp.client.McpClient;
16+
import dev.langchain4j.mcp.client.McpRoot;
17+
import io.quarkiverse.langchain4j.mcp.runtime.McpClientName;
18+
19+
public abstract class McpRootsTestBase {
20+
21+
@Inject
22+
@McpClientName("client1")
23+
McpClient mcpClient;
24+
25+
private static final Logger log = LoggerFactory.getLogger(McpRootsTestBase.class);
26+
27+
@Test
28+
public void verifyServerHasReceivedTools() {
29+
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
30+
.name("assertRoots")
31+
.arguments("{}")
32+
.build();
33+
String result = mcpClient.executeTool(toolExecutionRequest);
34+
assertThat(result).isEqualTo("OK");
35+
36+
// now update the roots
37+
List<McpRoot> newRoots = new ArrayList<>();
38+
newRoots.add(new McpRoot("Paul's workspace", "file:///home/paul/workspace"));
39+
mcpClient.setRoots(newRoots);
40+
41+
// and verify that the server has asked for the roots again and received them
42+
toolExecutionRequest = ToolExecutionRequest.builder()
43+
.name("assertRootsAfterUpdate")
44+
.arguments("{}")
45+
.build();
46+
result = mcpClient.executeTool(toolExecutionRequest);
47+
assertThat(result).isEqualTo("OK");
48+
}
49+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
///usr/bin/env jbang "$0" "$@" ; exit $?
2+
//DEPS io.quarkus:quarkus-bom:${quarkus.version:3.25.0}@pom
3+
//DEPS io.quarkiverse.mcp:quarkus-mcp-server-stdio:1.4.0
4+
//DEPS io.quarkiverse.mcp:quarkus-mcp-server-sse:1.4.0
5+
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
import java.util.concurrent.CountDownLatch;
9+
import java.util.concurrent.TimeUnit;
10+
11+
import io.quarkus.logging.Log;
12+
13+
import io.quarkiverse.mcp.server.McpConnection;
14+
import io.quarkiverse.mcp.server.Root;
15+
import io.quarkiverse.mcp.server.Roots;
16+
import io.quarkiverse.mcp.server.Notification;
17+
import io.quarkiverse.mcp.server.Notification.Type;
18+
import io.quarkiverse.mcp.server.TextContent;
19+
import io.quarkiverse.mcp.server.Tool;
20+
import io.quarkiverse.mcp.server.ToolArg;
21+
import io.quarkiverse.mcp.server.ToolResponse;
22+
23+
// this server expects the client to provide a list of roots during initialization
24+
// and provides a tool named 'assertRoots' that returns 'OK' if the
25+
// roots were received correctly.
26+
public class roots_mcp_server {
27+
28+
private volatile List<Root> rootList;
29+
private CountDownLatch initializationLatch = new CountDownLatch(1);
30+
private CountDownLatch updateLatch = new CountDownLatch(1);
31+
32+
@Notification(Type.INITIALIZED)
33+
void init(McpConnection connection, Roots roots) throws Exception {
34+
if (!connection.initialRequest().supportsRoots()) {
35+
throw new RuntimeException("The client does not support roots.");
36+
}
37+
rootList = roots.listAndAwait();
38+
initializationLatch.countDown();
39+
Log.info("Roots list = " + rootList);
40+
}
41+
42+
43+
@Notification(Type.ROOTS_LIST_CHANGED)
44+
void change(McpConnection connection, Roots roots) {
45+
rootList = roots.listAndAwait();
46+
updateLatch.countDown();
47+
}
48+
49+
@Tool
50+
String assertRoots() throws Exception {
51+
// Wait up to 20 seconds until the `rootList` variable has been set because that happens asynchronously
52+
// and the tool may be called before the MCP server receives the roots.
53+
initializationLatch.await(20, TimeUnit.SECONDS);
54+
return assertRoot("David's workspace", "file:///home/david/workspace");
55+
}
56+
57+
@Tool
58+
String assertRootsAfterUpdate() throws Exception {
59+
updateLatch.await(20, TimeUnit.SECONDS);
60+
return assertRoot("Paul's workspace", "file:///home/paul/workspace");
61+
}
62+
63+
private String assertRoot(String name, String uri) {
64+
if(rootList.isEmpty()) {
65+
throw new RuntimeException("The client didn't send any roots");
66+
}
67+
if(rootList.size() != 1) {
68+
throw new RuntimeException("The client sent more than one root: " + rootList);
69+
}
70+
Root root = rootList.get(0);
71+
if(!root.name().equals(name) || !root.uri().equals(uri)) {
72+
throw new RuntimeException("The client sent the wrong root: " + root);
73+
}
74+
return "OK";
75+
}
76+
77+
}

mcp/runtime/src/main/java/io/quarkiverse/langchain4j/mcp/runtime/McpRecorder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import dev.langchain4j.mcp.client.DefaultMcpClient;
1515
import dev.langchain4j.mcp.client.McpClient;
16+
import dev.langchain4j.mcp.client.McpRoot;
1617
import dev.langchain4j.mcp.client.transport.McpTransport;
1718
import dev.langchain4j.mcp.client.transport.stdio.StdioMcpTransport;
1819
import dev.langchain4j.service.tool.ToolProvider;
@@ -59,6 +60,13 @@ public Supplier<McpClient> mcpClientSupplier(String key,
5960
public McpClient get() {
6061
McpTransport transport;
6162
McpClientRuntimeConfig runtimeConfig = mcpRuntimeConfiguration.getValue().clients().get(key);
63+
List<McpRoot> initialRoots = new ArrayList<>();
64+
if (runtimeConfig.roots().isPresent()) {
65+
for (String kvPair : runtimeConfig.roots().get()) {
66+
String[] split = kvPair.split("=");
67+
initialRoots.add(new McpRoot(split[0], split[1]));
68+
}
69+
}
6270
transport = switch (mcpTransportType) {
6371
case STDIO -> {
6472
List<String> command = runtimeConfig.command().orElseThrow(() -> new ConfigurationException(
@@ -95,6 +103,7 @@ public McpClient get() {
95103
.pingTimeout(runtimeConfig.pingTimeout())
96104
// TODO: it should be possible to choose a log handler class via configuration
97105
.logHandler(new QuarkusDefaultMcpLogHandler(key))
106+
.roots(initialRoots)
98107
.build();
99108
shutdown.addShutdownTask(client::close);
100109
return client;

mcp/runtime/src/main/java/io/quarkiverse/langchain4j/mcp/runtime/config/McpClientRuntimeConfig.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,13 @@ public interface McpClientRuntimeConfig {
6666
*/
6767
@WithDefault("10s")
6868
Duration pingTimeout();
69+
70+
/**
71+
* The initial list of MCP roots that the client can present to the server. The list can
72+
* be later updated programmatically during runtime. The list is formatted as key-value pairs
73+
* separated by commas. For example:
74+
* workspace1=/path/to/workspace1,workspace2=/path/to/workspace2
75+
*/
76+
Optional<List<String>> roots();
77+
6978
}

0 commit comments

Comments
 (0)