Skip to content

Commit e3683eb

Browse files
committed
Always return current ClientRegistration in loadAuthorizedClient
This changes `InMemoryOAuth2AuthorizedClientService.loadAuthorizedClient` (and its reactive counterpart) to always return `OAuth2AuthorizedClient` instances containing the current `ClientRegistration` as obtained from the `ClientRegistrationRepository`. Before this change, the first `ClientRegistration` instance was cached, with the effect that any changes made in the `ClientRegistrationRepository` (such as a new client secret) would not have taken effect. Closes gh-15511
1 parent 30c9860 commit e3683eb

File tree

4 files changed

+173
-44
lines changed

4 files changed

+173
-44
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -80,7 +80,13 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
8080
if (registration == null) {
8181
return null;
8282
}
83-
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
83+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
84+
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
85+
if (cachedAuthorizedClient == null) {
86+
return null;
87+
}
88+
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
89+
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
8490
}
8591

8692
@Override

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,11 +32,11 @@
3232
*
3333
* @author Rob Winch
3434
* @author Vedran Pavic
35-
* @since 5.1
3635
* @see OAuth2AuthorizedClientService
3736
* @see OAuth2AuthorizedClient
3837
* @see ClientRegistration
3938
* @see Authentication
39+
* @since 5.1
4040
*/
4141
public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
4242

@@ -47,6 +47,7 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
4747
/**
4848
* Constructs an {@code InMemoryReactiveOAuth2AuthorizedClientService} using the
4949
* provided parameters.
50+
*
5051
* @param clientRegistrationRepository the repository of client registrations
5152
*/
5253
public InMemoryReactiveOAuth2AuthorizedClientService(
@@ -62,8 +63,19 @@ public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String cl
6263
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
6364
Assert.hasText(principalName, "principalName cannot be empty");
6465
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
65-
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
66-
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
66+
.mapNotNull((clientRegistration) -> {
67+
OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
68+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
69+
if (cachedAuthorizedClient == null) {
70+
return null;
71+
}
72+
// @formatter:off
73+
return new OAuth2AuthorizedClient(clientRegistration,
74+
cachedAuthorizedClient.getPrincipalName(),
75+
cachedAuthorizedClient.getAccessToken(),
76+
cachedAuthorizedClient.getRefreshToken());
77+
// @formatter:on
78+
});
6779
}
6880

6981
@Override

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@
3333
import static org.assertj.core.api.Assertions.assertThatObject;
3434
import static org.mockito.ArgumentMatchers.eq;
3535
import static org.mockito.BDDMockito.given;
36-
import static org.mockito.Mockito.mock;
36+
import static org.mockito.BDDMockito.mock;
3737

3838
/**
3939
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
@@ -43,23 +43,23 @@
4343
*/
4444
public class InMemoryOAuth2AuthorizedClientServiceTests {
4545

46-
private String principalName1 = "principal-1";
46+
private final String principalName1 = "principal-1";
4747

48-
private String principalName2 = "principal-2";
48+
private final String principalName2 = "principal-2";
4949

50-
private ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build();
50+
private final ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build();
5151

52-
private ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build();
52+
private final ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build();
5353

54-
private ClientRegistration registration3 = TestClientRegistrations.clientRegistration()
55-
.clientId("client-3")
56-
.registrationId("registration-3")
57-
.build();
54+
private final ClientRegistration registration3 = TestClientRegistrations.clientRegistration()
55+
.clientId("client-3")
56+
.registrationId("registration-3")
57+
.build();
5858

59-
private ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(
59+
private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(
6060
this.registration1, this.registration2, this.registration3);
6161

62-
private InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
62+
private final InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
6363
this.clientRegistrationRepository);
6464

6565
@Test
@@ -79,9 +79,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept
7979
@Test
8080
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
8181
String registrationId = this.registration3.getRegistrationId();
82+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
83+
mock(OAuth2AccessToken.class));
8284
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
8385
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
84-
mock(OAuth2AuthorizedClient.class));
86+
authorizedClient);
8587
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
8688
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
8789
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
@@ -92,7 +94,7 @@ public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedCli
9294
@Test
9395
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
9496
assertThatIllegalArgumentException()
95-
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, this.principalName1));
97+
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, this.principalName1));
9698
}
9799

