|
18 | 18 |
|
19 | 19 | import java.util.function.Supplier;
|
20 | 20 |
|
| 21 | +import jakarta.servlet.DispatcherType; |
21 | 22 | import jakarta.servlet.FilterChain;
|
22 | 23 | import jakarta.servlet.http.HttpServletRequest;
|
23 | 24 | import jakarta.servlet.http.HttpServletResponse;
|
24 | 25 | import org.junit.jupiter.api.AfterEach;
|
25 | 26 | import org.junit.jupiter.api.BeforeEach;
|
26 | 27 | import org.junit.jupiter.api.Test;
|
27 | 28 | import org.junit.jupiter.api.extension.ExtendWith;
|
| 29 | +import org.junit.jupiter.params.ParameterizedTest; |
| 30 | +import org.junit.jupiter.params.provider.EnumSource; |
28 | 31 | import org.mockito.ArgumentCaptor;
|
29 | 32 | import org.mockito.Captor;
|
| 33 | +import org.mockito.InOrder; |
30 | 34 | import org.mockito.Mock;
|
31 | 35 | import org.mockito.junit.jupiter.MockitoExtension;
|
32 | 36 |
|
| 37 | +import org.springframework.mock.web.MockFilterChain; |
33 | 38 | import org.springframework.security.authentication.TestAuthentication;
|
34 | 39 | import org.springframework.security.core.Authentication;
|
35 | 40 | import org.springframework.security.core.context.SecurityContext;
|
|
39 | 44 |
|
40 | 45 | import static org.assertj.core.api.Assertions.assertThat;
|
41 | 46 | import static org.mockito.BDDMockito.given;
|
| 47 | +import static org.mockito.Mockito.inOrder; |
| 48 | +import static org.mockito.Mockito.lenient; |
| 49 | +import static org.mockito.Mockito.times; |
42 | 50 | import static org.mockito.Mockito.verify;
|
| 51 | +import static org.mockito.Mockito.verifyNoInteractions; |
43 | 52 |
|
44 | 53 | @ExtendWith(MockitoExtension.class)
|
45 | 54 | class SecurityContextHolderFilterTests {
|
46 | 55 |
|
| 56 | + private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED"; |
| 57 | + |
47 | 58 | @Mock
|
48 | 59 | private SecurityContextRepository repository;
|
49 | 60 |
|
@@ -104,14 +115,38 @@ void doFilterThenSetsAndClearsSecurityContextHolderStrategy() throws Exception {
|
104 | 115 | }
|
105 | 116 |
|
106 | 117 | @Test
|
107 |
| - void shouldNotFilterErrorDispatchWhenDefault() { |
108 |
| - assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse(); |
| 118 | + void doFilterWhenFilterAppliedThenDoNothing() throws Exception { |
| 119 | + given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true); |
| 120 | + this.filter.doFilter(this.request, this.response, new MockFilterChain()); |
| 121 | + verify(this.request, times(1)).getAttribute(FILTER_APPLIED); |
| 122 | + verifyNoInteractions(this.repository, this.response); |
109 | 123 | }
|
110 | 124 |
|
111 | 125 | @Test
|
112 |
| - void shouldNotFilterErrorDispatchWhenOverridden() { |
113 |
| - this.filter.setShouldNotFilterErrorDispatch(true); |
114 |
| - assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue(); |
| 126 | + void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception { |
| 127 | + given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn( |
| 128 | + new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy)); |
| 129 | + |
| 130 | + this.filter.doFilter(this.request, this.response, new MockFilterChain()); |
| 131 | + |
| 132 | + InOrder inOrder = inOrder(this.request, this.repository); |
| 133 | + inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true); |
| 134 | + inOrder.verify(this.repository).loadDeferredContext(this.request); |
| 135 | + inOrder.verify(this.request).removeAttribute(FILTER_APPLIED); |
| 136 | + } |
| 137 | + |
| 138 | + @ParameterizedTest |
| 139 | + @EnumSource(DispatcherType.class) |
| 140 | + void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception { |
| 141 | + lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType); |
| 142 | + Authentication authentication = TestAuthentication.authenticatedUser(); |
| 143 | + SecurityContext expectedContext = new SecurityContextImpl(authentication); |
| 144 | + given(this.repository.loadDeferredContext(this.requestArg.capture())) |
| 145 | + .willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy)); |
| 146 | + FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext()) |
| 147 | + .isEqualTo(expectedContext); |
| 148 | + |
| 149 | + this.filter.doFilter(this.request, this.response, filterChain); |
115 | 150 | }
|
116 | 151 |
|
117 | 152 | }
|
0 commit comments