Skip to content

Commit c5ce722

Browse files
committed
Specify clientRegistrationId in TokenRelayFilterFunctions
Closes gh-3541
1 parent 031e249 commit c5ce722

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/TokenRelayFilterFunctions.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.security.Principal;
2020

2121
import org.springframework.cloud.gateway.server.mvc.common.Shortcut;
22+
import org.springframework.security.core.Authentication;
2223
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
2324
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2425
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
@@ -37,13 +38,21 @@ private TokenRelayFilterFunctions() {
3738

3839
@Shortcut
3940
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay() {
41+
return tokenRelay(null);
42+
}
43+
44+
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay(String defaultClientRegistrationId) {
4045
return (request, next) -> {
41-
Principal principle = request.servletRequest().getUserPrincipal();
42-
if (principle instanceof OAuth2AuthenticationToken token) {
43-
String clientRegistrationId = token.getAuthorizedClientRegistrationId();
46+
Authentication principal = (Authentication) request.servletRequest().getUserPrincipal();
47+
48+
String clientRegistrationId = defaultClientRegistrationId;
49+
if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken token) {
50+
clientRegistrationId = token.getAuthorizedClientRegistrationId();
51+
}
52+
if (clientRegistrationId != null) {
4453
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
4554
.withClientRegistrationId(clientRegistrationId)
46-
.principal(token)
55+
.principal(principal)
4756
.build();
4857
OAuth2AuthorizedClientManager clientManager = getApplicationContext(request)
4958
.getBean(OAuth2AuthorizedClientManager.class);

spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/TokenRelayFilterFunctionsTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,29 @@ public void whenPrincipalExistsAuthorizationHeaderAdded() throws Exception {
104104
});
105105
}
106106

107+
@Test
108+
public void whenDefaultClientRegistrationIdProvidedAuthorizationHeaderAdded() throws Exception {
109+
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
110+
when(accessToken.getTokenValue()).thenReturn("mytoken");
111+
112+
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("myregistrationid")
113+
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
114+
.clientId("myclientid")
115+
.tokenUri("mytokenuri")
116+
.build();
117+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "joe", accessToken);
118+
119+
when(authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))).thenReturn(authorizedClient);
120+
121+
request.setUserPrincipal(new TestingAuthenticationToken("my", null));
122+
123+
filter = TokenRelayFilterFunctions.tokenRelay("myId");
124+
filter.filter(ServerRequest.create(request, converters), req -> {
125+
assertThat(req.headers().firstHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer mytoken");
126+
return null;
127+
});
128+
}
129+
107130
@Test
108131
public void principalIsNotOAuth2AuthenticationToken() throws Exception {
109132
request.setUserPrincipal(new TestingAuthenticationToken("my", null));

0 commit comments

Comments
 (0)