Skip to content

Commit b7bdd72

Browse files
committed
Simplify use of headers for SockJsClient requests
Before this change, XhrTransport implementations had to be configured with the headers to use for HTTP requests other than the initial handshake. After this change the handshake headers passed to SockJsClient by default are used for all other HTTP requests related to the SockJS connection (e.g. info request, xhr send/receive). A property on SockJsClient allows restricting the headers to use for other HTTP requests to a subset of the handshake headers. Issue: SPR-13254
1 parent 9f557cf commit b7bdd72

17 files changed

+264
-87
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import org.springframework.http.HttpHeaders;
2828
import org.springframework.http.HttpStatus;
29-
import org.springframework.http.MediaType;
3029
import org.springframework.http.ResponseEntity;
3130
import org.springframework.util.concurrent.ListenableFuture;
3231
import org.springframework.util.concurrent.SettableListenableFuture;
@@ -61,8 +60,6 @@ public abstract class AbstractXhrTransport implements XhrTransport {
6160

6261
private HttpHeaders requestHeaders = new HttpHeaders();
6362

64-
private HttpHeaders xhrSendRequestHeaders = new HttpHeaders();
65-
6663

6764
@Override
6865
public List<TransportType> getTransportTypes() {
@@ -97,24 +94,25 @@ public boolean isXhrStreamingDisabled() {
9794
/**
9895
* Configure headers to be added to every executed HTTP request.
9996
* @param requestHeaders the headers to add to requests
97+
* @deprecated as of 4.2 in favor of {@link SockJsClient#setHttpHeaderNames}.
10098
*/
99+
@Deprecated
101100
public void setRequestHeaders(HttpHeaders requestHeaders) {
102101
this.requestHeaders.clear();
103-
this.xhrSendRequestHeaders.clear();
104102
if (requestHeaders != null) {
105103
this.requestHeaders.putAll(requestHeaders);
106-
this.xhrSendRequestHeaders.putAll(requestHeaders);
107-
this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON);
108104
}
109105
}
110106

107+
@Deprecated
111108
public HttpHeaders getRequestHeaders() {
112109
return this.requestHeaders;
113110
}
114111

115112

116113
// Transport methods
117114

115+
@SuppressWarnings("deprecation")
118116
@Override
119117
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
120118
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
@@ -128,8 +126,8 @@ public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebS
128126
}
129127

130128
HttpHeaders handshakeHeaders = new HttpHeaders();
131-
handshakeHeaders.putAll(request.getHandshakeHeaders());
132129
handshakeHeaders.putAll(getRequestHeaders());
130+
handshakeHeaders.putAll(request.getHandshakeHeaders());
133131

134132
connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture);
135133
return connectFuture;
@@ -142,11 +140,17 @@ protected abstract void connectInternal(TransportRequest request, WebSocketHandl
142140
// InfoReceiver methods
143141

144142
@Override
145-
public String executeInfoRequest(URI infoUrl) {
143+
@SuppressWarnings("deprecation")
144+
public String executeInfoRequest(URI infoUrl, HttpHeaders headers) {
146145
if (logger.isDebugEnabled()) {
147146
logger.debug("Executing SockJS Info request, url=" + infoUrl);
148147
}
149-
ResponseEntity<String> response = executeInfoRequestInternal(infoUrl);
148+
HttpHeaders infoRequestHeaders = new HttpHeaders();
149+
infoRequestHeaders.putAll(getRequestHeaders());
150+
if (headers != null) {
151+
infoRequestHeaders.putAll(headers);
152+
}
153+
ResponseEntity<String> response = executeInfoRequestInternal(infoUrl, infoRequestHeaders);
150154
if (response.getStatusCode() != HttpStatus.OK) {
151155
if (logger.isErrorEnabled()) {
152156
logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response);
@@ -159,16 +163,16 @@ public String executeInfoRequest(URI infoUrl) {
159163
return response.getBody();
160164
}
161165

162-
protected abstract ResponseEntity<String> executeInfoRequestInternal(URI infoUrl);
166+
protected abstract ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers);
163167

164168
// XhrTransport methods
165169

166170
@Override
167-
public void executeSendRequest(URI url, TextMessage message) {
171+
public void executeSendRequest(URI url, HttpHeaders headers, TextMessage message) {
168172
if (logger.isTraceEnabled()) {
169173
logger.trace("Starting XHR send, url=" + url);
170174
}
171-
ResponseEntity<String> response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message);
175+
ResponseEntity<String> response = executeSendRequestInternal(url, headers, message);
172176
if (response.getStatusCode() != HttpStatus.NO_CONTENT) {
173177
if (logger.isErrorEnabled()) {
174178
logger.error("XHR send request (url=" + url + ") failed: " + response);
@@ -180,7 +184,8 @@ public void executeSendRequest(URI url, TextMessage message) {
180184
}
181185
}
182186

183-
protected abstract ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message);
187+
protected abstract ResponseEntity<String> executeSendRequestInternal(URI url,
188+
HttpHeaders headers, TextMessage message);
184189

185190

186191
@Override

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class DefaultTransportRequest implements TransportRequest {
5252

5353
private final HttpHeaders handshakeHeaders;
5454

55+
private final HttpHeaders httpRequestHeaders;
56+
5557
private final Transport transport;
5658

5759
private final TransportType serverTransportType;
@@ -69,7 +71,8 @@ class DefaultTransportRequest implements TransportRequest {
6971
private DefaultTransportRequest fallbackRequest;
7072

7173

72-
public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders,
74+
public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo,
75+
HttpHeaders handshakeHeaders, HttpHeaders httpRequestHeaders,
7376
Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) {
7477

7578
Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required");
@@ -78,6 +81,7 @@ public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshak
7881
Assert.notNull(codec, "'codec' is required");
7982
this.sockJsUrlInfo = sockJsUrlInfo;
8083
this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders());
84+
this.httpRequestHeaders = (httpRequestHeaders != null ? httpRequestHeaders : new HttpHeaders());
8185
this.transport = transport;
8286
this.serverTransportType = serverTransportType;
8387
this.codec = codec;
@@ -94,6 +98,11 @@ public HttpHeaders getHandshakeHeaders() {
9498
return this.handshakeHeaders;
9599
}
96100

101+
@Override
102+
public HttpHeaders getHttpRequestHeaders() {
103+
return this.httpRequestHeaders;
104+
}
105+
97106
@Override
98107
public URI getTransportUrl() {
99108
return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType);

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import java.net.URI;
1919

20+
import org.springframework.http.HttpHeaders;
21+
2022
/**
2123
* A component that can execute the SockJS "Info" request that needs to be
2224
* performed before the SockJS session starts in order to check server endpoint
@@ -34,10 +36,11 @@ public interface InfoReceiver {
3436
/**
3537
* Perform an HTTP request to the SockJS "Info" URL.
3638
* and return the resulting JSON response content, or raise an exception.
37-
*
39+
* <p>Note that as of 4.2 this method accepts a {@code headers} parameter.
3840
* @param infoUrl the URL to obtain SockJS server information from
41+
* @param headers the headers to use for the request
3942
* @return the body of the response
4043
*/
41-
String executeInfoRequest(URI infoUrl);
44+
String executeInfoRequest(URI infoUrl, HttpHeaders headers);
4245

4346
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@ public boolean isRunning() {
106106

107107

108108
@Override
109-
protected void connectInternal(TransportRequest request, WebSocketHandler handler,
109+
protected void connectInternal(TransportRequest transportRequest, WebSocketHandler handler,
110110
URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
111111
SettableListenableFuture<WebSocketSession> connectFuture) {
112112

113-
SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture);
113+
HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
114+
SockJsResponseListener listener = new SockJsResponseListener(url, httpHeaders, session, connectFuture);
114115
executeReceiveRequest(url, handshakeHeaders, listener);
115116
}
116117

@@ -124,8 +125,8 @@ private void executeReceiveRequest(URI url, HttpHeaders headers, SockJsResponseL
124125
}
125126

126127
@Override
127-
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
128-
return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null);
128+
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
129+
return executeRequest(infoUrl, HttpMethod.GET, headers, null);
129130
}
130131

131132
@Override

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,16 @@ public TaskExecutor getTaskExecutor() {
9494

9595

9696
@Override
97-
protected void connectInternal(final TransportRequest request, final WebSocketHandler handler,
97+
protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler,
9898
final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session,
9999
final SettableListenableFuture<WebSocketSession> connectFuture) {
100100

101101
getTaskExecutor().execute(new Runnable() {
102102
@Override
103103
public void run() {
104+
HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders();
104105
XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders);
105-
XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders());
106+
XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders);
106107
XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session);
107108
while (true) {
108109
if (session.isDisconnected()) {
@@ -132,8 +133,8 @@ public void run() {
132133
}
133134

134135
@Override
135-
public ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
136-
RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders());
136+
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
137+
RequestCallback requestCallback = new XhrRequestCallback(headers);
137138
return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor);
138139
}
139140

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle {
7878

7979
private final List<Transport> transports;
8080

81+
private String[] httpHeaderNames;
82+
8183
private InfoReceiver infoReceiver;
8284

8385
private SockJsMessageCodec messageCodec;
@@ -116,6 +118,30 @@ private static InfoReceiver initInfoReceiver(List<Transport> transports) {
116118
}
117119

118120

121+
/**
122+
* The names of HTTP headers that should be copied from the handshake headers
123+
* of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)}
124+
* and also used with other HTTP requests issued as part of that SockJS
125+
* connection, e.g. the initial info request, XHR send or receive requests.
126+
*
127+
* <p>By default if this property is not set, all handshake headers are also
128+
* used for other HTTP requests. Set it if you want only a subset of handshake
129+
* headers (e.g. auth headers) to be used for other HTTP requests.
130+
*
131+
* @param httpHeaderNames HTTP header names
132+
*/
133+
public void setHttpHeaderNames(String... httpHeaderNames) {
134+
this.httpHeaderNames = httpHeaderNames;
135+
}
136+
137+
/**
138+
* The configured HTTP header names to be copied from the handshake
139+
* headers and also included in other HTTP requests.
140+
*/
141+
public String[] getHttpHeaderNames() {
142+
return this.httpHeaderNames;
143+
}
144+
119145
/**
120146
* Configure the {@code InfoReceiver} to use to perform the SockJS "Info"
121147
* request before the SockJS session starts.
@@ -225,7 +251,7 @@ public final ListenableFuture<WebSocketSession> doHandshake(
225251
SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<WebSocketSession>();
226252
try {
227253
SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url);
228-
ServerInfo serverInfo = getServerInfo(sockJsUrlInfo);
254+
ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers));
229255
createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture);
230256
}
231257
catch (Throwable exception) {
@@ -237,12 +263,27 @@ public final ListenableFuture<WebSocketSession> doHandshake(
237263
return connectFuture;
238264
}
239265

240-
private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) {
266+
private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) {
267+
if (getHttpHeaderNames() == null) {
268+
return webSocketHttpHeaders;
269+
}
270+
else {
271+
HttpHeaders httpHeaders = new HttpHeaders();
272+
for (String name : getHttpHeaderNames()) {
273+
if (webSocketHttpHeaders.containsKey(name)) {
274+
httpHeaders.put(name, webSocketHttpHeaders.get(name));
275+
}
276+
}
277+
return httpHeaders;
278+
}
279+
}
280+
281+
private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) {
241282
URI infoUrl = sockJsUrlInfo.getInfoUrl();
242283
ServerInfo info = this.serverInfoCache.get(infoUrl);
243284
if (info == null) {
244285
long start = System.currentTimeMillis();
245-
String response = this.infoReceiver.executeInfoRequest(infoUrl);
286+
String response = this.infoReceiver.executeInfoRequest(infoUrl, headers);
246287
long infoRequestTime = System.currentTimeMillis() - start;
247288
info = new ServerInfo(response, infoRequestTime);
248289
this.serverInfoCache.put(infoUrl, info);
@@ -255,7 +296,8 @@ private DefaultTransportRequest createRequest(SockJsUrlInfo urlInfo, HttpHeaders
255296
for (Transport transport : this.transports) {
256297
for (TransportType type : transport.getTransportTypes()) {
257298
if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) {
258-
requests.add(new DefaultTransportRequest(urlInfo, headers, transport, type, getMessageCodec()));
299+
requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers),
300+
transport, type, getMessageCodec()));
259301
}
260302
}
261303
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ public interface TransportRequest {
4747
*/
4848
HttpHeaders getHandshakeHeaders();
4949

50+
/**
51+
* Return the headers to add to all other HTTP requests besides the handshake
52+
* request such XHR receive and send requests.
53+
* @since 4.2
54+
*/
55+
HttpHeaders getHttpRequestHeaders();
56+
5057
/**
5158
* Return the transport URL for the given transport.
5259
* For an {@link XhrTransport} this is the URL for receiving messages.

0 commit comments

Comments
 (0)