Skip to content

Commit f54ee52

Browse files
committed
Add WebSocketGraphQlClientInterceptor
See gh-322
1 parent 92547de commit f54ee52

File tree

5 files changed

+121
-28
lines changed

5 files changed

+121
-28
lines changed

spring-graphql/src/main/java/org/springframework/graphql/client/AbstractGraphQlClientBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ protected void setJsonCodecs(Encoder<?> encoder, Decoder<?> decoder) {
111111
this.jsonDecoder = decoder;
112112
}
113113

114+
/**
115+
* Return the configured interceptors. For subclasses that look for a
116+
* transport specific interceptor extensions.
117+
*/
118+
protected List<GraphQlClientInterceptor> getInterceptors() {
119+
return this.interceptors;
120+
}
121+
114122
/**
115123
* Build the default transport-agnostic client that subclasses can then wrap
116124
* with {@link AbstractDelegatingGraphQlClient}.

spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClient.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import java.net.URI;
2020
import java.util.Arrays;
21+
import java.util.List;
2122
import java.util.function.Consumer;
23+
import java.util.stream.Collectors;
2224

2325
import reactor.core.publisher.Mono;
2426

@@ -157,12 +159,25 @@ public WebSocketGraphQlClient build() {
157159
CodecDelegate.findJsonDecoder(this.codecConfigurer));
158160

159161
WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
160-
this.url, this.headers, this.webSocketClient, this.codecConfigurer, null, payload -> {});
162+
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor());
161163

162164
GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
163165
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
164166
}
165167

168+
private WebSocketGraphQlClientInterceptor getInterceptor() {
169+
170+
List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()
171+
.filter(interceptor -> interceptor instanceof WebSocketGraphQlClientInterceptor)
172+
.map(interceptor -> (WebSocketGraphQlClientInterceptor) interceptor)
173+
.collect(Collectors.toList());
174+
175+
Assert.state(interceptors.size() <= 1,
176+
"Only a single interceptor of type WebSocketGraphQlClientInterceptor may be configured");
177+
178+
return (!interceptors.isEmpty() ? interceptors.get(0) : new WebSocketGraphQlClientInterceptor() {});
179+
}
180+
166181
}
167182

168183
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright 2020-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.graphql.client;
18+
19+
20+
import java.util.Map;
21+
22+
import reactor.core.publisher.Mono;
23+
24+
25+
/**
26+
* An extension of {@link GraphQlClientInterceptor} with additional methods to
27+
* for WebSocket interception points. Only a single interceptor of type
28+
* {@link WebSocketGraphQlClientInterceptor} may be configured.
29+
*
30+
* @author Rossen Stoyanchev
31+
* @since 1.0.0
32+
*/
33+
public interface WebSocketGraphQlClientInterceptor extends GraphQlClientInterceptor {
34+
35+
/**
36+
* Provide a {@code Mono} that returns the payload for the
37+
* {@code "connection_init"} message. The {@code Mono} is subscribed to every
38+
* type a new WebSocket connection is established.
39+
*/
40+
default Mono<Object> connectionInitPayload() {
41+
return Mono.empty();
42+
}
43+
44+
/**
45+
* Handler the {@code "connection_ack"} message received from the server at
46+
* the start of the WebSocket connection.
47+
* @param ackPayload the payload of the {@code "connection_ack"} message
48+
*/
49+
default Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {
50+
return Mono.empty();
51+
}
52+
53+
}

spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.concurrent.ConcurrentHashMap;
2323
import java.util.concurrent.atomic.AtomicBoolean;
2424
import java.util.concurrent.atomic.AtomicLong;
25-
import java.util.function.Consumer;
2625

2726
import org.apache.commons.logging.Log;
2827
import org.apache.commons.logging.LogFactory;
@@ -69,17 +68,18 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
6968

7069
WebSocketGraphQlTransport(
7170
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
72-
@Nullable Object connectionInitPayload, Consumer<Map<String, Object>> connectionAckHandler) {
71+
WebSocketGraphQlClientInterceptor interceptor) {
7372

7473
Assert.notNull(url, "URI is required");
75-
Assert.notNull(url, "URI is required");
74+
Assert.notNull(client, "WebSocketClient is required");
75+
Assert.notNull(codecConfigurer, "CodecConfigurer is required");
76+
Assert.notNull(interceptor, "WebSocketGraphQlClientInterceptor is required");
7677

7778
this.url = url;
7879
this.headers.putAll(headers != null ? headers : HttpHeaders.EMPTY);
7980
this.webSocketClient = client;
8081

81-
this.graphQlSessionHandler = new GraphQlSessionHandler(
82-
codecConfigurer, connectionInitPayload, connectionAckHandler);
82+
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor);
8383

