Skip to content

Commit 8c2b095

Browse files
committed
Extract JwtDecoderFactory from JwtClientAssertionAuthenticationProvider
Closes gh-944
1 parent b1b2bc4 commit 8c2b095

File tree

4 files changed

+335
-217
lines changed

4 files changed

+335
-217
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProvider.java

Lines changed: 18 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -15,53 +15,27 @@
1515
*/
1616
package org.springframework.security.oauth2.server.authorization.authentication;
1717

18-
import java.nio.charset.StandardCharsets;
19-
import java.util.ArrayList;
20-
import java.util.Collections;
21-
import java.util.HashMap;
22-
import java.util.List;
23-
import java.util.Map;
24-
import java.util.Objects;
25-
import java.util.concurrent.ConcurrentHashMap;
26-
import java.util.function.Predicate;
27-
28-
import javax.crypto.spec.SecretKeySpec;
29-
3018
import org.springframework.security.authentication.AuthenticationProvider;
3119
import org.springframework.security.core.Authentication;
3220
import org.springframework.security.core.AuthenticationException;
3321
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
34-
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
3522
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3623
import org.springframework.security.oauth2.core.OAuth2Error;
3724
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
38-
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
3925
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
40-
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
41-
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
4226
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
4327
import org.springframework.security.oauth2.jwt.Jwt;
44-
import org.springframework.security.oauth2.jwt.JwtClaimNames;
45-
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
4628
import org.springframework.security.oauth2.jwt.JwtDecoder;
4729
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
4830
import org.springframework.security.oauth2.jwt.JwtException;
49-
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
50-
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
5131
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
5232
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
5333
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
54-
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext;
55-
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
56-
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
5734
import org.springframework.util.Assert;
58-
import org.springframework.util.CollectionUtils;
59-
import org.springframework.util.StringUtils;
60-
import org.springframework.web.util.UriComponentsBuilder;
6135

