Skip to content

Commit 329fbf3

Browse files
committed
Fix concurrency issue in DefaultSubscriptionRegistry
1 parent d73c2e2 commit 329fbf3

File tree

4 files changed

+88
-39
lines changed

4 files changed

+88
-39
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
4343

4444

4545
public SimpMessagingTemplate(MessageChannel messageChannel) {
46-
Assert.notNull(messageChannel, "outputChannel is required");
46+
Assert.notNull(messageChannel, "messageChannel is required");
4747
this.messageChannel = messageChannel;
4848
}
4949

@@ -117,6 +117,8 @@ protected void doSend(String destination, Message<?> message) {
117117
}
118118
}
119119

120+
121+
120122
@Override
121123
public <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException {
122124
convertAndSendToUser(user, destination, message, null);

spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
import org.springframework.core.MethodParameter;
3535
import org.springframework.core.annotation.AnnotationUtils;
3636
import org.springframework.messaging.Message;
37+
import org.springframework.messaging.MessageChannel;
3738
import org.springframework.messaging.MessageHandler;
3839
import org.springframework.messaging.MessagingException;
40+
import org.springframework.messaging.core.AbstractMessageSendingTemplate;
3941
import org.springframework.messaging.handler.annotation.MessageMapping;
4042
import org.springframework.messaging.handler.annotation.ReplyTo;
4143
import org.springframework.messaging.handler.annotation.support.ExceptionHandlerMethodResolver;
@@ -48,6 +50,7 @@
4850
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
4951
import org.springframework.messaging.simp.SimpMessageSendingOperations;
5052
import org.springframework.messaging.simp.SimpMessageType;
53+
import org.springframework.messaging.simp.SimpMessagingTemplate;
5154
import org.springframework.messaging.simp.annotation.SubscribeEvent;
5255
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
5356
import org.springframework.messaging.simp.annotation.support.PrincipalMethodArgumentResolver;
@@ -68,9 +71,9 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
6871

6972
private static final Log logger = LogFactory.getLog(AnnotationMethodMessageHandler.class);
7073

71-
private final SimpMessageSendingOperations inboundMessagingTemplate;
74+
private final SimpMessageSendingOperations dispatchMessagingTemplate;
7275

73-
private final SimpMessageSendingOperations outboundMessagingTemplate;
76+
private final SimpMessageSendingOperations webSocketSessionMessagingTemplate;
7477

7578
private MessageConverter<?> messageConverter;
7679

@@ -91,31 +94,30 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
9194

9295

9396
/**
94-
* @param inboundMessagingTemplate a template for sending messages on the channel
95-
* where incoming messages from clients are sent; essentially messages sent
96-
* through this template will be re-processed by the application. One example
97-
* is the use of {@link ReplyTo} annotation on a method to send a broadcast
98-
* message.
99-
* @param outboundMessagingTemplate a template for sending messages on the client used
100-
* to send messages back out to connected clients; such messages must have all
101-
* necessary information to reach the client such as session and subscription
102-
* id's. One example is returning a value from an {@link SubscribeEvent}
103-
* method.
97+
* @param dispatchMessagingTemplate a messaging template to dispatch messages to for
98+
* further processing, e.g. the use of an {@link ReplyTo} annotation on a
99+
* message handling method, causes a new (broadcast) message to be sent.
100+
* @param webSocketSessionChannel the channel to send messages to WebSocket sessions
101+
* on this application server. This is used primarily for processing the return
102+
* values from {@link SubscribeEvent}-annotated methods.
104103
*/
105-
public AnnotationMethodMessageHandler(SimpMessageSendingOperations inboundMessagingTemplate,
106-
SimpMessageSendingOperations outboundMessagingTemplate) {
104+
public AnnotationMethodMessageHandler(SimpMessageSendingOperations dispatchMessagingTemplate,
105+
MessageChannel webSocketSessionChannel) {
107106

108-
Assert.notNull(inboundMessagingTemplate, "inboundMessagingTemplate is required");
109-
Assert.notNull(outboundMessagingTemplate, "outboundMessagingTemplate is required");
110-
this.inboundMessagingTemplate = inboundMessagingTemplate;
111-
this.outboundMessagingTemplate = outboundMessagingTemplate;
107+
Assert.notNull(dispatchMessagingTemplate, "dispatchMessagingTemplate is required");
108+
Assert.notNull(webSocketSessionChannel, "webSocketSessionChannel is required");
109+
this.dispatchMessagingTemplate = dispatchMessagingTemplate;
110+
this.webSocketSessionMessagingTemplate = new SimpMessagingTemplate(webSocketSessionChannel);
112111
}
113112

114113
/**
115114
* TODO: multiple converters with 'content-type' header
116115
*/
117116
public void setMessageConverter(MessageConverter<?> converter) {
118117
this.messageConverter = converter;
118+
if (converter != null) {
119+
((AbstractMessageSendingTemplate<?>) this.webSocketSessionMessagingTemplate).setMessageConverter(converter);
120+
}
119121
}
120122

121123
@Override
@@ -131,8 +133,8 @@ public void afterPropertiesSet() {
131133
this.argumentResolvers.addResolver(new PrincipalMethodArgumentResolver());
132134
this.argumentResolvers.addResolver(new MessageBodyMethodArgumentResolver(this.messageConverter));
133135

134-
this.returnValueHandlers.addHandler(new ReplyToMethodReturnValueHandler(this.inboundMessagingTemplate));
135-
this.returnValueHandlers.addHandler(new SubscriptionMethodReturnValueHandler(this.outboundMessagingTemplate));
136+
this.returnValueHandlers.addHandler(new ReplyToMethodReturnValueHandler(this.dispatchMessagingTemplate));
137+
this.returnValueHandlers.addHandler(new SubscriptionMethodReturnValueHandler(this.webSocketSessionMessagingTemplate));
136138
}
137139

138140
protected void initHandlerMethods() {

spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.messaging.simp.handler;
1818

1919
import java.util.Collection;
20-
import java.util.HashMap;
2120
import java.util.HashSet;
2221
import java.util.Map;
2322
import java.util.Set;
@@ -29,6 +28,8 @@
2928
import org.springframework.util.LinkedMultiValueMap;
3029
import org.springframework.util.MultiValueMap;
3130

31+
import reactor.util.Assert;
32+
3233

3334
/**
3435
* @author Rossen Stoyanchev
@@ -102,6 +103,14 @@ protected MultiValueMap<String, String> findSubscriptionsInternal(String destina
102103
return result;
103104
}
104105

106+
@Override
107+
public String toString() {
108+
return "[destinationCache=" + this.destinationCache + ", subscriptionRegistry="
109+
+ this.subscriptionRegistry + "]";
110+
}
111+
112+
113+
105114

106115
/**
107116
* Provide direct lookup of session subscriptions by destination (for non-pattern destinations).
@@ -116,7 +125,7 @@ private static class DestinationCache {
116125

117126

118127
public void mapToDestination(String destination, SessionSubscriptionInfo info) {
119-
synchronized (monitor) {
128+
synchronized(this.monitor) {
120129
Set<SessionSubscriptionInfo> registrations = this.subscriptionsByDestination.get(destination);
121130
if (registrations == null) {
122131
registrations = new CopyOnWriteArraySet<SessionSubscriptionInfo>();
@@ -127,7 +136,7 @@ public void mapToDestination(String destination, SessionSubscriptionInfo info) {
127136
}
128137

129138
public void unmapFromDestination(String destination, SessionSubscriptionInfo info) {
130-
synchronized (monitor) {
139+
synchronized(this.monitor) {
131140
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
132141
if (infos != null) {
133142
infos.remove(info);
@@ -159,6 +168,11 @@ public MultiValueMap<String, String> getSubscriptions(String destination) {
159168
}
160169
return result;
161170
}
171+
172+
@Override
173+
public String toString() {
174+
return "[subscriptionsByDestination=" + this.subscriptionsByDestination + "]";
175+
}
162176
}
163177

164178
/**
@@ -169,6 +183,8 @@ private static class SessionSubscriptionRegistry {
169183
private final Map<String, SessionSubscriptionInfo> sessions =
170184
new ConcurrentHashMap<String, SessionSubscriptionInfo>();
171185

186+
private final Object monitor = new Object();
187+
172188

173189
public SessionSubscriptionInfo getSubscriptions(String sessionId) {
174190
return this.sessions.get(sessionId);
@@ -181,16 +197,26 @@ public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
181197
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) {
182198
SessionSubscriptionInfo info = this.sessions.get(sessionId);
183199
if (info == null) {
184-
info = new SessionSubscriptionInfo(sessionId);
185-
this.sessions.put(sessionId, info);
200+
synchronized(this.monitor) {
201+
info = this.sessions.get(sessionId);
202+
if (info == null) {
203+
info = new SessionSubscriptionInfo(sessionId);
204+
this.sessions.put(sessionId, info);
205+
}
206+
}
186207
}
187-
info.addSubscription(subscriptionId, destination);
208+
info.addSubscription(destination, subscriptionId);
188209
return info;
189210
}
190211

191212
public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
192213
return this.sessions.remove(sessionId);
193214
}
215+
216+
@Override
217+
public String toString() {
218+
return "[sessions=" + sessions + "]";
219+
}
194220
}
195221

196222
/**
@@ -200,10 +226,13 @@ private static class SessionSubscriptionInfo {
200226

201227
private final String sessionId;
202228

203-
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
229+
private final Map<String, Set<String>> subscriptions = new ConcurrentHashMap<String, Set<String>>(4);
230+
231+
private final Object monitor = new Object();
204232

205233

206234
public SessionSubscriptionInfo(String sessionId) {
235+
Assert.notNull(sessionId, "sessionId is required");
207236
this.sessionId = sessionId;
208237
}
209238

@@ -219,27 +248,36 @@ public Set<String> getSubscriptions(String destination) {
219248
return this.subscriptions.get(destination);
220249
}
221250

222-
public void addSubscription(String subscriptionId, String destination) {
223-
Set<String> subs = this.subscriptions.get(destination);
224-
if (subs == null) {
225-
subs = new HashSet<String>(4);
226-
this.subscriptions.put(destination, subs);
251+
public void addSubscription(String destination, String subscriptionId) {
252+
synchronized(this.monitor) {
253+
Set<String> subs = this.subscriptions.get(destination);
254+
if (subs == null) {
255+
subs = new HashSet<String>(4);
256+
this.subscriptions.put(destination, subs);
257+
}
258+
subs.add(subscriptionId);
227259
}
228-
subs.add(subscriptionId);
229260
}
230261

231262
public String removeSubscription(String subscriptionId) {
232263
for (String destination : this.subscriptions.keySet()) {
233264
Set<String> subscriptionIds = this.subscriptions.get(destination);
234265
if (subscriptionIds.remove(subscriptionId)) {
235-
if (subscriptionIds.isEmpty()) {
236-
this.subscriptions.remove(destination);
266+
synchronized(this.monitor) {
267+
if (subscriptionIds.isEmpty()) {
268+
this.subscriptions.remove(destination);
269+
}
237270
}
238271
return destination;
239272
}
240273
}
241274
return null;
242275
}
276+
277+
@Override
278+
public String toString() {
279+
return "[sessionId=" + this.sessionId + ", subscriptions=" + this.subscriptions + "]";
280+
}
243281
}
244282

245283
}

spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,30 @@ public void handleMessage(Message<?> message) throws MessagingException {
6868
SimpMessageType messageType = headers.getMessageType();
6969

7070
if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
71-
// TODO: need a way to communicate back if subscription was successfully created or
72-
// not in which case an ERROR should be sent back and close the connection
73-
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
71+
preProcessMessage(message);
7472
this.subscriptionRegistry.registerSubscription(message);
7573
}
7674
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
75+
preProcessMessage(message);
7776
this.subscriptionRegistry.unregisterSubscription(message);
7877
}
7978
else if (SimpMessageType.MESSAGE.equals(messageType)) {
79+
preProcessMessage(message);
8080
sendMessageToSubscribers(headers.getDestination(), message);
8181
}
8282
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
83+
preProcessMessage(message);
8384
String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId();
8485
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
8586
}
8687
}
8788

89+
private void preProcessMessage(Message<?> message) {
90+
if (logger.isTraceEnabled()) {
91+
logger.trace("Processing " + message);
92+
}
93+
}
94+
8895
protected void sendMessageToSubscribers(String destination, Message<?> message) {
8996
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
9097
for (String sessionId : subscriptions.keySet()) {

0 commit comments

Comments
 (0)