Skip to content

Commit 9331e1e

Browse files
committed
Allow Protobuf codec extensions
Closes spring-projectsgh-35403
1 parent 6ebb207 commit 9331e1e

File tree

3 files changed

+127
-52
lines changed

3 files changed

+127
-52
lines changed

spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,22 @@ public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType
129129
public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType,
130130
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
131131

132-
MessageDecoderFunction decoderFunction = new MessageDecoderFunction(elementType, this.maxMessageSize);
132+
MessageDecoderFunction decoderFunction =
133+
new MessageDecoderFunction(elementType, this.maxMessageSize, initMessageSizeReader());
133134

134135
return Flux.from(inputStream)
135136
.flatMapIterable(decoderFunction)
136137
.doOnTerminate(decoderFunction::discard);
137138
}
138139

140+
/**
141+
* Return a reader for message size information encoded in the input stream.
142+
* @since 7.0
143+
*/
144+
protected MessageSizeReader initMessageSizeReader() {
145+
return new DefaultMessageSizeReader();
146+
}
147+
139148
@Override
140149
public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType,
141150
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
@@ -150,9 +159,7 @@ public Message decode(DataBuffer dataBuffer, ResolvableType targetType,
150159

151160
try {
152161
Message.Builder builder = getMessageBuilder(targetType.toClass());
153-
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount());
154-
dataBuffer.toByteBuffer(byteBuffer);
155-
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
162+
merge(dataBuffer, builder);
156163
return builder.build();
157164
}
158165
catch (IOException ex) {
@@ -166,6 +173,17 @@ public Message decode(DataBuffer dataBuffer, ResolvableType targetType,
166173
}
167174
}
168175

176+
/**
177+
* Use merge methods on {@link Message.Builder} to read a single message
178+
* from the given {@code DataBuffer}.
179+
* @since 7.0
180+
*/
181+
protected void merge(DataBuffer dataBuffer, Message.Builder builder) throws IOException {
182+
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount());
183+
dataBuffer.toByteBuffer(byteBuffer);
184+
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
185+
}
186+
169187

