Skip to content

Commit ec7ab5c

Browse files
Kehrlannjgrandja
authored andcommitted
Add authenticationDetailsSource to AuthorizationEndpointFilter
Closes gh-768
1 parent fdf0a2f commit ec7ab5c

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2021 the original author or authors.
2+
* Copyright 2020-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@
2828
import org.springframework.http.HttpMethod;
2929
import org.springframework.http.HttpStatus;
3030
import org.springframework.http.MediaType;
31+
import org.springframework.security.authentication.AuthenticationDetailsSource;
3132
import org.springframework.security.authentication.AuthenticationManager;
3233
import org.springframework.security.core.Authentication;
3334
import org.springframework.security.core.AuthenticationException;
@@ -45,6 +46,7 @@
4546
import org.springframework.security.web.authentication.AuthenticationConverter;
4647
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
4748
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
49+
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
4850
import org.springframework.security.web.util.RedirectUrlBuilder;
4951
import org.springframework.security.web.util.UrlUtils;
5052
import org.springframework.security.web.util.matcher.AndRequestMatcher;
@@ -82,6 +84,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
8284
private final AuthenticationManager authenticationManager;
8385
private final RequestMatcher authorizationEndpointMatcher;
8486
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
87+
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
8588
private AuthenticationConverter authenticationConverter;
8689
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse;
8790
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
@@ -144,6 +147,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
144147
try {
145148
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
146149
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationConverter.convert(request);
150+
authorizationCodeRequestAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
147151

148152
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
149153
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationManager.authenticate(authorizationCodeRequestAuthentication);
@@ -169,6 +173,17 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
169173
}
170174
}
171175

176+
/**
177+
* Sets the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}.
178+
*
179+
* @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}
180+
* @since 0.3.1
181+
*/
182+
public void setAuthenticationDetailsSource(AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
183+
Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
184+
this.authenticationDetailsSource = authenticationDetailsSource;
185+
}
186+
172187
/**
173188
* Sets the {@link AuthenticationConverter} used when attempting to extract an Authorization Request (or Consent) from {@link HttpServletRequest}
174189
* to an instance of {@link OAuth2AuthorizationCodeRequestAuthenticationToken} used for authenticating the request.

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232
import org.junit.Before;
3333
import org.junit.Test;
3434

35+
import org.mockito.ArgumentCaptor;
3536
import org.springframework.http.HttpStatus;
3637
import org.springframework.http.MediaType;
3738
import org.springframework.mock.web.MockHttpServletRequest;
3839
import org.springframework.mock.web.MockHttpServletResponse;
40+
import org.springframework.security.authentication.AuthenticationDetailsSource;
3941
import org.springframework.security.authentication.AuthenticationManager;
4042
import org.springframework.security.authentication.TestingAuthenticationToken;
4143
import org.springframework.security.core.Authentication;
@@ -55,10 +57,12 @@
5557
import org.springframework.security.web.authentication.AuthenticationConverter;
5658
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
5759
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
60+
import org.springframework.security.web.authentication.WebAuthenticationDetails;
5861
import org.springframework.util.StringUtils;
5962

6063
import static org.assertj.core.api.Assertions.assertThat;
6164
import static org.assertj.core.api.Assertions.assertThatThrownBy;
65+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
6266
import static org.mockito.ArgumentMatchers.any;
6367
import static org.mockito.ArgumentMatchers.same;
6468
import static org.mockito.Mockito.mock;
@@ -78,6 +82,7 @@
7882
*/
7983
public class OAuth2AuthorizationEndpointFilterTests {
8084
private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
85+
private static final String REMOTE_ADDRESS = "remote-address";
8186
private AuthenticationManager authenticationManager;
8287
private OAuth2AuthorizationEndpointFilter filter;
8388
private TestingAuthenticationToken principal;
@@ -116,6 +121,13 @@ public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentE
116121
.hasMessage("authorizationEndpointUri cannot be empty");
117122
}
118123

124+
@Test
125+
public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
126+
assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
127+
.isInstanceOf(IllegalArgumentException.class)
128+
.hasMessage("authenticationDetailsSource cannot be null");
129+
}
130+
119131
@Test
120132
public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
121133
assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
@@ -364,6 +376,32 @@ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exce
364376
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
365377
}
366378

379+
@Test
380+
public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
381+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
382+
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
383+
authorizationCodeRequestAuthentication(registeredClient, this.principal).build();
384+
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
385+
386+
AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource =
387+
mock(AuthenticationDetailsSource.class);
388+
WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
389+
when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails);
390+
this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
391+
392+
when(this.authenticationManager.authenticate(any()))
393+
.thenReturn(authorizationCodeRequestAuthentication);
394+
395+
MockHttpServletResponse response = new MockHttpServletResponse();
396+
FilterChain filterChain = mock(FilterChain.class);
397+
398+
this.filter.doFilter(request, response, filterChain);
399+
400+
verify(authenticationDetailsSource).buildDetails(any());
401+
verify(this.authenticationManager).authenticate(any());
402+
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
403+
}
404+
367405
@Test
368406
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
369407
this.principal.setAuthenticated(false);
@@ -507,9 +545,15 @@ public void doFilterWhenAuthorizationRequestAuthenticatedThenAuthorizationRespon
507545

508546
this.filter.doFilter(request, response, filterChain);
509547

510-
verify(this.authenticationManager).authenticate(any());
548+
ArgumentCaptor<OAuth2AuthorizationCodeRequestAuthenticationToken> authorizationCodeRequestAuthenticationCaptor =
549+
ArgumentCaptor.forClass(OAuth2AuthorizationCodeRequestAuthenticationToken.class);
550+
verify(this.authenticationManager).authenticate(authorizationCodeRequestAuthenticationCaptor.capture());
511551
verifyNoInteractions(filterChain);
512552

553+
assertThat(authorizationCodeRequestAuthenticationCaptor.getValue().getDetails())
554+
.asInstanceOf(type(WebAuthenticationDetails.class))
555+
.extracting(WebAuthenticationDetails::getRemoteAddress)
556+
.isEqualTo(REMOTE_ADDRESS);
513557
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
514558
assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
515559
}
@@ -578,6 +622,7 @@ private static MockHttpServletRequest createAuthorizationRequest(RegisteredClien
578622
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
579623
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
580624
request.setServletPath(requestUri);
625+
request.setRemoteAddr(REMOTE_ADDRESS);
581626

582627
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
583628
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
@@ -593,6 +638,7 @@ private static MockHttpServletRequest createAuthorizationConsentRequest(Register
593638
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
594639
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
595640
request.setServletPath(requestUri);
641+
request.setRemoteAddr(REMOTE_ADDRESS);
596642

597643
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
598644
registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));

0 commit comments

Comments
 (0)