Skip to content

Commit b5b2d8f

Browse files
committed
Support custom HTTP headers for MCP SSE transport
Closes GH-3948 Signed-off-by: Yanming Zhou <[email protected]>
1 parent c7f7b68 commit b5b2d8f

File tree

6 files changed

+103
-11
lines changed

6 files changed

+103
-11
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
* </pre>
4040
*
4141
* @author Christian Tzolov
42+
* @author Yanming Zhou
4243
* @since 1.0.0
4344
* @see SseParameters
4445
*/
@@ -68,8 +69,9 @@ public Map<String, SseParameters> getConnections() {
6869
*
6970
* @param url the URL endpoint for SSE communication with the MCP server
7071
* @param sseEndpoint the SSE endpoint for the MCP server
72+
* @param headers the custom HTTP headers for the MCP server
7173
*/
72-
public record SseParameters(String url, String sseEndpoint) {
74+
public record SseParameters(String url, String sseEndpoint, Map<String, String> headers) {
7375
}
7476

7577
}

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
* Tests for {@link McpSseClientProperties}.
3131
*
3232
* @author Christian Tzolov
33+
* @author Yanming Zhou
3334
*/
3435
class McpSseClientPropertiesTests {
3536

@@ -105,7 +106,7 @@ void connectionWithNullUrl() {
105106
void sseParametersRecord() {
106107
String url = "http://test-server:8080/events";
107108
String sseUrl = "/sse";
108-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
109+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, null);
109110

110111
assertThat(params.url()).isEqualTo(url);
111112
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
@@ -114,7 +115,7 @@ void sseParametersRecord() {
114115
@Test
115116
void sseParametersRecordWithNullSseEndpoint() {
116117
String url = "http://test-server:8080/events";
117-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
118+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, null);
118119

119120
assertThat(params.url()).isEqualTo(url);
120121
assertThat(params.sseEndpoint()).isNull();
@@ -129,7 +130,8 @@ void configPrefixConstant() {
129130
void yamlConfigurationBinding() {
130131
this.contextRunner
131132
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080/events",
132-
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events")
133+
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events",
134+
"spring.ai.mcp.client.sse.connections.server2.headers.Authorization=Bearer <access_token>")
133135
.run(context -> {
134136
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
135137
assertThat(properties.getConnections()).hasSize(2);
@@ -139,6 +141,8 @@ void yamlConfigurationBinding() {
139141
assertThat(properties.getConnections().get("server2").url())
140142
.isEqualTo("http://otherserver:8081/events");
141143
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
144+
assertThat(properties.getConnections().get("server2").headers()).containsEntry("Authorization",
145+
"Bearer <access_token>");
142146
});
143147
}
144148

@@ -150,21 +154,21 @@ void connectionMapManipulation() {
150154

151155
// Add a connection
152156
connections.put("server1",
153-
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
157+
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", null));
154158
assertThat(properties.getConnections()).hasSize(1);
155159
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
156160
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");
157161

158162
// Add another connection
159163
connections.put("server2",
160-
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
164+
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, null));
161165
assertThat(properties.getConnections()).hasSize(2);
162166
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
163167
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
164168

165169
// Replace a connection
166170
connections.put("server1",
167-
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
171+
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", null));
168172
assertThat(properties.getConnections()).hasSize(2);
169173
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
170174
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.mcp.client.httpclient.autoconfigure;
1818

1919
import java.net.http.HttpClient;
20+
import java.net.http.HttpRequest;
2021
import java.util.ArrayList;
2122
import java.util.List;
2223
import java.util.Map;
@@ -39,6 +40,7 @@
3940
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4041
import org.springframework.context.annotation.Bean;
4142
import org.springframework.core.log.LogAccessor;
43+
import org.springframework.util.CollectionUtils;
4244

4345
/**
4446
* Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model
@@ -113,6 +115,15 @@ public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientPropert
113115
.clientBuilder(HttpClient.newBuilder())
114116
.objectMapper(objectMapper);
115117

118+
Map<String, String> headers = serverParameters.getValue().headers();
119+
if (!CollectionUtils.isEmpty(headers)) {
120+
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
121+
for (Map.Entry<String, String> entry : headers.entrySet()) {
122+
requestBuilder = requestBuilder.header(entry.getKey(), entry.getValue());
123+
}
124+
transportBuilder = transportBuilder.requestBuilder(requestBuilder);
125+
}
126+
116127
asyncHttpRequestCustomizer.ifUnique(transportBuilder::asyncHttpRequestCustomizer);
117128
syncHttpRequestCustomizer.ifUnique(transportBuilder::httpRequestCustomizer);
118129
if (asyncHttpRequestCustomizer.getIfUnique() != null && syncHttpRequestCustomizer.getIfUnique() != null) {

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import static org.assertj.core.api.Assertions.assertThat;
2020

2121
import java.lang.reflect.Field;
22+
import java.net.URI;
23+
import java.net.http.HttpRequest;
2224
import java.util.List;
2325

2426
import org.junit.jupiter.api.Test;
@@ -38,6 +40,7 @@
3840
* Tests for {@link SseHttpClientTransportAutoConfiguration}.
3941
*
4042
* @author Christian Tzolov
43+
* @author Yanming Zhou
4144
*/
4245
public class SseHttpClientTransportAutoConfigurationTests {
4346

@@ -154,10 +157,38 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
154157
});
155158
}
156159

160+
@Test
161+
void customHttpHeaders() {
162+
this.applicationContext
163+
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
164+
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/custom-sse",
165+
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
166+
.run(context -> {
167+
List<NamedClientMcpTransport> transports = context.getBean("sseHttpClientTransports", List.class);
168+
assertThat(transports).hasSize(1);
169+
assertThat(transports.get(0).name()).isEqualTo("server1");
170+
assertThat(transports.get(0).transport()).isInstanceOf(HttpClientSseClientTransport.class);
171+
172+
HttpRequest.Builder builder = getRequestBuilder(
173+
(HttpClientSseClientTransport) transports.get(0).transport());
174+
assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization"))
175+
.hasValue("Bearer <access_token>");
176+
});
177+
}
178+
157179
private String getSseEndpoint(HttpClientSseClientTransport transport) {
158-
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, "sseEndpoint");
180+
return getField(transport, "sseEndpoint", String.class);
181+
}
182+
183+
private HttpRequest.Builder getRequestBuilder(HttpClientSseClientTransport transport) {
184+
return getField(transport, "requestBuilder", HttpRequest.Builder.class);
185+
}
186+
187+
@SuppressWarnings("unchecked")
188+
private <T> T getField(HttpClientSseClientTransport transport, String fieldName, Class<T> type) {
189+
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, fieldName);
159190
ReflectionUtils.makeAccessible(privateField);
160-
return (String) ReflectionUtils.getField(privateField, transport);
191+
return (T) ReflectionUtils.getField(privateField, transport);
161192
}
162193

163194
@Configuration

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
3434
import org.springframework.boot.context.properties.EnableConfigurationProperties;
3535
import org.springframework.context.annotation.Bean;
36+
import org.springframework.util.CollectionUtils;
3637
import org.springframework.web.reactive.function.client.WebClient;
3738

3839
/**
@@ -91,6 +92,12 @@ public List<NamedClientMcpTransport> sseWebFluxClientTransports(McpSseClientProp
9192

9293
for (Map.Entry<String, SseParameters> serverParameters : sseProperties.getConnections().entrySet()) {
9394
var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url());
95+
var headers = serverParameters.getValue().headers();
96+
if (!CollectionUtils.isEmpty(headers)) {
97+
for (Map.Entry<String, String> entry : headers.entrySet()) {
98+
webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(), entry.getValue());
99+
}
100+
}
94101
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
95102
? serverParameters.getValue().sseEndpoint() : "/sse";
96103
var transport = WebFluxSseClientTransport.builder(webClientBuilder)

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@
1919
import static org.assertj.core.api.Assertions.assertThat;
2020

2121
import java.lang.reflect.Field;
22+
import java.net.URI;
23+
import java.net.http.HttpRequest;
2224
import java.util.List;
2325

26+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
2427
import org.junit.jupiter.api.Test;
2528
import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport;
2629
import org.springframework.boot.autoconfigure.AutoConfigurations;
2730
import org.springframework.boot.test.context.FilteredClassLoader;
2831
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
2932
import org.springframework.context.annotation.Bean;
3033
import org.springframework.context.annotation.Configuration;
34+
import org.springframework.http.HttpHeaders;
3135
import org.springframework.util.ReflectionUtils;
3236
import org.springframework.web.reactive.function.client.WebClient;
3337

@@ -39,6 +43,7 @@
3943
* Tests for {@link SseWebFluxTransportAutoConfiguration}.
4044
*
4145
* @author Christian Tzolov
46+
* @author Yanming Zhou
4247
*/
4348
public class SseWebFluxTransportAutoConfigurationTests {
4449

@@ -178,10 +183,42 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
178183
});
179184
}
180185

186+
@Test
187+
void customHttpHeaders() {
188+
this.applicationContext.withUserConfiguration(CustomWebClientConfiguration.class)
189+
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
190+
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
191+
.run(context -> {
192+
assertThat(context.getBean(WebClient.Builder.class)).isNotNull();
193+
List<NamedClientMcpTransport> transports = context.getBean("sseWebFluxClientTransports", List.class);
194+
assertThat(transports).hasSize(1);
195+
196+
WebClient webClient = getWebClient((WebFluxSseClientTransport) transports.get(0).transport());
197+
HttpHeaders defaultHeaders = getDefaultHeaders(webClient);
198+
assertThat(defaultHeaders.getFirst("Authorization")).isEqualTo("Bearer <access_token>");
199+
});
200+
}
201+
181202
private String getSseEndpoint(WebFluxSseClientTransport transport) {
182-
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, "sseEndpoint");
203+
return getField(transport, "sseEndpoint", String.class);
204+
}
205+
206+
private WebClient getWebClient(WebFluxSseClientTransport transport) {
207+
return getField(transport, "webClient", WebClient.class);
208+
}
209+
210+
@SuppressWarnings("unchecked")
211+
private <T> T getField(WebFluxSseClientTransport transport, String fieldName, Class<T> type) {
212+
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, fieldName);
213+
ReflectionUtils.makeAccessible(privateField);
214+
return (T) ReflectionUtils.getField(privateField, transport);
215+
}
216+
217+
@SuppressWarnings("unchecked")
218+
private HttpHeaders getDefaultHeaders(WebClient webClient) {
219+
Field privateField = ReflectionUtils.findField(webClient.getClass(), "defaultHeaders");
183220
ReflectionUtils.makeAccessible(privateField);
184-
return (String) ReflectionUtils.getField(privateField, transport);
221+
return (HttpHeaders) ReflectionUtils.getField(privateField, webClient);
185222
}
186223

187224
@Configuration

0 commit comments

Comments
 (0)