Skip to content

Commit 9df735b

Browse files
committed
Merge pull request #30028 from yuzawa-san:request-predicate-commit
* gh-30028: Polishing external contribution Improve attribute handling in RequestPredicates
2 parents ed172d6 + c5c8436 commit 9df735b

File tree

3 files changed

+300
-109
lines changed

3 files changed

+300
-109
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java

Lines changed: 151 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.security.Principal;
2222
import java.util.Arrays;
2323
import java.util.Collections;
24-
import java.util.HashMap;
2524
import java.util.LinkedHashMap;
2625
import java.util.LinkedHashSet;
2726
import java.util.List;
@@ -296,11 +295,6 @@ private static void traceMatch(String prefix, Object desired, @Nullable Object a
296295
}
297296
}
298297

299-
private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
300-
request.attributes().clear();
301-
request.attributes().putAll(attributes);
302-
}
303-
304298
private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
305299
Map<String, String> newVariables) {
306300

@@ -432,13 +426,94 @@ public interface Visitor {
432426
}
433427

434428

429+
/**
430+
* Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}.
431+
*/
432+
static abstract class RequestModifyingPredicate implements RequestPredicate {
433+
434+
435+
public static RequestModifyingPredicate of(RequestPredicate requestPredicate) {
436+
if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) {
437+
return modifyingPredicate;
438+
}
439+
else {
440+
return new RequestModifyingPredicate() {
441+
@Override
442+
protected Result testInternal(ServerRequest request) {
443+
return Result.of(requestPredicate.test(request));
444+
}
445+
};
446+
}
447+
}
448+
449+
450+
@Override
451+
public final boolean test(ServerRequest request) {
452+
Result result = testInternal(request);
453+
boolean value = result.value();
454+
if (value) {
455+
result.modify(request);
456+
}
457+
return value;
458+
}
459+
460+
protected abstract Result testInternal(ServerRequest request);
461+
462+
463+
protected static final class Result {
464+
465+
private static final Result TRUE = new Result(true, null);
466+
467+
private static final Result FALSE = new Result(false, null);
468+
469+
470+
private final boolean value;
471+
472+
@Nullable
473+
private final Consumer<ServerRequest> modify;
474+
475+
476+
private Result(boolean value, @Nullable Consumer<ServerRequest> modify) {
477+
this.value = value;
478+
this.modify = modify;
479+
}
480+
481+
482+
public static Result of(boolean value) {
483+
return of(value, null);
484+
}
485+
486+
public static Result of(boolean value, @Nullable Consumer<ServerRequest> commit) {
487+
if (commit == null) {
488+
return value ? TRUE : FALSE;
489+
}
490+
else {
491+
return new Result(value, commit);
492+
}
493+
}
494+
495+
496+
public boolean value() {
497+
return this.value;
498+
}
499+
500+
public void modify(ServerRequest request) {
501+
if (this.modify != null) {
502+
this.modify.accept(request);
503+
}
504+
}
505+
}
506+
507+
}
508+
509+
435510
private static class HttpMethodPredicate implements RequestPredicate {
436511

437512
private final Set<HttpMethod> httpMethods;
438513

439514
public HttpMethodPredicate(HttpMethod httpMethod) {
440515
Assert.notNull(httpMethod, "HttpMethod must not be null");
441-
this.httpMethods = Collections.singleton(httpMethod);
516+
this.httpMethods = Set.of(httpMethod);
442517
}
443518

444519
public HttpMethodPredicate(HttpMethod... httpMethods) {
@@ -482,39 +557,41 @@ public String toString() {
482557
}
483558

484559

485-
private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
560+
private static class PathPatternPredicate extends RequestModifyingPredicate
561+
implements ChangePathPatternParserVisitor.Target {
486562

487563
private PathPattern pattern;
488564

565+
489566
public PathPatternPredicate(PathPattern pattern) {
490567
Assert.notNull(pattern, "'pattern' must not be null");
491568
this.pattern = pattern;
492569
}
493570

571+
494572
@Override
495-
public boolean test(ServerRequest request) {
573+
protected Result testInternal(ServerRequest request) {
496574
PathContainer pathContainer = request.requestPath().pathWithinApplication();
497575
PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
498576
traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
499577
if (info != null) {
500-
mergeAttributes(request, info.getUriVariables(), this.pattern);
501-
return true;
578+
return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables()));
502579
}
503580
else {
504-
return false;
581+
return Result.of(false);
505582
}
506583
}
507584

508-
private static void mergeAttributes(ServerRequest request, Map<String, String> variables,
509-
PathPattern pattern) {
585+
private void mergeAttributes(ServerRequest request, Map<String, String> variables) {
586+
Map<String, Object> attributes = request.attributes();
510587
Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
511-
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
512-
Collections.unmodifiableMap(pathVariables));
588+
attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
589+
Collections.unmodifiableMap(pathVariables));
513590

514-
pattern = mergePatterns(
515-
(PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
516-
pattern);
517-
request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
591+
PathPattern pattern = mergePatterns(
592+
(PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
593+
this.pattern);
594+
attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
518595
}
519596

520597
@Override
@@ -756,28 +833,42 @@ public String toString() {
756833
* {@link RequestPredicate} for where both {@code left} and {@code right} predicates
757834
* must match.
758835
*/
759-
static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
836+
static class AndRequestPredicate extends RequestModifyingPredicate
837+
implements ChangePathPatternParserVisitor.Target {
760838

761839
private final RequestPredicate left;
762840

841+
private final RequestModifyingPredicate leftModifying;
842+
763843
private final RequestPredicate right;
764844

845+
private final RequestModifyingPredicate rightModifying;
846+
847+
765848
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
766849
Assert.notNull(left, "Left RequestPredicate must not be null");
767850
Assert.notNull(right, "Right RequestPredicate must not be null");
768851
this.left = left;
852+
this.leftModifying = of(left);
769853
this.right = right;
854+
this.rightModifying = of(right);
770855
}
771856

772-
@Override
773-
public boolean test(ServerRequest request) {
774-
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
775857

776-
if (this.left.test(request) && this.right.test(request)) {
777-
return true;
858+
@Override
859+
protected Result testInternal(ServerRequest request) {
860+
Result leftResult = this.leftModifying.testInternal(request);
861+
if (!leftResult.value()) {
862+
return leftResult;
863+
}
864+
Result rightResult = this.rightModifying.testInternal(request);
865+
if (!rightResult.value()) {
866+
return rightResult;
778867
}
779-
restoreAttributes(request, oldAttributes);
780-
return false;
868+
return Result.of(true, serverRequest -> {
869+
leftResult.modify(serverRequest);
870+
rightResult.modify(serverRequest);
871+
});
781872
}
782873

783874
@Override
@@ -796,11 +887,11 @@ public void accept(Visitor visitor) {
796887

797888
@Override
798889
public void changeParser(PathPatternParser parser) {
799-
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) {
800-
leftTarget.changeParser(parser);
890+
if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
891+
target.changeParser(parser);
801892
}
802-
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) {
803-
rightTarget.changeParser(parser);
893+
if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
894+
target.changeParser(parser);
804895
}
805896
}
806897

@@ -814,23 +905,25 @@ public String toString() {
814905
/**
815906
* {@link RequestPredicate} that negates a delegate predicate.
816907
*/
817-
static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
908+
static class NegateRequestPredicate extends RequestModifyingPredicate
909+
implements ChangePathPatternParserVisitor.Target {
818910

819911
private final RequestPredicate delegate;
820912

913+
private final RequestModifyingPredicate delegateModifying;
914+
915+
821916
public NegateRequestPredicate(RequestPredicate delegate) {
822917
Assert.notNull(delegate, "Delegate must not be null");
823918
this.delegate = delegate;
919+
this.delegateModifying = of(delegate);
824920
}
825921

922+
826923
@Override
827-
public boolean test(ServerRequest request) {
828-
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
829-
boolean result = !this.delegate.test(request);
830-
if (!result) {
831-
restoreAttributes(request, oldAttributes);
832-
}
833-
return result;
924+
protected Result testInternal(ServerRequest request) {
925+
Result result = this.delegateModifying.testInternal(request);
926+
return Result.of(!result.value(), result::modify);
834927
}
835928

836929
@Override
@@ -858,34 +951,36 @@ public String toString() {
858951
* {@link RequestPredicate} where either {@code left} or {@code right} predicates
859952
* may match.
860953
*/
861-
static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
954+
static class OrRequestPredicate extends RequestModifyingPredicate
955+
implements ChangePathPatternParserVisitor.Target {
862956

863957
private final RequestPredicate left;
864958

959+
private final RequestModifyingPredicate leftModifying;
960+
865961
private final RequestPredicate right;
866962

963+
private final RequestModifyingPredicate rightModifying;
964+
965+
867966
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
868967
Assert.notNull(left, "Left RequestPredicate must not be null");
869968
Assert.notNull(right, "Right RequestPredicate must not be null");
870969
this.left = left;
970+
this.leftModifying = of(left);
871971
this.right = right;
972+
this.rightModifying = of(right);
872973
}
873974

874975
@Override
875-
public boolean test(ServerRequest request) {
876-
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
877-
878-
if (this.left.test(request)) {
879-
return true;
976+
protected Result testInternal(ServerRequest request) {
977+
Result leftResult = this.leftModifying.testInternal(request);
978+
if (leftResult.value()) {
979+
return leftResult;
880980
}
881981
else {
882-
restoreAttributes(request, oldAttributes);
883-
if (this.right.test(request)) {
884-
return true;
885-
}
982+
return this.rightModifying.testInternal(request);
886983
}
887-
restoreAttributes(request, oldAttributes);
888-
return false;
889984
}
890985

891986
@Override
@@ -910,11 +1005,11 @@ public void accept(Visitor visitor) {
9101005

9111006
@Override
9121007
public void changeParser(PathPatternParser parser) {
913-
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) {
914-
leftTarget.changeParser(parser);
1008+
if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
1009+
target.changeParser(parser);
9151010
}
916-
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) {
917-
rightTarget.changeParser(parser);
1011+
if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
1012+
target.changeParser(parser);
9181013
}
9191014
}
9201015

spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicateAttributesTests.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,25 @@ public void orBothFail() {
182182
}
183183

184184

185-
private static class AddAttributePredicate implements RequestPredicate {
185+
private static class AddAttributePredicate extends RequestPredicates.RequestModifyingPredicate {
186186

187-
private boolean result;
187+
private final boolean result;
188188

189189
private final String key;
190190

191191
private final String value;
192192

193-
private AddAttributePredicate(boolean result, String key, String value) {
193+
194+
public AddAttributePredicate(boolean result, String key, String value) {
194195
this.result = result;
195196
this.key = key;
196197
this.value = value;
197198
}
198199

200+
199201
@Override
200-
public boolean test(ServerRequest request) {
201-
request.attributes().put(key, value);
202-
return this.result;
202+
protected Result testInternal(ServerRequest request) {
203+
return Result.of(this.result, serverRequest -> serverRequest.attributes().put(this.key, this.value));
203204
}
204205
}
205206

0 commit comments

Comments
 (0)