8484
this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
8585
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
@@ -167,21 +167,16 @@ private static class GraphQlSessionHandler implements WebSocketHandler {
167167

168168
private final CodecDelegate codecDelegate;
169169

170-
private final GraphQlMessage connectionInitMessage;
171-
172-
private final Consumer<Map<String, Object>> connectionAckHandler;
170+
private final WebSocketGraphQlClientInterceptor interceptor;
173171

174172
private Sinks.One<GraphQlSession> graphQlSessionSink;
175173

176174
private final AtomicBoolean stopped = new AtomicBoolean();
177175

178176

179-
GraphQlSessionHandler(CodecConfigurer codecConfigurer,
180-
@Nullable Object connectionInitPayload, Consumer<Map<String, Object>> connectionAckHandler) {
181-
177+
GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) {
182178
this.codecDelegate = new CodecDelegate(codecConfigurer);
183-
this.connectionInitMessage = GraphQlMessage.connectionInit(connectionInitPayload);
184-
this.connectionAckHandler = connectionAckHandler;
179+
this.interceptor = interceptor;
185180
this.graphQlSessionSink = Sinks.unsafe().one();
186181
}
187182

@@ -231,8 +226,12 @@ public Mono<Void> handle(WebSocketSession session) {
231226
GraphQlSession graphQlSession = new GraphQlSession(session);
232227
registerCloseStatusHandling(graphQlSession, session);
233228

229+
Mono<GraphQlMessage> connectionInitMono = this.interceptor.connectionInitPayload()
230+
.defaultIfEmpty(Collections.emptyMap())
231+
.map(GraphQlMessage::connectionInit);
232+
234233
Mono<Void> sendCompletion =
235-
session.send(Flux.just(this.connectionInitMessage).concatWith(graphQlSession.getRequestFlux())
234+
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
236235
.map(message -> this.codecDelegate.encode(session, message)));
237236

238237
Mono<Void> receiveCompletion = session.receive()
@@ -242,20 +241,23 @@ public Mono<Void> handle(WebSocketSession session) {
242241
GraphQlMessage message = this.codecDelegate.decode(webSocketMessage);
243242
Assert.state(message.resolvedType() == GraphQlMessageType.CONNECTION_ACK,
244243
() -> "Unexpected message before connection_ack: " + message);
245-
this.connectionAckHandler.accept(message.getPayload());
246-
if (logger.isDebugEnabled()) {
247-
logger.debug(graphQlSession + " initialized");
248-
}
244+
return this.interceptor.handleConnectionAck(message.getPayload())
245+
.then(Mono.defer(() -> {
246+
if (logger.isDebugEnabled()) {
247+
logger.debug(graphQlSession + " initialized");
248+
}
249+
Sinks.EmitResult result = this.graphQlSessionSink.tryEmitValue(graphQlSession);
250+
if (result.isFailure()) {
251+
return Mono.error(new IllegalStateException(
252+
"GraphQlSession initialized but could not be emitted: " + result));
253+
}
254+
return Mono.empty();
255+
}));
249256
}
250257
catch (Throwable ex) {
251258
this.graphQlSessionSink.tryEmitError(ex);
252259
return Mono.error(ex);
253260
}
254-
Sinks.EmitResult emitResult = this.graphQlSessionSink.tryEmitValue(graphQlSession);
255-
if (emitResult.isFailure()) {
256-
return Mono.error(new IllegalStateException(
257-
"GraphQlSession initialized but could not be emitted: " + emitResult));
258-
}
259261
}
260262
else {
261263
try {

spring-graphql/src/test/java/org/springframework/graphql/client/MockWebSocketGraphQlTransportTests.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
import reactor.core.publisher.Mono;
3232
import reactor.test.StepVerifier;
3333

34-
import org.springframework.graphql.support.DefaultGraphQlRequest;
3534
import org.springframework.graphql.GraphQlRequest;
3635
import org.springframework.graphql.GraphQlResponse;
3736
import org.springframework.graphql.ResponseError;
37+
import org.springframework.graphql.support.DefaultGraphQlRequest;
3838
import org.springframework.graphql.web.TestWebSocketClient;
3939
import org.springframework.graphql.web.TestWebSocketConnection;
4040
import org.springframework.graphql.web.support.GraphQlMessage;
@@ -196,9 +196,23 @@ void start() {
196196
Map<String, String> initPayload = Collections.singletonMap("key", "valueInit");
197197
AtomicReference<Map<String, Object>> connectionAckRef = new AtomicReference<>();
198198

199+
WebSocketGraphQlClientInterceptor interceptor = new WebSocketGraphQlClientInterceptor() {
200+
201+
@Override
202+
public Mono<Object> connectionInitPayload() {
203+
return Mono.just(initPayload);
204+
}
205+
206+
@Override
207+
public Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {
208+
connectionAckRef.set(ackPayload);
209+
return Mono.empty();
210+
}
211+
};
212+
213+
199214
WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
200-
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
201-
initPayload, connectionAckRef::set);
215+
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor);
202216

203217
transport.start().block(TIMEOUT);
204218

@@ -311,7 +325,8 @@ void errorDuringResponseHandling() {
311325

312326
private static WebSocketGraphQlTransport createTransport(WebSocketClient client) {
313327
return new WebSocketGraphQlTransport(
314-
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), null, p -> {});
328+
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
329+
new WebSocketGraphQlClientInterceptor() {});
315330
}
316331

317332
private void assertActualClientMessages(GraphQlMessage... expectedMessages) {

0 commit comments

Comments
 (0)