28
28
import org .springframework .http .HttpRequest ;
29
29
import org .springframework .http .HttpStatus ;
30
30
import org .springframework .http .HttpStatusCode ;
31
- import org .springframework .http .client .ClientHttpRequest ;
32
31
import org .springframework .http .client .ClientHttpRequestExecution ;
33
32
import org .springframework .http .client .ClientHttpRequestInterceptor ;
34
33
import org .springframework .http .client .ClientHttpResponse ;
34
+ import org .springframework .security .access .AccessDeniedException ;
35
35
import org .springframework .security .authentication .AnonymousAuthenticationToken ;
36
36
import org .springframework .security .core .Authentication ;
37
37
import org .springframework .security .core .authority .AuthorityUtils ;
@@ -121,6 +121,8 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
121
121
122
122
private String defaultClientRegistrationId ;
123
123
124
+ private boolean useAuthenticatedClientRegistrationId ;
125
+
124
126
// @formatter:off
125
127
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
126
128
(clientRegistrationId , principal , attributes ) -> { };
@@ -157,6 +159,22 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
157
159
this .defaultClientRegistrationId = clientRegistrationId ;
158
160
}
159
161
162
+ /**
163
+ * Enables or disables discovering the {@code clientRegistrationId} from the current
164
+ * {@link Authentication principal}. It is recommended to be cautious with this
165
+ * feature since all HTTP requests will receive the access token if it can be resolved
166
+ * from the current Authentication.
167
+ *
168
+ * <p>
169
+ * This feature requires the user to be logged in via OAuth2 or OpenID Connect Login.
170
+ * @param useAuthenticatedClientRegistrationId true if the
171
+ * {@code clientRegistrationId} should be discovered from the current
172
+ * {@link Authentication principal}. The default is false.
173
+ */
174
+ public void setUseAuthenticatedClientRegistrationId (boolean useAuthenticatedClientRegistrationId ) {
175
+ this .useAuthenticatedClientRegistrationId = useAuthenticatedClientRegistrationId ;
176
+ }
177
+
160
178
/**
161
179
* Sets the {@link OAuth2AuthorizationFailureHandler} that handles authentication and
162
180
* authorization failures when communicating to the OAuth 2.0 Resource Server.
@@ -251,9 +269,9 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
251
269
}
252
270
253
271
/**
254
- * Modifies the {@link ClientHttpRequest#getAttributes( ) attributes} to include the
255
- * {@link ClientRegistration#getRegistrationId() clientRegistrationId} to be used to
256
- * look up the {@link OAuth2AuthorizedClient}.
272
+ * Modifies the {@link RestClient.RequestHeadersSpec#attributes(Consumer ) attributes}
273
+ * to include the {@link ClientRegistration#getRegistrationId() clientRegistrationId}
274
+ * to be used to look up the {@link OAuth2AuthorizedClient}.
257
275
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
258
276
* clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
259
277
* @return the {@link Consumer} to populate the attributes
@@ -289,16 +307,9 @@ public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttp
289
307
290
308
private void authorizeClient (HttpRequest request , Authentication principal ) {
291
309
String clientRegistrationId = clientRegistrationId (request , principal );
292
- if (clientRegistrationId == null ) {
293
- return ;
294
- }
295
-
296
- // @formatter:off
297
- OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
298
- .withClientRegistrationId (clientRegistrationId )
299
- .principal (principal )
300
- .build ();
301
- // @formatter:on
310
+ OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest .withClientRegistrationId (clientRegistrationId )
311
+ .principal (principal )
312
+ .build ();
302
313
OAuth2AuthorizedClient authorizedClient = this .authorizedClientManager .authorize (authorizeRequest );
303
314
if (authorizedClient != null ) {
304
315
request .getHeaders ().setBearerAuth (authorizedClient .getAccessToken ().getTokenValue ());
@@ -313,10 +324,6 @@ private void handleAuthorizationFailure(HttpRequest request, Authentication prin
313
324
}
314
325
315
326
String clientRegistrationId = clientRegistrationId (request , principal );
316
- if (clientRegistrationId == null ) {
317
- return ;
318
- }
319
-
320
327
ClientAuthorizationException authorizationException = new ClientAuthorizationException (error ,
321
328
clientRegistrationId );
322
329
handleAuthorizationFailure (authorizationException , principal );
@@ -366,8 +373,25 @@ private String clientRegistrationId(HttpRequest request, Authentication principa
366
373
if (clientRegistrationId == null ) {
367
374
clientRegistrationId = this .defaultClientRegistrationId ;
368
375
}
369
- if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken authentication ) {
370
- clientRegistrationId = authentication .getAuthorizedClientRegistrationId ();
376
+ if (clientRegistrationId == null && this .useAuthenticatedClientRegistrationId ) {
377
+ if (principal instanceof OAuth2AuthenticationToken ) {
378
+ clientRegistrationId = ((OAuth2AuthenticationToken ) principal ).getAuthorizedClientRegistrationId ();
379
+ }
380
+ else if (principal instanceof AnonymousAuthenticationToken ) {
381
+ throw new AccessDeniedException ("Authentication is required" );
382
+ }
383
+ else {
384
+ throw new IllegalStateException ("Unable to discover clientRegistrationId."
385
+ + " When useAuthenticatedClientRegistrationId=true, the current principal must be of type OAuth2AuthenticationToken"
386
+ + " (OAuth2 or OpenID Connect Login is required in order to use this feature)." );
387
+ }
388
+ }
389
+ if (clientRegistrationId == null ) {
390
+ throw new IllegalStateException ("No clientRegistrationId was provided."
391
+ + " Please consider using OAuth2ClientHttpRequestInterceptor.clientRegistrationId(String) to provide one per request via RestClient.RequestHeadersSpec#attributes(Consumer),"
392
+ + " OAuth2ClientHttpRequestInterceptor#setDefaultClientRegistrationId(String) to provide a default for all requests,"
393
+ + " or OAuth2ClientHttpRequestInterceptor#setUseAuthenticatedClientRegistrationId(true) to configure resolving one from the current principal"
394
+ + " (OAuth2 or OpenID Connect Login is required in order to use this feature)." );
371
395
}
372
396
373
397
return clientRegistrationId ;
0 commit comments