Skip to content

Commit e29bc3d

Browse files
committed
Refactor PartGenerator to use isLast
This commit refactors the PartGenerator to use the newly introduced Token::isLast property. See gh-28006
1 parent d44ba0a commit e29bc3d

File tree

3 files changed

+50
-80
lines changed

3 files changed

+50
-80
lines changed

spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReader.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Collections;
2424
import java.util.List;
2525
import java.util.Map;
26+
import java.util.concurrent.atomic.AtomicInteger;
2627

2728
import reactor.core.publisher.Flux;
2829
import reactor.core.publisher.Mono;
@@ -222,12 +223,30 @@ public Flux<Part> read(ResolvableType elementType, ReactiveHttpInputMessage mess
222223
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
223224
message.getHeaders().getContentType() + "\""));
224225
}
225-
Flux<MultipartParser.Token> tokens = MultipartParser.parse(message.getBody(), boundary,
226+
Flux<MultipartParser.Token> allPartsTokens = MultipartParser.parse(message.getBody(), boundary,
226227
this.maxHeadersSize, this.headersCharset);
227228

228-
return PartGenerator.createParts(tokens, this.maxParts, this.maxInMemorySize, this.maxDiskUsagePerPart,
229-
this.streaming, this.fileStorage.directory(), this.blockingOperationScheduler);
229+
AtomicInteger partCount = new AtomicInteger();
230+
return allPartsTokens
231+
.windowUntil(MultipartParser.Token::isLast)
232+
.concatMap(partsTokens -> {
233+
if (tooManyParts(partCount)) {
234+
return Mono.error(new DecodingException("Too many parts (" + partCount.get() + "/" +
235+
this.maxParts + " allowed)"));
236+
}
237+
else {
238+
return PartGenerator.createPart(partsTokens,
239+
this.maxInMemorySize, this.maxDiskUsagePerPart, this.streaming,
240+
this.fileStorage.directory(), this.blockingOperationScheduler);
241+
}
242+
});
230243
});
231244
}
232245

246+
private boolean tooManyParts(AtomicInteger partCount) {
247+
int count = partCount.incrementAndGet();
248+
return this.maxParts > 0 && count > this.maxParts;
249+
}
250+
251+
233252
}

spring-web/src/main/java/org/springframework/http/codec/multipart/PartGenerator.java

Lines changed: 24 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import java.util.Queue;
3131
import java.util.concurrent.ConcurrentLinkedQueue;
3232
import java.util.concurrent.atomic.AtomicBoolean;
33-
import java.util.concurrent.atomic.AtomicInteger;
3433
import java.util.concurrent.atomic.AtomicLong;
3534
import java.util.concurrent.atomic.AtomicReference;
3635

@@ -41,10 +40,10 @@
4140
import reactor.core.publisher.Flux;
4241
import reactor.core.publisher.FluxSink;
4342
import reactor.core.publisher.Mono;
43+
import reactor.core.publisher.MonoSink;
4444
import reactor.core.scheduler.Scheduler;
4545
import reactor.util.context.Context;
4646

47-
import org.springframework.core.codec.DecodingException;
4847
import org.springframework.core.io.buffer.DataBuffer;
4948
import org.springframework.core.io.buffer.DataBufferLimitException;
5049
import org.springframework.core.io.buffer.DataBufferUtils;
@@ -65,13 +64,9 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
6564

6665
private final AtomicReference<State> state = new AtomicReference<>(new InitialState());
6766

68-
private final AtomicInteger partCount = new AtomicInteger();
69-
7067
private final AtomicBoolean requestOutstanding = new AtomicBoolean();
7168

72-
private final FluxSink<Part> sink;
73-
74-
private final int maxParts;
69+
private final MonoSink<Part> sink;
7570

7671
private final boolean streaming;
7772

