Skip to content

Commit 7fee863

Browse files
authored
GH-10487: Add STOMP CONNECT frame from the client (#10488)
Fixes: #10487 * Fix `WebSocketInboundChannelAdapter` to register own client session in the `StompSubProtocolHandler` for a proper correlation for upcoming messages from the server * Fix `WebSocketOutboundMessageHandlerTests` to produce required STOMP `CONNECT` before publishing data Cherry-pick to `6.5.x` & `6.4.x`
1 parent 20bc3a5 commit 7fee863

File tree

6 files changed

+78
-49
lines changed

6 files changed

+78
-49
lines changed

spring-integration-websocket/src/main/java/org/springframework/integration/websocket/inbound/WebSocketInboundChannelAdapter.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@
5151
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
5252
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
5353
import org.springframework.messaging.simp.stomp.StompCommand;
54+
import org.springframework.messaging.simp.stomp.StompEncoder;
5455
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
5556
import org.springframework.messaging.support.MessageBuilder;
5657
import org.springframework.util.Assert;
5758
import org.springframework.util.CollectionUtils;
5859
import org.springframework.util.MimeTypeUtils;
60+
import org.springframework.web.socket.BinaryMessage;
5961
import org.springframework.web.socket.CloseStatus;
6062
import org.springframework.web.socket.WebSocketMessage;
6163
import org.springframework.web.socket.WebSocketSession;
@@ -141,7 +143,7 @@ public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketCon
141143
}
142144

143145
/**
144-
* Set the message converters to use. These converters are used to convert the message to send for appropriate
146+
* Set the message converters to use. These converters are used to convert the message to send for the appropriate
145147
* internal subProtocols type.
146148
* @param messageConverters The message converters.
147149
*/
@@ -160,7 +162,7 @@ public void setMergeWithDefaultConverters(boolean mergeWithDefaultConverters) {
160162
}
161163

