Skip to content

Commit baa3b28

Browse files
committed
Add Predicate for authorizationConsentRequired for device code grant
Introduces customizable Predicate to determine if user consent is required in device authorization flows. Previously, device consent handling used fixed logic. Now applications can define custom logic for skipping or displaying consent pages. Adds OAuth2DeviceVerificationAuthenticationContext and updates OAuth2DeviceVerificationAuthenticationProvider with setAuthorizationConsentRequired method. Fixes gh-18016 Signed-off-by: Dinesh Gupta <[email protected]>
1 parent d5c5bb2 commit baa3b28

File tree

3 files changed

+231
-4
lines changed

3 files changed

+231
-4
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2004-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.server.authorization.authentication;
18+
19+
import java.util.Collections;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
import java.util.Set;
23+
24+
import org.springframework.lang.Nullable;
25+
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
26+
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
27+
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
28+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
29+
import org.springframework.util.Assert;
30+
31+
/**
32+
* An {@link OAuth2AuthenticationContext} that holds an
33+
* {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is
34+
* used when determining if authorization consent is required.
35+
*
36+
* @author Dinesh Gupta
37+
* @since 7.0
38+
* @see OAuth2AuthenticationContext
39+
* @see OAuth2DeviceVerificationAuthenticationToken
40+
* @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate)
41+
*/
42+
public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext {
43+
44+
private final Map<Object, Object> context;
45+
46+
private OAuth2DeviceVerificationAuthenticationContext(Map<Object, Object> context) {
47+
this.context = Collections.unmodifiableMap(new HashMap<>(context));
48+
}
49+
50+
@SuppressWarnings("unchecked")
51+
@Nullable
52+
@Override
53+
public <V> V get(Object key) {
54+
return hasKey(key) ? (V) this.context.get(key) : null;
55+
}
56+
57+
@Override
58+
public boolean hasKey(Object key) {
59+
Assert.notNull(key, "key cannot be null");
60+
return this.context.containsKey(key);
61+
}
62+
63+
/**
64+
* Returns the {@link RegisteredClient registered client}.
65+
* @return the {@link RegisteredClient}
66+
*/
67+
public RegisteredClient getRegisteredClient() {
68+
return get(RegisteredClient.class);
69+
}
70+
71+
/**
72+
* Returns the {@link OAuth2Authorization authorization}.
73+
* @return the {@link OAuth2Authorization}
74+
*/
75+
public OAuth2Authorization getAuthorization() {
76+
return get(OAuth2Authorization.class);
77+
}
78+
79+
/**
80+
* Returns the {@link OAuth2AuthorizationConsent authorization consent}.
81+
* @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
82+
*/
83+
@Nullable
84+
public OAuth2AuthorizationConsent getAuthorizationConsent() {
85+
return get(OAuth2AuthorizationConsent.class);
86+
}
87+
88+
/**
89+
* Returns the requested scopes.
90+
* @return the requested scopes
91+
*/
92+
public Set<String> getRequestedScopes() {
93+
Set<String> requestedScopes = getAuthorization().getAttribute(OAuth2ParameterNames.SCOPE);
94+
return (requestedScopes != null) ? requestedScopes : Collections.emptySet();
95+
}
96+
97+
/**
98+
* Constructs a new {@link Builder} with the provided
99+
* {@link OAuth2DeviceVerificationAuthenticationToken}.
100+
* @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken}
101+
* @return the {@link Builder}
102+
*/
103+
public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) {
104+
return new Builder(authentication);
105+
}
106+
107+
/**
108+
* A builder for {@link OAuth2DeviceVerificationAuthenticationContext}.
109+
*/
110+
public static final class Builder extends AbstractBuilder<OAuth2DeviceVerificationAuthenticationContext, Builder> {
111+
112+
private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) {
113+
super(authentication);
114+
}
115+
116+
/**
117+
* Sets the {@link RegisteredClient registered client}.
118+
* @param registeredClient the {@link RegisteredClient}
119+
* @return the {@link Builder} for further configuration
120+
*/
121+
public Builder registeredClient(RegisteredClient registeredClient) {
122+
return put(RegisteredClient.class, registeredClient);
123+
}
124+
125+
/**
126+
* Sets the {@link OAuth2Authorization authorization}.
127+
* @param authorization the {@link OAuth2Authorization}
128+
* @return the {@link Builder} for further configuration
129+
*/
130+
public Builder authorization(OAuth2Authorization authorization) {
131+
return put(OAuth2Authorization.class, authorization);
132+
}
133+
134+
/**
135+
* Sets the {@link OAuth2AuthorizationConsent authorization consent}.
136+
* @param authorizationConsent the {@link OAuth2AuthorizationConsent}
137+
* @return the {@link Builder} for further configuration
138+
*/
139+
public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
140+
return put(OAuth2AuthorizationConsent.class, authorizationConsent);
141+
}
142+
143+
/**
144+
* Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}.
145+
* @return the {@link OAuth2DeviceVerificationAuthenticationContext}
146+
*/
147+
@Override
148+
public OAuth2DeviceVerificationAuthenticationContext build() {
149+
Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null");
150+
Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null");
151+
return new OAuth2DeviceVerificationAuthenticationContext(getContext());
152+
}
153+
154+
}
155+
156+
}

oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.security.Principal;
2020
import java.util.Base64;
2121
import java.util.Set;
22+
import java.util.function.Predicate;
2223

2324
import org.apache.commons.logging.Log;
2425
import org.apache.commons.logging.LogFactory;
@@ -79,6 +80,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
7980

8081
private final OAuth2AuthorizationConsentService authorizationConsentService;
8182

83+
private Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired;
84+
8285
/**
8386
* Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the
8487
* provided parameters.
@@ -143,10 +146,18 @@ public Authentication authenticate(Authentication authentication) throws Authent
143146

144147
Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
145148

149+
OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
150+
.with(deviceVerificationAuthentication)
151+
.registeredClient(registeredClient)
152+
.authorization(authorization);
153+
146154
OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService
147155
.findById(registeredClient.getId(), principal.getName());
156+
if (currentAuthorizationConsent != null) {
157+
authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);
158+
}
148159

149-
if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) {
160+
if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
150161
String state = DEFAULT_STATE_GENERATOR.generateKey();
151162
authorization = OAuth2Authorization.from(authorization)
152163
.principalName(principal.getName())
@@ -204,10 +215,37 @@ public boolean supports(Class<?> authentication) {
204215
return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication);
205216
}
206217

207-
private static boolean requiresAuthorizationConsent(Set<String> requestedScopes,
208-
OAuth2AuthorizationConsent authorizationConsent) {
218+
/**
219+
* Sets the {@code Predicate} used to determine if authorization consent is required.
220+
*
221+
* <p>
222+
* The {@link OAuth2DeviceVerificationAuthenticationContext} gives the predicate
223+
* access to the {@link OAuth2DeviceVerificationAuthenticationToken}, as well as, the
224+
* following context attributes:
225+
* <ul>
226+
* <li>The {@link RegisteredClient} associated with the device authorization
227+
* request.</li>
228+
* <li>The {@link OAuth2Authorization} containing the device authorization request
229+
* parameters.</li>
230+
* <li>The {@link OAuth2AuthorizationConsent} previously granted to the
231+
* {@link RegisteredClient}, or {@code null} if not available.</li>
232+
* </ul>
233+
* </p>
234+
* @param authorizationConsentRequired the {@code Predicate} used to determine if
235+
* authorization consent is required
236+
*/
237+
public void setAuthorizationConsentRequired(
238+
Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired) {
239+
Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null");
240+
this.authorizationConsentRequired = authorizationConsentRequired;
241+
}
242+
243+
private static boolean isAuthorizationConsentRequired(
244+
OAuth2DeviceVerificationAuthenticationContext authenticationContext) {
209245

210-
if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) {
246+
if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
247+
.getScopes()
248+
.containsAll(authenticationContext.getRequestedScopes())) {
211249
return false;
212250
}
213251

oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.function.Consumer;
2525
import java.util.function.Function;
26+
import java.util.function.Predicate;
2627

2728
import org.junit.jupiter.api.BeforeEach;
2829
import org.junit.jupiter.api.Test;
@@ -125,6 +126,13 @@ public void constructorWhenAuthorizationConsentServiceIsNullThenThrowIllegalArgu
125126
// @formatter:on
126127
}
127128

129+
@Test
130+
public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() {
131+
assertThatIllegalArgumentException()
132+
.isThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null))
133+
.withMessage("authorizationConsentRequired cannot be null");
134+
}
135+
128136
@Test
129137
public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() {
130138
assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue();
@@ -382,6 +390,31 @@ public void authenticateWhenAuthorizationConsentExistsAndRequestedScopesDoNotMat
382390
.isEqualTo(authenticationResult.getState());
383391
}
384392

393+
@Test
394+
public void authenticateWhenCustomAuthorizationConsentRequiredThenUsed() {
395+
@SuppressWarnings("unchecked")
396+
Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired = mock(Predicate.class);
397+
this.authenticationProvider.setAuthorizationConsentRequired(authorizationConsentRequired);
398+
399+
// @formatter:off
400+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
401+
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
402+
.authorizationGrantType(AuthorizationGrantType.DEVICE_CODE)
403+
.token(createDeviceCode())
404+
.token(createUserCode())
405+
.attributes(Map::clear)
406+
.attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
407+
.build();
408+
// @formatter:on
409+
Authentication authentication = createAuthentication();
410+
given(this.registeredClientRepository.findById(anyString())).willReturn(registeredClient);
411+
given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).willReturn(authorization);
412+
413+
this.authenticationProvider.authenticate(authentication);
414+
415+
verify(authorizationConsentRequired).test(any());
416+
}
417+
385418
private static void mockAuthorizationServerContext() {
386419
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build();
387420
TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(

0 commit comments

Comments
 (0)