Skip to content

Commit 85dc2e5

Browse files
committed
Propagate context in WebMvc GraphQlWebSocketHandler
See gh-342
1 parent e2948e9 commit 85dc2e5

File tree

9 files changed

+196
-34
lines changed

9 files changed

+196
-34
lines changed

spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,21 +16,24 @@
1616

1717
package org.springframework.graphql.execution;
1818

19+
import java.util.ArrayList;
1920
import java.util.List;
2021
import java.util.Map;
2122

23+
2224
/**
2325
* Default implementation of a composite accessor that is returned from
2426
* {@link ThreadLocalAccessor#composite(List)}.
2527
*
2628
* @author Rossen Stoyanchev
29+
* @since 1.0.0
2730
*/
2831
class CompositeThreadLocalAccessor implements ThreadLocalAccessor {
2932

3033
private final List<ThreadLocalAccessor> accessors;
3134

3235
CompositeThreadLocalAccessor(List<ThreadLocalAccessor> accessors) {
33-
this.accessors = accessors;
36+
this.accessors = new ArrayList<>(accessors);
3437
}
3538

3639
@Override

spring-graphql/src/main/java/org/springframework/graphql/server/DefaultWebGraphQlHandlerBuilder.java

+16-8
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,34 @@ public WebGraphQlHandler build() {
9696
.map(interceptor -> interceptor.apply(endOfChain))
9797
.orElse(endOfChain);
9898

99+
ThreadLocalAccessor accessor = (CollectionUtils.isEmpty(this.accessors) ? null :
100+
ThreadLocalAccessor.composite(this.accessors));
101+
99102
return new WebGraphQlHandler() {
100103

104+
@Override
105+
public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
106+
return (webSocketInterceptor != null ?
107+
webSocketInterceptor : new WebSocketGraphQlInterceptor() {});
108+
}
109+
110+
@Nullable
111+
@Override
112+
public ThreadLocalAccessor getThreadLocalAccessor() {
113+
return accessor;
114+
}
115+
101116
@Override
102117
public Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request) {
103118
return executionChain.next(request)
104119
.contextWrite(context -> {
105-
if (!CollectionUtils.isEmpty(accessors)) {
106-
ThreadLocalAccessor accessor = ThreadLocalAccessor.composite(accessors);
120+
if (accessor != null) {
107121
return ReactorContextManager.extractThreadLocalValues(accessor, context);
108122
}
109123
return context;
110124
});
111125
}
112126

113-
@Override
114-
public WebSocketGraphQlInterceptor webSocketInterceptor() {
115-
return (webSocketInterceptor != null ?
116-
webSocketInterceptor : new WebSocketGraphQlInterceptor() {});
117-
}
118-
119127
};
120128
}
121129

spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlHandler.java

+14-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import org.springframework.graphql.ExecutionGraphQlService;
2424
import org.springframework.graphql.execution.ThreadLocalAccessor;
25+
import org.springframework.lang.Nullable;
2526

2627

2728
/**
@@ -33,6 +34,19 @@
3334
*/
3435
public interface WebGraphQlHandler {
3536

37+
/**
38+
* Return the single interceptor of type
39+
* {@link WebSocketGraphQlInterceptor} among all the configured
40+
* interceptors.
41+
*/
42+
WebSocketGraphQlInterceptor getWebSocketInterceptor();
43+
44+
/**
45+
* Return the composite {@link ThreadLocalAccessor} that the handler is
46+
* configured with.
47+
*/
48+
@Nullable
49+
ThreadLocalAccessor getThreadLocalAccessor();
3650

3751
/**
3852
* Execute the given request and return the response.
@@ -41,13 +55,6 @@ public interface WebGraphQlHandler {
4155
*/
4256
Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request);
4357

44-
/**
45-
* Return the single interceptor of type
46-
* {@link WebSocketGraphQlInterceptor} among all the configured
47-
* interceptors.
48-
*/
49-
WebSocketGraphQlInterceptor webSocketInterceptor();
50-
5158

