Skip to content

Commit d47db53

Browse files
committed
Add ClientRegistrationIdResolver
1 parent 32aca6b commit d47db53

File tree

3 files changed

+136
-195
lines changed

3 files changed

+136
-195
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/function/client/OAuth2ClientHttpRequestInterceptor.java

Lines changed: 45 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.io.IOException;
2020
import java.util.HashMap;
2121
import java.util.Map;
22-
import java.util.function.Consumer;
2322

2423
import jakarta.servlet.http.HttpServletRequest;
2524
import jakarta.servlet.http.HttpServletResponse;
@@ -31,7 +30,7 @@
3130
import org.springframework.http.client.ClientHttpRequestExecution;
3231
import org.springframework.http.client.ClientHttpRequestInterceptor;
3332
import org.springframework.http.client.ClientHttpResponse;
34-
import org.springframework.security.access.AccessDeniedException;
33+
import org.springframework.lang.Nullable;
3534
import org.springframework.security.authentication.AnonymousAuthenticationToken;
3635
import org.springframework.security.core.Authentication;
3736
import org.springframework.security.core.authority.AuthorityUtils;
@@ -45,16 +44,13 @@
4544
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
4645
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
4746
import org.springframework.security.oauth2.client.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
48-
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
49-
import org.springframework.security.oauth2.client.registration.ClientRegistration;
5047
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
5148
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
5249
import org.springframework.security.oauth2.core.OAuth2Error;
5350
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
5451
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
5552
import org.springframework.util.Assert;
5653
import org.springframework.util.StringUtils;
57-
import org.springframework.web.client.RestClient;
5854
import org.springframework.web.client.RestClientResponseException;
5955
import org.springframework.web.context.request.RequestContextHolder;
6056
import org.springframework.web.context.request.ServletRequestAttributes;
@@ -114,14 +110,9 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
114110
private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
115111
"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
116112

117-
private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2ClientHttpRequestInterceptor.class.getName()
118-
.concat(".clientRegistrationId");
119-
120113
private final OAuth2AuthorizedClientManager authorizedClientManager;
121114

122-
private String defaultClientRegistrationId;
123-
124-
private boolean useAuthenticatedClientRegistrationId;
115+
private final ClientRegistrationIdResolver clientRegistrationIdResolver;
125116

126117
// @formatter:off
127118
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
@@ -138,41 +129,23 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
138129
* manages the authorized client(s)
139130
*/
140131
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) {
141-
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
142-
this.authorizedClientManager = authorizedClientManager;
143-
}
144-
145-
/**
146-
* Sets the default {@code clientRegistrationId} to be used for resolving an
147-
* {@link OAuth2AuthorizedClient}.
148-
*
149-
* <p>
150-
* By default, the {@code clientRegistrationId} is obtained from the current
151-
* {@link Authentication principal}. Using this setter overrides the default, but can
152-
* be overridden by providing an
153-
* {@link RestClient.RequestHeadersSpec#attributes(Consumer) attribute} via
154-
* {@link #clientRegistrationId(String)}.
155-
* @param clientRegistrationId the default {@code clientRegistrationId}
156-
*/
157-
public void setDefaultClientRegistrationId(String clientRegistrationId) {
158-
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
159-
this.defaultClientRegistrationId = clientRegistrationId;
132+
this(authorizedClientManager, new RequestAttributeClientRegistrationIdResolver());
160133
}
161134

162135
/**
163-
* Enables or disables discovering the {@code clientRegistrationId} from the current
164-
* {@link Authentication principal}. It is recommended to be cautious with this
165-
* feature since all HTTP requests will receive the access token if it can be resolved
166-
* from the current Authentication.
167-
*
168-
* <p>
169-
* This feature requires the user to be logged in via OAuth2 or OpenID Connect Login.
170-
* @param useAuthenticatedClientRegistrationId true if the
171-
* {@code clientRegistrationId} should be discovered from the current
172-
* {@link Authentication principal}. The default is false.
136+
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
137+
* parameters.
138+
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
139+
* manages the authorized client(s)
140+
* @param clientRegistrationIdResolver the strategy for resolving a
141+
* {@code clientRegistrationId} from the intercepted request
173142
*/
174-
public void setUseAuthenticatedClientRegistrationId(boolean useAuthenticatedClientRegistrationId) {
175-
this.useAuthenticatedClientRegistrationId = useAuthenticatedClientRegistrationId;
143+
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager,
144+
ClientRegistrationIdResolver clientRegistrationIdResolver) {
145+
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
146+
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
147+
this.authorizedClientManager = authorizedClientManager;
148+
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
176149
}
177150

178151
/**
@@ -268,19 +241,6 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
268241
this.securityContextHolderStrategy = securityContextHolderStrategy;
269242
}
270243

271-
/**
272-
* Modifies the {@link RestClient.RequestHeadersSpec#attributes(Consumer) attributes}
273-
* to include the {@link ClientRegistration#getRegistrationId() clientRegistrationId}
274-
* to be used to look up the {@link OAuth2AuthorizedClient}.
275-
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
276-
* clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
277-
* @return the {@link Consumer} to populate the attributes
278-
*/
279-
public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
280-
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
281-
return (attributes) -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
282-
}
283-
284244
@Override
285245
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
286246
throws IOException {
@@ -306,7 +266,11 @@ public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttp
306266
}
307267

