Skip to content

Commit 183f1de

Browse files
committed
Support custom HTTP headers for MCP SSE transport
Closes GH-3948 Signed-off-by: Yanming Zhou <[email protected]>
1 parent 1256eaf commit 183f1de

File tree

6 files changed

+104
-14
lines changed

6 files changed

+104
-14
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
* </pre>
4040
*
4141
* @author Christian Tzolov
42+
* @author Yanming Zhou
4243
* @since 1.0.0
4344
* @see SseParameters
4445
*/
@@ -68,8 +69,9 @@ public Map<String, SseParameters> getConnections() {
6869
*
6970
* @param url the URL endpoint for SSE communication with the MCP server
7071
* @param sseEndpoint the SSE endpoint for the MCP server
72+
* @param headers the custom HTTP headers for the MCP server
7173
*/
72-
public record SseParameters(String url, String sseEndpoint) {
74+
public record SseParameters(String url, String sseEndpoint, Map<String, String> headers) {
7375
}
7476

7577
}

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
* Tests for {@link McpSseClientProperties}.
3131
*
3232
* @author Christian Tzolov
33+
* @author Yanming Zhou
3334
*/
3435
class McpSseClientPropertiesTests {
3536

@@ -105,7 +106,7 @@ void connectionWithNullUrl() {
105106
void sseParametersRecord() {
106107
String url = "http://test-server:8080/events";
107108
String sseUrl = "/sse";
108-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
109+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, null);
109110

110111
assertThat(params.url()).isEqualTo(url);
111112
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
@@ -114,7 +115,7 @@ void sseParametersRecord() {
114115
@Test
115116
void sseParametersRecordWithNullSseEndpoint() {
116117
String url = "http://test-server:8080/events";
117-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
118+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, null);
118119

119120
assertThat(params.url()).isEqualTo(url);
120121
assertThat(params.sseEndpoint()).isNull();
@@ -129,7 +130,8 @@ void configPrefixConstant() {
129130
void yamlConfigurationBinding() {
130131
this.contextRunner
131132
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080/events",
132-
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events")
133+
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081/events",
134+
"spring.ai.mcp.client.sse.connections.server2.headers.Authorization=Bearer <access_token>")
133135
.run(context -> {
134136
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
135137
assertThat(properties.getConnections()).hasSize(2);
@@ -139,6 +141,8 @@ void yamlConfigurationBinding() {
139141
assertThat(properties.getConnections().get("server2").url())
140142
.isEqualTo("http://otherserver:8081/events");
141143
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
144+
assertThat(properties.getConnections().get("server2").headers()).containsEntry("Authorization",
145+
"Bearer <access_token>");
142146
});
143147
}
144148

@@ -150,21 +154,21 @@ void connectionMapManipulation() {
150154

151155
// Add a connection
152156
connections.put("server1",
153-
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
157+
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", null));
154158
assertThat(properties.getConnections()).hasSize(1);
155159
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
156160
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");
157161

158162
// Add another connection
159163
connections.put("server2",
160-
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
164+
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, null));
161165
assertThat(properties.getConnections()).hasSize(2);
162166
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
163167
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
164168

165169
// Replace a connection
166170
connections.put("server1",
167-
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
171+
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", null));
168172
assertThat(properties.getConnections()).hasSize(2);
169173
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
170174
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/main/java/org/springframework/ai/mcp/client/httpclient/autoconfigure/SseHttpClientTransportAutoConfiguration.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.mcp.client.httpclient.autoconfigure;
1818

1919
import java.net.http.HttpClient;
20+
import java.net.http.HttpRequest;
2021
import java.util.ArrayList;
2122
import java.util.List;
2223
import java.util.Map;
@@ -36,6 +37,7 @@
3637
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
3738
import org.springframework.boot.context.properties.EnableConfigurationProperties;
3839
import org.springframework.context.annotation.Bean;
40+
import org.springframework.util.CollectionUtils;
3941

4042
/**
4143
* Auto-configuration for Server-Sent Events (SSE) HTTP client transport in the Model
@@ -96,11 +98,19 @@ public List<NamedClientMcpTransport> sseHttpClientTransports(McpSseClientPropert
9698
String baseUrl = serverParameters.getValue().url();
9799
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
98100
? serverParameters.getValue().sseEndpoint() : "/sse";
99-
var transport = HttpClientSseClientTransport.builder(baseUrl)
101+
var transportBuilder = HttpClientSseClientTransport.builder(baseUrl)
100102
.sseEndpoint(sseEndpoint)
101103
.clientBuilder(HttpClient.newBuilder())
102-
.objectMapper(objectMapper)
103-
.build();
104+
.objectMapper(objectMapper);
105+
var headers = serverParameters.getValue().headers();
106+
if (!CollectionUtils.isEmpty(headers)) {
107+
var requestBuilder = HttpRequest.newBuilder();
108+
for (Map.Entry<String, String> entry : headers.entrySet()) {
109+
requestBuilder = requestBuilder.header(entry.getKey(), entry.getValue());
110+
}
111+
transportBuilder = transportBuilder.requestBuilder(requestBuilder);
112+
}
113+
var transport = transportBuilder.build();
104114
sseTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport));
105115
}
106116

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationTests.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import static org.assertj.core.api.Assertions.assertThat;
2020

2121
import java.lang.reflect.Field;
22+
import java.net.URI;
23+
import java.net.http.HttpRequest;
2224
import java.util.List;
2325

2426
import org.junit.jupiter.api.Test;
@@ -154,10 +156,38 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
154156
});
155157
}
156158

159+
@Test
160+
void customHttpHeaders() {
161+
this.applicationContext
162+
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
163+
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/custom-sse",
164+
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
165+
.run(context -> {
166+
List<NamedClientMcpTransport> transports = context.getBean("sseHttpClientTransports", List.class);
167+
assertThat(transports).hasSize(1);
168+
assertThat(transports.get(0).name()).isEqualTo("server1");
169+
assertThat(transports.get(0).transport()).isInstanceOf(HttpClientSseClientTransport.class);
170+
171+
HttpRequest.Builder builder = getRequestBuilder(
172+
(HttpClientSseClientTransport) transports.get(0).transport());
173+
assertThat(builder.uri(new URI("http://localhost:8080")).build().headers().firstValue("Authorization"))
174+
.hasValue("Bearer <access_token>");
175+
});
176+
}
177+
157178
private String getSseEndpoint(HttpClientSseClientTransport transport) {
158-
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, "sseEndpoint");
179+
return getField(transport, "sseEndpoint", String.class);
180+
}
181+
182+
private HttpRequest.Builder getRequestBuilder(HttpClientSseClientTransport transport) {
183+
return getField(transport, "requestBuilder", HttpRequest.Builder.class);
184+
}
185+
186+
@SuppressWarnings("unchecked")
187+
private <T> T getField(HttpClientSseClientTransport transport, String fieldName, Class<T> type) {
188+
Field privateField = ReflectionUtils.findField(HttpClientSseClientTransport.class, fieldName);
159189
ReflectionUtils.makeAccessible(privateField);
160-
return (String) ReflectionUtils.getField(privateField, transport);
190+
return (T) ReflectionUtils.getField(privateField, transport);
161191
}
162192

163193
@Configuration

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfiguration.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
3434
import org.springframework.boot.context.properties.EnableConfigurationProperties;
3535
import org.springframework.context.annotation.Bean;
36+
import org.springframework.util.CollectionUtils;
3637
import org.springframework.web.reactive.function.client.WebClient;
3738

3839
/**
@@ -91,6 +92,12 @@ public List<NamedClientMcpTransport> sseWebFluxClientTransports(McpSseClientProp
9192

9293
for (Map.Entry<String, SseParameters> serverParameters : sseProperties.getConnections().entrySet()) {
9394
var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url());
95+
var headers = serverParameters.getValue().headers();
96+
if (!CollectionUtils.isEmpty(headers)) {
97+
for (Map.Entry<String, String> entry : headers.entrySet()) {
98+
webClientBuilder = webClientBuilder.defaultHeader(entry.getKey(), entry.getValue());
99+
}
100+
}
94101
String sseEndpoint = serverParameters.getValue().sseEndpoint() != null
95102
? serverParameters.getValue().sseEndpoint() : "/sse";
96103
var transport = WebFluxSseClientTransport.builder(webClientBuilder)

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationTests.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@
1919
import static org.assertj.core.api.Assertions.assertThat;
2020

2121
import java.lang.reflect.Field;
22+
import java.net.URI;
23+
import java.net.http.HttpRequest;
2224
import java.util.List;
2325

26+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
2427
import org.junit.jupiter.api.Test;
2528
import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport;
2629
import org.springframework.boot.autoconfigure.AutoConfigurations;
2730
import org.springframework.boot.test.context.FilteredClassLoader;
2831
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
2932
import org.springframework.context.annotation.Bean;
3033
import org.springframework.context.annotation.Configuration;
34+
import org.springframework.http.HttpHeaders;
3135
import org.springframework.util.ReflectionUtils;
3236
import org.springframework.web.reactive.function.client.WebClient;
3337

@@ -39,6 +43,7 @@
3943
* Tests for {@link SseWebFluxTransportAutoConfiguration}.
4044
*
4145
* @author Christian Tzolov
46+
* @author Yanming Zhou
4247
*/
4348
public class SseWebFluxTransportAutoConfigurationTests {
4449

@@ -178,10 +183,42 @@ void mixedConnectionsWithAndWithoutCustomSseEndpoint() {
178183
});
179184
}
180185

186+
@Test
187+
void customHttpHeaders() {
188+
this.applicationContext.withUserConfiguration(CustomWebClientConfiguration.class)
189+
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
190+
"spring.ai.mcp.client.sse.connections.server1.headers.Authorization=Bearer <access_token>")
191+
.run(context -> {
192+
assertThat(context.getBean(WebClient.Builder.class)).isNotNull();
193+
List<NamedClientMcpTransport> transports = context.getBean("sseWebFluxClientTransports", List.class);
194+
assertThat(transports).hasSize(1);
195+
196+
WebClient webClient = getWebClient((WebFluxSseClientTransport) transports.get(0).transport());
197+
HttpHeaders defaultHeaders = getDefaultHeaders(webClient);
198+
assertThat(defaultHeaders.getFirst("Authorization")).isEqualTo("Bearer <access_token>");
199+
});
200+
}
201+
181202
private String getSseEndpoint(WebFluxSseClientTransport transport) {
182-
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, "sseEndpoint");
203+
return getField(transport, "sseEndpoint", String.class);
204+
}
205+
206+
private WebClient getWebClient(WebFluxSseClientTransport transport) {
207+
return getField(transport, "webClient", WebClient.class);
208+
}
209+
210+
@SuppressWarnings("unchecked")
211+
private <T> T getField(WebFluxSseClientTransport transport, String fieldName, Class<T> type) {
212+
Field privateField = ReflectionUtils.findField(WebFluxSseClientTransport.class, fieldName);
213+
ReflectionUtils.makeAccessible(privateField);
214+
return (T) ReflectionUtils.getField(privateField, transport);
215+
}
216+
217+
@SuppressWarnings("unchecked")
218+
private HttpHeaders getDefaultHeaders(WebClient webClient) {
219+
Field privateField = ReflectionUtils.findField(webClient.getClass(), "defaultHeaders");
183220
ReflectionUtils.makeAccessible(privateField);
184-
return (String) ReflectionUtils.getField(privateField, transport);
221+
return (HttpHeaders) ReflectionUtils.getField(privateField, webClient);
185222
}
186223

187224
@Configuration

0 commit comments

Comments
 (0)