5259
/**
5360
* Provides access to a builder to create a {@link WebGraphQlHandler} instance.

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public GraphQlWebSocketHandler(
8383
Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
8484

8585
this.graphQlHandler = graphQlHandler;
86-
this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor();
86+
this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor();
8787
this.codecDelegate = new CodecDelegate(codecConfigurer);
8888
this.initTimeoutDuration = connectionInitTimeout;
8989
}

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java

+87-5
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
import java.io.ByteArrayInputStream;
2020
import java.io.ByteArrayOutputStream;
21+
import java.io.Closeable;
2122
import java.io.IOException;
2223
import java.io.InputStream;
2324
import java.io.OutputStream;
2425
import java.net.URI;
2526
import java.time.Duration;
2627
import java.util.Arrays;
2728
import java.util.Collections;
29+
import java.util.LinkedHashMap;
2830
import java.util.List;
2931
import java.util.Map;
3032
import java.util.concurrent.ConcurrentHashMap;
@@ -43,6 +45,7 @@
4345
import reactor.core.scheduler.Scheduler;
4446
import reactor.core.scheduler.Schedulers;
4547

48+
import org.springframework.graphql.execution.ThreadLocalAccessor;
4649
import org.springframework.graphql.server.WebGraphQlHandler;
4750
import org.springframework.graphql.server.WebGraphQlRequest;
4851
import org.springframework.graphql.server.WebGraphQlResponse;
@@ -53,15 +56,21 @@
5356
import org.springframework.http.HttpOutputMessage;
5457
import org.springframework.http.converter.GenericHttpMessageConverter;
5558
import org.springframework.http.converter.HttpMessageConverter;
59+
import org.springframework.http.server.ServerHttpRequest;
60+
import org.springframework.http.server.ServerHttpResponse;
5661
import org.springframework.lang.Nullable;
5762
import org.springframework.util.Assert;
5863
import org.springframework.util.CollectionUtils;
5964
import org.springframework.web.socket.CloseStatus;
6065
import org.springframework.web.socket.SubProtocolCapable;
6166
import org.springframework.web.socket.TextMessage;
67+
import org.springframework.web.socket.WebSocketHandler;
6268
import org.springframework.web.socket.WebSocketSession;
6369
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
6470
import org.springframework.web.socket.handler.TextWebSocketHandler;
71+
import org.springframework.web.socket.server.HandshakeHandler;
72+
import org.springframework.web.socket.server.HandshakeInterceptor;
73+
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
6574

6675
/**
6776
* WebSocketHandler for GraphQL based on
@@ -81,7 +90,9 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
8190

8291
private final WebGraphQlHandler graphQlHandler;
8392

84-
private final WebSocketGraphQlInterceptor webSocketInterceptor;
93+
private final ContextHandshakeInterceptor contextHandshakeInterceptor;
94+
95+
private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor;
8596

8697
private final Duration initTimeoutDuration;
8798

@@ -103,7 +114,8 @@ public GraphQlWebSocketHandler(
103114
Assert.notNull(converter, "HttpMessageConverter for JSON is required");
104115

105116
this.graphQlHandler = graphQlHandler;
106-
this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor();
117+
this.contextHandshakeInterceptor = new ContextHandshakeInterceptor(graphQlHandler.getThreadLocalAccessor());
118+
this.webSocketGraphQlInterceptor = this.graphQlHandler.getWebSocketInterceptor();
107119
this.initTimeoutDuration = connectionInitTimeout;
108120
this.converter = converter;
109121
}
@@ -113,6 +125,18 @@ public List<String> getSubProtocols() {
113125
return SUB_PROTOCOL_LIST;
114126
}
115127

128+
/**
129+
* Return a {@link WebSocketHttpRequestHandler} that uses this instance as
130+
* its {@link WebGraphQlHandler} and adds a {@link HandshakeInterceptor} to
131+
* propagate context.
132+
*/
133+
public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) {
134+
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this, handshakeHandler);
135+
handler.setHandshakeInterceptors(Collections.singletonList(this.contextHandshakeInterceptor));
136+
return handler;
137+
}
138+
139+
116140
@Override
117141
public void afterConnectionEstablished(WebSocketSession session) {
118142
if ("graphql-ws".equalsIgnoreCase(session.getAcceptedProtocol())) {
@@ -137,8 +161,15 @@ public void afterConnectionEstablished(WebSocketSession session) {
137161

138162
}
139163

164+
@SuppressWarnings({"unused", "try"})
140165
@Override
141166
protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
167+
try (Closeable closeable = this.contextHandshakeInterceptor.restoreThreadLocalValue(session)) {
168+
handleInternal(session, webSocketMessage);
169+
}
170+
}
171+
172+
private void handleInternal(WebSocketSession session, TextMessage webSocketMessage) throws IOException {
142173
GraphQlWebSocketMessage message = decode(webSocketMessage);
143174
String id = message.getId();
144175
Map<String, Object> payload = message.getPayload();
@@ -174,7 +205,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
174205
if (subscription != null) {
175206
subscription.cancel();
176207
}
177-
this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id)
208+
this.webSocketGraphQlInterceptor.handleCancelledSubscription(session.getId(), id)
178209
.block(Duration.ofSeconds(10));
179210
}
180211
return;
@@ -183,7 +214,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
183214
GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
184215
return;
185216
}
186-
this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload)
217+
this.webSocketGraphQlInterceptor.handleConnectionInitialization(session.getId(), payload)
187218
.defaultIfEmpty(Collections.emptyMap())
188219
.publishOn(sessionState.getScheduler()) // Serial blocking send via single thread
189220
.doOnNext(ackPayload -> {
@@ -285,7 +316,7 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeSta
285316
info.dispose();
286317
Map<String, Object> connectionInitPayload = info.getConnectionInitPayload();
287318
if (connectionInitPayload != null) {
288-
this.webSocketInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload);
319+
this.webSocketGraphQlInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload);
289320
}
290321
}
291322
}
@@ -296,6 +327,57 @@ public boolean supportsPartialMessages() {
296327
}
297328

