Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package org.springframework.ai.mcp.client.common.autoconfigure.properties;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.lang.Nullable;

/**
* Configuration properties for Server-Sent Events (SSE) based MCP client connections.
Expand Down Expand Up @@ -53,6 +55,7 @@
* </pre>
*
* @author Christian Tzolov
* @author Yanming Zhou
* @since 1.0.0
* @see SseParameters
*/
Expand Down Expand Up @@ -82,8 +85,9 @@ public Map<String, SseParameters> getConnections() {
*
* @param url the URL endpoint for SSE communication with the MCP server
* @param sseEndpoint the SSE endpoint for the MCP server
* @param headers the custom HTTP headers for the MCP server
*/
public record SseParameters(String url, String sseEndpoint) {
public record SseParameters(String url, String sseEndpoint, @Nullable Map<String, List<String>> headers) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package org.springframework.ai.mcp.client.common.autoconfigure.properties;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.lang.Nullable;

/**
* Configuration properties for Streamable Http client connections.
Expand All @@ -39,6 +41,7 @@
* </pre>
*
* @author Christian Tzolov
* @author Yanming Zhou
* @see ConnectionParameters
*/
@ConfigurationProperties(McpStreamableHttpClientProperties.CONFIG_PREFIX)
Expand Down Expand Up @@ -67,8 +70,9 @@ public Map<String, ConnectionParameters> getConnections() {
*
* @param url the URL endpoint for Streamable Http communication with the MCP server
* @param endpoint the endpoint for the MCP server
* @param headers the custom HTTP headers for the MCP server
*/
public record ConnectionParameters(String url, String endpoint) {
public record ConnectionParameters(String url, String endpoint, @Nullable Map<String, List<String>> headers) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.mcp.client.common.autoconfigure.properties;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
Expand All @@ -30,6 +31,7 @@
* Tests for {@link McpSseClientProperties}.
*
* @author Christian Tzolov
* @author Yanming Zhou
*/
class McpSseClientPropertiesTests {

Expand Down Expand Up @@ -105,7 +107,7 @@ void connectionWithNullUrl() {
void sseParametersRecord() {
String url = "http://test-server:8080/events";
String sseUrl = "/sse";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, null);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
Expand All @@ -114,7 +116,7 @@ void sseParametersRecord() {
@Test
void sseParametersRecordWithNullSseEndpoint() {
String url = "http://test-server:8080/events";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, null);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isNull();
Expand All @@ -129,7 +131,8 @@ void configPrefixConstant() {
void yamlConfigurationBinding() {
this.contextRunner
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080/events",
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events")
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events",
"spring.ai.mcp.client.sse.connections.server2.headers.Authorization=Bearer <access_token>")
.run(context -> {
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
assertThat(properties.getConnections()).hasSize(2);
Expand All @@ -139,6 +142,8 @@ void yamlConfigurationBinding() {
assertThat(properties.getConnections().get("server2").url())
.isEqualTo("http://otherserver:8081/events");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
assertThat(properties.getConnections().get("server2").headers()).containsEntry("Authorization",
List.of("Bearer <access_token>"));
});
}

Expand All @@ -150,21 +155,21 @@ void connectionMapManipulation() {

// Add a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", null));
assertThat(properties.getConnections()).hasSize(1);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");

// Add another connection
connections.put("server2",
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, null));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();

// Replace a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", null));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp.client.httpclient.autoconfigure;

import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -42,6 +43,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.core.log.LogAccessor;
import org.springframework.util.CollectionUtils;

/**
* Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model
Expand Down Expand Up @@ -125,6 +127,17 @@ public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientConnect
.clientBuilder(HttpClient.newBuilder())
.jsonMapper(new JacksonMcpJsonMapper(objectMapper));

Map<String, List<String>> headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
for (String value : entry.getValue()) {
requestBuilder = requestBuilder.header(entry.getKey(), value);
}
}
transportBuilder = transportBuilder.requestBuilder(requestBuilder);
}

asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
if (asyncHttpRequestCustomizer.getIfUnique() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp.client.httpclient.autoconfigure;

import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -40,6 +41,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.core.log.LogAccessor;
import org.springframework.util.CollectionUtils;

/**
* Auto-configuration for Streamable HTTP client transport in the Model Context Protocol
Expand Down Expand Up @@ -116,6 +118,17 @@ public List<NamedClientMcpTransport> streamableHttpHttpClientTransports(
.clientBuilder(HttpClient.newBuilder())
.jsonMapper(new JacksonMcpJsonMapper(objectMapper));

Map<String, List<String>> headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
for (String value : entry.getValue()) {
requestBuilder = requestBuilder.header(entry.getKey(), value);
}
}
transportBuilder = transportBuilder.requestBuilder(requestBuilder);
}

asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.mcp.client.autoconfigure;

import java.lang.reflect.Field;
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -37,6 +39,7 @@
* Tests for {@link SseHttpClientTransportAutoConfiguration}.
*
* @author Christian Tzolov
* @author Yanming Zhou
*/
public class SseHttpClientTransportAutoConfigurationTests {

Expand Down Expand Up @@ -153,10 +156,38 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
});
}

@Test
void customHttpHeaders() {
this.applicationContext
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/custom-sse",
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
.run(context -> {
List<NamedClientMcpTransport> transports = context.getBean("sseHttpClientTransports", List.class);
assertThat(transports).hasSize(1);
assertThat(transports.get(0).name()).isEqualTo("server1");
assertThat(transports.get(0).transport()).isInstanceOf(HttpClientSseClientTransport.class);

HttpRequest.Builder builder = getRequestBuilder(
(HttpClientSseClientTransport) transports.get(0).transport());
assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization"))
.hasValue("Bearer <access_token>");
});
}

private String getSseEndpoint(HttpClientSseClientTransport transport) {
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, "sseEndpoint");
return getField(transport, "sseEndpoint", String.class);
}

private HttpRequest.Builder getRequestBuilder(HttpClientSseClientTransport transport) {
return getField(transport, "requestBuilder", HttpRequest.Builder.class);
}

@SuppressWarnings("unchecked")
private <T> T getField(HttpClientSseClientTransport transport, String fieldName, Class<T> type) {
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, fieldName);
ReflectionUtils.makeAccessible(privateField);
return (String) ReflectionUtils.getField(privateField, transport);
return (T) ReflectionUtils.getField(privateField, transport);
}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.mcp.client.autoconfigure;

import java.lang.reflect.Field;
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -162,10 +164,38 @@ void mixedConnectionsWithAndWithoutCustomEndpoint() {
});
}

@Test
void customHttpHeaders() {
this.applicationContext.withPropertyValues(
"spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.streamable-http.connections.server1.headers.Authorization=Bearer <access_token>")
.run(context -> {
List<NamedClientMcpTransport> transports = context.getBean("streamableHttpHttpClientTransports",
List.class);
assertThat(transports).hasSize(1);
assertThat(transports.get(0).name()).isEqualTo("server1");
assertThat(transports.get(0).transport()).isInstanceOf(HttpClientStreamableHttpTransport.class);

HttpRequest.Builder builder = getRequestBuilder(
(HttpClientStreamableHttpTransport) transports.get(0).transport());
assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization"))
.hasValue("Bearer <access_token>");
});
}

private String getStreamableHttpEndpoint(HttpClientStreamableHttpTransport transport) {
Field privateField = ReflectionUtils.findField(HttpClientStreamableHttpTransport.class, "endpoint");
return getField(transport, "endpoint", String.class);
}

private HttpRequest.Builder getRequestBuilder(HttpClientStreamableHttpTransport transport) {
return getField(transport, "requestBuilder", HttpRequest.Builder.class);
}

@SuppressWarnings("unchecked")
private <T> T getField(HttpClientStreamableHttpTransport transport, String fieldName, Class<T> type) {
Field privateField = ReflectionUtils.findField(HttpClientStreamableHttpTransport.class, fieldName);
ReflectionUtils.makeAccessible(privateField);
return (String) ReflectionUtils.getField(privateField, transport);
return (T) ReflectionUtils.getField(privateField, transport);
}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand Down Expand Up @@ -99,6 +100,13 @@ public List<NamedClientMcpTransport> sseWebFluxClientTransports(McpSseClientConn

for (Map.Entry<String, SseParameters> serverParameters : connectionDetails.getConnections().entrySet()) {
var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url());
var headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(),
entry.getValue().toArray(new String[0]));
}
}
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
? serverParameters.getValue().sseEndpoint() : "/sse";
var transport = WebFluxSseClientTransport.builder(webClientBuilder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand Down Expand Up @@ -98,6 +99,13 @@ public List<NamedClientMcpTransport> streamableHttpWebFluxClientTransports(
var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url());
String streamableHttpEndpoint = serverParameters.getValue().endpoint() != null
? serverParameters.getValue().endpoint() : "/mcp";
var headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(),
entry.getValue().toArray(new String[0]));
}
}

var transport = WebClientStreamableHttpTransport.builder(webClientBuilder)
.endpoint(streamableHttpEndpoint)
Expand Down
Loading