6236
/**
6337
* An {@link AuthenticationProvider} implementation used for OAuth 2.0 Client Authentication,
64-
* which authenticates the (JWT) {@link OAuth2ParameterNames#CLIENT_ASSERTION client_assertion} parameter.
38+
* which authenticates the {@link Jwt} {@link OAuth2ParameterNames#CLIENT_ASSERTION client_assertion} parameter.
6539
*
6640
* @author Rafal Lewczuk
6741
* @author Joe Grandja
@@ -70,14 +44,15 @@
7044
* @see OAuth2ClientAuthenticationToken
7145
* @see RegisteredClientRepository
7246
* @see OAuth2AuthorizationService
47+
* @see JwtClientAssertionDecoderFactory
7348
*/
7449
public final class JwtClientAssertionAuthenticationProvider implements AuthenticationProvider {
7550
private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
7651
private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD =
7752
new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
7853
private final RegisteredClientRepository registeredClientRepository;
7954
private final CodeVerifierAuthenticator codeVerifierAuthenticator;
80-
private final JwtClientAssertionDecoderFactory jwtClientAssertionDecoderFactory;
55+
private JwtDecoderFactory<RegisteredClient> jwtDecoderFactory;
8156

8257
/**
8358
* Constructs a {@code JwtClientAssertionAuthenticationProvider} using the provided parameters.
@@ -91,7 +66,7 @@ public JwtClientAssertionAuthenticationProvider(RegisteredClientRepository regis
9166
Assert.notNull(authorizationService, "authorizationService cannot be null");
9267
this.registeredClientRepository = registeredClientRepository;
9368
this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService);
94-
this.jwtClientAssertionDecoderFactory = new JwtClientAssertionDecoderFactory();
69+
this.jwtDecoderFactory = new JwtClientAssertionDecoderFactory();
9570
}
9671

9772
@Override
@@ -119,7 +94,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
11994
}
12095

12196
Jwt jwtAssertion = null;
122-
JwtDecoder jwtDecoder = this.jwtClientAssertionDecoderFactory.createDecoder(registeredClient);
97+
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient);
12398
try {
12499
jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString());
125100
} catch (JwtException ex) {
@@ -142,6 +117,19 @@ public boolean supports(Class<?> authentication) {
142117
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
143118
}
144119

120+
/**
121+
* Sets the {@link JwtDecoderFactory} that provides a {@link JwtDecoder} for the specified {@link RegisteredClient}
122+
* and is used for authenticating a {@link Jwt} Bearer Token during OAuth 2.0 Client Authentication.
123+
* The default factory is {@link JwtClientAssertionDecoderFactory}.
124+
*
125+
* @param jwtDecoderFactory the {@link JwtDecoderFactory} that provides a {@link JwtDecoder} for the specified {@link RegisteredClient}
126+
* @since 0.4.0
127+
*/
128+
public void setJwtDecoderFactory(JwtDecoderFactory<RegisteredClient> jwtDecoderFactory) {
129+
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
130+
this.jwtDecoderFactory = jwtDecoderFactory;
131+
}
132+
145133
private static void throwInvalidClient(String parameterName) {
146134
throwInvalidClient(parameterName, null);
147135
}
@@ -155,112 +143,4 @@ private static void throwInvalidClient(String parameterName, Throwable cause) {
155143
throw new OAuth2AuthenticationException(error, error.toString(), cause);
156144
}
157145

158-
private static class JwtClientAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
159-
private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
160-
161-
private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
162-
163-
static {
164-
Map<JwsAlgorithm, String> mappings = new HashMap<>();
165-
mappings.put(MacAlgorithm.HS256, "HmacSHA256");
166-
mappings.put(MacAlgorithm.HS384, "HmacSHA384");
167-
mappings.put(MacAlgorithm.HS512, "HmacSHA512");
168-
JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
169-
}
170-
171-
private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
172-
173-
@Override
174-
public JwtDecoder createDecoder(RegisteredClient registeredClient) {
175-
Assert.notNull(registeredClient, "registeredClient cannot be null");
176-
return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (key) -> {
177-
NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient);
178-
jwtDecoder.setJwtValidator(createJwtValidator(registeredClient));
179-
return jwtDecoder;
180-
});
181-
}
182-
183-
private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
184-
JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
185-
if (jwsAlgorithm instanceof SignatureAlgorithm) {
186-
String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
187-
if (!StringUtils.hasText(jwkSetUrl)) {
188-
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
189-
"Failed to find a Signature Verifier for Client: '"
190-
+ registeredClient.getId()
191-
+ "'. Check to ensure you have configured the JWK Set URL.",
192-
JWT_CLIENT_AUTHENTICATION_ERROR_URI);
193-
throw new OAuth2AuthenticationException(oauth2Error);
194-
}
195-
return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
196-
}
197-
if (jwsAlgorithm instanceof MacAlgorithm) {
198-
String clientSecret = registeredClient.getClientSecret();
199-
if (!StringUtils.hasText(clientSecret)) {
200-
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
201-
"Failed to find a Signature Verifier for Client: '"
202-
+ registeredClient.getId()
203-
+ "'. Check to ensure you have configured the client secret.",
204-
JWT_CLIENT_AUTHENTICATION_ERROR_URI);
205-
throw new OAuth2AuthenticationException(oauth2Error);
206-
}
207-
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
208-
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
209-
return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
210-
}
211-
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
212-
"Failed to find a Signature Verifier for Client: '"
213-
+ registeredClient.getId()
214-
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.",
215-
JWT_CLIENT_AUTHENTICATION_ERROR_URI);
216-
throw new OAuth2AuthenticationException(oauth2Error);
217-
}
218-
219-
private static OAuth2TokenValidator<Jwt> createJwtValidator(RegisteredClient registeredClient) {
220-
String clientId = registeredClient.getClientId();
221-
return new DelegatingOAuth2TokenValidator<>(
222-
new JwtClaimValidator<>(JwtClaimNames.ISS, clientId::equals),
223-
new JwtClaimValidator<>(JwtClaimNames.SUB, clientId::equals),
224-
new JwtClaimValidator<>(JwtClaimNames.AUD, containsAudience()),
225-
new JwtClaimValidator<>(JwtClaimNames.EXP, Objects::nonNull),
226-
new JwtTimestampValidator()
227-
);
228-
}
229-
230-
private static Predicate<List<String>> containsAudience() {
231-
return (audienceClaim) -> {
232-
if (CollectionUtils.isEmpty(audienceClaim)) {
233-
return false;
234-
}
235-
List<String> audienceList = getAudience();
236-
for (String audience : audienceClaim) {
237-
if (audienceList.contains(audience)) {
238-
return true;
239-
}
240-
}
241-
return false;
242-
};
243-
}
244-
245-
private static List<String> getAudience() {
246-
AuthorizationServerContext authorizationServerContext = AuthorizationServerContextHolder.getContext();
247-
if (!StringUtils.hasText(authorizationServerContext.getIssuer())) {
248-
return Collections.emptyList();
249-
}
250-
251-
AuthorizationServerSettings authorizationServerSettings = authorizationServerContext.getAuthorizationServerSettings();
252-
List<String> audience = new ArrayList<>();
253-
audience.add(authorizationServerContext.getIssuer());
254-
audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenEndpoint()));
255-
audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenIntrospectionEndpoint()));
256-
audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenRevocationEndpoint()));
257-
return audience;
258-
}
259-
260-
private static String asUrl(String issuer, String endpoint) {
261-
return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString();
262-
}
263-
264-
}
265-
266146
}

0 commit comments

Comments
 (0)