3232import org .junit .Before ;
3333import org .junit .Test ;
3434
35+ import org .mockito .ArgumentCaptor ;
3536import org .springframework .http .HttpStatus ;
3637import org .springframework .http .MediaType ;
3738import org .springframework .mock .web .MockHttpServletRequest ;
3839import org .springframework .mock .web .MockHttpServletResponse ;
40+ import org .springframework .security .authentication .AuthenticationDetailsSource ;
3941import org .springframework .security .authentication .AuthenticationManager ;
4042import org .springframework .security .authentication .TestingAuthenticationToken ;
4143import org .springframework .security .core .Authentication ;
5557import org .springframework .security .web .authentication .AuthenticationConverter ;
5658import org .springframework .security .web .authentication .AuthenticationFailureHandler ;
5759import org .springframework .security .web .authentication .AuthenticationSuccessHandler ;
60+ import org .springframework .security .web .authentication .WebAuthenticationDetails ;
5861import org .springframework .util .StringUtils ;
5962
6063import static org .assertj .core .api .Assertions .assertThat ;
6164import static org .assertj .core .api .Assertions .assertThatThrownBy ;
65+ import static org .assertj .core .api .InstanceOfAssertFactories .type ;
6266import static org .mockito .ArgumentMatchers .any ;
6367import static org .mockito .ArgumentMatchers .same ;
6468import static org .mockito .Mockito .mock ;
7882 */
7983public class OAuth2AuthorizationEndpointFilterTests {
8084 private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize" ;
85+ private static final String REMOTE_ADDRESS = "remote-address" ;
8186 private AuthenticationManager authenticationManager ;
8287 private OAuth2AuthorizationEndpointFilter filter ;
8388 private TestingAuthenticationToken principal ;
@@ -116,6 +121,13 @@ public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentE
116121 .hasMessage ("authorizationEndpointUri cannot be empty" );
117122 }
118123
124+ @ Test
125+ public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException () {
126+ assertThatThrownBy (() -> this .filter .setAuthenticationDetailsSource (null ))
127+ .isInstanceOf (IllegalArgumentException .class )
128+ .hasMessage ("authenticationDetailsSource cannot be null" );
129+ }
130+
119131 @ Test
120132 public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException () {
121133 assertThatThrownBy (() -> this .filter .setAuthenticationConverter (null ))
@@ -364,6 +376,32 @@ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exce
364376 verify (authenticationFailureHandler ).onAuthenticationFailure (any (), any (), same (authenticationException ));
365377 }
366378
379+ @ Test
380+ public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed () throws Exception {
381+ RegisteredClient registeredClient = TestRegisteredClients .registeredClient ().build ();
382+ OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
383+ authorizationCodeRequestAuthentication (registeredClient , this .principal ).build ();
384+ MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
385+
386+ AuthenticationDetailsSource <HttpServletRequest , WebAuthenticationDetails > authenticationDetailsSource =
387+ mock (AuthenticationDetailsSource .class );
388+ WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails (request );
389+ when (authenticationDetailsSource .buildDetails (request )).thenReturn (webAuthenticationDetails );
390+ this .filter .setAuthenticationDetailsSource (authenticationDetailsSource );
391+
392+ when (this .authenticationManager .authenticate (any ()))
393+ .thenReturn (authorizationCodeRequestAuthentication );
394+
395+ MockHttpServletResponse response = new MockHttpServletResponse ();
396+ FilterChain filterChain = mock (FilterChain .class );
397+
398+ this .filter .doFilter (request , response , filterChain );
399+
400+ verify (authenticationDetailsSource ).buildDetails (any ());
401+ verify (this .authenticationManager ).authenticate (any ());
402+ verify (filterChain ).doFilter (any (HttpServletRequest .class ), any (HttpServletResponse .class ));
403+ }
404+
367405 @ Test
368406 public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication () throws Exception {
369407 this .principal .setAuthenticated (false );
@@ -507,9 +545,15 @@ public void doFilterWhenAuthorizationRequestAuthenticatedThenAuthorizationRespon
507545
508546 this .filter .doFilter (request , response , filterChain );
509547
510- verify (this .authenticationManager ).authenticate (any ());
548+ ArgumentCaptor <OAuth2AuthorizationCodeRequestAuthenticationToken > authorizationCodeRequestAuthenticationCaptor =
549+ ArgumentCaptor .forClass (OAuth2AuthorizationCodeRequestAuthenticationToken .class );
550+ verify (this .authenticationManager ).authenticate (authorizationCodeRequestAuthenticationCaptor .capture ());
511551 verifyNoInteractions (filterChain );
512552
553+ assertThat (authorizationCodeRequestAuthenticationCaptor .getValue ().getDetails ())
554+ .asInstanceOf (type (WebAuthenticationDetails .class ))
555+ .extracting (WebAuthenticationDetails ::getRemoteAddress )
556+ .isEqualTo (REMOTE_ADDRESS );
513557 assertThat (response .getStatus ()).isEqualTo (HttpStatus .FOUND .value ());
514558 assertThat (response .getRedirectedUrl ()).isEqualTo ("https://example.com?code=code&state=state" );
515559 }
@@ -578,6 +622,7 @@ private static MockHttpServletRequest createAuthorizationRequest(RegisteredClien
578622 String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI ;
579623 MockHttpServletRequest request = new MockHttpServletRequest ("GET" , requestUri );
580624 request .setServletPath (requestUri );
625+ request .setRemoteAddr (REMOTE_ADDRESS );
581626
582627 request .addParameter (OAuth2ParameterNames .RESPONSE_TYPE , OAuth2AuthorizationResponseType .CODE .getValue ());
583628 request .addParameter (OAuth2ParameterNames .CLIENT_ID , registeredClient .getClientId ());
@@ -593,6 +638,7 @@ private static MockHttpServletRequest createAuthorizationConsentRequest(Register
593638 String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI ;
594639 MockHttpServletRequest request = new MockHttpServletRequest ("POST" , requestUri );
595640 request .setServletPath (requestUri );
641+ request .setRemoteAddr (REMOTE_ADDRESS );
596642
597643 request .addParameter (OAuth2ParameterNames .CLIENT_ID , registeredClient .getClientId ());
598644 registeredClient .getScopes ().forEach ((scope ) -> request .addParameter (OAuth2ParameterNames .SCOPE , scope ));
0 commit comments