Skip to content

Commit 00c114c

Browse files
committed
ID Token contains sid claim after refresh_token grant
Closes gh-1224
1 parent ece9f10 commit 00c114c

File tree

4 files changed

+139
-16
lines changed

4 files changed

+139
-16
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.time.Instant;
1919
import java.time.temporal.ChronoUnit;
2020
import java.util.Collections;
21+
import java.util.Date;
2122

2223
import org.springframework.lang.Nullable;
2324
import org.springframework.security.core.session.SessionInformation;
@@ -126,11 +127,15 @@ public Jwt generate(OAuth2TokenContext context) {
126127
if (StringUtils.hasText(nonce)) {
127128
claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
128129
}
129-
}
130-
SessionInformation sessionInformation = context.get(SessionInformation.class);
131-
if (sessionInformation != null) {
132-
claimsBuilder.claim("sid", sessionInformation.getSessionId());
133-
claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest());
130+
SessionInformation sessionInformation = context.get(SessionInformation.class);
131+
if (sessionInformation != null) {
132+
claimsBuilder.claim("sid", sessionInformation.getSessionId());
133+
claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest());
134+
}
135+
} else if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
136+
OidcIdToken currentIdToken = context.getAuthorization().getToken(OidcIdToken.class).getToken();
137+
claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
138+
claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
134139
}
135140
}
136141
// @formatter:on

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.time.Instant;
2020
import java.time.temporal.ChronoUnit;
2121
import java.util.Collections;
22+
import java.util.Date;
2223
import java.util.HashMap;
2324
import java.util.HashSet;
2425
import java.util.Map;
@@ -38,6 +39,7 @@
3839
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3940
import org.springframework.security.oauth2.core.OAuth2Token;
4041
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
42+
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
4143
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
4244
import org.springframework.security.oauth2.core.oidc.OidcScopes;
4345
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
@@ -196,7 +198,15 @@ public void authenticateWhenValidRefreshTokenThenReturnAccessToken() {
196198
@Test
197199
public void authenticateWhenValidRefreshTokenThenReturnIdToken() {
198200
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
199-
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
201+
OidcIdToken authorizedIdToken = OidcIdToken.withTokenValue("id-token")
202+
.issuer("https://provider.com")
203+
.subject("subject")
204+
.issuedAt(Instant.now())
205+
.expiresAt(Instant.now().plusSeconds(60))
206+
.claim("sid", "sessionId-1234")
207+
.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
208+
.build();
209+
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).token(authorizedIdToken).build();
200210
when(this.authorizationService.findByToken(
201211
eq(authorization.getRefreshToken().getToken().getTokenValue()),
202212
eq(OAuth2TokenType.REFRESH_TOKEN)))

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,63 @@ public void requestWhenAuthenticationRequestThenTokenResponseIncludesIdToken() t
237237
assertThat(idToken.<String>getClaim("sid")).isNotNull();
238238
}
239239

