diff --git a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java index a40e0925de26..711cc903e638 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java @@ -65,6 +65,7 @@ * @author Juergen Hoeller * @author Rossen Stoyanchev * @author Sam Brannen + * @author Simone Conte * @since 06.12.2003 */ public abstract class OncePerRequestFilter extends GenericFilterBean { @@ -83,6 +84,7 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { * attribute is already there. * @see #getAlreadyFilteredAttributeName * @see #shouldNotFilter + * @see #shouldFilter * @see #doFilterInternal */ @Override @@ -109,7 +111,7 @@ else if (hasAlreadyFilteredAttribute) { // Proceed without invoking this filter... filterChain.doFilter(request, response); } - else { + else if (shouldFilter(httpRequest)) { // Do invoke this filter... request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); try { @@ -120,6 +122,10 @@ else if (hasAlreadyFilteredAttribute) { request.removeAttribute(alreadyFilteredAttributeName); } } + else { + // Proceed without invoking this filter... + filterChain.doFilter(request, response); + } } private boolean skipDispatch(HttpServletRequest request) { @@ -173,6 +179,24 @@ protected String getAlreadyFilteredAttributeName() { return name + ALREADY_FILTERED_SUFFIX; } + /** + * Can be overridden in subclasses for custom filtering control, + * returning {@code true} to allow filtering of the given request. + *

This method is called after {@link #shouldNotFilter(HttpServletRequest)} + * and only if that method returns {@code false}. Both methods provide + * complementary ways to control filter execution. + *

The default implementation always returns {@code true}, meaning + * all requests will be filtered unless explicitly excluded by + * {@link #shouldNotFilter(HttpServletRequest)}. + * @param request current HTTP request + * @return whether the given request should be filtered + * @throws ServletException in case of errors + * @since 7.0 + */ + protected boolean shouldFilter(HttpServletRequest request) throws ServletException { + return true; + } + /** * Can be overridden in subclasses for custom filtering control, * returning {@code true} to avoid filtering of the given request. diff --git a/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java index ab2a206c5278..44843c647620 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/OncePerRequestFilterTests.java @@ -17,6 +17,7 @@ package org.springframework.web.filter; import java.io.IOException; +import java.util.stream.Stream; import jakarta.servlet.DispatcherType; import jakarta.servlet.FilterChain; @@ -26,6 +27,9 @@ import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -33,11 +37,16 @@ import org.springframework.web.util.WebUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link OncePerRequestFilter}. * * @author Rossen Stoyanchev + * @author Simone Conte * @since 5.1.9 */ class OncePerRequestFilterTests { @@ -46,6 +55,8 @@ class OncePerRequestFilterTests { private MockHttpServletRequest request; + private MockHttpServletResponse response = new MockHttpServletResponse(); + private MockFilterChain filterChain; @@ -56,9 +67,17 @@ public void setup() throws Exception { this.request.setScheme("http"); this.request.setServerName("localhost"); this.request.setServerPort(80); - this.filterChain = new MockFilterChain(new HttpServlet() {}); + this.filterChain = spy(new MockFilterChain(new HttpServlet() {})); } + @Test + void filterInternal() throws ServletException, IOException { + this.filter.doFilter(this.request, this.response, this.filterChain); + + assertThat(this.filter.didFilter).isTrue(); + assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + verifyNoInteractions(this.filterChain); + } @Test void filterOnce() throws ServletException, IOException { @@ -66,17 +85,19 @@ void filterOnce() throws ServletException, IOException { // Already filtered this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); - this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.filter.didFilter).isFalse(); assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + verify(this.filterChain).doFilter(this.request, this.response); // Remove already filtered this.request.removeAttribute(this.filter.getAlreadyFilteredAttributeName()); this.filter.reset(); - this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.filter.didFilter).isTrue(); assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + verify(this.filterChain).doFilter(this.request, this.response); } @Test @@ -84,9 +105,10 @@ void shouldNotFilterErrorDispatch() throws ServletException, IOException { initErrorDispatch(); - this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.filter.didFilter).isFalse(); assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + verify(this.filterChain).doFilter(this.request, this.response); } @Test @@ -95,9 +117,10 @@ void shouldNotFilterNestedErrorDispatch() throws ServletException, IOException { initErrorDispatch(); this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); - this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.filter.didFilter).isFalse(); assertThat(this.filter.didFilterNestedErrorDispatch).isFalse(); + verify(this.filterChain).doFilter(this.request, this.response); } @Test // gh-23196 @@ -109,9 +132,61 @@ public void filterNestedErrorDispatch() throws ServletException, IOException { this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE); initErrorDispatch(); - this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); + this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.filter.didFilter).isFalse(); assertThat(this.filter.didFilterNestedErrorDispatch).isTrue(); + verify(this.filterChain).doFilter(this.request, this.response); + } + + @ParameterizedTest + @MethodSource + public void shouldNotFilterForUri(String uri, boolean shouldFilterInternal) throws ServletException, IOException { + ShouldNotFilterRequestFilter filter = new ShouldNotFilterRequestFilter(); + + this.request.setRequestURI(uri); + + filter.doFilter(this.request, this.response, this.filterChain); + + assertThat(filter.didFilter).isEqualTo(shouldFilterInternal); + verify(this.filterChain, times(shouldFilterInternal ? 0 : 1)).doFilter(this.request, this.response); + } + + static Stream shouldNotFilterForUri() { + return Stream.of( + Arguments.of("/skip", false), + Arguments.of("/skip/something", false), + Arguments.of("//skip", true), + Arguments.of("", true), + Arguments.of("/", true), + Arguments.of("/do_not_skip", true), + Arguments.of("//do_not_skip", true) + ); + } + + @ParameterizedTest + @MethodSource + public void shouldFilterForUri(String uri, boolean shouldFilterInternal) throws ServletException, IOException { + ShouldFilterRequestFilter filter = new ShouldFilterRequestFilter(); + + this.request.setRequestURI(uri); + + filter.doFilter(this.request, this.response, this.filterChain); + + assertThat(filter.didFilter).isEqualTo(shouldFilterInternal); + verify(this.filterChain, times(shouldFilterInternal ? 0 : 1)).doFilter(this.request, this.response); + } + + static Stream shouldFilterForUri() { + return Stream.of( + Arguments.of("/skip", false), + Arguments.of("/skip/something", false), + Arguments.of("//skip", false), + Arguments.of("", false), + Arguments.of("/", false), + Arguments.of("/do_not_skip", true), + Arguments.of("//do_not_skip", false), + Arguments.of("/do_not_skip/something", true) + ); } private void initErrorDispatch() { @@ -174,4 +249,38 @@ protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServl } } + private static class ShouldNotFilterRequestFilter extends OncePerRequestFilter { + + private boolean didFilter; + + @Override + protected boolean shouldNotFilter(HttpServletRequest request) { + return request.getRequestURI().startsWith("/skip"); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { + + this.didFilter = true; + } + } + + private static class ShouldFilterRequestFilter extends OncePerRequestFilter { + + private boolean didFilter; + + @Override + protected boolean shouldFilter(HttpServletRequest request) { + return request.getRequestURI().startsWith("/do_not_skip"); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) { + + this.didFilter = true; + } + } + }