308268
private void authorizeClient(HttpRequest request, Authentication principal) {
309-
String clientRegistrationId = clientRegistrationId(request, principal);
269+
String clientRegistrationId = this.clientRegistrationIdResolver.resolve(request);
270+
if (clientRegistrationId == null) {
271+
return;
272+
}
273+
310274
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId)
311275
.principal(principal)
312276
.build();
@@ -323,7 +287,11 @@ private void handleAuthorizationFailure(HttpRequest request, Authentication prin
323287
return;
324288
}
325289

326-
String clientRegistrationId = clientRegistrationId(request, principal);
290+
String clientRegistrationId = this.clientRegistrationIdResolver.resolve(request);
291+
if (clientRegistrationId == null) {
292+
return;
293+
}
294+
327295
ClientAuthorizationException authorizationException = new ClientAuthorizationException(error,
328296
clientRegistrationId);
329297
handleAuthorizationFailure(authorizationException, principal);
@@ -368,35 +336,6 @@ private static Map<String, String> parseWwwAuthenticateHeader(String wwwAuthenti
368336
return parameters;
369337
}
370338

371-
private String clientRegistrationId(HttpRequest request, Authentication principal) {
372-
String clientRegistrationId = (String) request.getAttributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
373-
if (clientRegistrationId == null) {
374-
clientRegistrationId = this.defaultClientRegistrationId;
375-
}
376-
if (clientRegistrationId == null && this.useAuthenticatedClientRegistrationId) {
377-
if (principal instanceof OAuth2AuthenticationToken) {
378-
clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
379-
}
380-
else if (principal instanceof AnonymousAuthenticationToken) {
381-
throw new AccessDeniedException("Authentication is required");
382-
}
383-
else {
384-
throw new IllegalStateException("Unable to discover clientRegistrationId."
385-
+ " When useAuthenticatedClientRegistrationId=true, the current principal must be of type OAuth2AuthenticationToken"
386-
+ " (OAuth2 or OpenID Connect Login is required in order to use this feature).");
387-
}
388-
}
389-
if (clientRegistrationId == null) {
390-
throw new IllegalStateException("No clientRegistrationId was provided."
391-
+ " Please consider using OAuth2ClientHttpRequestInterceptor.clientRegistrationId(String) to provide one per request via RestClient.RequestHeadersSpec#attributes(Consumer),"
392-
+ " OAuth2ClientHttpRequestInterceptor#setDefaultClientRegistrationId(String) to provide a default for all requests,"
393-
+ " or OAuth2ClientHttpRequestInterceptor#setUseAuthenticatedClientRegistrationId(true) to configure resolving one from the current principal"
394-
+ " (OAuth2 or OpenID Connect Login is required in order to use this feature).");
395-
}
396-
397-
return clientRegistrationId;
398-
}
399-
400339
private void handleAuthorizationFailure(OAuth2AuthorizationException authorizationException,
401340
Authentication principal) {
402341
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder
@@ -412,4 +351,24 @@ private void handleAuthorizationFailure(OAuth2AuthorizationException authorizati
412351
this.authorizationFailureHandler.onAuthorizationFailure(authorizationException, principal, attributes);
413352
}
414353

354+
/**
355+
* A strategy for resolving a {@code clientRegistrationId} from an intercepted
356+
* request.
357+
*/
358+
@FunctionalInterface
359+
public interface ClientRegistrationIdResolver {
360+
361+
/**
362+
* Resolve the {@code clientRegistrationId} from the current request, which is
363+
* used to obtain an {@link OAuth2AuthorizedClient}.
364+
* @param request the intercepted request, containing HTTP method, URI, headers,
365+
* and request attributes
366+
* @return the {@code clientRegistrationId} to be used for resolving an
367+
* {@link OAuth2AuthorizedClient}.
368+
*/
369+
@Nullable
370+
String resolve(HttpRequest request);
371+
372+
}
373+
415374
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client.web.function.client;
18+
19+
import java.util.Map;
20+
import java.util.function.Consumer;
21+
22+
import org.springframework.http.HttpRequest;
23+
import org.springframework.http.client.ClientHttpRequest;
24+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
25+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
26+
import org.springframework.util.Assert;
27+
28+
/**
29+
* @author Steve Riesenberg
30+
*/
31+
public final class RequestAttributeClientRegistrationIdResolver
32+
implements OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver {
33+
34+
private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = RequestAttributeClientRegistrationIdResolver.class
35+
.getName()
36+
.concat(".clientRegistrationId");
37+
38+
@Override
39+
public String resolve(HttpRequest request) {
40+
return (String) request.getAttributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
41+
}
42+
43+
/**
44+
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
45+
* {@link ClientRegistration#getRegistrationId() clientRegistrationId} to be used to
46+
* look up the {@link OAuth2AuthorizedClient}.
47+
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
48+
* clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
49+
* @return the {@link Consumer} to populate the attributes
50+
*/
51+
public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
52+
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
53+
return (attributes) -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
54+
}
55+
56+
}

0 commit comments

Comments
 (0)