98100
@Test
@@ -104,14 +106,14 @@ public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentE
104106
@Test
105107
public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
106108
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
107-
.loadAuthorizedClient("registration-not-found", this.principalName1);
109+
.loadAuthorizedClient("registration-not-found", this.principalName1);
108110
assertThat(authorizedClient).isNull();
109111
}
110112

111113
@Test
112114
public void loadAuthorizedClientWhenClientRegistrationFoundButNotAssociatedToPrincipalThenReturnNull() {
113115
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService
114-
.loadAuthorizedClient(this.registration1.getRegistrationId(), "principal-not-found");
116+
.loadAuthorizedClient(this.registration1.getRegistrationId(), "principal-not-found");
115117
assertThat(authorizedClient).isNull();
116118
}
117119

@@ -123,14 +125,42 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci
123125
mock(OAuth2AccessToken.class));
124126
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
125127
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
126-
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
127-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
128+
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
129+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
130+
}
131+
132+
@Test
133+
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
134+
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
135+
.clientSecret("updated secret")
136+
.build();
137+
ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
138+
given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
139+
updatedRegistration);
140+
141+
Authentication authentication = mock(Authentication.class);
142+
given(authentication.getName()).willReturn(this.principalName1);
143+
144+
InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);
145+
146+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
147+
mock(OAuth2AccessToken.class));
148+
service.saveAuthorizedClient(authorizedClient, authentication);
149+
150+
OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
151+
this.principalName1, mock(OAuth2AccessToken.class));
152+
OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
153+
this.principalName1);
154+
OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
155+
this.principalName1);
156+
assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
157+
assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
128158
}
129159

130160
@Test
131161
public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
132162
assertThatIllegalArgumentException()
133-
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class)));
163+
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class)));
134164
}
135165

136166
@Test
@@ -147,20 +177,20 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() {
147177
mock(OAuth2AccessToken.class));
148178
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
149179
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
150-
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
151-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
180+
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
181+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
152182
}
153183

154184
@Test
155185
public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
156186
assertThatIllegalArgumentException()
157-
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, this.principalName2));
187+
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, this.principalName2));
158188
}
159189

160190
@Test
161191
public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
162192
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientService
163-
.removeAuthorizedClient(this.registration3.getRegistrationId(), null));
193+
.removeAuthorizedClient(this.registration3.getRegistrationId(), null));
164194
}
165195

166196
@Test
@@ -171,13 +201,38 @@ public void removeAuthorizedClientWhenSavedThenRemoved() {
171201
mock(OAuth2AccessToken.class));
172202
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
173203
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
174-
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
204+
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
175205
assertThat(loadedAuthorizedClient).isNotNull();
176206
this.authorizedClientService.removeAuthorizedClient(this.registration2.getRegistrationId(),
177207
this.principalName2);
178208
loadedAuthorizedClient = this.authorizedClientService
179-
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
209+
.loadAuthorizedClient(this.registration2.getRegistrationId(), this.principalName2);
180210
assertThat(loadedAuthorizedClient).isNull();
181211
}
182212

213+
private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
214+
assertThat(actual).isNotNull();
215+
assertThat(actual.getClientRegistration().getRegistrationId())
216+
.isEqualTo(expected.getClientRegistration().getRegistrationId());
217+
assertThat(actual.getClientRegistration().getClientName())
218+
.isEqualTo(expected.getClientRegistration().getClientName());
219+
assertThat(actual.getClientRegistration().getRedirectUri())
220+
.isEqualTo(expected.getClientRegistration().getRedirectUri());
221+
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
222+
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
223+
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
224+
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
225+
assertThat(actual.getClientRegistration().getClientId())
226+
.isEqualTo(expected.getClientRegistration().getClientId());
227+
assertThat(actual.getClientRegistration().getClientSecret())
228+
.isEqualTo(expected.getClientRegistration().getClientSecret());
229+
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
230+
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
231+
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
232+
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
233+
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
234+
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
235+
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
236+
}
237+
183238
}

0 commit comments

Comments
 (0)