@@ -84,11 +79,10 @@ final class PartGenerator extends BaseSubscriber<MultipartParser.Token> {
8479
private final Scheduler blockingOperationScheduler;
8580

8681

87-
private PartGenerator(FluxSink<Part> sink, int maxParts, int maxInMemorySize, long maxDiskUsagePerPart,
82+
private PartGenerator(MonoSink<Part> sink, int maxInMemorySize, long maxDiskUsagePerPart,
8883
boolean streaming, Mono<Path> fileStorageDirectory, Scheduler blockingOperationScheduler) {
8984

9085
this.sink = sink;
91-
this.maxParts = maxParts;
9286
this.maxInMemorySize = maxInMemorySize;
9387
this.maxDiskUsagePerPart = maxDiskUsagePerPart;
9488
this.streaming = streaming;
@@ -99,15 +93,15 @@ private PartGenerator(FluxSink<Part> sink, int maxParts, int maxInMemorySize, lo
9993
/**
10094
* Creates parts from a given stream of tokens.
10195
*/
102-
public static Flux<Part> createParts(Flux<MultipartParser.Token> tokens, int maxParts, int maxInMemorySize,
96+
public static Mono<Part> createPart(Flux<MultipartParser.Token> tokens, int maxInMemorySize,
10397
long maxDiskUsagePerPart, boolean streaming, Mono<Path> fileStorageDirectory,
10498
Scheduler blockingOperationScheduler) {
10599

106-
return Flux.create(sink -> {
107-
PartGenerator generator = new PartGenerator(sink, maxParts, maxInMemorySize, maxDiskUsagePerPart, streaming,
100+
return Mono.create(sink -> {
101+
PartGenerator generator = new PartGenerator(sink, maxInMemorySize, maxDiskUsagePerPart, streaming,
108102
fileStorageDirectory, blockingOperationScheduler);
109103

110-
sink.onCancel(generator::onSinkCancel);
104+
sink.onCancel(generator);
111105
sink.onRequest(l -> generator.requestToken());
112106
tokens.subscribe(generator);
113107
});
@@ -128,13 +122,6 @@ protected void hookOnNext(MultipartParser.Token token) {
128122
this.requestOutstanding.set(false);
129123
State state = this.state.get();
130124
if (token instanceof MultipartParser.HeadersToken) {
131-
// finish previous part
132-
state.partComplete(false);
133-
134-
if (tooManyParts()) {
135-
return;
136-
}
137-
138125
newPart(state, token.headers());
139126
}
140127
else {
@@ -144,11 +131,11 @@ protected void hookOnNext(MultipartParser.Token token) {
144131

145132
private void newPart(State currentState, HttpHeaders headers) {
146133
if (MultipartUtils.isFormField(headers)) {
147-
changeStateInternal(new FormFieldState(headers));
134+
changeState(currentState, new FormFieldState(headers));
148135
requestToken();
149136
}
150137
else if (!this.streaming) {
151-
changeStateInternal(new InMemoryState(headers));
138+
changeState(currentState, new InMemoryState(headers));
152139
requestToken();
153140
}
154141
else {
@@ -165,7 +152,7 @@ else if (!this.streaming) {
165152

166153
@Override
167154
protected void hookOnComplete() {
168-
this.state.get().partComplete(true);
155+
this.state.get().onComplete();
169156
}
170157

171158
@Override
@@ -175,7 +162,8 @@ protected void hookOnError(Throwable throwable) {
175162
this.sink.error(throwable);
176163
}
177164

178-
private void onSinkCancel() {
165+
@Override
166+
public void dispose() {
179167
changeStateInternal(DisposedState.INSTANCE);
180168
cancel();
181169
}
@@ -211,39 +199,21 @@ void emitPart(Part part) {
211199
if (logger.isTraceEnabled()) {
212200
logger.trace("Emitting: " + part);
213201
}
214-
this.sink.next(part);
202+
this.sink.success(part);
215203
}
216204

217-
void emitComplete() {
218-
this.sink.complete();
219-
}
220-
221-
222205
void emitError(Throwable t) {
223206
cancel();
224207
this.sink.error(t);
225208
}
226209

227210
void requestToken() {
228211
if (upstream() != null &&
229-
!this.sink.isCancelled() &&
230-
this.sink.requestedFromDownstream() > 0 &&
231212
this.requestOutstanding.compareAndSet(false, true)) {
232213
request(1);
233214
}
234215
}
235216

236-
private boolean tooManyParts() {
237-
int count = this.partCount.incrementAndGet();
238-
if (this.maxParts > 0 && count > this.maxParts) {
239-
emitError(new DecodingException("Too many parts (" + count + "/" + this.maxParts + " allowed)"));
240-
return true;
241-
}
242-
else {
243-
return false;
244-
}
245-
}
246-
247217
/**
248218
* Represents the internal state of the {@link PartGenerator} for
249219
* creating a single {@link Part}.
@@ -273,10 +243,8 @@ private interface State {
273243

274244
/**
275245
* Invoked when all tokens for the part have been received.
276-
* @param finalPart {@code true} if this was the last part (and
277-
* {@link #emitComplete()} should be called; {@code false} otherwise
278246
*/
279-
void partComplete(boolean finalPart);
247+
void onComplete();
280248

281249
/**
282250
* Invoked when an error has been received.
@@ -307,10 +275,7 @@ public void body(DataBuffer dataBuffer) {
307275
}
308276

309277
@Override
310-
public void partComplete(boolean finalPart) {
311-
if (finalPart) {
312-
emitComplete();
313-
}
278+
public void onComplete() {
314279
}
315280

316281
@Override
@@ -364,13 +329,10 @@ private void store(DataBuffer dataBuffer) {
364329
}
365330

366331
@Override
367-
public void partComplete(boolean finalPart) {
332+
public void onComplete() {
368333
byte[] bytes = this.value.toByteArrayUnsafe();
369334
String value = new String(bytes, MultipartUtils.charset(this.headers));
370335
emitPart(DefaultParts.formFieldPart(this.headers, value));
371-
if (finalPart) {
372-
emitComplete();
373-
}
374336
}
375337

376338
@Override
@@ -410,13 +372,10 @@ public void body(DataBuffer dataBuffer) {
410372
}
411373

412374
@Override
413-
public void partComplete(boolean finalPart) {
375+
public void onComplete() {
414376
if (!this.bodySink.isCancelled()) {
415377
this.bodySink.complete();
416378
}
417-
if (finalPart) {
418-
emitComplete();
419-
}
420379
}
421380

422381
@Override
@@ -493,11 +452,8 @@ private void switchToFile(DataBuffer current, long byteCount) {
493452
}
494453

495454
@Override
496-
public void partComplete(boolean finalPart) {
455+
public void onComplete() {
497456
emitMemoryPart();
498-
if (finalPart) {
499-
emitComplete();
500-
}
501457
}
502458

503459
private void emitMemoryPart() {
@@ -545,8 +501,6 @@ private final class CreateFileState implements State {
545501

546502
private volatile boolean completed;
547503

548-
private volatile boolean finalPart;
549-
550504
private volatile boolean releaseOnDispose = true;
551505

552506

@@ -563,9 +517,8 @@ public void body(DataBuffer dataBuffer) {
563517
}
564518

565519
@Override
566-
public void partComplete(boolean finalPart) {
520+
public void onComplete() {
567521
this.completed = true;
568-
this.finalPart = finalPart;
569522
}
570523

571524
public void createFile() {
@@ -597,7 +550,7 @@ private void fileCreated(WritingFileState newState) {
597550
newState.writeBuffers(this.content);
598551

599552
if (this.completed) {
600-
newState.partComplete(this.finalPart);
553+
newState.onComplete();
601554
}
602555
}
603556
else {
@@ -665,12 +618,9 @@ public void body(DataBuffer dataBuffer) {
665618
}
666619

667620
@Override
668-
public void partComplete(boolean finalPart) {
621+
public void onComplete() {
669622
MultipartUtils.closeChannel(this.channel);
670623
emitPart(DefaultParts.part(this.headers, this.file, PartGenerator.this.blockingOperationScheduler));
671-
if (finalPart) {
672-
emitComplete();
673-
}
674624
}
675625

676626
@Override
@@ -701,8 +651,6 @@ private final class WritingFileState implements State {
701651

702652
private volatile boolean completed;
703653

704-
private volatile boolean finalPart;
705-
706654

707655
public WritingFileState(CreateFileState state, Path file, WritableByteChannel channel) {
708656
this.headers = state.headers;
@@ -725,9 +673,8 @@ public void body(DataBuffer dataBuffer) {
725673
}
726674

727675
@Override
728-
public void partComplete(boolean finalPart) {
676+
public void onComplete() {
729677
this.completed = true;
730-
this.finalPart = finalPart;
731678
}
732679

733680
public void writeBuffer(DataBuffer dataBuffer) {
@@ -752,7 +699,7 @@ public void writeBuffers(Iterable<DataBuffer> dataBuffers) {
752699
private void writeComplete() {
753700
IdleFileState newState = new IdleFileState(this);
754701
if (this.completed) {
755-
newState.partComplete(this.finalPart);
702+
newState.onComplete();
756703
}
757704
else if (changeState(this, newState)) {
758705
requestToken();
@@ -799,7 +746,7 @@ public void body(DataBuffer dataBuffer) {
799746
}
800747

801748
@Override
802-
public void partComplete(boolean finalPart) {
749+
public void onComplete() {
803750
}
804751

805752
@Override

spring-web/src/test/java/org/springframework/http/codec/multipart/DefaultPartHttpMessageReaderTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ void noEndBoundary(DefaultPartHttpMessageReader reader) {
118118
Flux<Part> result = reader.read(forClass(Part.class), request, emptyMap());
119119

120120
StepVerifier.create(result)
121+
.consumeNextWith(part -> {
122+
assertThat(part.headers().getFirst("Header")).isEqualTo("Value");
123+
part.content().subscribe(DataBufferUtils::release);
124+
})
121125
.expectError(DecodingException.class)
122126
.verify();
123127
}

0 commit comments

Comments
 (0)