Skip to content

Commit ebffd67

Browse files
committed
Add BufferingStompDecoder
Before this change the StompDecoder decoded and returned only the first Message in the ByteBuffer passed to it. So to obtain all messages from the buffer, one had to loop passing the same buffer in until no more complete STOMP frames could be decoded. This chage modifies StompDecoder to return List<Message> after exhaustively decoding all available STOMP frames from the input buffer. Also an overloaded decode method allows passing in Map that will be populated with any headers successfully parsed, which is useful for "peeking" at the "content-length" header. This change also adds a BufferingStompDecoder sub-class which buffers any content left in the input buffer after parsing one or more STOMP frames. This sub-class can also deal with fragmented messages, re-assembling them and parsing as a whole message. Issue: SPR-11527
1 parent 465ca24 commit ebffd67

File tree

7 files changed

+481
-73
lines changed

7 files changed

+481
-73
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright 2002-2014 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.simp.stomp;
18+
19+
20+
import org.springframework.messaging.Message;
21+
import org.springframework.util.Assert;
22+
import org.springframework.util.LinkedMultiValueMap;
23+
import org.springframework.util.MultiValueMap;
24+
25+
import java.nio.ByteBuffer;
26+
import java.util.ArrayList;
27+
import java.util.Collections;
28+
import java.util.List;
29+
import java.util.Queue;
30+
import java.util.concurrent.LinkedBlockingQueue;
31+
32+
33+
/**
34+
* A an extension of {@link org.springframework.messaging.simp.stomp.StompDecoder}
35+
* that chunks any bytes remaining after a single full STOMP frame has been read.
36+
* The remaining bytes may contain more STOMP frames or an incomplete STOMP frame.
37+
*
38+
* <p>Similarly if there is not enough content for a full STOMP frame, the content
39+
* is buffered until more input is received. That means the
40+
* {@link #decode(java.nio.ByteBuffer)} effectively never returns {@code null} as
41+
* the parent class does.
42+
*
43+
* @author Rossen Stoyanchev
44+
* @since 4.0.3
45+
*/
46+
public class BufferingStompDecoder extends StompDecoder {
47+
48+
private final int bufferSizeLimit;
49+
50+
private final Queue<ByteBuffer> chunks = new LinkedBlockingQueue<ByteBuffer>();
51+
52+
private volatile Integer expectedContentLength;
53+
54+
55+
public BufferingStompDecoder(int bufferSizeLimit) {
56+
Assert.isTrue(bufferSizeLimit > 0, "Buffer size must be greater than 0");
57+
this.bufferSizeLimit = bufferSizeLimit;
58+
}
59+
60+
61+
public int getBufferSizeLimit() {
62+
return this.bufferSizeLimit;
63+
}
64+
65+
public int getBufferSize() {
66+
int size = 0;
67+
for (ByteBuffer buffer : this.chunks) {
68+
size = size + buffer.remaining();
69+
}
70+
return size;
71+
}
72+
73+
public Integer getExpectedContentLength() {
74+
return this.expectedContentLength;
75+
}
76+
77+
78+
@Override
79+
public List<Message<byte[]>> decode(ByteBuffer newData) {
80+
81+
this.chunks.add(newData);
82+
83+
checkBufferLimits();
84+
85+
if (getExpectedContentLength() != null && getBufferSize() < this.expectedContentLength) {
86+
return Collections.<Message<byte[]>>emptyList();
87+
}
88+
89+
ByteBuffer buffer = assembleChunksAndReset();
90+
91+
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
92+
List<Message<byte[]>> messages = decode(buffer, headers);
93+
94+
if (buffer.hasRemaining()) {
95+
this.chunks.add(buffer);
96+
this.expectedContentLength = getContentLength(headers);
97+
}
98+
99+
return messages;
100+
}
101+
102+
private void checkBufferLimits() {
103+
if (getExpectedContentLength() != null) {
104+
if (getExpectedContentLength() > getBufferSizeLimit()) {
105+
throw new StompConversionException(
106+
"The 'content-length' header " + getExpectedContentLength() +
107+
" exceeds the configured message buffer size limit " + getBufferSizeLimit());
108+
}
109+
}
110+
if (getBufferSize() > getBufferSizeLimit()) {
111+
throw new StompConversionException("The configured stomp frame buffer size limit of " +
112+
getBufferSizeLimit() + " bytes has been exceeded");
113+
114+
}
115+
}
116+
117+
private ByteBuffer assembleChunksAndReset() {
118+
ByteBuffer result;
119+
if (this.chunks.size() == 1) {
120+
result = this.chunks.remove();
121+
}
122+
else {
123+
result = ByteBuffer.allocate(getBufferSize());
124+
for (ByteBuffer partial : this.chunks) {
125+
result.put(partial);
126+
}
127+
result.flip();
128+
}
129+
this.chunks.clear();
130+
this.expectedContentLength = null;
131+
return result;
132+
}
133+
134+
}

spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2013 the original author or authors.
2+
* Copyright 2002-2014 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,10 +23,13 @@
2323
import reactor.io.Buffer;
2424
import reactor.tcp.encoding.Codec;
2525

26+
import java.util.List;
27+
2628
/**
27-
* A Reactor TCP {@link Codec} for sending and receiving STOMP messages
29+
* A Reactor TCP {@link Codec} for sending and receiving STOMP messages.
2830
*
2931
* @author Andy Wilkinson
32+
* @author Rossen Stoyanchev
3033
* @since 4.0
3134
*/
3235
public class StompCodec implements Codec<Buffer, Message<byte[]>, Message<byte[]>> {
@@ -49,14 +52,8 @@ public Function<Buffer, Message<byte[]>> decoder(final Consumer<Message<byte[]>>
4952

5053
@Override
5154
public Message<byte[]> apply(Buffer buffer) {
52-
while (buffer.remaining() > 0) {
53-
Message<byte[]> message = DECODER.decode(buffer.byteBuffer());
54-
if (message != null) {
55-
next.accept(message);
56-
}
57-
else {
58-
break;
59-
}
55+
for (Message<byte[]> message : DECODER.decode(buffer.byteBuffer())) {
56+
next.accept(message);
6057
}
6158
return null;
6259
}

spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.io.ByteArrayOutputStream;
2020
import java.nio.ByteBuffer;
2121
import java.nio.charset.Charset;
22+
import java.util.ArrayList;
23+
import java.util.List;
2224

2325
import org.apache.commons.logging.Log;
2426
import org.apache.commons.logging.LogFactory;
@@ -30,9 +32,10 @@
3032
import org.springframework.util.MultiValueMap;
3133

3234
/**
33-
* Decodes STOMP frames from a {@link ByteBuffer}. If the buffer does not contain
34-
* enough data to form a complete STOMP frame, the buffer is reset and the value
35-
* returned is {@code null} indicating that no message could be read.
35+
* Decodes one or more STOMP frames from a {@link ByteBuffer}. If the buffer
36+
* contains any additional (incomplete) data, or perhaps not enough data to
37+
* form even one Message, the the buffer is reset and the value returned is
38+
* an empty list indicating that no more message can be read.
3639
*
3740
* @author Andy Wilkinson
3841
* @author Rossen Stoyanchev
@@ -47,21 +50,66 @@ public class StompDecoder {
4750
private final Log logger = LogFactory.getLog(StompDecoder.class);
4851

4952

53+
54+
/**
55+
* Decodes one or more STOMP frames from the given {@code buffer} into a
56+
* list of {@link Message}s.
57+
*
58+
* <p>If the given ByteBuffer contains partial STOMP frame content, or additional
59+
* content with a partial STOMP frame, the buffer is reset and {@code null} is
60+
* returned.
61+
*
62+
* @param buffer The buffer to decode the STOMP frame from
63+
*
64+
* @return the decoded messages or an empty list
65+
*/
66+
public List<Message<byte[]>> decode(ByteBuffer buffer) {
67+
return decode(buffer, new LinkedMultiValueMap<String, String>());
68+
}
69+
70+
/**
71+
* Decodes one or more STOMP frames from the given {@code buffer} into a
72+
* list of {@link Message}s.
73+
*
74+
* <p>If the given ByteBuffer contains partial STOMP frame content, or additional
75+
* content with a partial STOMP frame, the buffer is reset and {@code null} is
76+
* returned.
77+
*
78+
* @param buffer The buffer to decode the STOMP frame from
79+
* @param headers an empty map that will be filled with the successfully parsed
80+
* headers of the last decoded message, or the last attempt at decoding an
81+
* (incomplete) STOMP frame. This can be useful for detecting 'content-length'.
82+
*
83+
* @return the decoded messages or an empty list
84+
*/
85+
public List<Message<byte[]>> decode(ByteBuffer buffer, MultiValueMap<String, String> headers) {
86+
List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
87+
while (buffer.hasRemaining()) {
88+
headers.clear();
89+
Message<byte[]> m = decodeMessage(buffer, headers);
90+
if (m != null) {
91+
messages.add(m);
92+
}
93+
else {
94+
break;
95+
}
96+
}
97+
return messages;
98+
}
99+
50100
/**
51-
* Decodes a STOMP frame in the given {@code buffer} into a {@link Message}.
52-
* If the given ByteBuffer contains partial STOMP frame content, the method
53-
* resets the buffer and returns {@code null}.
54-
* @param buffer the buffer to decode the frame from
55-
* @return the decoded message or {@code null}
101+
* Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
56102
*/
57-
public Message<byte[]> decode(ByteBuffer buffer) {
103+
private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, String> headers) {
104+
58105
Message<byte[]> decodedMessage = null;
59106
skipLeadingEol(buffer);
60107
buffer.mark();
61108

62109
String command = readCommand(buffer);
63110
if (command.length() > 0) {
64-
MultiValueMap<String, String> headers = readHeaders(buffer);
111+
112+
readHeaders(buffer, headers);
65113
byte[] payload = readPayload(buffer, headers);
66114

67115
if (payload != null) {
@@ -78,7 +126,7 @@ public Message<byte[]> decode(ByteBuffer buffer) {
78126
}
79127
else {
80128
if (logger.isTraceEnabled()) {
81-
logger.trace("Received incomplete frame. Resetting buffer");
129+
logger.trace("Received incomplete frame. Resetting buffer.");
82130
}
83131
buffer.reset();
84132
}
@@ -93,27 +141,31 @@ public Message<byte[]> decode(ByteBuffer buffer) {
93141
return decodedMessage;
94142
}
95143

96-
private void skipLeadingEol(ByteBuffer buffer) {
144+
145+
/**
146+
* Skip one ore more EOL characters at the start of the given ByteBuffer.
147+
* Those are STOMP heartbeat frames.
148+
*/
149+
protected void skipLeadingEol(ByteBuffer buffer) {
97150
while (true) {
98-
if (!isEol(buffer)) {
151+
if (!tryConsumeEndOfLine(buffer)) {
99152
break;
100153
}
101154
}
102155
}
103156

104157
private String readCommand(ByteBuffer buffer) {
105158
ByteArrayOutputStream command = new ByteArrayOutputStream(256);
106-
while (buffer.remaining() > 0 && !isEol(buffer)) {
159+
while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
107160
command.write(buffer.get());
108161
}
109162
return new String(command.toByteArray(), UTF8_CHARSET);
110163
}
111164

112-
private MultiValueMap<String, String> readHeaders(ByteBuffer buffer) {
113-
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
165+
private void readHeaders(ByteBuffer buffer, MultiValueMap<String, String> headers) {
114166
while (true) {
115167
ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
116-
while (buffer.remaining() > 0 && !isEol(buffer)) {
168+
while (buffer.remaining() > 0 && !tryConsumeEndOfLine(buffer)) {
117169
headerStream.write(buffer.get());
118170
}
119171
if (headerStream.size() > 0) {
@@ -135,7 +187,6 @@ private MultiValueMap<String, String> readHeaders(ByteBuffer buffer) {
135187
break;
136188
}
137189
}
138-
return headers;
139190
}
140191

141192
private String unescape(String input) {
@@ -146,16 +197,7 @@ private String unescape(String input) {
146197
}
147198

148199
private byte[] readPayload(ByteBuffer buffer, MultiValueMap<String, String> headers) {
149-
Integer contentLength = null;
150-
if (headers.containsKey("content-length")) {
151-
String rawContentLength = headers.getFirst("content-length");
152-
try {
153-
contentLength = Integer.valueOf(rawContentLength);
154-
}
155-
catch (NumberFormatException ex) {
156-
logger.warn("Ignoring invalid content-length header value: '" + rawContentLength + "'");
157-
}
158-
}
200+
Integer contentLength = getContentLength(headers);
159201
if (contentLength != null && contentLength >= 0) {
160202
if (buffer.remaining() > contentLength) {
161203
byte[] payload = new byte[contentLength];
@@ -184,7 +226,25 @@ private byte[] readPayload(ByteBuffer buffer, MultiValueMap<String, String> head
184226
return null;
185227
}
186228

187-
private boolean isEol(ByteBuffer buffer) {
229+
protected Integer getContentLength(MultiValueMap<String, String> headers) {
230+
if (headers.containsKey(StompHeaderAccessor.STOMP_CONTENT_LENGTH_HEADER)) {
231+
String rawContentLength = headers.getFirst(StompHeaderAccessor.STOMP_CONTENT_LENGTH_HEADER);
232+
try {
233+
return Integer.valueOf(rawContentLength);
234+
}
235+
catch (NumberFormatException ex) {
236+
logger.warn("Ignoring invalid content-length header value: '" + rawContentLength + "'");
237+
}
238+
}
239+
return null;
240+
}
241+
242+
/**
243+
* Try to read an EOL incrementing the buffer position if successful.
244+
*
245+
* @return whether an EOL was consumed
246+
*/
247+
private boolean tryConsumeEndOfLine(ByteBuffer buffer) {
188248
if (buffer.remaining() > 0) {
189249
byte b = buffer.get();
190250
if (b == '\n') {

0 commit comments

Comments
 (0)