Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;

import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.util.Assert;

/**
* An {@link OAuth2AuthenticationException} that holds an
* {@link OAuth2ClientAuthenticationToken} and is used by an
* {@code AuthenticationFailureHandler} when handling a failed authentication attempt by
* an OAuth 2.0 Client.
*
* @author Joe Grandja
* @since 1.5
* @see OAuth2ClientAuthenticationToken
*/
public class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {

private final OAuth2ClientAuthenticationToken clientAuthentication;

/**
* Constructs an {@code OAuth2ClientAuthenticationException} using the provided
* parameters.
* @param error the {@link OAuth2Error OAuth 2.0 Error}
* @param clientAuthentication the {@link OAuth2ClientAuthenticationToken OAuth 2.0
* Client Authentication} request
*/
public OAuth2ClientAuthenticationException(OAuth2Error error,
OAuth2ClientAuthenticationToken clientAuthentication) {
super(error);
Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
this.clientAuthentication = clientAuthentication;
}

/**
* Constructs an {@code OAuth2ClientAuthenticationException} using the provided
* parameters.
* @param error the {@link OAuth2Error OAuth 2.0 Error}
* @param cause the root cause
* @param clientAuthentication the {@link OAuth2ClientAuthenticationToken OAuth 2.0
* Client Authentication} request
*/
public OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
OAuth2ClientAuthenticationToken clientAuthentication) {
super(error, cause);
Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
this.clientAuthentication = clientAuthentication;
}

/**
* Returns the {@link OAuth2ClientAuthenticationToken OAuth 2.0 Client Authentication}
* request.
* @return the {@link OAuth2ClientAuthenticationToken}
*/
public OAuth2ClientAuthenticationToken getClientAuthentication() {
return this.clientAuthentication;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,32 @@
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationException;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.PublicClientAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.X509ClientCertificateAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretBasicAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientAuthenticationFailureHandler;
import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.X509ClientCertificateAuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.DelegatingAuthenticationConverter;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
Expand All @@ -75,6 +70,7 @@
* @see ClientSecretAuthenticationProvider
* @see PublicClientAuthenticationConverter
* @see PublicClientAuthenticationProvider
* @see OAuth2ClientAuthenticationFailureHandler
* @see <a target="_blank" href=
* "https://datatracker.ietf.org/doc/html/rfc6749#section-2.3">Section 2.3 Client
* Authentication</a>
Expand All @@ -88,17 +84,13 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter

private final RequestMatcher requestMatcher;

private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();

private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();

private AuthenticationConverter authenticationConverter;

private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;

private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;
private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ClientAuthenticationFailureHandler();

/**
* Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided
Expand All @@ -114,7 +106,6 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
this.authenticationManager = authenticationManager;
this.requestMatcher = requestMatcher;
this.basicAuthenticationEntryPoint.setRealmName("default");
// @formatter:off
this.authenticationConverter = new DelegatingAuthenticationConverter(
Arrays.asList(
Expand All @@ -138,16 +129,16 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
Authentication authenticationRequest = null;
try {
authenticationRequest = this.authenticationConverter.convert(request);
if (authenticationRequest == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
}
if (authenticationRequest instanceof AbstractAuthenticationToken authenticationToken) {
authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
}
if (authenticationRequest != null) {
validateClientIdentifier(authenticationRequest);
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
}
validateClientIdentifier(authenticationRequest);
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
filterChain.doFilter(request, response);

}
catch (OAuth2AuthenticationException ex) {
if (this.logger.isTraceEnabled()) {
Expand All @@ -160,8 +151,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
else {
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
}

}

}

/**
Expand Down Expand Up @@ -211,35 +202,6 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp
}
}

private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException {

SecurityContextHolder.clearContext();

if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
.getClientAuthentication();
if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientAuthentication.getClientAuthenticationMethod())) {
this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
return;
}
}

OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
}
else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
// We don't want to reveal too much information to the caller so just return the
// error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
}

private static void validateClientIdentifier(Authentication authentication) {
if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
return;
Expand All @@ -261,21 +223,4 @@ private static void validateClientIdentifier(Authentication authentication) {
}
}

private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {

private final OAuth2ClientAuthenticationToken clientAuthentication;

private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
OAuth2ClientAuthenticationToken clientAuthentication) {
super(error, cause);
Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
this.clientAuthentication = clientAuthentication;
}

private OAuth2ClientAuthenticationToken getClientAuthentication() {
return this.clientAuthentication;
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web.authentication;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationException;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
import org.springframework.util.Assert;

/**
* An implementation of an {@link AuthenticationFailureHandler} used for handling a failed
* authentication attempt by an OAuth 2.0 Client and delegating to an
* {@link AuthenticationEntryPoint} based on the {@link ClientAuthenticationMethod} used
* by the client.
*
* @author Joe Grandja
* @since 1.5
* @see AuthenticationFailureHandler
* @see AuthenticationEntryPoint
* @see OAuth2ClientAuthenticationFilter
* @see OAuth2ClientAuthenticationException
*/
public final class OAuth2ClientAuthenticationFailureHandler implements AuthenticationFailureHandler {

private final Map<ClientAuthenticationMethod, AuthenticationEntryPoint> authenticationEntryPoints;

private AuthenticationEntryPoint defaultAuthenticationEntryPoint = new DefaultAuthenticationEntryPoint();

public OAuth2ClientAuthenticationFailureHandler() {
this.authenticationEntryPoints = new HashMap<>();
BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
basicAuthenticationEntryPoint.setRealmName("default");
this.authenticationEntryPoints.put(ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
basicAuthenticationEntryPoint);
}

@Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException, ServletException {
SecurityContextHolder.clearContext();
AuthenticationEntryPoint authenticationEntryPoint = this.defaultAuthenticationEntryPoint;
if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
.getClientAuthentication();
AuthenticationEntryPoint clientAuthenticationMethodEntryPoint = this.authenticationEntryPoints
.get(clientAuthentication.getClientAuthenticationMethod());
if (clientAuthenticationMethodEntryPoint != null) {
// Override the default
authenticationEntryPoint = clientAuthenticationMethodEntryPoint;
}
}
authenticationEntryPoint.commence(request, response, authenticationException);
}

/**
* Sets the {@link AuthenticationEntryPoint} used for the specified
* {@link ClientAuthenticationMethod}.
* @param authenticationEntryPoint the {@link AuthenticationEntryPoint}
* @param clientAuthenticationMethod the {@link ClientAuthenticationMethod}
*/
public void setAuthenticationEntryPointFor(AuthenticationEntryPoint authenticationEntryPoint,
ClientAuthenticationMethod clientAuthenticationMethod) {
Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null");
Assert.notNull(clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
this.authenticationEntryPoints.put(clientAuthenticationMethod, authenticationEntryPoint);
}

/**
* Sets the default {@link AuthenticationEntryPoint} used when unable to determine the
* {@link ClientAuthenticationMethod} used by the client.
* @param defaultAuthenticationEntryPoint the default {@link AuthenticationEntryPoint}
*/
public void setDefaultAuthenticationEntryPoint(AuthenticationEntryPoint defaultAuthenticationEntryPoint) {
Assert.notNull(defaultAuthenticationEntryPoint, "defaultAuthenticationEntryPoint cannot be null");
this.defaultAuthenticationEntryPoint = defaultAuthenticationEntryPoint;
}

private static final class DefaultAuthenticationEntryPoint implements AuthenticationEntryPoint {

private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

@Override
public void commence(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
}
else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
// We don't want to reveal too much information to the caller
// so just return the error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
}

}

}
Loading