|
1 | 1 | /*
|
2 |
| - * Copyright 2002-2020 the original author or authors. |
| 2 | + * Copyright 2002-2021 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
16 | 16 |
|
17 | 17 | package org.springframework.security.oauth2.client.web;
|
18 | 18 |
|
| 19 | +import java.util.Collection; |
19 | 20 | import java.util.HashMap;
|
20 | 21 | import java.util.Map;
|
21 | 22 |
|
|
33 | 34 | import org.springframework.security.authentication.AuthenticationManager;
|
34 | 35 | import org.springframework.security.core.Authentication;
|
35 | 36 | import org.springframework.security.core.AuthenticationException;
|
| 37 | +import org.springframework.security.core.GrantedAuthority; |
36 | 38 | import org.springframework.security.core.authority.AuthorityUtils;
|
37 | 39 | import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
38 | 40 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
39 | 41 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
| 42 | +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; |
40 | 43 | import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
|
41 | 44 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
42 | 45 | import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
@@ -152,6 +155,12 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI
|
152 | 155 | assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
|
153 | 156 | }
|
154 | 157 |
|
| 158 | + // gh-10033 |
| 159 | + @Test |
| 160 | + public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() { |
| 161 | + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null)); |
| 162 | + } |
| 163 | + |
155 | 164 | @Test
|
156 | 165 | public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
|
157 | 166 | String requestUri = "/path";
|
@@ -416,6 +425,41 @@ public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationR
|
416 | 425 | assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
|
417 | 426 | }
|
418 | 427 |
|
| 428 | + // gh-10033 |
| 429 | + @Test |
| 430 | + public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception { |
| 431 | + this.filter.setAuthenticationResultConverter((authentication) -> null); |
| 432 | + String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); |
| 433 | + String state = "state"; |
| 434 | + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
| 435 | + request.setServletPath(requestUri); |
| 436 | + request.addParameter(OAuth2ParameterNames.CODE, "code"); |
| 437 | + request.addParameter(OAuth2ParameterNames.STATE, state); |
| 438 | + MockHttpServletResponse response = new MockHttpServletResponse(); |
| 439 | + this.setUpAuthorizationRequest(request, response, this.registration1, state); |
| 440 | + this.setUpAuthenticationResult(this.registration1); |
| 441 | + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response)); |
| 442 | + } |
| 443 | + |
| 444 | + // gh-10033 |
| 445 | + @Test |
| 446 | + public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() { |
| 447 | + this.filter.setAuthenticationResultConverter( |
| 448 | + (authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(), |
| 449 | + authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId())); |
| 450 | + String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId(); |
| 451 | + String state = "state"; |
| 452 | + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
| 453 | + request.setServletPath(requestUri); |
| 454 | + request.addParameter(OAuth2ParameterNames.CODE, "code"); |
| 455 | + request.addParameter(OAuth2ParameterNames.STATE, state); |
| 456 | + MockHttpServletResponse response = new MockHttpServletResponse(); |
| 457 | + this.setUpAuthorizationRequest(request, response, this.registration1, state); |
| 458 | + this.setUpAuthenticationResult(this.registration1); |
| 459 | + Authentication authenticationResult = this.filter.attemptAuthentication(request, response); |
| 460 | + assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class); |
| 461 | + } |
| 462 | + |
419 | 463 | private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
420 | 464 | ClientRegistration registration, String state) {
|
421 | 465 | Map<String, Object> attributes = new HashMap<>();
|
@@ -454,4 +498,13 @@ private void setUpAuthenticationResult(ClientRegistration registration) {
|
454 | 498 | given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
|
455 | 499 | }
|
456 | 500 |
|
| 501 | + private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken { |
| 502 | + |
| 503 | + CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities, |
| 504 | + String authorizedClientRegistrationId) { |
| 505 | + super(principal, authorities, authorizedClientRegistrationId); |
| 506 | + } |
| 507 | + |
| 508 | + } |
| 509 | + |
457 | 510 | }
|
0 commit comments