Skip to content

Commit 951e641

Browse files
committed
Register OAuth2AuthorizedClientArgumentResolver for XML Config
Closes gh-8669
1 parent e113bd3 commit 951e641

9 files changed

+443
-104
lines changed

config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,8 @@
1515
*/
1616
package org.springframework.security.config.http;
1717

18-
import java.security.SecureRandom;
19-
import java.util.ArrayList;
20-
import java.util.Collections;
21-
import java.util.List;
22-
import java.util.Map;
23-
import java.util.function.Function;
24-
import javax.servlet.http.HttpServletRequest;
25-
2618
import org.apache.commons.logging.Log;
2719
import org.apache.commons.logging.LogFactory;
28-
import org.w3c.dom.Element;
29-
3020
import org.springframework.beans.BeanMetadataElement;
3121
import org.springframework.beans.factory.config.BeanDefinition;
3222
import org.springframework.beans.factory.config.BeanReference;
@@ -63,8 +53,18 @@
6353
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
6454
import org.springframework.security.web.csrf.CsrfToken;
6555
import org.springframework.util.Assert;
56+
import org.springframework.util.ClassUtils;
6657
import org.springframework.util.StringUtils;
6758
import org.springframework.util.xml.DomUtils;
59+
import org.w3c.dom.Element;
60+
61+
import javax.servlet.http.HttpServletRequest;
62+
import java.security.SecureRandom;
63+
import java.util.ArrayList;
64+
import java.util.Collections;
65+
import java.util.List;
66+
import java.util.Map;
67+
import java.util.function.Function;
6868

6969
import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER;
7070
import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER;
@@ -160,12 +160,16 @@ final class AuthenticationConfigBuilder {
160160

161161
private String openIDLoginPage;
162162

163+
private boolean oauth2LoginEnabled;
164+
private boolean defaultAuthorizedClientRepositoryRegistered;
163165
private String oauth2LoginFilterId;
164166
private BeanDefinition oauth2AuthorizationRequestRedirectFilter;
165167
private BeanDefinition oauth2LoginEntryPoint;
166168
private BeanReference oauth2LoginAuthenticationProviderRef;
167169
private BeanReference oauth2LoginOidcAuthenticationProviderRef;
168170
private BeanDefinition oauth2LoginLinks;
171+
172+
private boolean oauth2ClientEnabled;
169173
private BeanDefinition authorizationRequestRedirectFilter;
170174
private BeanDefinition authorizationCodeGrantFilter;
171175
private BeanReference authorizationCodeAuthenticationProviderRef;
@@ -196,8 +200,7 @@ final class AuthenticationConfigBuilder {
196200
createBasicFilter(authenticationManager);
197201
createBearerTokenAuthenticationFilter(authenticationManager);
198202
createFormLoginFilter(sessionStrategy, authenticationManager);
199-
createOAuth2LoginFilter(sessionStrategy, authenticationManager);
200-
createOAuth2ClientFilter(requestCache, authenticationManager);
203+
createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager);
201204
createOpenIDLoginFilter(sessionStrategy, authenticationManager);
202205
createX509Filter(authenticationManager);
203206
createJeeFilter(authenticationManager);
@@ -274,15 +277,27 @@ void createFormLoginFilter(BeanReference sessionStrategy, BeanReference authMana
274277
}
275278
}
276279

280+
void createOAuth2ClientFilters(BeanReference sessionStrategy, BeanReference requestCache,
281+
BeanReference authenticationManager) {
282+
createOAuth2LoginFilter(sessionStrategy, authenticationManager);
283+
createOAuth2ClientFilter(requestCache, authenticationManager);
284+
registerOAuth2ClientPostProcessors();
285+
}
286+
277287
void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
278288
Element oauth2LoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_LOGIN);
279289
if (oauth2LoginElt == null) {
280290
return;
281291
}
292+
this.oauth2LoginEnabled = true;
282293

283294
OAuth2LoginBeanDefinitionParser parser = new OAuth2LoginBeanDefinitionParser(requestCache, portMapper,
284295
portResolver, sessionStrategy, allowSessionCreation);
285296
BeanDefinition oauth2LoginFilterBean = parser.parse(oauth2LoginElt, this.pc);
297+
298+
BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
299+
registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
300+
286301
oauth2LoginFilterBean.getPropertyValues().addPropertyValue("authenticationManager", authManager);
287302

