Skip to content

Commit a60fd43

Browse files
wangzwrwinch
authored andcommitted
Fix OAuth2 Client with Ditributed Session
Fixes: gh-6215
1 parent 0c27f64 commit a60fd43

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
5353
if (state == null) {
5454
return Mono.empty();
5555
}
56-
return getStateToAuthorizationRequest(exchange, false)
56+
return getStateToAuthorizationRequest(exchange)
5757
.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
5858
.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
5959
}
@@ -62,9 +62,8 @@ public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
6262
public Mono<Void> saveAuthorizationRequest(
6363
OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
6464
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
65-
return getStateToAuthorizationRequest(exchange, true)
66-
.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
67-
.then();
65+
return saveStateToAuthorizationRequest(exchange).doOnNext(stateToAuthorizationRequest ->
66+
stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)).then();
6867
}
6968

7069
@Override
@@ -108,16 +107,28 @@ private Mono<Map<String, Object>> getSessionAttributes(ServerWebExchange exchang
108107
return exchange.getSession().map(WebSession::getAttributes);
109108
}
110109

111-
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) {
110+
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
111+
Assert.notNull(exchange, "exchange cannot be null");
112+
113+
return getSessionAttributes(exchange)
114+
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
115+
}
116+
117+
private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
112118
Assert.notNull(exchange, "exchange cannot be null");
113119

114120
return getSessionAttributes(exchange)
115121
.doOnNext(sessionAttrs -> {
116-
if (create) {
117-
sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap<String, OAuth2AuthorizationRequest>());
122+
Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
123+
124+
if (stateToAuthzRequest == null) {
125+
stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
118126
}
119-
})
120-
.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
127+
128+
// No matter stateToAuthzRequest was in session or not, we should always put it into session again
129+
// in case of redis or hazelcast session. #6215
130+
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
131+
}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
121132
}
122133

123134
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818

1919
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2020

21+
import static org.mockito.ArgumentMatchers.any;
22+
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.spy;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.when;
27+
import java.util.HashMap;
2128
import java.util.Map;
2229

2330
import org.junit.Test;
@@ -99,6 +106,36 @@ public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() {
99106
.verifyComplete();
100107
}
101108

109+
@Test
110+
public void multipleSavedAuthorizationRequestAndRedisCookie() {
111+
String oldState = "state0";
112+
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
113+
.queryParam(OAuth2ParameterNames.STATE, oldState).build();
114+
115+
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
116+
.authorizationUri("https://example.com/oauth2/authorize")
117+
.clientId("client-id")
118+
.redirectUri("http://localhost/client-1")
119+
.state(oldState)
120+
.build();
121+
122+
Map<String, Object> sessionAttrs = spy(new HashMap<>());
123+
WebSession session = mock(WebSession.class);
124+
when(session.getAttributes()).thenReturn(sessionAttrs);
125+
WebSessionManager sessionManager = e -> Mono.just(session);
126+
127+
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
128+
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
129+
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
130+
ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
131+
132+
Mono<Void> saveAndSave = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
133+
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange));
134+
135+
StepVerifier.create(saveAndSave).verifyComplete();
136+
verify(sessionAttrs, times(2)).put(any(), any());
137+
}
138+
102139
@Test
103140
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
104141
String oldState = "state0";

0 commit comments

Comments
 (0)