240+
// gh-1224
241+
@Test
242+
public void requestWhenRefreshTokenRequestThenIdTokenContainsSidClaim() throws Exception {
243+
this.spring.register(AuthorizationServerConfiguration.class).autowire();
244+
245+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
246+
this.registeredClientRepository.save(registeredClient);
247+
248+
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
249+
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
250+
.params(authorizationRequestParameters)
251+
.with(user("user").roles("A", "B")))
252+
.andExpect(status().is3xxRedirection())
253+
.andReturn();
254+
String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
255+
String expectedRedirectUri = authorizationRequestParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
256+
assertThat(redirectedUrl).matches(expectedRedirectUri + "\\?code=.{15,}&state=state");
257+
258+
String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
259+
OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
260+
261+
mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
262+
.params(getTokenRequestParameters(registeredClient, authorization))
263+
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
264+
registeredClient.getClientId(), registeredClient.getClientSecret())))
265+
.andExpect(status().isOk())
266+
.andReturn();
267+
268+
MockHttpServletResponse servletResponse = mvcResult.getResponse();
269+
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
270+
servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
271+
OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
272+
273+
Jwt idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
274+
275+
String sidClaim = idToken.getClaim("sid");
276+
assertThat(sidClaim).isNotNull();
277+
278+
// Refresh access token
279+
mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
280+
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue())
281+
.param(OAuth2ParameterNames.REFRESH_TOKEN, accessTokenResponse.getRefreshToken().getTokenValue())
282+
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
283+
registeredClient.getClientId(), registeredClient.getClientSecret())))
284+
.andExpect(status().isOk())
285+
.andReturn();
286+
287+
servletResponse = mvcResult.getResponse();
288+
httpResponse = new MockClientHttpResponse(
289+
servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
290+
accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
291+
292+
idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
293+
294+
assertThat(idToken.<String>getClaim("sid")).isEqualTo(sidClaim);
295+
}
296+
240297
@Test
241298
public void requestWhenLogoutRequestThenLogout() throws Exception {
242299
this.spring.register(AuthorizationServerConfiguration.class).autowire();

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
import org.springframework.security.core.session.SessionInformation;
3232
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3333
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
34+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3435
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3536
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3637
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
38+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
3739
import org.springframework.security.oauth2.core.oidc.OidcScopes;
3840
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
3941
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -46,6 +48,7 @@
4648
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
4749
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
4850
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
51+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken;
4952
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
5053
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
5154
import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext;
@@ -152,7 +155,7 @@ public void generateWhenAccessTokenTypeThenReturnJwt() {
152155
}
153156

154157
@Test
155-
public void generateWhenIdTokenTypeThenReturnJwt() {
158+
public void generateWhenIdTokenTypeAndAuthorizationCodeGrantThenReturnJwt() {
156159
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
157160
.scope(OidcScopes.OPENID)
158161
.tokenSettings(TokenSettings.builder().idTokenSignatureAlgorithm(SignatureAlgorithm.ES256).build())
@@ -190,6 +193,49 @@ public void generateWhenIdTokenTypeThenReturnJwt() {
190193
assertGeneratedTokenType(tokenContext);
191194
}
192195

196+
// gh-1224
197+
@Test
198+
public void generateWhenIdTokenTypeAndRefreshTokenGrantThenReturnJwt() {
199+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
200+
.scope(OidcScopes.OPENID)
201+
.build();
202+
OidcIdToken idToken = OidcIdToken.withTokenValue("id-token")
203+
.issuer("https://provider.com")
204+
.subject("subject")
205+
.issuedAt(Instant.now())
206+
.expiresAt(Instant.now().plusSeconds(60))
207+
.claim("sid", "sessionId-1234")
208+
.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
209+
.build();
210+
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
211+
.token(idToken)
212+
.build();
213+
214+
OAuth2RefreshToken refreshToken = authorization.getRefreshToken().getToken();
215+
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
216+
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
217+
218+
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
219+
refreshToken.getTokenValue(), clientPrincipal, null, null);
220+
221+
Authentication principal = authorization.getAttribute(Principal.class.getName());
222+
223+
// @formatter:off
224+
OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
225+
.registeredClient(registeredClient)
226+
.principal(principal)
227+
.authorizationServerContext(this.authorizationServerContext)
228+
.authorization(authorization)
229+
.authorizedScopes(authorization.getAuthorizedScopes())
230+
.tokenType(ID_TOKEN_TOKEN_TYPE)
231+
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
232+
.authorizationGrant(authentication)
233+
.build();
234+
// @formatter:on
235+
236+
assertGeneratedTokenType(tokenContext);
237+
}
238+
193239
private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) {
194240
this.jwtGenerator.generate(tokenContext);
195241

@@ -239,15 +285,20 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) {
239285
assertThat(scopes).isEqualTo(tokenContext.getAuthorizedScopes());
240286
} else {
241287
assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.AZP)).isEqualTo(tokenContext.getRegisteredClient().getClientId());
242-
243-
OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute(
244-
OAuth2AuthorizationRequest.class.getName());
245-
String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
246-
assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce);
247-
248-
SessionInformation sessionInformation = tokenContext.get(SessionInformation.class);
249-
assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(sessionInformation.getSessionId());
250-
assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(sessionInformation.getLastRequest());
288+
if (tokenContext.getAuthorizationGrantType().equals(AuthorizationGrantType.AUTHORIZATION_CODE)) {
289+
OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute(
290+
OAuth2AuthorizationRequest.class.getName());
291+
String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
292+
assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce);
293+
294+
SessionInformation sessionInformation = tokenContext.get(SessionInformation.class);
295+
assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(sessionInformation.getSessionId());
296+
assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(sessionInformation.getLastRequest());
297+
} else if (tokenContext.getAuthorizationGrantType().equals(AuthorizationGrantType.REFRESH_TOKEN)) {
298+
OidcIdToken currentIdToken = tokenContext.getAuthorization().getToken(OidcIdToken.class).getToken();
299+
assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(currentIdToken.getClaim("sid"));
300+
assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
301+
}
251302
}
252303
}
253304

0 commit comments

Comments
 (0)