288303
// retrieve the other bean result
@@ -319,11 +334,15 @@ void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenti
319334
if (oauth2ClientElt == null) {
320335
return;
321336
}
337+
this.oauth2ClientEnabled = true;
322338

323339
OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser(
324340
requestCache, authenticationManager);
325341
parser.parse(oauth2ClientElt, this.pc);
326342

343+
BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
344+
registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
345+
327346
this.authorizationRequestRedirectFilter = parser.getAuthorizationRequestRedirectFilter();
328347
String authorizationRequestRedirectFilterId = pc.getReaderContext()
329348
.generateBeanName(this.authorizationRequestRedirectFilter);
@@ -344,6 +363,28 @@ void createOAuth2ClientFilter(BeanReference requestCache, BeanReference authenti
344363
this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId);
345364
}
346365

366+
void registerDefaultAuthorizedClientRepositoryIfNecessary(BeanDefinition defaultAuthorizedClientRepository) {
367+
if (!this.defaultAuthorizedClientRepositoryRegistered && defaultAuthorizedClientRepository != null) {
368+
String authorizedClientRepositoryId = pc.getReaderContext()
369+
.generateBeanName(defaultAuthorizedClientRepository);
370+
this.pc.registerBeanComponent(new BeanComponentDefinition(
371+
defaultAuthorizedClientRepository, authorizedClientRepositoryId));
372+
this.defaultAuthorizedClientRepositoryRegistered = true;
373+
}
374+
}
375+
376+
private void registerOAuth2ClientPostProcessors() {
377+
if (!this.oauth2LoginEnabled && !this.oauth2ClientEnabled) {
378+
return;
379+
}
380+
381+
boolean webmvcPresent = ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader());
382+
if (webmvcPresent) {
383+
this.pc.getReaderContext().registerWithGeneratedName(
384+
new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class));
385+
}
386+
}
387+
347388
void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
348389
Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt,
349390
Elements.OPENID_LOGIN);

config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,30 @@
2323
import org.springframework.beans.factory.xml.BeanDefinitionParser;
2424
import org.springframework.beans.factory.xml.ParserContext;
2525
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
26-
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2726
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
2827
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
2928
import org.springframework.util.StringUtils;
3029
import org.springframework.util.xml.DomUtils;
3130
import org.w3c.dom.Element;
3231

32+
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createAuthorizedClientRepository;
33+
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository;
34+
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository;
35+
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService;
36+
import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository;
37+
3338
/**
3439
* @author Joe Grandja
3540
* @since 5.3
3641
*/
3742
final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
3843
private static final String ELT_AUTHORIZATION_CODE_GRANT = "authorization-code-grant";
39-
private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref";
40-
private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref";
41-
private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref";
4244
private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref";
4345
private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref";
4446
private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref";
4547
private final BeanReference requestCache;
4648
private final BeanReference authenticationManager;
49+
private BeanDefinition defaultAuthorizedClientRepository;
4750
private BeanDefinition authorizationRequestRedirectFilter;
4851
private BeanDefinition authorizationCodeGrantFilter;
4952
private BeanDefinition authorizationCodeAuthenticationProvider;
@@ -58,8 +61,16 @@ public BeanDefinition parse(Element element, ParserContext parserContext) {
5861
Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT);
5962

6063
BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element);
61-
BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(
62-
element, clientRegistrationRepository);
64+
BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element);
65+
if (authorizedClientRepository == null) {
66+
BeanMetadataElement authorizedClientService = getAuthorizedClientService(element);
67+
if (authorizedClientService == null) {
68+
this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository(clientRegistrationRepository);
69+
authorizedClientRepository = this.defaultAuthorizedClientRepository;
70+
} else {
71+
authorizedClientRepository = createAuthorizedClientRepository(authorizedClientService);
72+
}
73+
}
6374
BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(
6475
authorizationCodeGrantElt);
6576

@@ -95,41 +106,6 @@ public BeanDefinition parse(Element element, ParserContext parserContext) {
95106
return null;
96107
}
97108

