Skip to content

Commit 87de6ce

Browse files
committed
Use Reactive JSON Encoder
Closes gh-16177
1 parent 3d1e4b5 commit 87de6ce

File tree

4 files changed

+190
-34
lines changed

4 files changed

+190
-34
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.web.server;
18+
19+
import org.springframework.http.converter.GenericHttpMessageConverter;
20+
import org.springframework.http.converter.HttpMessageConverter;
21+
import org.springframework.http.converter.json.GsonHttpMessageConverter;
22+
import org.springframework.http.converter.json.JsonbHttpMessageConverter;
23+
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
24+
import org.springframework.util.ClassUtils;
25+
26+
/**
27+
* Utility methods for {@link HttpMessageConverter}'s.
28+
*
29+
* @author Joe Grandja
30+
* @author luamas
31+
* @since 5.1
32+
*/
33+
final class HttpMessageConverters {
34+
35+
private static final boolean jackson2Present;
36+
37+
private static final boolean gsonPresent;
38+
39+
private static final boolean jsonbPresent;
40+
41+
static {
42+
ClassLoader classLoader = HttpMessageConverters.class.getClassLoader();
43+
jackson2Present = ClassUtils.isPresent("com.fasterxml.jackson.databind.ObjectMapper", classLoader)
44+
&& ClassUtils.isPresent("com.fasterxml.jackson.core.JsonGenerator", classLoader);
45+
gsonPresent = ClassUtils.isPresent("com.google.gson.Gson", classLoader);
46+
jsonbPresent = ClassUtils.isPresent("jakarta.json.bind.Jsonb", classLoader);
47+
}
48+
49+
private HttpMessageConverters() {
50+
}
51+
52+
static GenericHttpMessageConverter<Object> getJsonMessageConverter() {
53+
if (jackson2Present) {
54+
return new MappingJackson2HttpMessageConverter();
55+
}
56+
if (gsonPresent) {
57+
return new GsonHttpMessageConverter();
58+
}
59+
if (jsonbPresent) {
60+
return new JsonbHttpMessageConverter();
61+
}
62+
return null;
63+
}
64+
65+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.web.server;
18+
19+
import java.io.ByteArrayOutputStream;
20+
import java.io.IOException;
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
import org.jetbrains.annotations.NotNull;
25+
import org.reactivestreams.Publisher;
26+
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
28+
29+
import org.springframework.core.ResolvableType;
30+
import org.springframework.core.io.buffer.DataBuffer;
31+
import org.springframework.core.io.buffer.DataBufferFactory;
32+
import org.springframework.http.HttpHeaders;
33+
import org.springframework.http.HttpOutputMessage;
34+
import org.springframework.http.MediaType;
35+
import org.springframework.http.codec.HttpMessageEncoder;
36+
import org.springframework.http.converter.HttpMessageConverter;
37+
import org.springframework.security.oauth2.core.OAuth2Error;
38+
import org.springframework.util.MimeType;
39+
40+
class OAuth2ErrorEncoder implements HttpMessageEncoder<OAuth2Error> {
41+
42+
private final HttpMessageConverter<Object> messageConverter = HttpMessageConverters.getJsonMessageConverter();
43+
44+
@NotNull
45+
@Override
46+
public List<MediaType> getStreamingMediaTypes() {
47+
return List.of();
48+
}
49+
50+
@Override
51+
public boolean canEncode(ResolvableType elementType, MimeType mimeType) {
52+
return getEncodableMimeTypes().contains(mimeType);
53+
}
54+
55+
@NotNull
56+
@Override
57+
public Flux<DataBuffer> encode(Publisher<? extends OAuth2Error> error, DataBufferFactory bufferFactory,
58+
ResolvableType elementType, MimeType mimeType, Map<String, Object> hints) {
59+
return Mono.from(error).flatMap((data) -> {
60+
ByteArrayHttpOutputMessage bytes = new ByteArrayHttpOutputMessage();
61+
try {
62+
this.messageConverter.write(data, MediaType.APPLICATION_JSON, bytes);
63+
return Mono.just(bytes.getBody().toByteArray());
64+
}
65+
catch (IOException ex) {
66+
return Mono.error(ex);
67+
}
68+
}).map(bufferFactory::wrap).flux();
69+
}
70+
71+
@NotNull
72+
@Override
73+
public List<MimeType> getEncodableMimeTypes() {
74+
return List.of(MediaType.APPLICATION_JSON);
75+
}
76+
77+
private static class ByteArrayHttpOutputMessage implements HttpOutputMessage {
78+
79+
private final ByteArrayOutputStream body = new ByteArrayOutputStream();
80+
81+
@NotNull
82+
@Override
83+
public ByteArrayOutputStream getBody() {
84+
return this.body;
85+
}
86+
87+
@NotNull
88+
@Override
89+
public HttpHeaders getHeaders() {
90+
return new HttpHeaders();
91+
}
92+
93+
}
94+
95+
}

config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutWebFilter.java

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616

1717
package org.springframework.security.config.web.server;
1818

19-
import java.nio.charset.StandardCharsets;
19+
import java.util.Collections;
2020

2121
import jakarta.servlet.http.HttpServletResponse;
2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
24-
import reactor.core.publisher.Flux;
2524
import reactor.core.publisher.Mono;
2625

27-
import org.springframework.core.io.buffer.DataBuffer;
28-
import org.springframework.http.server.reactive.ServerHttpResponse;
26+
import org.springframework.core.ResolvableType;
27+
import org.springframework.http.MediaType;
28+
import org.springframework.http.codec.EncoderHttpMessageWriter;
29+
import org.springframework.http.codec.HttpMessageWriter;
2930
import org.springframework.security.authentication.AuthenticationManager;
3031
import org.springframework.security.authentication.AuthenticationServiceException;
3132
import org.springframework.security.authentication.ReactiveAuthenticationManager;
@@ -62,6 +63,9 @@ class OidcBackChannelLogoutWebFilter implements WebFilter {
6263

6364
private ServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
6465

66+
private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
67+
new OAuth2ErrorEncoder());
68+
6569
/**
6670
* Construct an {@link OidcBackChannelLogoutWebFilter}
6771
* @param authenticationConverter the {@link AuthenticationConverter} for deriving
@@ -84,7 +88,7 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
8488
if (ex instanceof AuthenticationServiceException) {
8589
return Mono.error(ex);
8690
}
87-
return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
91+
return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
8892
})
8993
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
9094
.flatMap(this.authenticationManager::authenticate)
@@ -93,27 +97,20 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
9397
if (ex instanceof AuthenticationServiceException) {
9498
return Mono.error(ex);
9599
}
96-
return handleAuthenticationFailure(exchange.getResponse(), ex).then(Mono.empty());
100+
return handleAuthenticationFailure(exchange, ex).then(Mono.empty());
97101
})
98102
.flatMap((authentication) -> {
99103
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
100104
return this.logoutHandler.logout(webFilterExchange, authentication);
101105
});
102106
}
103107

104-
private Mono<Void> handleAuthenticationFailure(ServerHttpResponse response, Exception ex) {
108+
private Mono<Void> handleAuthenticationFailure(ServerWebExchange exchange, Exception ex) {
105109
this.logger.debug("Failed to process OIDC Back-Channel Logout", ex);
106-
response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
107-
OAuth2Error error = oauth2Error(ex);
108-
byte[] bytes = String.format("""
109-
{
110-
"error_code": "%s",
111-
"error_description": "%s",
112-
"error_uri: "%s"
113-
}
114-
""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
115-
DataBuffer buffer = response.bufferFactory().wrap(bytes);
116-
return response.writeWith(Flux.just(buffer));
110+
exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
111+
return this.errorHttpMessageConverter.write(Mono.just(oauth2Error(ex)), ResolvableType.forClass(Object.class),
112+
ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
113+
exchange.getResponse(), Collections.emptyMap());
117114
}
118115

119116
private OAuth2Error oauth2Error(Exception ex) {

config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,24 @@
1616

1717
package org.springframework.security.config.web.server;
1818

19-
import java.nio.charset.StandardCharsets;
2019
import java.util.Collection;
20+
import java.util.Collections;
2121
import java.util.HashMap;
2222
import java.util.Map;
2323
import java.util.concurrent.atomic.AtomicInteger;
2424

2525
import jakarta.servlet.http.HttpServletResponse;
2626
import org.apache.commons.logging.Log;
2727
import org.apache.commons.logging.LogFactory;
28-
import reactor.core.publisher.Flux;
2928
import reactor.core.publisher.Mono;
3029

31-
import org.springframework.core.io.buffer.DataBuffer;
30+
import org.springframework.core.ResolvableType;
3231
import org.springframework.http.HttpHeaders;
32+
import org.springframework.http.MediaType;
3333
import org.springframework.http.ResponseEntity;
34+
import org.springframework.http.codec.EncoderHttpMessageWriter;
35+
import org.springframework.http.codec.HttpMessageWriter;
3436
import org.springframework.http.server.reactive.ServerHttpRequest;
35-
import org.springframework.http.server.reactive.ServerHttpResponse;
3637
import org.springframework.security.core.Authentication;
3738
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
3839
import org.springframework.security.oauth2.client.oidc.server.session.InMemoryReactiveOidcSessionRegistry;
@@ -44,6 +45,7 @@
4445
import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
4546
import org.springframework.util.Assert;
4647
import org.springframework.web.reactive.function.client.WebClient;
48+
import org.springframework.web.server.ServerWebExchange;
4749
import org.springframework.web.util.UriComponents;
4850
import org.springframework.web.util.UriComponentsBuilder;
4951

@@ -63,6 +65,9 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler {
6365

6466
private ReactiveOidcSessionRegistry sessionRegistry = new InMemoryReactiveOidcSessionRegistry();
6567

68+
private final HttpMessageWriter<OAuth2Error> errorHttpMessageConverter = new EncoderHttpMessageWriter<>(
69+
new OAuth2ErrorEncoder());
70+
6671
private WebClient web = WebClient.create();
6772

6873
private String logoutUri = "{baseScheme}://localhost{basePort}/logout";
@@ -97,7 +102,7 @@ public Mono<Void> logout(WebFilterExchange exchange, Authentication authenticati
97102
totalCount.intValue()));
98103
}
99104
if (!list.isEmpty()) {
100-
return handleLogoutFailure(exchange.getExchange().getResponse(), oauth2Error(list));
105+
return handleLogoutFailure(exchange.getExchange(), oauth2Error(list));
101106
}
102107
else {
103108
return Mono.empty();
@@ -148,17 +153,11 @@ private OAuth2Error oauth2Error(Collection<?> errors) {
148153
"https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation");
149154
}
150155

151-
private Mono<Void> handleLogoutFailure(ServerHttpResponse response, OAuth2Error error) {
152-
response.setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
153-
byte[] bytes = String.format("""
154-
{
155-
"error_code": "%s",
156-
"error_description": "%s",
157-
"error_uri: "%s"
158-
}
159-
""", error.getErrorCode(), error.getDescription(), error.getUri()).getBytes(StandardCharsets.UTF_8);
160-
DataBuffer buffer = response.bufferFactory().wrap(bytes);
161-
return response.writeWith(Flux.just(buffer));
156+
private Mono<Void> handleLogoutFailure(ServerWebExchange exchange, OAuth2Error error) {
157+
exchange.getResponse().setRawStatusCode(HttpServletResponse.SC_BAD_REQUEST);
158+
return this.errorHttpMessageConverter.write(Mono.just(error), ResolvableType.forClass(Object.class),
159+
ResolvableType.forClass(Object.class), MediaType.APPLICATION_JSON, exchange.getRequest(),
160+
exchange.getResponse(), Collections.emptyMap());
162161
}
163162

164163
/**

0 commit comments

Comments
 (0)