Skip to content

Commit 17e774d

Browse files
farraultjgrandja
authored andcommitted
Preserve existing refresh token if new refresh token not returned
During an oauth2 refresh if the authorization server doesn't return a new refresh token, preserve the existing one. Fixes: gh-6503
1 parent 0428906 commit 17e774d

File tree

4 files changed

+133
-8
lines changed

4 files changed

+133
-8
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -45,6 +45,7 @@
4545
import java.time.Duration;
4646
import java.time.Instant;
4747
import java.util.Map;
48+
import java.util.Optional;
4849
import java.util.function.Consumer;
4950

5051
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
@@ -289,7 +290,11 @@ private Mono<OAuth2AuthorizedClient> authorizeWithRefreshToken(ExchangeFunction
289290
.build();
290291
return next.exchange(refreshRequest)
291292
.flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
292-
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
293+
.map(accessTokenResponse -> {
294+
OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken())
295+
.orElse(authorizedClient.getRefreshToken());
296+
return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
297+
})
293298
.flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
294299
.thenReturn(result));
295300
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -391,7 +391,11 @@ private Mono<OAuth2AuthorizedClient> authorizeWithRefreshToken(ClientRequest req
391391
.build();
392392
return next.exchange(refreshRequest)
393393
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
394-
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
394+
.map(accessTokenResponse -> {
395+
OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken())
396+
.orElse(authorizedClient.getRefreshToken());
397+
return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken);
398+
})
395399
.map(result -> {
396400
Authentication principal = (Authentication) request.attribute(
397401
AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()));

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@
1919
import org.junit.Before;
2020
import org.junit.Test;
2121
import org.junit.runner.RunWith;
22+
import org.mockito.ArgumentCaptor;
23+
import org.mockito.Captor;
2224
import org.mockito.Mock;
2325
import org.mockito.junit.MockitoJUnitRunner;
2426
import org.springframework.core.codec.ByteBufferEncoder;
@@ -94,6 +96,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
9496
@Mock
9597
private ServerWebExchange serverWebExchange;
9698

99+
@Captor
100+
private ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor;
101+
97102
private ServerOAuth2AuthorizedClientExchangeFilterFunction function;
98103

99104
private MockExchangeFunction exchange = new MockExchangeFunction();
@@ -260,7 +265,62 @@ public void filterWhenRefreshRequiredThenRefresh() {
260265
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
261266
.block();
262267

263-
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
268+
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any());
269+
270+
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
271+
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
272+
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
273+
274+
List<ClientRequest> requests = this.exchange.getRequests();
275+
assertThat(requests).hasSize(2);
276+
277+
ClientRequest request0 = requests.get(0);
278+
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
279+
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token");
280+
assertThat(request0.method()).isEqualTo(HttpMethod.POST);
281+
assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token");
282+
283+
ClientRequest request1 = requests.get(1);
284+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
285+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
286+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
287+
assertThat(getBody(request1)).isEmpty();
288+
}
289+
290+
@Test
291+
public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken() {
292+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
293+
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
294+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
295+
.expiresIn(3600)
296+
// .refreshToken(xxx) // No refreshToken in response
297+
.build();
298+
when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response));
299+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
300+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
301+
302+
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
303+
this.accessToken.getTokenValue(),
304+
issuedAt,
305+
accessTokenExpiresAt);
306+
307+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
308+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
309+
"principalName", this.accessToken, refreshToken);
310+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
311+
.attributes(oauth2AuthorizedClient(authorizedClient))
312+
.build();
313+
314+
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
315+
this.function.filter(request, this.exchange)
316+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
317+
.block();
318+
319+
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(authentication), any());
320+
321+
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
322+
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
323+
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(authorizedClient.getRefreshToken());
264324

265325
List<ClientRequest> requests = this.exchange.getRequests();
266326
assertThat(requests).hasSize(2);

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -101,6 +101,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
101101
private WebClient.RequestHeadersSpec<?> spec;
102102
@Captor
103103
private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
104+
@Captor
105+
private ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor;
104106

105107
/**
106108
* Used for get the attributes from defaultRequest.
@@ -406,7 +408,61 @@ public void filterWhenRefreshRequiredThenRefresh() {
406408

407409
this.function.filter(request, this.exchange).block();
408410

409-
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
411+
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any());
412+
413+
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
414+
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
415+
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
416+
417+
List<ClientRequest> requests = this.exchange.getRequests();
418+
assertThat(requests).hasSize(2);
419+
420+
ClientRequest request0 = requests.get(0);
421+
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
422+
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token");
423+
assertThat(request0.method()).isEqualTo(HttpMethod.POST);
424+
assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token");
425+
426+
ClientRequest request1 = requests.get(1);
427+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
428+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
429+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
430+
assertThat(getBody(request1)).isEmpty();
431+
}
432+
433+
@Test
434+
public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefreshToken() {
435+
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
436+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
437+
.expiresIn(3600)
438+
// .refreshToken(xxx) // No refreshToken in response
439+
.build();
440+
when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response));
441+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
442+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
443+
444+
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
445+
this.accessToken.getTokenValue(),
446+
issuedAt,
447+
accessTokenExpiresAt);
448+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
449+
this.authorizedClientRepository);
450+
451+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
452+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
453+
"principalName", this.accessToken, refreshToken);
454+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
455+
.attributes(oauth2AuthorizedClient(authorizedClient))
456+
.attributes(authentication(this.authentication))
457+
.build();
458+
459+
this.function.filter(request, this.exchange).block();
460+
461+
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any());
462+
463+
OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue();
464+
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
465+
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(refreshToken);
410466

411467
List<ClientRequest> requests = this.exchange.getRequests();
412468
assertThat(requests).hasSize(2);

0 commit comments

Comments
 (0)