Skip to content

Commit 311d9a1

Browse files
authored
Specify clientRegistrationId in TokenRelayFilterFunctions (#3591)
Closes gh-3541
1 parent 9050809 commit 311d9a1

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/TokenRelayFilterFunctions.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616

1717
package org.springframework.cloud.gateway.server.mvc.filter;
1818

19-
import java.security.Principal;
20-
2119
import org.springframework.cloud.gateway.server.mvc.common.Shortcut;
20+
import org.springframework.security.core.Authentication;
2221
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
2322
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2423
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
@@ -37,13 +36,21 @@ private TokenRelayFilterFunctions() {
3736

3837
@Shortcut
3938
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay() {
39+
return tokenRelay(null);
40+
}
41+
42+
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay(String defaultClientRegistrationId) {
4043
return (request, next) -> {
41-
Principal principle = request.servletRequest().getUserPrincipal();
42-
if (principle instanceof OAuth2AuthenticationToken token) {
43-
String clientRegistrationId = token.getAuthorizedClientRegistrationId();
44+
Authentication principal = (Authentication) request.servletRequest().getUserPrincipal();
45+
46+
String clientRegistrationId = defaultClientRegistrationId;
47+
if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken token) {
48+
clientRegistrationId = token.getAuthorizedClientRegistrationId();
49+
}
50+
if (clientRegistrationId != null) {
4451
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
4552
.withClientRegistrationId(clientRegistrationId)
46-
.principal(token)
53+
.principal(principal)
4754
.build();
4855
OAuth2AuthorizedClientManager clientManager = getApplicationContext(request)
4956
.getBean(OAuth2AuthorizedClientManager.class);

spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/TokenRelayFilterFunctionsTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,29 @@ public void whenPrincipalExistsAuthorizationHeaderAdded() throws Exception {
104104
});
105105
}
106106

107+
@Test
108+
public void whenDefaultClientRegistrationIdProvidedAuthorizationHeaderAdded() throws Exception {
109+
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
110+
when(accessToken.getTokenValue()).thenReturn("mytoken");
111+
112+
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("myregistrationid")
113+
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
114+
.clientId("myclientid")
115+
.tokenUri("mytokenuri")
116+
.build();
117+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "joe", accessToken);
118+
119+
when(authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))).thenReturn(authorizedClient);
120+
121+
request.setUserPrincipal(new TestingAuthenticationToken("my", null));
122+
123+
filter = TokenRelayFilterFunctions.tokenRelay("myId");
124+
filter.filter(ServerRequest.create(request, converters), req -> {
125+
assertThat(req.headers().firstHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer mytoken");
126+
return null;
127+
});
128+
}
129+
107130
@Test
108131
public void principalIsNotOAuth2AuthenticationToken() throws Exception {
109132
request.setUserPrincipal(new TestingAuthenticationToken("my", null));

0 commit comments

Comments
 (0)