From ee7e27370a2b75a9870cb515f8f434bbc5ce49e9 Mon Sep 17 00:00:00 2001 From: Yanming Zhou Date: Thu, 31 Jul 2025 11:11:16 +0800 Subject: [PATCH] Support custom HTTP headers for MCP transport Closes GH-3948 Signed-off-by: Yanming Zhou --- .../properties/McpSseClientProperties.java | 4 +- .../McpStreamableHttpClientProperties.java | 4 +- .../McpSseClientPropertiesTests.java | 16 +++++--- ...eHttpClientTransportAutoConfiguration.java | 11 +++++ ...pHttpClientTransportAutoConfiguration.java | 11 +++++ ...ClientTransportAutoConfigurationTests.java | 35 +++++++++++++++- .../SseWebFluxTransportAutoConfiguration.java | 7 ++++ ...HttpWebFluxTransportAutoConfiguration.java | 7 ++++ ...ebFluxTransportAutoConfigurationTests.java | 40 ++++++++++++++++++- ...ebFluxTransportAutoConfigurationTests.java | 39 +++++++++++++++++- 10 files changed, 160 insertions(+), 14 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java index f23029ddd96..581bf665e3d 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java @@ -39,6 +39,7 @@ * * * @author Christian Tzolov + * @author Yanming Zhou * @since 1.0.0 * @see SseParameters */ @@ -68,8 +69,9 @@ public Map getConnections() { * * @param url the URL endpoint for SSE communication with the MCP server * @param sseEndpoint the SSE endpoint for the MCP server + * @param headers the custom HTTP headers for the MCP server */ - public record SseParameters(String url, String sseEndpoint) { + public record SseParameters(String url, String sseEndpoint, Map headers) { } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java index afa74ded003..96ffa63c0a3 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java @@ -39,6 +39,7 @@ * * * @author Christian Tzolov + * @author Yanming Zhou * @see ConnectionParameters */ @ConfigurationProperties(McpStreamableHttpClientProperties.CONFIG_PREFIX) @@ -67,8 +68,9 @@ public Map getConnections() { * * @param url the URL endpoint for Streamable Http communication with the MCP server * @param endpoint the endpoint for the MCP server + * @param headers the custom HTTP headers for the MCP server */ - public record ConnectionParameters(String url, String endpoint) { + public record ConnectionParameters(String url, String endpoint, Map headers) { } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java index b3c72aa08b3..993c5567a98 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java @@ -30,6 +30,7 @@ * Tests for {@link McpSseClientProperties}. * * @author Christian Tzolov + * @author Yanming Zhou */ class McpSseClientPropertiesTests { @@ -105,7 +106,7 @@ void connectionWithNullUrl() { void sseParametersRecord() { String url = "http://test-server:8080/events"; String sseUrl = "/sse"; - McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl); + McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, null); assertThat(params.url()).isEqualTo(url); assertThat(params.sseEndpoint()).isEqualTo(sseUrl); @@ -114,7 +115,7 @@ void sseParametersRecord() { @Test void sseParametersRecordWithNullSseEndpoint() { String url = "http://test-server:8080/events"; - McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null); + McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, null); assertThat(params.url()).isEqualTo(url); assertThat(params.sseEndpoint()).isNull(); @@ -129,7 +130,8 @@ void configPrefixConstant() { void yamlConfigurationBinding() { this.contextRunner .withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080/events", - "spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events") + "spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events", + "spring.ai.mcp.client.sse.connections.server2.headers.Authorization=Bearer ") .run(context -> { McpSseClientProperties properties = context.getBean(McpSseClientProperties.class); assertThat(properties.getConnections()).hasSize(2); @@ -139,6 +141,8 @@ void yamlConfigurationBinding() { assertThat(properties.getConnections().get("server2").url()) .isEqualTo("http://otherserver:8081/events"); assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull(); + assertThat(properties.getConnections().get("server2").headers()).containsEntry("Authorization", + "Bearer "); }); } @@ -150,21 +154,21 @@ void connectionMapManipulation() { // Add a connection connections.put("server1", - new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse")); + new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", null)); assertThat(properties.getConnections()).hasSize(1); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse"); // Add another connection connections.put("server2", - new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null)); + new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, null)); assertThat(properties.getConnections()).hasSize(2); assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events"); assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull(); // Replace a connection connections.put("server1", - new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events")); + new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", null)); assertThat(properties.getConnections()).hasSize(2); assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events"); assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events"); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java index 6d695a468d7..5c2f39bcb8f 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp.client.httpclient.autoconfigure; import java.net.http.HttpClient; +import java.net.http.HttpRequest; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -39,6 +40,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.core.log.LogAccessor; +import org.springframework.util.CollectionUtils; /** * Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model @@ -113,6 +115,15 @@ public List sseHttpClientTransports(McpSseClientPropert .clientBuilder(HttpClient.newBuilder()) .objectMapper(objectMapper); + Map headers = serverParameters.getValue().headers(); + if (!CollectionUtils.isEmpty(headers)) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + for (Map.Entry entry : headers.entrySet()) { + requestBuilder = requestBuilder.header(entry.getKey(), entry.getValue()); + } + transportBuilder = transportBuilder.requestBuilder(requestBuilder); + } + asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer); syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer); if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/StreamableHttpHttpClientTransportAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/StreamableHttpHttpClientTransportAutoConfiguration.java index 93f07b617fe..568aec3efe6 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/StreamableHttpHttpClientTransportAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/StreamableHttpHttpClientTransportAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp.client.httpclient.autoconfigure; import java.net.http.HttpClient; +import java.net.http.HttpRequest; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -40,6 +41,7 @@ import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer; import io.modelcontextprotocol.spec.McpSchema; +import org.springframework.util.CollectionUtils; /** * Auto-configuration for Streamable HTTP client transport in the Model Context Protocol @@ -120,6 +122,15 @@ public List streamableHttpHttpClientTransports( .clientBuilder(HttpClient.newBuilder()) .objectMapper(objectMapper); + Map headers = serverParameters.getValue().headers(); + if (!CollectionUtils.isEmpty(headers)) { + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + for (Map.Entry entry : headers.entrySet()) { + requestBuilder = requestBuilder.header(entry.getKey(), entry.getValue()); + } + transportBuilder = transportBuilder.requestBuilder(requestBuilder); + } + asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer); syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer); if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java index 1e57a9551ea..a86ecf1d259 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java @@ -19,6 +19,8 @@ import static org.assertj.core.api.Assertions.assertThat; import java.lang.reflect.Field; +import java.net.URI; +import java.net.http.HttpRequest; import java.util.List; import org.junit.jupiter.api.Test; @@ -38,6 +40,7 @@ * Tests for {@link SseHttpClientTransportAutoConfiguration}. * * @author Christian Tzolov + * @author Yanming Zhou */ public class SseHttpClientTransportAutoConfigurationTests { @@ -154,10 +157,38 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() { }); } + @Test + void customHttpHeaders() { + this.applicationContext + .withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/custom-sse", + "spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer ") + .run(context -> { + List transports = context.getBean("sseHttpClientTransports", List.class); + assertThat(transports).hasSize(1); + assertThat(transports.get(0).name()).isEqualTo("server1"); + assertThat(transports.get(0).transport()).isInstanceOf(HttpClientSseClientTransport.class); + + HttpRequest.Builder builder = getRequestBuilder( + (HttpClientSseClientTransport) transports.get(0).transport()); + assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization")) + .hasValue("Bearer "); + }); + } + private String getSseEndpoint(HttpClientSseClientTransport transport) { - Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, "sseEndpoint"); + return getField(transport, "sseEndpoint", String.class); + } + + private HttpRequest.Builder getRequestBuilder(HttpClientSseClientTransport transport) { + return getField(transport, "requestBuilder", HttpRequest.Builder.class); + } + + @SuppressWarnings("unchecked") + private T getField(HttpClientSseClientTransport transport, String fieldName, Class type) { + Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, fieldName); ReflectionUtils.makeAccessible(privateField); - return (String) ReflectionUtils.getField(privateField, transport); + return (T) ReflectionUtils.getField(privateField, transport); } @Configuration diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java index 595cd97dfa6..0393ff60827 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java @@ -33,6 +33,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.util.CollectionUtils; import org.springframework.web.reactive.function.client.WebClient; /** @@ -91,6 +92,12 @@ public List sseWebFluxClientTransports(McpSseClientProp for (Map.Entry serverParameters : sseProperties.getConnections().entrySet()) { var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url()); + var headers = serverParameters.getValue().headers(); + if (!CollectionUtils.isEmpty(headers)) { + for (Map.Entry entry : headers.entrySet()) { + webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(), entry.getValue()); + } + } String sseEndpoint = serverParameters.getValue().sseEndpoint() != null ? serverParameters.getValue().sseEndpoint() : "/sse"; var transport = WebFluxSseClientTransport.builder(webClientBuilder) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java index 524785b3322..de660e4ccef 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java @@ -30,6 +30,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; +import org.springframework.util.CollectionUtils; import org.springframework.web.reactive.function.client.WebClient; import com.fasterxml.jackson.databind.ObjectMapper; @@ -98,6 +99,12 @@ public List streamableHttpWebFluxClientTransports( var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url()); String streamableHttpEndpoint = serverParameters.getValue().endpoint() != null ? serverParameters.getValue().endpoint() : "/mcp"; + var headers = serverParameters.getValue().headers(); + if (!CollectionUtils.isEmpty(headers)) { + for (Map.Entry entry : headers.entrySet()) { + webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(), entry.getValue()); + } + } var transport = WebClientStreamableHttpTransport.builder(webClientBuilder) .endpoint(streamableHttpEndpoint) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java index fbac5a2a15c..f5188da2144 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java @@ -19,8 +19,11 @@ import static org.assertj.core.api.Assertions.assertThat; import java.lang.reflect.Field; +import java.net.URI; +import java.net.http.HttpRequest; import java.util.List; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import org.junit.jupiter.api.Test; import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -28,6 +31,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; import org.springframework.util.ReflectionUtils; import org.springframework.web.reactive.function.client.WebClient; @@ -39,6 +43,7 @@ * Tests for {@link SseWebFluxTransportAutoConfiguration}. * * @author Christian Tzolov + * @author Yanming Zhou */ public class SseWebFluxTransportAutoConfigurationTests { @@ -178,10 +183,41 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() { }); } + @Test + void customHttpHeaders() { + this.applicationContext.withUserConfiguration(CustomWebClientConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer ") + .run(context -> { + assertThat(context.getBean(WebClient.Builder.class)).isNotNull(); + List transports = context.getBean("sseWebFluxClientTransports", List.class); + assertThat(transports).hasSize(1); + + WebClient webClient = getWebClient((WebFluxSseClientTransport) transports.get(0).transport()); + HttpHeaders defaultHeaders = getDefaultHeaders(webClient); + assertThat(defaultHeaders.getFirst("Authorization")).isEqualTo("Bearer "); + }); + } + private String getSseEndpoint(WebFluxSseClientTransport transport) { - Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, "sseEndpoint"); + return getField(transport, "sseEndpoint", String.class); + } + + private WebClient getWebClient(WebFluxSseClientTransport transport) { + return getField(transport, "webClient", WebClient.class); + } + + @SuppressWarnings("unchecked") + private T getField(WebFluxSseClientTransport transport, String fieldName, Class type) { + Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, fieldName); + ReflectionUtils.makeAccessible(privateField); + return (T) ReflectionUtils.getField(privateField, transport); + } + + private HttpHeaders getDefaultHeaders(WebClient webClient) { + Field privateField = ReflectionUtils.findField(webClient.getClass(), "defaultHeaders"); ReflectionUtils.makeAccessible(privateField); - return (String) ReflectionUtils.getField(privateField, transport); + return (HttpHeaders) ReflectionUtils.getField(privateField, webClient); } @Configuration diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java index f22c2b2a2ea..58a7ffcc1af 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java @@ -28,6 +28,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; import org.springframework.util.ReflectionUtils; import org.springframework.web.reactive.function.client.WebClient; @@ -39,6 +40,7 @@ * Tests for {@link StreamableHttpWebFluxTransportAutoConfiguration}. * * @author Christian Tzolov + * @author Yanming Zhou */ public class StreamableHttpWebFluxTransportAutoConfigurationTests { @@ -190,10 +192,43 @@ void mixedConnectionsWithAndWithoutCustomStreamableHttpEndpoint() { }); } + @Test + void customHttpHeaders() { + this.applicationContext + .withUserConfiguration(SseWebFluxTransportAutoConfigurationTests.CustomWebClientConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.streamable-http.connections.server1.headers.Authorization=Bearer ") + .run(context -> { + assertThat(context.getBean(WebClient.Builder.class)).isNotNull(); + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(1); + + WebClient webClient = getWebClient((WebClientStreamableHttpTransport) transports.get(0).transport()); + HttpHeaders defaultHeaders = getDefaultHeaders(webClient); + assertThat(defaultHeaders.getFirst("Authorization")).isEqualTo("Bearer "); + }); + } + private String getStreamableHttpEndpoint(WebClientStreamableHttpTransport transport) { - Field privateField = ReflectionUtils.findField(WebClientStreamableHttpTransport.class, "endpoint"); + return getField(transport, "endpoint", String.class); + } + + private WebClient getWebClient(WebClientStreamableHttpTransport transport) { + return getField(transport, "webClient", WebClient.class); + } + + @SuppressWarnings("unchecked") + private T getField(WebClientStreamableHttpTransport transport, String fieldName, Class type) { + Field privateField = ReflectionUtils.findField(WebClientStreamableHttpTransport.class, fieldName); + ReflectionUtils.makeAccessible(privateField); + return (T) ReflectionUtils.getField(privateField, transport); + } + + private HttpHeaders getDefaultHeaders(WebClient webClient) { + Field privateField = ReflectionUtils.findField(webClient.getClass(), "defaultHeaders"); ReflectionUtils.makeAccessible(privateField); - return (String) ReflectionUtils.getField(privateField, transport); + return (HttpHeaders) ReflectionUtils.getField(privateField, webClient); } @Configuration