162164
/**
163-
* Set the type for target message payload to convert the WebSocket message body to.
165+
* Set the type for the target message payload to convert the WebSocket message body to.
164166
* @param payloadType to convert inbound WebSocket message body
165167
* @see CompositeMessageConverter
166168
*/
@@ -174,9 +176,9 @@ public void setPayloadType(Class<?> payloadType) {
174176
* bean for {@code non-MESSAGE} {@link org.springframework.web.socket.WebSocketMessage}s
175177
* and to route messages with broker destinations.
176178
* Since only single {@link AbstractBrokerMessageHandler} bean is allowed in the current
177-
* application context, the algorithm to lookup the former by type, rather than applying
179+
* application context, the algorithm is to look up the former by type, rather than applying
178180
* the bean reference.
179-
* This is used only on server side and is ignored from client side.
181+
* This is used only on the server side and is ignored from the client side.
180182
* @param useBroker the boolean flag.
181183
*/
182184
public void setUseBroker(boolean useBroker) {
@@ -234,13 +236,23 @@ public void afterSessionStarted(WebSocketSession session) {
234236
SubProtocolHandler protocolHandler = this.subProtocolHandlerRegistry.findProtocolHandler(session);
235237
protocolHandler.afterSessionStarted(session, this.subProtocolHandlerChannel);
236238
if (!this.server && protocolHandler instanceof StompSubProtocolHandler) {
239+
// The CONNECT frame is required by the STOMP specification.
237240
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
238241
accessor.setSessionId(session.getId());
239242
accessor.setLeaveMutable(true);
240243
accessor.setAcceptVersion("1.1,1.2");
241244

242-
Message<?> connectMessage =
245+
Message<byte[]> connectMessage =
243246
MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
247+
248+
// In the client mode, the client session has to register itself
249+
// into the StompSubProtocolHandler cache
250+
// for proper correlation of the messages from the server side.
251+
StompEncoder stompEncoder = new StompEncoder();
252+
byte[] connectMessageBytes = stompEncoder.encode(connectMessage);
253+
protocolHandler.handleMessageFromClient(session, new BinaryMessage(connectMessageBytes),
254+
this.subProtocolHandlerChannel);
255+
244256
protocolHandler.handleMessageToClient(session, connectMessage);
245257
}
246258
}
@@ -313,7 +325,11 @@ private void handleMessageAndSend(final Message<?> message) {
313325
SimpMessageType messageType = headerAccessor.getMessageType();
314326
if (isProcessingTypeOrCommand(headerAccessor, stompCommand, messageType)) {
315327
if (SimpMessageType.CONNECT.equals(messageType)) {
316-
produceConnectAckMessage(message, headerAccessor);
328+
// Ignore the CONNECT frame in the client mode.
329+
// Essentially, it has been just initiated from the {@link #afterSessionStarted}.
330+
if (this.server) {
331+
produceConnectAckMessage(message, headerAccessor);
332+
}
317333
}
318334
else if (StompCommand.CONNECTED.equals(stompCommand)) {
319335
this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
@@ -338,10 +354,10 @@ else if (StompCommand.RECEIPT.equals(stompCommand)) {
338354
}
339355
}
340356

341-
private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor, @Nullable StompCommand stompCommand,
342-
@Nullable SimpMessageType messageType) {
357+
private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor,
358+
@Nullable StompCommand stompCommand, @Nullable SimpMessageType messageType) {
343359

344-
return (messageType == null // NOSONAR pretty simple logic
360+
return (messageType == null
345361
|| SimpMessageType.MESSAGE.equals(messageType)
346362
|| (SimpMessageType.CONNECT.equals(messageType) && !this.useBroker)
347363
|| StompCommand.CONNECTED.equals(stompCommand)

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/client/StompIntegrationTests.java

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import java.util.concurrent.CountDownLatch;
2626
import java.util.concurrent.TimeUnit;
2727

28-
import org.junit.jupiter.api.Disabled;
2928
import org.junit.jupiter.api.Test;
3029

3130
import org.springframework.beans.factory.annotation.Autowired;
@@ -85,6 +84,7 @@
8584
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
8685
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
8786
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
87+
import org.springframework.web.socket.messaging.SessionConnectEvent;
8888
import org.springframework.web.socket.messaging.SessionConnectedEvent;
8989
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
9090
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
@@ -95,6 +95,7 @@
9595
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
9696

9797
import static org.assertj.core.api.Assertions.assertThat;
98+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
9899

99100
/**
100101
* @author Artem Bilan
@@ -103,7 +104,6 @@
103104
*/
104105
@SpringJUnitConfig(classes = StompIntegrationTests.ClientConfig.class)
105106
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
106-
@Disabled("TODO until the lastest fix from SF mitigation")
107107
public class StompIntegrationTests {
108108

109109
@Value("#{server.serverContext}")
@@ -126,35 +126,44 @@ public class StompIntegrationTests {
126126

127127
@Test
128128
public void sendMessageToController() throws Exception {
129-
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
130-
this.webSocketOutputChannel.send(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
131-
132129
Message<?> receive = this.webSocketEvents.receive(20000);
133-
assertThat(receive).isNotNull();
134-
Object event = receive.getPayload();
135-
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
136-
Message<?> connectedMessage = ((SessionConnectedEvent) event).getMessage();
137-
headers = StompHeaderAccessor.wrap(connectedMessage);
138-
assertThat(headers.getCommand()).isEqualTo(StompCommand.CONNECTED);
130+
assertThat(receive)
131+
.extracting(Message::getPayload)
132+
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
133+
.isInstanceOf(SessionConnectEvent.class);
139134

140-
headers = StompHeaderAccessor.create(StompCommand.SEND);
135+
receive = this.webSocketEvents.receive(20000);
136+
assertThat(receive)
137+
.extracting(Message::getPayload)
138+
.asInstanceOf(type(SessionConnectedEvent.class))
139+
.extracting(SessionConnectedEvent::getMessage)
140+
.extracting(connectedMessage -> StompHeaderAccessor.wrap(connectedMessage).getCommand())
141+
.isEqualTo(StompCommand.CONNECTED);
142+
143+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
141144
headers.setSubscriptionId("sub1");
142145
headers.setDestination("/app/simple");
143146
Message<String> message = MessageBuilder.withPayload("foo").setHeaders(headers).build();
144147

145148
this.webSocketOutputChannel.send(message);
146149

147150
SimpleController controller = this.serverContext.getBean(SimpleController.class);
148-
assertThat(controller.latch.await(20, TimeUnit.SECONDS)).isTrue();
151+
assertThat(controller.latch.await(10, TimeUnit.SECONDS)).isTrue();
149152
assertThat(controller.stompCommand).isEqualTo(StompCommand.SEND.name());
150153
}
151154

152155
@Test
153156
public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
154157
Message<?> receive = this.webSocketEvents.receive(20000);
155-
assertThat(receive).isNotNull();
156-
Object event = receive.getPayload();
157-
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
158+
assertThat(receive)
159+
.extracting(Message::getPayload)
160+
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
161+
.isInstanceOf(SessionConnectEvent.class);
162+
163+
receive = this.webSocketEvents.receive(20000);
164+
assertThat(receive)
165+
.extracting(Message::getPayload)
166+
.isInstanceOf(SessionConnectedEvent.class);
158167

159168
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
160169
headers.setSubscriptionId("subs1");
@@ -167,13 +176,14 @@ public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
167176
this.webSocketOutputChannel.send(message);
168177

169178
receive = this.webSocketEvents.receive(20000);
170-
assertThat(receive).isNotNull();
171-
event = receive.getPayload();
172-
assertThat(event).isInstanceOf(ReceiptEvent.class);
173-
Message<?> receiptMessage = ((ReceiptEvent) event).getMessage();
174-
headers = StompHeaderAccessor.wrap(receiptMessage);
175-
assertThat(headers.getCommand()).isEqualTo(StompCommand.RECEIPT);
176-
assertThat(headers.getReceiptId()).isEqualTo("myReceipt");
179+
assertThat(receive)
180+
.extracting(Message::getPayload)
181+
.asInstanceOf(type(ReceiptEvent.class))
182+
.extracting(event -> StompHeaderAccessor.wrap(event.getMessage()))
183+
.satisfies(headerAccessor -> {
184+
assertThat(headerAccessor.getCommand()).isEqualTo(StompCommand.RECEIPT);
185+
assertThat(headerAccessor.getReceiptId()).isEqualTo("myReceipt");
186+
});
177187

178188
waitForSubscribe("/topic/increment");
179189

@@ -494,7 +504,7 @@ public void configureMessageBroker(MessageBrokerRegistry configurer) {
494504
public ApplicationListener<SessionSubscribeEvent> webSocketEventListener(
495505
final AbstractSubscribableChannel clientOutboundChannel) {
496506
// Cannot be lambda because Java can't infer generic type from lambdas,
497-
// therefore we end up with ClassCastException for other event types
507+
// therefore, we end up with ClassCastException for other event types
498508
return new ApplicationListener<SessionSubscribeEvent>() {
499509

500510
@Override

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/client/WebSocketClientTests.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.Map;
2222

2323
import org.apache.tomcat.websocket.Constants;
24-
import org.junit.jupiter.api.Disabled;
2524
import org.junit.jupiter.api.Test;
2625

2726
import org.springframework.beans.factory.annotation.Autowired;
@@ -68,7 +67,6 @@
6867
*/
6968
@SpringJUnitConfig(classes = WebSocketClientTests.ClientConfig.class)
7069
@DirtiesContext
71-
@Disabled("TODO until the lastest fix from SF mitigation")
7270
public class WebSocketClientTests {
7371

7472
@Autowired

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/inbound/WebSocketInboundChannelAdapterTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.Collections;
2121
import java.util.Map;
2222

23-
import org.junit.jupiter.api.Disabled;
2423
import org.junit.jupiter.api.Test;
2524

2625
import org.springframework.beans.factory.annotation.Autowired;
@@ -64,7 +63,6 @@
6463
*/
6564
@SpringJUnitConfig
6665
@DirtiesContext
67-
@Disabled("TODO until the lastest fix from SF mitigation")
6866
public class WebSocketInboundChannelAdapterTests {
6967

7068
@Value("#{server.serverContext.getBean('subProtocolWebSocketHandler')}")
@@ -98,6 +96,7 @@ public void testWebSocketInboundChannelAdapter() {
9896
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
9997
headers.setLeaveMutable(true);
10098
headers.setSessionId(sessionId);
99+
headers.setSubscriptionId("sub1");
101100
Message<byte[]> message =
102101
MessageBuilder.createMessage(ByteBuffer.allocate(0).array(), headers.getMessageHeaders());
103102

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/outbound/WebSocketOutboundMessageHandlerTests.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import java.util.Collections;
2020

21-
import org.junit.jupiter.api.Disabled;
2221
import org.junit.jupiter.api.Test;
2322

2423
import org.springframework.beans.factory.annotation.Autowired;
@@ -56,7 +55,6 @@
5655
*/
5756
@SpringJUnitConfig
5857
@DirtiesContext
59-
@Disabled("TODO until the lastest fix from SF mitigation")
6058
public class WebSocketOutboundMessageHandlerTests {
6159

6260
@Autowired
@@ -68,22 +66,32 @@ public class WebSocketOutboundMessageHandlerTests {
6866

6967
@Test
7068
public void testWebSocketOutboundMessageHandler() {
71-
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
69+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
70+
this.messageHandler.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
71+
72+
headers = StompHeaderAccessor.create(StompCommand.SEND);
7273
headers.setMessageId("mess0");
7374
headers.setSubscriptionId("sub0");
74-
headers.setDestination("/foo");
75+
headers.setDestination("/dest");
7576
String payload = "Hello World";
7677
Message<String> message = MessageBuilder.withPayload(payload).setHeaders(headers).build();
7778

7879
this.messageHandler.handleMessage(message);
7980

8081
Message<?> received = this.clientInboundChannel.receive(10000);
81-
assertThat(received).isNotNull();
82-
83-
StompHeaderAccessor receivedHeaders = StompHeaderAccessor.wrap(received);
84-
assertThat(receivedHeaders.getMessageId()).isEqualTo("mess0");
85-
assertThat(receivedHeaders.getSubscriptionId()).isEqualTo("sub0");
86-
assertThat(receivedHeaders.getDestination()).isEqualTo("/foo");
82+
assertThat(received)
83+
.extracting(StompHeaderAccessor::wrap)
84+
.extracting(StompHeaderAccessor::getCommand)
85+
.isEqualTo(StompCommand.CONNECT);
86+
87+
received = this.clientInboundChannel.receive(10000);
88+
assertThat(received)
89+
.extracting(StompHeaderAccessor::wrap)
90+
.satisfies(headerAccessor -> {
91+
assertThat(headerAccessor.getMessageId()).isEqualTo("mess0");
92+
assertThat(headerAccessor.getSubscriptionId()).isEqualTo("sub0");
93+
assertThat(headerAccessor.getDestination()).isEqualTo("/dest");
94+
});
8795

8896
Object receivedPayload = received.getPayload();
8997
assertThat(receivedPayload).isInstanceOf(byte[].class);

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/server/WebSocketServerTests.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.Collections;
2121
import java.util.List;
2222

23-
import org.junit.jupiter.api.Disabled;
2423
import org.junit.jupiter.api.Test;
2524
import org.mockito.Mockito;
2625

@@ -118,7 +117,6 @@ public class WebSocketServerTests {
118117
private Lifecycle requestUpgradeStrategy;
119118

120119
@Test
121-
@Disabled("TODO until the lastest fix from SF mitigation")
122120
public void testWebSocketOutboundMessageHandler() {
123121
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
124122
headers.setSubscriptionId("subs1");

0 commit comments

Comments
 (0)