Skip to content

Commit 6d6dc11

Browse files
Steve Riesenbergsjohnr
authored andcommitted
Add converter for authentication result in OAuth2LoginAuthenticationFilter
Closes gh-10033
1 parent fc553bf commit 6d6dc11

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import javax.servlet.http.HttpServletRequest;
2020
import javax.servlet.http.HttpServletResponse;
2121

22+
import org.springframework.core.convert.converter.Converter;
2223
import org.springframework.security.authentication.AuthenticationManager;
2324
import org.springframework.security.core.Authentication;
2425
import org.springframework.security.core.AuthenticationException;
@@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
111112

112113
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
113114

115+
private Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter = this::createAuthenticationResult;
116+
114117
/**
115118
* Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided
116119
* parameters.
@@ -190,9 +193,9 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ
190193
authenticationRequest.setDetails(authenticationDetails);
191194
OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this
192195
.getAuthenticationManager().authenticate(authenticationRequest);
193-
OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken(
194-
authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
195-
authenticationResult.getClientRegistration().getRegistrationId());
196+
OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter
197+
.convert(authenticationResult);
198+
Assert.notNull(oauth2Authentication, "authentication result cannot be null");
196199
oauth2Authentication.setDetails(authenticationDetails);
197200
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
198201
authenticationResult.getClientRegistration(), oauth2Authentication.getName(),
@@ -213,4 +216,22 @@ public final void setAuthorizationRequestRepository(
213216
this.authorizationRequestRepository = authorizationRequestRepository;
214217
}
215218

219+
/**
220+
* Sets the converter responsible for converting from
221+
* {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken}
222+
* authentication result.
223+
* @param authenticationResultConverter the converter for
224+
* {@link OAuth2AuthenticationToken}'s
225+
*/
226+
public final void setAuthenticationResultConverter(
227+
Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter) {
228+
Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null");
229+
this.authenticationResultConverter = authenticationResultConverter;
230+
}
231+
232+
private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) {
233+
return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
234+
authenticationResult.getClientRegistration().getRegistrationId());
235+
}
236+
216237
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.security.oauth2.client.web;
1818

19+
import java.util.Collection;
1920
import java.util.HashMap;
2021
import java.util.Map;
2122

@@ -33,10 +34,12 @@
3334
import org.springframework.security.authentication.AuthenticationManager;
3435
import org.springframework.security.core.Authentication;
3536
import org.springframework.security.core.AuthenticationException;
37+
import org.springframework.security.core.GrantedAuthority;
3638
import org.springframework.security.core.authority.AuthorityUtils;
3739
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
3840
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
3941
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
42+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
4043
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
4144
import org.springframework.security.oauth2.client.registration.ClientRegistration;
4245
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -152,6 +155,12 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI
152155
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
153156
}
154157

158+
// gh-10033
159+
@Test
160+
public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
161+
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
162+
}
163+
155164
@Test
156165
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
157166
String requestUri = "/path";
@@ -416,6 +425,41 @@ public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationR
416425
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
417426
}
418427

428+
// gh-10033
429+
@Test
430+
public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
431+
this.filter.setAuthenticationResultConverter((authentication) -> null);
432+
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
433+
String state = "state";
434+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
435+
request.setServletPath(requestUri);
436+
request.addParameter(OAuth2ParameterNames.CODE, "code");
437+
request.addParameter(OAuth2ParameterNames.STATE, state);
438+
MockHttpServletResponse response = new MockHttpServletResponse();
439+
this.setUpAuthorizationRequest(request, response, this.registration1, state);
440+
this.setUpAuthenticationResult(this.registration1);
441+
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
442+
}
443+
444+
// gh-10033
445+
@Test
446+
public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
447+
this.filter.setAuthenticationResultConverter(
448+
(authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
449+
authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
450+
String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
451+
String state = "state";
452+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
453+
request.setServletPath(requestUri);
454+
request.addParameter(OAuth2ParameterNames.CODE, "code");
455+
request.addParameter(OAuth2ParameterNames.STATE, state);
456+
MockHttpServletResponse response = new MockHttpServletResponse();
457+
this.setUpAuthorizationRequest(request, response, this.registration1, state);
458+
this.setUpAuthenticationResult(this.registration1);
459+
Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
460+
assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
461+
}
462+
419463
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
420464
ClientRegistration registration, String state) {
421465
Map<String, Object> attributes = new HashMap<>();
@@ -454,4 +498,13 @@ private void setUpAuthenticationResult(ClientRegistration registration) {
454498
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
455499
}
456500

501+
private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
502+
503+
CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
504+
String authorizedClientRegistrationId) {
505+
super(principal, authorities, authorizedClientRegistrationId);
506+
}
507+
508+
}
509+
457510
}

0 commit comments

Comments
 (0)