Skip to content

Commit ddb7408

Browse files
committed
Expose WebSocket session info in interceptor
Expose information about the WebSocketSession consistently in all methods of WebSocketGraphQlInterceptor. See gh-268
1 parent 8b5e2a3 commit ddb7408

File tree

9 files changed

+288
-42
lines changed

9 files changed

+288
-42
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlInterceptor.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import org.springframework.beans.factory.ObjectProvider;
2424
import org.springframework.graphql.ExecutionGraphQlService;
25+
import org.springframework.util.Assert;
2526

2627

2728
/**
@@ -44,7 +45,8 @@ public interface WebGraphQlInterceptor {
4445
/**
4546
* Intercept a request and delegate to the rest of the chain including other
4647
* interceptors and a {@link ExecutionGraphQlService}.
47-
* @param request the request to execute
48+
* @param request the request which may be a {@link WebSocketGraphQlRequest}
49+
* when intercepting a GraphQL request over WebSocket
4850
* @param chain the rest of the chain to execute the request
4951
* @return a {@link Mono} with the response
5052
*/
@@ -57,7 +59,13 @@ public interface WebGraphQlInterceptor {
5759
* @return a new interceptor that chains the two
5860
*/
5961
default WebGraphQlInterceptor andThen(WebGraphQlInterceptor nextInterceptor) {
60-
return (request, chain) -> intercept(request, nextRequest -> nextInterceptor.intercept(nextRequest, chain));
62+
return (request, chain) -> intercept(request, nextRequest -> {
63+
if (request instanceof WebSocketGraphQlRequest) {
64+
Assert.isTrue(nextRequest instanceof WebSocketGraphQlRequest,
65+
"Expected WebSocketGraphQlRequest but was: " + nextRequest.getClass().getName());
66+
}
67+
return nextInterceptor.intercept(nextRequest, chain);
68+
});
6169
}
6270

6371
/**

spring-graphql/src/main/java/org/springframework/graphql/server/WebSocketGraphQlInterceptor.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ default Mono<WebGraphQlResponse> intercept(WebGraphQlRequest request, Chain chai
4040
* Handle the {@code "connection_init"} message at the start of a GraphQL over
4141
* WebSocket session and return an optional payload for the
4242
* {@code "connection_ack"} message to send back.
43-
* @param sessionId the id of the WebSocket session
43+
* @param sessionInfo information about the underlying WebSocket session
4444
* @param connectionInitPayload the payload from the {@code "connection_init"} message
4545
* @return the payload for the {@code "connection_ack"}, or empty
4646
*/
4747
default Mono<Object> handleConnectionInitialization(
48-
String sessionId, Map<String, Object> connectionInitPayload) {
48+
WebSocketSessionInfo sessionInfo, Map<String, Object> connectionInitPayload) {
4949

5050
return Mono.empty();
5151
}
@@ -55,25 +55,25 @@ default Mono<Object> handleConnectionInitialization(
5555
* subscription stream. The underlying {@link org.reactivestreams.Publisher}
5656
* for the subscription is automatically cancelled. This callback is for any
5757
* additional, or more centralized handling across subscriptions.
58-
* @param sessionId the id of the WebSocket session
58+
* @param sessionInfo information about the underlying WebSocket session
5959
* @param subscriptionId the unique id for the subscription; correlates to the
6060
* {@link WebGraphQlRequest#getId() requestId} from the original {@code "subscribe"}
6161
* message that started the subscription
6262
* @return {@code Mono} for the completion of handling
6363
*/
64-
default Mono<Void> handleCancelledSubscription(String sessionId, String subscriptionId) {
64+
default Mono<Void> handleCancelledSubscription(WebSocketSessionInfo sessionInfo, String subscriptionId) {
6565
return Mono.empty();
6666
}
6767

6868
/**
6969
* Invoked when the WebSocket session is closed, from either side.
70-
* @param sessionId the id of the WebSocket session
70+
* @param sessionInfo information about the underlying WebSocket session
7171
* @param statusCode the WebSocket "close" status code
7272
* @param connectionInitPayload the payload from the {@code "connect_init"}
7373
* message received at the start of the connection
7474
*/
7575
default void handleConnectionClosed(
76-
String sessionId, int statusCode, Map<String, Object> connectionInitPayload) {
76+
WebSocketSessionInfo sessionInfo, int statusCode, Map<String, Object> connectionInitPayload) {
7777
}
7878

7979
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright 2020-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.graphql.server;
18+
19+
20+
import java.net.URI;
21+
import java.util.Locale;
22+
import java.util.Map;
23+
24+
import org.springframework.http.HttpHeaders;
25+
import org.springframework.lang.Nullable;
26+
import org.springframework.util.Assert;
27+
28+
29+
/**
30+
* {@link org.springframework.graphql.server.WebGraphQlRequest} extension for
31+
* server handling of GraphQL over WebSocket requests.
32+
*
33+
* @author Rossen Stoyanchev
34+
* @since 1.0.0
35+
*/
36+
public class WebSocketGraphQlRequest extends WebGraphQlRequest {
37+
38+
private final WebSocketSessionInfo sessionInfo;
39+
40+
41+
/**
42+
* Create an instance.
43+
* @param uri the URL for the HTTP request or WebSocket handshake
44+
* @param headers the HTTP request headers
45+
* @param body the deserialized content of the GraphQL request
46+
* @param id the id from the GraphQL over WebSocket {@code "subscribe"} message
47+
* @param locale the locale from the HTTP request, if any
48+
* @param sessionInfo the WebSocket session id
49+
*/
50+
public WebSocketGraphQlRequest(
51+
URI uri, HttpHeaders headers, Map<String, Object> body, String id, @Nullable Locale locale,
52+
WebSocketSessionInfo sessionInfo) {
53+
54+
super(uri, headers, body, id, locale);
55+
Assert.notNull(sessionInfo, "WebSocketSessionInfo is required");
56+
this.sessionInfo = sessionInfo;
57+
}
58+
59+
60+
/**
61+
* Return information about the underlying WebSocket session.
62+
*/
63+
public WebSocketSessionInfo getSessionInfo() {
64+
return this.sessionInfo;
65+
}
66+
67+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.graphql.server;
17+
18+
import java.net.InetSocketAddress;
19+
import java.net.URI;
20+
import java.security.Principal;
21+
import java.util.Map;
22+
23+
import reactor.core.publisher.Mono;
24+
25+
import org.springframework.http.HttpHeaders;
26+
import org.springframework.lang.Nullable;
27+
28+
/**
29+
* Expose information about the underlying WebSocketSession including the
30+
* session id, the attributes, and HTTP handshake request.
31+
*
32+
* @author Rossen Stoyanchev
33+
* @since 1.0.0
34+
*/
35+
public interface WebSocketSessionInfo {
36+
37+
/**
38+
* Return the id for the WebSocketSession.
39+
*/
40+
String getId();
41+
42+
/**
43+
* Return the map with attributes associated with the WebSocket session.
44+
*/
45+
Map<String, Object> getAttributes();
46+
47+
/**
48+
* Return the URL for the WebSocket endpoint.
49+
*/
50+
URI getUri();
51+
52+
/**
53+
* Return the HTTP headers from the handshake request.
54+
*/
55+
HttpHeaders getHeaders();
56+
57+
/**
58+
* Return the principal associated with the handshake request, if any.
59+
*/
60+
Mono<Principal> getPrincipal();
61+
62+
/**
63+
* For a server session this is the remote address where the handshake
64+
* request came from. For a client session, it is {@code null}.
65+
*/
66+
@Nullable
67+
InetSocketAddress getRemoteAddress();
68+
69+
}

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.graphql.server.webflux;
1818

19+
import java.net.InetSocketAddress;
20+
import java.net.URI;
21+
import java.security.Principal;
1922
import java.time.Duration;
2023
import java.util.Arrays;
2124
import java.util.Collections;
@@ -33,10 +36,12 @@
3336
import reactor.core.publisher.Mono;
3437

3538
import org.springframework.graphql.server.WebGraphQlHandler;
36-
import org.springframework.graphql.server.WebGraphQlRequest;
3739
import org.springframework.graphql.server.WebGraphQlResponse;
3840
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
41+
import org.springframework.graphql.server.WebSocketGraphQlRequest;
42+
import org.springframework.graphql.server.WebSocketSessionInfo;
3943
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
44+
import org.springframework.http.HttpHeaders;
4045
import org.springframework.http.codec.CodecConfigurer;
4146
import org.springframework.util.Assert;
4247
import org.springframework.util.CollectionUtils;
@@ -106,6 +111,7 @@ public Mono<Void> handle(WebSocketSession session) {
106111
}
107112

108113
// Session state
114+
WebSocketSessionInfo sessionInfo = new WebFluxSessionInfo(session);
109115
AtomicReference<Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();
110116
Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();
111117

@@ -123,7 +129,7 @@ public Mono<Void> handle(WebSocketSession session) {
123129
return;
124130
}
125131
int statusCode = (closeStatus != null ? closeStatus.getCode() : 1005);
126-
this.webSocketInterceptor.handleConnectionClosed(session.getId(), statusCode, connectionInitPayload);
132+
this.webSocketInterceptor.handleConnectionClosed(sessionInfo, statusCode, connectionInitPayload);
127133
})
128134
.subscribe();
129135

@@ -139,8 +145,8 @@ public Mono<Void> handle(WebSocketSession session) {
139145
if (id == null) {
140146
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
141147
}
142-
WebGraphQlRequest request = new WebGraphQlRequest(
143-
handshakeInfo.getUri(), handshakeInfo.getHeaders(), payload, id, null);
148+
WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
149+
handshakeInfo.getUri(), handshakeInfo.getHeaders(), payload, id, null, sessionInfo);
144150
if (logger.isDebugEnabled()) {
145151
logger.debug("Executing: " + request);
146152
}
@@ -155,15 +161,15 @@ public Mono<Void> handle(WebSocketSession session) {
155161
if (subscription != null) {
156162
subscription.cancel();
157163
}
158-
return this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id)
164+
return this.webSocketInterceptor.handleCancelledSubscription(sessionInfo, id)
159165
.thenMany(Flux.empty());
160166
}
161167
return Flux.empty();
162168
case CONNECTION_INIT:
163169
if (!connectionInitPayloadRef.compareAndSet(null, payload)) {
164170
return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
165171
}
166-
return this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload)
172+
return this.webSocketInterceptor.handleConnectionInitialization(sessionInfo, payload)
167173
.defaultIfEmpty(Collections.emptyMap())
168174
.map(ackPayload -> this.codecDelegate.encodeConnectionAck(session, ackPayload))
169175
.flux()
@@ -232,6 +238,46 @@ static <V> Flux<V> close(WebSocketSession session, CloseStatus status) {
232238
}
233239

234240

241+
private static class WebFluxSessionInfo implements WebSocketSessionInfo {
242+
243+
private final WebSocketSession session;
244+
245+
private WebFluxSessionInfo(WebSocketSession session) {
246+
this.session = session;
247+
}
248+
249+
@Override
250+
public String getId() {
251+
return this.session.getId();
252+
}
253+
254+
@Override
255+
public Map<String, Object> getAttributes() {
256+
return this.session.getAttributes();
257+
}
258+
259+
@Override
260+
public URI getUri() {
261+
return this.session.getHandshakeInfo().getUri();
262+
}
263+
264+
@Override
265+
public HttpHeaders getHeaders() {
266+
return this.session.getHandshakeInfo().getHeaders();
267+
}
268+
269+
@Override
270+
public Mono<Principal> getPrincipal() {
271+
return this.session.getHandshakeInfo().getPrincipal();
272+
}
273+
274+
@Override
275+
public InetSocketAddress getRemoteAddress() {
276+
return this.session.getHandshakeInfo().getRemoteAddress();
277+
}
278+
}
279+
280+
235281
@SuppressWarnings("serial")
236282
private static class SubscriptionExistsException extends RuntimeException {
237283
}

0 commit comments

Comments
 (0)