98-
private BeanMetadataElement getClientRegistrationRepository(Element element) {
99-
BeanMetadataElement clientRegistrationRepository;
100-
String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF);
101-
if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) {
102-
clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef);
103-
} else {
104-
clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class);
105-
}
106-
return clientRegistrationRepository;
107-
}
108-
109-
private BeanMetadataElement getAuthorizedClientRepository(Element element,
110-
BeanMetadataElement clientRegistrationRepository) {
111-
BeanMetadataElement authorizedClientRepository;
112-
String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF);
113-
if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) {
114-
authorizedClientRepository = new RuntimeBeanReference(authorizedClientRepositoryRef);
115-
} else {
116-
BeanMetadataElement authorizedClientService;
117-
String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF);
118-
if (!StringUtils.isEmpty(authorizedClientServiceRef)) {
119-
authorizedClientService = new RuntimeBeanReference(authorizedClientServiceRef);
120-
} else {
121-
authorizedClientService = BeanDefinitionBuilder
122-
.rootBeanDefinition(
123-
"org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService")
124-
.addConstructorArgValue(clientRegistrationRepository).getBeanDefinition();
125-
}
126-
authorizedClientRepository = BeanDefinitionBuilder.rootBeanDefinition(
127-
"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
128-
.addConstructorArgValue(authorizedClientService).getBeanDefinition();
129-
}
130-
return authorizedClientRepository;
131-
}
132-
133109
private BeanMetadataElement getAuthorizationRequestRepository(Element element) {
134110
BeanMetadataElement authorizationRequestRepository;
135111
String authorizationRequestRepositoryRef = element != null ?
@@ -158,6 +134,10 @@ private BeanMetadataElement getAccessTokenResponseClient(Element element) {
158134
return accessTokenResponseClient;
159135
}
160136

137+
BeanDefinition getDefaultAuthorizedClientRepository() {
138+
return this.defaultAuthorizedClientRepository;
139+
}
140+
161141
BeanDefinition getAuthorizationRequestRedirectFilter() {
162142
return this.authorizationRequestRedirectFilter;
163143
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2002-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.config.http;
17+
18+
import org.springframework.beans.BeanMetadataElement;
19+
import org.springframework.beans.factory.config.BeanDefinition;
20+
import org.springframework.beans.factory.config.RuntimeBeanReference;
21+
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
22+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
23+
import org.springframework.util.StringUtils;
24+
import org.w3c.dom.Element;
25+
26+
/**
27+
* @author Joe Grandja
28+
* @since 5.4
29+
*/
30+
final class OAuth2ClientBeanDefinitionParserUtils {
31+
private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref";
32+
private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref";
33+
private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref";
34+
35+
static BeanMetadataElement getClientRegistrationRepository(Element element) {
36+
BeanMetadataElement clientRegistrationRepository;
37+
String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF);
38+
if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) {
39+
clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef);
40+
} else {
41+
clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class);
42+
}
43+
return clientRegistrationRepository;
44+
}
45+
46+
static BeanMetadataElement getAuthorizedClientRepository(Element element) {
47+
String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF);
48+
if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) {
49+
return new RuntimeBeanReference(authorizedClientRepositoryRef);
50+
}
51+
return null;
52+
}
53+
54+
static BeanMetadataElement getAuthorizedClientService(Element element) {
55+
String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF);
56+
if (!StringUtils.isEmpty(authorizedClientServiceRef)) {
57+
return new RuntimeBeanReference(authorizedClientServiceRef);
58+
}
59+
return null;
60+
}
61+
62+
static BeanMetadataElement createAuthorizedClientRepository(BeanMetadataElement authorizedClientService) {
63+
return BeanDefinitionBuilder.rootBeanDefinition(
64+
"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
65+
.addConstructorArgValue(authorizedClientService)
66+
.getBeanDefinition();
67+
}
68+
69+
static BeanDefinition createDefaultAuthorizedClientRepository(BeanMetadataElement clientRegistrationRepository) {
70+
BeanDefinition authorizedClientService = BeanDefinitionBuilder.rootBeanDefinition(
71+
"org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService")
72+
.addConstructorArgValue(clientRegistrationRepository)
73+
.getBeanDefinition();
74+
return BeanDefinitionBuilder.rootBeanDefinition(
75+
"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
76+
.addConstructorArgValue(authorizedClientService)
77+
.getBeanDefinition();
78+
}
79+
}

0 commit comments

Comments
 (0)