Skip to content

Support custom HTTP headers for MCP SSE transport #3949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
* </pre>
*
* @author Christian Tzolov
* @author Yanming Zhou
* @since 1.0.0
* @see SseParameters
*/
Expand Down Expand Up @@ -68,8 +69,9 @@ public Map<String, SseParameters> getConnections() {
*
* @param url the URL endpoint for SSE communication with the MCP server
* @param sseEndpoint the SSE endpoint for the MCP server
* @param headers the custom HTTP headers for the MCP server
*/
public record SseParameters(String url, String sseEndpoint) {
public record SseParameters(String url, String sseEndpoint, Map<String, String> headers) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
* Tests for {@link McpSseClientProperties}.
*
* @author Christian Tzolov
* @author Yanming Zhou
*/
class McpSseClientPropertiesTests {

Expand Down Expand Up @@ -105,7 +106,7 @@ void connectionWithNullUrl() {
void sseParametersRecord() {
String url = "http://test-server:8080/events";
String sseUrl = "/sse";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, null);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
Expand All @@ -114,7 +115,7 @@ void sseParametersRecord() {
@Test
void sseParametersRecordWithNullSseEndpoint() {
String url = "http://test-server:8080/events";
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, null);

assertThat(params.url()).isEqualTo(url);
assertThat(params.sseEndpoint()).isNull();
Expand All @@ -129,7 +130,8 @@ void configPrefixConstant() {
void yamlConfigurationBinding() {
this.contextRunner
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080/events",
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events")
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events",
"spring.ai.mcp.client.sse.connections.server2.headers.Authorization=Bearer <access_token>")
.run(context -> {
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
assertThat(properties.getConnections()).hasSize(2);
Expand All @@ -139,6 +141,8 @@ void yamlConfigurationBinding() {
assertThat(properties.getConnections().get("server2").url())
.isEqualTo("http://otherserver:8081/events");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
assertThat(properties.getConnections().get("server2").headers()).containsEntry("Authorization",
"Bearer <access_token>");
});
}

Expand All @@ -150,21 +154,21 @@ void connectionMapManipulation() {

// Add a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", null));
assertThat(properties.getConnections()).hasSize(1);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");

// Add another connection
connections.put("server2",
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, null));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();

// Replace a connection
connections.put("server1",
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", null));
assertThat(properties.getConnections()).hasSize(2);
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp.client.httpclient.autoconfigure;

import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -39,6 +40,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.core.log.LogAccessor;
import org.springframework.util.CollectionUtils;

/**
* Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model
Expand Down Expand Up @@ -113,6 +115,15 @@ public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientPropert
.clientBuilder(HttpClient.newBuilder())
.objectMapper(objectMapper);

Map<String, String> headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
for (Map.Entry<String, String> entry : headers.entrySet()) {
requestBuilder = requestBuilder.header(entry.getKey(), entry.getValue());
}
transportBuilder = transportBuilder.requestBuilder(requestBuilder);
}

asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.lang.reflect.Field;
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;

import org.junit.jupiter.api.Test;
Expand All @@ -38,6 +40,7 @@
* Tests for {@link SseHttpClientTransportAutoConfiguration}.
*
* @author Christian Tzolov
* @author Yanming Zhou
*/
public class SseHttpClientTransportAutoConfigurationTests {

Expand Down Expand Up @@ -154,10 +157,38 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
});
}

@Test
void customHttpHeaders() {
this.applicationContext
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/custom-sse",
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
.run(context -> {
List<NamedClientMcpTransport> transports = context.getBean("sseHttpClientTransports", List.class);
assertThat(transports).hasSize(1);
assertThat(transports.get(0).name()).isEqualTo("server1");
assertThat(transports.get(0).transport()).isInstanceOf(HttpClientSseClientTransport.class);

HttpRequest.Builder builder = getRequestBuilder(
(HttpClientSseClientTransport) transports.get(0).transport());
assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization"))
.hasValue("Bearer <access_token>");
});
}

private String getSseEndpoint(HttpClientSseClientTransport transport) {
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, "sseEndpoint");
return getField(transport, "sseEndpoint", String.class);
}

private HttpRequest.Builder getRequestBuilder(HttpClientSseClientTransport transport) {
return getField(transport, "requestBuilder", HttpRequest.Builder.class);
}

@SuppressWarnings("unchecked")
private <T> T getField(HttpClientSseClientTransport transport, String fieldName, Class<T> type) {
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, fieldName);
ReflectionUtils.makeAccessible(privateField);
return (String) ReflectionUtils.getField(privateField, transport);
return (T) ReflectionUtils.getField(privateField, transport);
}

@Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand Down Expand Up @@ -91,6 +92,12 @@ public List<NamedClientMcpTransport> sseWebFluxClientTransports(McpSseClientProp

for (Map.Entry<String, SseParameters> serverParameters : sseProperties.getConnections().entrySet()) {
var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url());
var headers = serverParameters.getValue().headers();
if (!CollectionUtils.isEmpty(headers)) {
for (Map.Entry<String, String> entry : headers.entrySet()) {
webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(), entry.getValue());
}
}
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
? serverParameters.getValue().sseEndpoint() : "/sse";
var transport = WebFluxSseClientTransport.builder(webClientBuilder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.lang.reflect.Field;
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;

import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import org.junit.jupiter.api.Test;
import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.FilteredClassLoader;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.reactive.function.client.WebClient;

Expand All @@ -39,6 +43,7 @@
* Tests for {@link SseWebFluxTransportAutoConfiguration}.
*
* @author Christian Tzolov
* @author Yanming Zhou
*/
public class SseWebFluxTransportAutoConfigurationTests {

Expand Down Expand Up @@ -178,10 +183,42 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
});
}

@Test
void customHttpHeaders() {
this.applicationContext.withUserConfiguration(CustomWebClientConfiguration.class)
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
.run(context -> {
assertThat(context.getBean(WebClient.Builder.class)).isNotNull();
List<NamedClientMcpTransport> transports = context.getBean("sseWebFluxClientTransports", List.class);
assertThat(transports).hasSize(1);

WebClient webClient = getWebClient((WebFluxSseClientTransport) transports.get(0).transport());
HttpHeaders defaultHeaders = getDefaultHeaders(webClient);
assertThat(defaultHeaders.getFirst("Authorization")).isEqualTo("Bearer <access_token>");
});
}

private String getSseEndpoint(WebFluxSseClientTransport transport) {
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, "sseEndpoint");
return getField(transport, "sseEndpoint", String.class);
}

private WebClient getWebClient(WebFluxSseClientTransport transport) {
return getField(transport, "webClient", WebClient.class);
}

@SuppressWarnings("unchecked")
private <T> T getField(WebFluxSseClientTransport transport, String fieldName, Class<T> type) {
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, fieldName);
ReflectionUtils.makeAccessible(privateField);
return (T) ReflectionUtils.getField(privateField, transport);
}

@SuppressWarnings("unchecked")
private HttpHeaders getDefaultHeaders(WebClient webClient) {
Field privateField = ReflectionUtils.findField(webClient.getClass(), "defaultHeaders");
ReflectionUtils.makeAccessible(privateField);
return (String) ReflectionUtils.getField(privateField, transport);
return (HttpHeaders) ReflectionUtils.getField(privateField, webClient);
}

@Configuration
Expand Down