diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index a6b5f7c52bf..fb51256930a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -35,6 +35,7 @@ import org.springframework.context.event.GenericApplicationListenerAdapter; import org.springframework.context.event.SmartApplicationListener; import org.springframework.core.ResolvableType; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; @@ -50,6 +51,7 @@ import org.springframework.security.core.session.SessionDestroyedEvent; import org.springframework.security.core.session.SessionIdChangedEvent; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient; @@ -430,6 +432,10 @@ public void configure(B http) throws Exception { authenticationFilter .setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository); } + if (this.authorizationEndpointConfig.authenticationResultConverter != null) { + authenticationFilter + .setAuthenticationResultConverter(this.authorizationEndpointConfig.authenticationResultConverter); + } configureOidcSessionRegistry(http); super.configure(http); } @@ -619,6 +625,8 @@ public final class AuthorizationEndpointConfig { private AuthorizationRequestRepository authorizationRequestRepository; + private Converter authenticationResultConverter; + private RedirectStrategy authorizationRedirectStrategy; private AuthorizationEndpointConfig() { @@ -663,6 +671,20 @@ public AuthorizationEndpointConfig authorizationRequestRepository( return this; } + /** + * Sets the converter responsible for converting from + * {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken} + * authentication result. + * @param authenticationResultConverter the converter for + * {@link OAuth2AuthenticationToken}'s + */ + public AuthorizationEndpointConfig authenticationResultConverter( + Converter authenticationResultConverter) { + Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null"); + this.authenticationResultConverter = authenticationResultConverter; + return this; + } + /** * Sets the redirect strategy for Authorization Endpoint redirect URI. * @param authorizationRedirectStrategy the redirect strategy diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index b56d047a5f7..61b050fabaf 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -211,6 +212,20 @@ public void oauth2Login() throws Exception { .hasToString("OAUTH2_USER"); } + @Test + public void oauth2LoginWhenSetAuthenticationResultConverter() throws Exception { + loadConfig(OAuth2LoginConfiguration.class); + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication).isInstanceOf(CustomOAuth2AuthenticationToken.class); + } + @Test public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception { loadConfig(OAuth2LoginConfig.class, SecurityContextChangedListenerConfig.class); @@ -754,6 +769,31 @@ public void onApplicationEvent(AuthenticationSuccessEvent event) { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfiguration extends CommonSecurityFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + http.oauth2Login((c) -> c + .authorizationEndpoint((d) -> d + .authenticationResultConverter((r) -> new CustomOAuth2AuthenticationToken(r.getPrincipal(), + r.getAuthorities(), r.getClientRegistration().getRegistrationId()))) + .clientRegistrationRepository(new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION))); + return super.configureFilterChain(http); + } + + } + + private static class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { + + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection authorities, + String authorizedClientRegistrationId) { + super(principal, authorities, authorizedClientRegistrationId); + } + + } + @Configuration @EnableWebSecurity static class OAuth2LoginConfigFormLogin extends CommonSecurityFilterChainConfig {