Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* @author Juergen Hoeller
* @author Rossen Stoyanchev
* @author Sam Brannen
* @author Simone Conte
* @since 06.12.2003
*/
public abstract class OncePerRequestFilter extends GenericFilterBean {
Expand All @@ -83,6 +84,7 @@ public abstract class OncePerRequestFilter extends GenericFilterBean {
* attribute is already there.
* @see #getAlreadyFilteredAttributeName
* @see #shouldNotFilter
* @see #shouldFilter
* @see #doFilterInternal
*/
@Override
Expand All @@ -109,7 +111,7 @@ else if (hasAlreadyFilteredAttribute) {
// Proceed without invoking this filter...
filterChain.doFilter(request, response);
}
else {
else if (shouldFilter(httpRequest)) {
// Do invoke this filter...
request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE);
try {
Expand All @@ -120,6 +122,10 @@ else if (hasAlreadyFilteredAttribute) {
request.removeAttribute(alreadyFilteredAttributeName);
}
}
else {
// Proceed without invoking this filter...
filterChain.doFilter(request, response);
}
}

private boolean skipDispatch(HttpServletRequest request) {
Expand Down Expand Up @@ -173,6 +179,24 @@ protected String getAlreadyFilteredAttributeName() {
return name + ALREADY_FILTERED_SUFFIX;
}

/**
* Can be overridden in subclasses for custom filtering control,
* returning {@code true} to allow filtering of the given request.
* <p>This method is called after {@link #shouldNotFilter(HttpServletRequest)}
* and only if that method returns {@code false}. Both methods provide
* complementary ways to control filter execution.
* <p>The default implementation always returns {@code true}, meaning
* all requests will be filtered unless explicitly excluded by
* {@link #shouldNotFilter(HttpServletRequest)}.
* @param request current HTTP request
* @return whether the given request should be filtered
* @throws ServletException in case of errors
* @since 7.0
*/
protected boolean shouldFilter(HttpServletRequest request) throws ServletException {
return true;
}

/**
* Can be overridden in subclasses for custom filtering control,
* returning {@code true} to avoid filtering of the given request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.web.filter;

import java.io.IOException;
import java.util.stream.Stream;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.FilterChain;
Expand All @@ -26,18 +27,26 @@
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.web.testfixture.servlet.MockFilterChain;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.util.WebUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

/**
* Tests for {@link OncePerRequestFilter}.
*
* @author Rossen Stoyanchev
* @author Simone Conte
* @since 5.1.9
*/
class OncePerRequestFilterTests {
Expand All @@ -46,6 +55,8 @@ class OncePerRequestFilterTests {

private MockHttpServletRequest request;

private MockHttpServletResponse response = new MockHttpServletResponse();

private MockFilterChain filterChain;


Expand All @@ -56,37 +67,48 @@ public void setup() throws Exception {
this.request.setScheme("http");
this.request.setServerName("localhost");
this.request.setServerPort(80);
this.filterChain = new MockFilterChain(new HttpServlet() {});
this.filterChain = spy(new MockFilterChain(new HttpServlet() {}));
}

@Test
void filterInternal() throws ServletException, IOException {
this.filter.doFilter(this.request, this.response, this.filterChain);

assertThat(this.filter.didFilter).isTrue();
assertThat(this.filter.didFilterNestedErrorDispatch).isFalse();
verifyNoInteractions(this.filterChain);
}

@Test
void filterOnce() throws ServletException, IOException {

// Already filtered
this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE);

this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.filter.didFilter).isFalse();
assertThat(this.filter.didFilterNestedErrorDispatch).isFalse();
verify(this.filterChain).doFilter(this.request, this.response);

// Remove already filtered
this.request.removeAttribute(this.filter.getAlreadyFilteredAttributeName());
this.filter.reset();

this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.filter.didFilter).isTrue();
assertThat(this.filter.didFilterNestedErrorDispatch).isFalse();
verify(this.filterChain).doFilter(this.request, this.response);
}

@Test
void shouldNotFilterErrorDispatch() throws ServletException, IOException {

initErrorDispatch();

this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.filter.didFilter).isFalse();
assertThat(this.filter.didFilterNestedErrorDispatch).isFalse();
verify(this.filterChain).doFilter(this.request, this.response);
}

@Test
Expand All @@ -95,9 +117,10 @@ void shouldNotFilterNestedErrorDispatch() throws ServletException, IOException {
initErrorDispatch();
this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE);

this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.filter.didFilter).isFalse();
assertThat(this.filter.didFilterNestedErrorDispatch).isFalse();
verify(this.filterChain).doFilter(this.request, this.response);
}

@Test // gh-23196
Expand All @@ -109,9 +132,61 @@ public void filterNestedErrorDispatch() throws ServletException, IOException {
this.request.setAttribute(this.filter.getAlreadyFilteredAttributeName(), Boolean.TRUE);
initErrorDispatch();

this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
this.filter.doFilter(this.request, this.response, this.filterChain);
assertThat(this.filter.didFilter).isFalse();
assertThat(this.filter.didFilterNestedErrorDispatch).isTrue();
verify(this.filterChain).doFilter(this.request, this.response);
}

@ParameterizedTest
@MethodSource
public void shouldNotFilterForUri(String uri, boolean shouldFilterInternal) throws ServletException, IOException {
ShouldNotFilterRequestFilter filter = new ShouldNotFilterRequestFilter();

this.request.setRequestURI(uri);

filter.doFilter(this.request, this.response, this.filterChain);

assertThat(filter.didFilter).isEqualTo(shouldFilterInternal);
verify(this.filterChain, times(shouldFilterInternal ? 0 : 1)).doFilter(this.request, this.response);
}

static Stream<Arguments> shouldNotFilterForUri() {
return Stream.of(
Arguments.of("/skip", false),
Arguments.of("/skip/something", false),
Arguments.of("//skip", true),
Arguments.of("", true),
Arguments.of("/", true),
Arguments.of("/do_not_skip", true),
Arguments.of("//do_not_skip", true)
);
}

@ParameterizedTest
@MethodSource
public void shouldFilterForUri(String uri, boolean shouldFilterInternal) throws ServletException, IOException {
ShouldFilterRequestFilter filter = new ShouldFilterRequestFilter();

this.request.setRequestURI(uri);

filter.doFilter(this.request, this.response, this.filterChain);

assertThat(filter.didFilter).isEqualTo(shouldFilterInternal);
verify(this.filterChain, times(shouldFilterInternal ? 0 : 1)).doFilter(this.request, this.response);
}

static Stream<Arguments> shouldFilterForUri() {
return Stream.of(
Arguments.of("/skip", false),
Arguments.of("/skip/something", false),
Arguments.of("//skip", false),
Arguments.of("", false),
Arguments.of("/", false),
Arguments.of("/do_not_skip", true),
Arguments.of("//do_not_skip", false),
Arguments.of("/do_not_skip/something", true)
);
}

private void initErrorDispatch() {
Expand Down Expand Up @@ -174,4 +249,38 @@ protected void doFilterNestedErrorDispatch(HttpServletRequest request, HttpServl
}
}

private static class ShouldNotFilterRequestFilter extends OncePerRequestFilter {

private boolean didFilter;

@Override
protected boolean shouldNotFilter(HttpServletRequest request) {
return request.getRequestURI().startsWith("/skip");
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) {

this.didFilter = true;
}
}

private static class ShouldFilterRequestFilter extends OncePerRequestFilter {

private boolean didFilter;

@Override
protected boolean shouldFilter(HttpServletRequest request) {
return request.getRequestURI().startsWith("/do_not_skip");
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) {

this.didFilter = true;
}
}

}