Skip to content

Added the ability to pass in a parameter when using JwtIssuerAuthenticationManagerResolver #17748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
import org.springframework.core.convert.converter.Converter;
import org.springframework.core.log.LogMessage;
import org.springframework.lang.NonNull;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationManagerResolver;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoders;
import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException;
Expand Down Expand Up @@ -93,7 +95,40 @@ public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Collecti
public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Predicate<String> trustedIssuers) {
Assert.notNull(trustedIssuers, "trustedIssuers cannot be null");
return new JwtIssuerAuthenticationManagerResolver(
new TrustedIssuerJwtAuthenticationManagerResolver(trustedIssuers));
new TrustedIssuerJwtAuthenticationManagerResolver(null, trustedIssuers));
}

/**
* Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided
* parameters
* @param trustedIssuers an array of trusted issuers
* @since 6.2
*/
public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter, String... trustedIssuers) {
return fromTrustedIssuers(jwtAuthenticationConverter, Set.of(trustedIssuers));
}

/**
* Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided
* parameters
* @param trustedIssuers a collection of trusted issuers
* @since 6.2
*/
public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter, Collection<String> trustedIssuers) {
Assert.notEmpty(trustedIssuers, "trustedIssuers cannot be empty");
return fromTrustedIssuers(jwtAuthenticationConverter, Set.copyOf(trustedIssuers)::contains);
}

/**
* Construct a {@link JwtIssuerAuthenticationManagerResolver} using the provided
* parameters
* @param trustedIssuers a predicate to validate issuers
* @since 6.2
*/
public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter, Predicate<String> trustedIssuers) {
Assert.notNull(trustedIssuers, "trustedIssuers cannot be null");
return new JwtIssuerAuthenticationManagerResolver(
new TrustedIssuerJwtAuthenticationManagerResolver(jwtAuthenticationConverter, trustedIssuers));
}

/**
Expand All @@ -117,6 +152,7 @@ public static JwtIssuerAuthenticationManagerResolver fromTrustedIssuers(Predicat
* {@link AuthenticationManager} by the issuer
*/
public JwtIssuerAuthenticationManagerResolver(

AuthenticationManagerResolver<String> issuerAuthenticationManagerResolver) {
Assert.notNull(issuerAuthenticationManagerResolver, "issuerAuthenticationManagerResolver cannot be null");
this.authenticationManager = new ResolvingAuthenticationManager(issuerAuthenticationManagerResolver);
Expand Down Expand Up @@ -197,7 +233,14 @@ static class TrustedIssuerJwtAuthenticationManagerResolver implements Authentica

private final Predicate<String> trustedIssuer;

private final Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter;

TrustedIssuerJwtAuthenticationManagerResolver(Predicate<String> trustedIssuer) {
this(null, trustedIssuer);
}

TrustedIssuerJwtAuthenticationManagerResolver(Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter, Predicate<String> trustedIssuer) {
this.jwtAuthenticationConverter = jwtAuthenticationConverter;
this.trustedIssuer = trustedIssuer;
}

Expand All @@ -208,7 +251,11 @@ public AuthenticationManager resolve(String issuer) {
(k) -> {
this.logger.debug("Constructing AuthenticationManager");
JwtDecoder jwtDecoder = JwtDecoders.fromIssuerLocation(issuer);
return new JwtAuthenticationProvider(jwtDecoder)::authenticate;
JwtAuthenticationProvider provider = new JwtAuthenticationProvider(jwtDecoder);
if (jwtAuthenticationConverter != null) {
provider.setJwtAuthenticationConverter(jwtAuthenticationConverter);
}
return provider::authenticate;
});
this.logger.debug(LogMessage.format("Resolved AuthenticationManager for issuer '%s'", issuer));
return authenticationManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package org.springframework.security.oauth2.server.resource.authentication;

import java.util.Collection;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -29,17 +29,22 @@
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import jakarta.servlet.http.HttpServletRequest;
import net.minidev.json.JSONObject;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.Test;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationManagerResolver;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException;
import org.springframework.security.oauth2.server.resource.authentication.JwtIssuerAuthenticationManagerResolver.TrustedIssuerJwtAuthenticationManagerResolver;
Expand All @@ -51,6 +56,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.verify;
import static org.mockito.Mockito.when;

/**
* Tests for {@link JwtIssuerAuthenticationManagerResolver}
Expand Down Expand Up @@ -267,6 +273,47 @@ public void resolveWhenBearerTokenEvilThenGenericException() {
// @formatter:on
}

@Test
public void testFactoryMethodWithConverter() throws Exception {
Converter<Jwt, ? extends AbstractAuthenticationToken> converter = new Converter<Jwt, AbstractAuthenticationToken>() {
@Override
public @Nullable AbstractAuthenticationToken convert(Jwt source) {
JwtAuthenticationToken authenticationToken = new JwtAuthenticationToken(source, Collections.emptyList(), "test_translated_name");
authenticationToken.setDetails("test_translated_details");
return authenticationToken;
}
};
try (MockWebServer server = new MockWebServer()) {
server.start();
String issuer = server.url("").toString();
// @formatter:off
server.enqueue(new MockResponse().setResponseCode(200)
.setHeader("Content-Type", "application/json")
.setBody(String.format(DEFAULT_RESPONSE_TEMPLATE, issuer, issuer)
));
server.enqueue(new MockResponse().setResponseCode(200)
.setHeader("Content-Type", "application/json")
.setBody(JWK_SET)
);
server.enqueue(new MockResponse().setResponseCode(200)
.setHeader("Content-Type", "application/json")
.setBody(JWK_SET)
);
// @formatter:on
JWSObject jws = new JWSObject(new JWSHeader(JWSAlgorithm.RS256),
new Payload(new JSONObject(Collections.singletonMap(JwtClaimNames.ISS, issuer))));
jws.sign(new RSASSASigner(TestKeys.DEFAULT_PRIVATE_KEY));
JwtIssuerAuthenticationManagerResolver authenticationManagerResolver = JwtIssuerAuthenticationManagerResolver
.fromTrustedIssuers(converter, issuer);
Authentication token = withBearerToken(jws.serialize());
AuthenticationManager authenticationManager = authenticationManagerResolver.resolve(null);
assertThat(authenticationManager).isNotNull();
Authentication authentication = authenticationManager.authenticate(token);
assertThat(authentication.isAuthenticated()).isTrue();
assertThat(authentication.getDetails()).isEqualTo("test_translated_details");
}
}

@Test
public void resolveWhenAuthenticationExceptionThenAuthenticationRequestIsIncluded() {
Authentication authentication = new BearerTokenAuthenticationToken(this.jwt);
Expand Down