298329

330+
/**
331+
* {@code HandshakeInterceptor} that propagates ThreadLocal context through
332+
* the attributes map in {@code WebSocketSession}.
333+
*/
334+
private static class ContextHandshakeInterceptor implements HandshakeInterceptor {
335+
336+
private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor.class.getName();
337+
338+
@Nullable
339+
private final ThreadLocalAccessor accessor;
340+
341+
ContextHandshakeInterceptor(@Nullable ThreadLocalAccessor accessor) {
342+
this.accessor = accessor;
343+
}
344+
345+
@Override
346+
public boolean beforeHandshake(
347+
ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
348+
Map<String, Object> attributes) {
349+
350+
if (this.accessor != null) {
351+
Map<String, Object> valuesMap = new LinkedHashMap<>();
352+
this.accessor.extractValues(valuesMap);
353+
attributes.put(SAVED_CONTEXT_KEY, valuesMap);
354+
}
355+
return true;
356+
}
357+
358+
@Override
359+
public void afterHandshake(
360+
ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
361+
@Nullable Exception exception) {
362+
}
363+
364+
@SuppressWarnings("unchecked")
365+
public Closeable restoreThreadLocalValue(WebSocketSession session) {
366+
if (this.accessor != null) {
367+
Map<String, Object> valuesMap = (Map<String, Object>) session.getAttributes().get(SAVED_CONTEXT_KEY);
368+
// Uncomment when Boot is updated to use HandshakeInterceptor
369+
// Assert.state(valuesMap != null, "No context");
370+
if (valuesMap != null) {
371+
this.accessor.restoreValues(valuesMap);
372+
return () -> this.accessor.resetValues(valuesMap);
373+
}
374+
}
375+
return () -> {};
376+
}
377+
378+
}
379+
380+
299381
private static class GraphQlStatus {
300382

301383
private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");

spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@
2020

2121
import org.springframework.graphql.BookSource;
2222
import org.springframework.graphql.GraphQlSetup;
23+
import org.springframework.graphql.execution.ThreadLocalAccessor;
24+
import org.springframework.lang.Nullable;
2325

2426
public abstract class WebSocketHandlerTestSupport {
2527

@@ -65,6 +67,12 @@ public abstract class WebSocketHandlerTestSupport {
6567

6668

6769
protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) {
70+
return initHandler(null, interceptors);
71+
}
72+
73+
protected WebGraphQlHandler initHandler(
74+
@Nullable ThreadLocalAccessor accessor, WebGraphQlInterceptor... interceptors) {
75+
6876
return GraphQlSetup.schemaResource(BookSource.schema)
6977
.queryFetcher("bookById", environment -> {
7078
Long id = Long.parseLong(environment.getArgument("id"));
@@ -75,6 +83,7 @@ protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) {
7583
return Flux.fromIterable(BookSource.books())
7684
.filter((book) -> book.getAuthor().getFullName().contains(author));
7785
})
86+
.threadLocalAccessor(accessor)
7887
.interceptor(interceptors)
7988
.toWebGraphQlHandler();
8089
}

0 commit comments

Comments
 (0)