170188
/**
171189
* Create a new {@code Message.Builder} instance for the given class.
@@ -196,15 +214,14 @@ private class MessageDecoderFunction implements Function<DataBuffer, Iterable<?
196214

197215
private int messageBytesToRead;
198216

199-
private int offset;
200-
217+
private final MessageSizeReader messageSizeReader;
201218

202-
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
219+
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize, MessageSizeReader messageSizeReader) {
203220
this.elementType = elementType;
204221
this.maxMessageSize = maxMessageSize;
222+
this.messageSizeReader = messageSizeReader;
205223
}
206224

207-
208225
@Override
209226
public Iterable<? extends Message> apply(DataBuffer input) {
210227
try {
@@ -214,9 +231,11 @@ public Iterable<? extends Message> apply(DataBuffer input) {
214231

215232
do {
216233
if (this.output == null) {
217-
if (!readMessageSize(input)) {
234+
Integer messageSize = this.messageSizeReader.readMessageSize(input);
235+
if (messageSize == null) {
218236
return messages;
219237
}
238+
this.messageBytesToRead = messageSize;
220239
if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) {
221240
throw new DataBufferLimitException(
222241
"The number of bytes to read for message " +
@@ -262,60 +281,89 @@ public Iterable<? extends Message> apply(DataBuffer input) {
262281
}
263282
}
264283

284+
public void discard() {
285+
if (this.output != null) {
286+
DataBufferUtils.release(this.output);
287+
}
288+
}
289+
}
290+
291+
/**
292+
* Component to read the size of a message. Implementations must be
293+
* stateful and expect size information is potentially split
294+
* across input chunks.
295+
* @since 7.0
296+
*/
297+
protected interface MessageSizeReader {
298+
265299
/**
266-
* Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and
267-
* {@code offset} fields if needed to allow processing of upcoming chunks.
268-
* Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)}
269-
* @return {@code true} when the message size is parsed successfully, {@code false} when the message size is
270-
* truncated
271-
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
300+
* Read the message size from the given buffer. This method may be
301+
* called multiple times before the message size is fully read.
302+
* @return return the message size, or {@code null} if the data in the
303+
* input buffer was insufficient
272304
*/
273-
private boolean readMessageSize(DataBuffer input) {
305+
@Nullable Integer readMessageSize(DataBuffer input);
306+
}
307+
308+
309+
/**
310+
* Default reader for Protobuf messages.
311+
* <p>Parses the message size as a varint from the input stream.
312+
* Inspired by {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)},
313+
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
314+
*/
315+
private static class DefaultMessageSizeReader implements MessageSizeReader {
316+
317+
private int offset;
318+
319+
private int messageSize;
320+
321+
@Override
322+
public @Nullable Integer readMessageSize(DataBuffer input) {
274323
if (this.offset == 0) {
275324
if (input.readableByteCount() == 0) {
276-
return false;
325+
return null;
277326
}
278327
int firstByte = input.read();
279328
if ((firstByte & 0x80) == 0) {
280-
this.messageBytesToRead = firstByte;
281-
return true;
329+
this.messageSize = firstByte;
330+
return getAndReset();
282331
}
283-
this.messageBytesToRead = firstByte & 0x7f;
332+
this.messageSize = firstByte & 0x7f;
284333
this.offset = 7;
285334
}
286335

287336
if (this.offset < 32) {
288337
for (; this.offset < 32; this.offset += 7) {
289338
if (input.readableByteCount() == 0) {
290-
return false;
339+
return null;
291340
}
292341
final int b = input.read();
293-
this.messageBytesToRead |= (b & 0x7f) << this.offset;
342+
this.messageSize |= (b & 0x7f) << this.offset;
294343
if ((b & 0x80) == 0) {
295-
this.offset = 0;
296-
return true;
344+
return getAndReset();
297345
}
298346
}
299347
}
300348
// Keep reading up to 64 bits.
301349
for (; this.offset < 64; this.offset += 7) {
302350
if (input.readableByteCount() == 0) {
303-
return false;
351+
return null;
304352
}
305353
final int b = input.read();
306354
if ((b & 0x80) == 0) {
307-
this.offset = 0;
308-
return true;
355+
return getAndReset();
309356
}
310357
}
311-
this.offset = 0;
358+
getAndReset();
312359
throw new DecodingException("Cannot parse message size: malformed varint");
313360
}
314361

315-
public void discard() {
316-
if (this.output != null) {
317-
DataBufferUtils.release(this.output);
318-
}
362+
private @Nullable Integer getAndReset() {
363+
Integer result = (this.messageSize != 0 ? this.messageSize : null);
364+
this.offset = 0;
365+
this.messageSize = 0;
366+
return result;
319367
}
320368
}
321369

spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,30 @@ public DataBuffer encodeValue(Message message, DataBufferFactory bufferFactory,
107107
}
108108

109109
private DataBuffer encodeValue(Message message, DataBufferFactory bufferFactory, boolean delimited) {
110-
FastByteArrayOutputStream bos = new FastByteArrayOutputStream();
110+
FastByteArrayOutputStream outputStream = new FastByteArrayOutputStream();
111111
try {
112-
if (delimited) {
113-
message.writeDelimitedTo((OutputStream) bos);
114-
}
115-
else {
116-
message.writeTo((OutputStream) bos);
117-
}
118-
byte[] bytes = bos.toByteArrayUnsafe();
112+
writeMessage(message, delimited, outputStream);
113+
byte[] bytes = outputStream.toByteArrayUnsafe();
119114
return bufferFactory.wrap(bytes);
120115
}
121116
catch (IOException ex) {
122117
throw new IllegalStateException("Unexpected I/O error while writing to data buffer", ex);
123118
}
124119
}
125120

121+
/**
122+
* Use write methods on {@link Message} to write to the given {@code OutputStream}.
123+
* @since 7.0
124+
*/
125+
protected void writeMessage(
126+
Message message, boolean delimited, OutputStream outputStream) throws IOException {
127+
128+
if (delimited) {
129+
message.writeDelimitedTo(outputStream);
130+
}
131+
else {
132+
message.writeTo(outputStream);
133+
}
134+
}
135+
126136
}

spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,30 +77,47 @@ public ProtobufHttpMessageWriter(Encoder<Message> encoder) {
7777
@SuppressWarnings("unchecked")
7878
@Override
7979
public Mono<Void> write(Publisher<? extends Message> inputStream, ResolvableType elementType,
80-
@Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map<String, Object> hints) {
80+
@Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage, Map<String, Object> hints) {
8181

8282
try {
8383
Message.Builder builder = getMessageBuilder(elementType.toClass());
8484
Descriptors.Descriptor descriptor = builder.getDescriptorForType();
85-
message.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName());
86-
message.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName());
85+
outputMessage.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName());
86+
outputMessage.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName());
8787
if (inputStream instanceof Flux) {
88-
if (mediaType == null) {
89-
message.getHeaders().setContentType(((HttpMessageEncoder<?>)getEncoder()).getStreamingMediaTypes().get(0));
90-
}
91-
else if (!ProtobufEncoder.DELIMITED_VALUE.equals(mediaType.getParameters().get(ProtobufEncoder.DELIMITED_KEY))) {
92-
Map<String, String> parameters = new HashMap<>(mediaType.getParameters());
93-
parameters.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
94-
message.getHeaders().setContentType(new MediaType(mediaType.getType(), mediaType.getSubtype(), parameters));
95-
}
88+
outputMessage.getHeaders().setContentType(getStreamingContentType(mediaType));
9689
}
97-
return super.write(inputStream, elementType, mediaType, message, hints);
90+
extendHeaders(outputMessage, hints);
91+
return super.write(inputStream, elementType, mediaType, outputMessage, hints);
9892
}
9993
catch (Exception ex) {
10094
return Mono.error(new EncodingException("Could not write Protobuf message: " + ex.getMessage(), ex));
10195
}
10296
}
10397

98+
/**
99+
* Return the {@code MediaType} to use when the input Publisher is multivalued.
100+
* @since 7.0
101+
*/
102+
protected MediaType getStreamingContentType(@Nullable MediaType mediaType) {
103+
if (mediaType == null) {
104+
return ((HttpMessageEncoder<?>) getEncoder()).getStreamingMediaTypes().get(0);
105+
}
106+
Map<String, String> params = new HashMap<>(mediaType.getParameters());
107+
if (!ProtobufEncoder.DELIMITED_VALUE.equals(params.get(ProtobufEncoder.DELIMITED_KEY))) {
108+
params.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
109+
mediaType = new MediaType(mediaType, params);
110+
}
111+
return mediaType;
112+
}
113+
114+
/**
115+
* Make further updates to headers.
116+
* @since 7.0
117+
*/
118+
protected void extendHeaders(ReactiveHttpOutputMessage message, Map<String, Object> hints) {
119+
}
120+
104121
/**
105122
* Create a new {@code Message.Builder} instance for the given class.
106123
* <p>This method uses a ConcurrentHashMap for caching method lookups.

0 commit comments

Comments
 (0)