3434import  org .springframework .security .core .AuthenticationException ;
3535import  org .springframework .security .core .context .SecurityContext ;
3636import  org .springframework .security .core .context .SecurityContextHolder ;
37+ import  org .springframework .security .oauth2 .core .ClientAuthenticationMethod ;
3738import  org .springframework .security .oauth2 .core .OAuth2AuthenticationException ;
3839import  org .springframework .security .oauth2 .core .OAuth2Error ;
3940import  org .springframework .security .oauth2 .core .OAuth2ErrorCodes ;
5354import  org .springframework .security .web .authentication .AuthenticationSuccessHandler ;
5455import  org .springframework .security .web .authentication .DelegatingAuthenticationConverter ;
5556import  org .springframework .security .web .authentication .WebAuthenticationDetailsSource ;
57+ import  org .springframework .security .web .authentication .www .BasicAuthenticationEntryPoint ;
5658import  org .springframework .security .web .util .matcher .RequestMatcher ;
5759import  org .springframework .util .Assert ;
5860import  org .springframework .web .filter .OncePerRequestFilter ;
@@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
9092
9193	private  final  AuthenticationDetailsSource <HttpServletRequest , ?> authenticationDetailsSource  = new  WebAuthenticationDetailsSource ();
9294
95+ 	private  final  BasicAuthenticationEntryPoint  basicAuthenticationEntryPoint  = new  BasicAuthenticationEntryPoint ();
96+ 
9397	private  AuthenticationConverter  authenticationConverter ;
9498
9599	private  AuthenticationSuccessHandler  authenticationSuccessHandler  = this ::onAuthenticationSuccess ;
@@ -110,6 +114,7 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana
110114		Assert .notNull (requestMatcher , "requestMatcher cannot be null" );
111115		this .authenticationManager  = authenticationManager ;
112116		this .requestMatcher  = requestMatcher ;
117+ 		this .basicAuthenticationEntryPoint .setRealmName ("default" );
113118		// @formatter:off 
114119		this .authenticationConverter  = new  DelegatingAuthenticationConverter (
115120				Arrays .asList (
@@ -130,8 +135,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
130135			return ;
131136		}
132137
138+ 		Authentication  authenticationRequest  = null ;
133139		try  {
134- 			Authentication   authenticationRequest  = this .authenticationConverter .convert (request );
140+ 			authenticationRequest  = this .authenticationConverter .convert (request );
135141			if  (authenticationRequest  instanceof  AbstractAuthenticationToken  authenticationToken ) {
136142				authenticationToken .setDetails (this .authenticationDetailsSource .buildDetails (request ));
137143			}
@@ -147,7 +153,14 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
147153			if  (this .logger .isTraceEnabled ()) {
148154				this .logger .trace (LogMessage .format ("Client authentication failed: %s" , ex .getError ()), ex );
149155			}
150- 			this .authenticationFailureHandler .onAuthenticationFailure (request , response , ex );
156+ 			if  (authenticationRequest  instanceof  OAuth2ClientAuthenticationToken  clientAuthentication ) {
157+ 				this .authenticationFailureHandler .onAuthenticationFailure (request , response ,
158+ 						new  OAuth2ClientAuthenticationException (ex .getError (), ex , clientAuthentication ));
159+ 			}
160+ 			else  {
161+ 				this .authenticationFailureHandler .onAuthenticationFailure (request , response , ex );
162+ 			}
163+ 
151164		}
152165	}
153166
@@ -199,21 +212,21 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp
199212	}
200213
201214	private  void  onAuthenticationFailure (HttpServletRequest  request , HttpServletResponse  response ,
202- 			AuthenticationException  exception ) throws  IOException  {
215+ 			AuthenticationException  authenticationException ) throws  IOException  {
203216
204217		SecurityContextHolder .clearContext ();
205218
206- 		// TODO 
207- 		// The authorization server MAY return an HTTP 401 (Unauthorized) status code 
208- 		// to indicate which HTTP authentication schemes are supported. 
209- 		// If the client attempted to authenticate via the "Authorization" request header 
210- 		// field, 
211- 		// the authorization server MUST respond with an HTTP 401 (Unauthorized) status 
212- 		// code and 
213- 		// include the "WWW-Authenticate" response header field 
214- 		// matching the authentication scheme used by the client. 
215- 
216- 		OAuth2Error  error  = ((OAuth2AuthenticationException ) exception ).getError ();
219+ 		if  ( authenticationException   instanceof   OAuth2ClientAuthenticationException   clientAuthenticationException ) { 
220+ 			 OAuth2ClientAuthenticationToken   clientAuthentication  =  clientAuthenticationException 
221+ 				. getClientAuthentication (); 
222+ 			 if  ( ClientAuthenticationMethod . CLIENT_SECRET_BASIC 
223+ 				. equals ( clientAuthentication . getClientAuthenticationMethod ())) { 
224+ 				 this . basicAuthenticationEntryPoint . commence ( request ,  response ,  authenticationException ); 
225+ 				 return ; 
226+ 			} 
227+ 		} 
228+ 
229+ 		OAuth2Error  error  = ((OAuth2AuthenticationException ) authenticationException ).getError ();
217230		ServletServerHttpResponse  httpResponse  = new  ServletServerHttpResponse (response );
218231		if  (OAuth2ErrorCodes .INVALID_CLIENT .equals (error .getErrorCode ())) {
219232			httpResponse .setStatusCode (HttpStatus .UNAUTHORIZED );
@@ -248,4 +261,21 @@ private static void validateClientIdentifier(Authentication authentication) {
248261		}
249262	}
250263
264+ 	private  static  final  class  OAuth2ClientAuthenticationException  extends  OAuth2AuthenticationException  {
265+ 
266+ 		private  final  OAuth2ClientAuthenticationToken  clientAuthentication ;
267+ 
268+ 		private  OAuth2ClientAuthenticationException (OAuth2Error  error , Throwable  cause ,
269+ 				OAuth2ClientAuthenticationToken  clientAuthentication ) {
270+ 			super (error , cause );
271+ 			Assert .notNull (clientAuthentication , "clientAuthentication cannot be null" );
272+ 			this .clientAuthentication  = clientAuthentication ;
273+ 		}
274+ 
275+ 		private  OAuth2ClientAuthenticationToken  getClientAuthentication () {
276+ 			return  this .clientAuthentication ;
277+ 		}
278+ 
279+ 	}
280+ 
251281}
0 commit comments