| 
16 | 16 | 
 
  | 
17 | 17 | package org.springframework.security.oauth2.client.oidc.authentication;  | 
18 | 18 | 
 
  | 
 | 19 | +import java.util.Map;  | 
 | 20 | + | 
19 | 21 | import org.springframework.context.ApplicationListener;  | 
20 | 22 | import org.springframework.security.core.Authentication;  | 
 | 23 | +import org.springframework.security.core.context.SecurityContext;  | 
21 | 24 | import org.springframework.security.core.context.SecurityContextHolder;  | 
 | 25 | +import org.springframework.security.core.context.SecurityContextHolderStrategy;  | 
22 | 26 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;  | 
23 | 27 | import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;  | 
24 | 28 | import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;  | 
 | 29 | +import org.springframework.security.oauth2.client.registration.ClientRegistration;  | 
 | 30 | +import org.springframework.security.oauth2.core.OAuth2AuthenticationException;  | 
 | 31 | +import org.springframework.security.oauth2.core.OAuth2Error;  | 
25 | 32 | import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;  | 
26 | 33 | import org.springframework.security.oauth2.core.oidc.OidcIdToken;  | 
27 | 34 | import org.springframework.security.oauth2.core.oidc.StandardClaimNames;  | 
 | 35 | +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;  | 
28 | 36 | import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;  | 
29 | 37 | import org.springframework.security.oauth2.core.oidc.user.OidcUser;  | 
 | 38 | +import org.springframework.security.oauth2.jwt.Jwt;  | 
 | 39 | +import org.springframework.security.oauth2.jwt.JwtDecoder;  | 
 | 40 | +import org.springframework.security.oauth2.jwt.JwtDecoderFactory;  | 
 | 41 | +import org.springframework.security.oauth2.jwt.JwtException;  | 
 | 42 | +import org.springframework.util.Assert;  | 
30 | 43 | 
 
  | 
31 | 44 | /**  | 
32 | 45 |  * An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s  | 
33 | 46 |  */  | 
34 | 47 | public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {  | 
35 | 48 | 
 
  | 
36 |  | -	private final OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider;  | 
 | 49 | +	private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";  | 
37 | 50 | 
 
  | 
38 |  | -	public RefreshOidcIdTokenHandler(  | 
39 |  | -			OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider) {  | 
40 |  | -		this.oidcAuthorizationCodeAuthenticationProvider = oidcAuthorizationCodeAuthenticationProvider;  | 
41 |  | -	}  | 
 | 51 | +	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder  | 
 | 52 | +		.getContextHolderStrategy();  | 
 | 53 | + | 
 | 54 | +	private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new OidcIdTokenDecoderFactory();  | 
42 | 55 | 
 
  | 
43 | 56 | 	@Override  | 
44 | 57 | 	public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {  | 
45 | 58 | 		OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();  | 
46 | 59 | 		OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();  | 
47 |  | -		OidcIdToken refreshedOidcToken = this.oidcAuthorizationCodeAuthenticationProvider  | 
48 |  | -			.createOidcToken(authorizedClient.getClientRegistration(), accessTokenResponse);  | 
49 |  | -		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();  | 
 | 60 | +		ClientRegistration clientRegistration = authorizedClient.getClientRegistration();  | 
 | 61 | +		OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);  | 
 | 62 | +		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();  | 
50 | 63 | 		if (authentication instanceof OAuth2AuthenticationToken oauth2AuthenticationToken) {  | 
51 | 64 | 			if (authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser) {  | 
52 | 65 | 				OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,  | 
53 | 66 | 						defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);  | 
54 |  | -				SecurityContextHolder.getContext()  | 
55 |  | -					.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),  | 
56 |  | -							oauth2AuthenticationToken.getAuthorizedClientRegistrationId()));  | 
 | 67 | +				SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();  | 
 | 68 | +				context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),  | 
 | 69 | +						oauth2AuthenticationToken.getAuthorizedClientRegistrationId()));  | 
 | 70 | +				this.securityContextHolderStrategy.setContext(context);  | 
57 | 71 | 			}  | 
58 | 72 | 		}  | 
59 | 73 | 	}  | 
60 | 74 | 
 
  | 
 | 75 | +	/**  | 
 | 76 | +	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use  | 
 | 77 | +	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.  | 
 | 78 | +	 */  | 
 | 79 | +	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {  | 
 | 80 | +		this.securityContextHolderStrategy = securityContextHolderStrategy;  | 
 | 81 | +	}  | 
 | 82 | + | 
 | 83 | +	/**  | 
 | 84 | +	 * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature  | 
 | 85 | +	 * verification. The factory returns a {@link JwtDecoder} associated to the provided  | 
 | 86 | +	 * {@link ClientRegistration}.  | 
 | 87 | +	 * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken}  | 
 | 88 | +	 * signature verification  | 
 | 89 | +	 */  | 
 | 90 | +	public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {  | 
 | 91 | +		Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");  | 
 | 92 | +		this.jwtDecoderFactory = jwtDecoderFactory;  | 
 | 93 | +	}  | 
 | 94 | + | 
 | 95 | +	private OidcIdToken createOidcToken(ClientRegistration clientRegistration,  | 
 | 96 | +			OAuth2AccessTokenResponse accessTokenResponse) {  | 
 | 97 | +		JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);  | 
 | 98 | +		Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);  | 
 | 99 | +		return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());  | 
 | 100 | +	}  | 
 | 101 | + | 
 | 102 | +	private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) {  | 
 | 103 | +		try {  | 
 | 104 | +			Map<String, Object> parameters = accessTokenResponse.getAdditionalParameters();  | 
 | 105 | +			return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN));  | 
 | 106 | +		}  | 
 | 107 | +		catch (JwtException ex) {  | 
 | 108 | +			OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);  | 
 | 109 | +			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);  | 
 | 110 | +		}  | 
 | 111 | +	}  | 
 | 112 | + | 
61 | 113 | }  | 
0 commit comments