18
18
19
19
import java .io .ByteArrayInputStream ;
20
20
import java .io .ByteArrayOutputStream ;
21
+ import java .io .Closeable ;
21
22
import java .io .IOException ;
22
23
import java .io .InputStream ;
23
24
import java .io .OutputStream ;
24
25
import java .net .URI ;
25
26
import java .time .Duration ;
26
27
import java .util .Arrays ;
27
28
import java .util .Collections ;
29
+ import java .util .LinkedHashMap ;
28
30
import java .util .List ;
29
31
import java .util .Map ;
30
32
import java .util .concurrent .ConcurrentHashMap ;
43
45
import reactor .core .scheduler .Scheduler ;
44
46
import reactor .core .scheduler .Schedulers ;
45
47
48
+ import org .springframework .graphql .execution .ThreadLocalAccessor ;
46
49
import org .springframework .graphql .server .WebGraphQlHandler ;
47
50
import org .springframework .graphql .server .WebGraphQlRequest ;
48
51
import org .springframework .graphql .server .WebGraphQlResponse ;
53
56
import org .springframework .http .HttpOutputMessage ;
54
57
import org .springframework .http .converter .GenericHttpMessageConverter ;
55
58
import org .springframework .http .converter .HttpMessageConverter ;
59
+ import org .springframework .http .server .ServerHttpRequest ;
60
+ import org .springframework .http .server .ServerHttpResponse ;
56
61
import org .springframework .lang .Nullable ;
57
62
import org .springframework .util .Assert ;
58
63
import org .springframework .util .CollectionUtils ;
59
64
import org .springframework .web .socket .CloseStatus ;
60
65
import org .springframework .web .socket .SubProtocolCapable ;
61
66
import org .springframework .web .socket .TextMessage ;
67
+ import org .springframework .web .socket .WebSocketHandler ;
62
68
import org .springframework .web .socket .WebSocketSession ;
63
69
import org .springframework .web .socket .handler .ExceptionWebSocketHandlerDecorator ;
64
70
import org .springframework .web .socket .handler .TextWebSocketHandler ;
71
+ import org .springframework .web .socket .server .HandshakeHandler ;
72
+ import org .springframework .web .socket .server .HandshakeInterceptor ;
73
+ import org .springframework .web .socket .server .support .WebSocketHttpRequestHandler ;
65
74
66
75
/**
67
76
* WebSocketHandler for GraphQL based on
@@ -81,7 +90,9 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
81
90
82
91
private final WebGraphQlHandler graphQlHandler ;
83
92
84
- private final WebSocketGraphQlInterceptor webSocketInterceptor ;
93
+ private final ContextHandshakeInterceptor contextHandshakeInterceptor ;
94
+
95
+ private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor ;
85
96
86
97
private final Duration initTimeoutDuration ;
87
98
@@ -103,7 +114,8 @@ public GraphQlWebSocketHandler(
103
114
Assert .notNull (converter , "HttpMessageConverter for JSON is required" );
104
115
105
116
this .graphQlHandler = graphQlHandler ;
106
- this .webSocketInterceptor = this .graphQlHandler .webSocketInterceptor ();
117
+ this .contextHandshakeInterceptor = new ContextHandshakeInterceptor (graphQlHandler .getThreadLocalAccessor ());
118
+ this .webSocketGraphQlInterceptor = this .graphQlHandler .getWebSocketInterceptor ();
107
119
this .initTimeoutDuration = connectionInitTimeout ;
108
120
this .converter = converter ;
109
121
}
@@ -113,6 +125,18 @@ public List<String> getSubProtocols() {
113
125
return SUB_PROTOCOL_LIST ;
114
126
}
115
127
128
+ /**
129
+ * Return a {@link WebSocketHttpRequestHandler} that uses this instance as
130
+ * its {@link WebGraphQlHandler} and adds a {@link HandshakeInterceptor} to
131
+ * propagate context.
132
+ */
133
+ public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler (HandshakeHandler handshakeHandler ) {
134
+ WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler (this , handshakeHandler );
135
+ handler .setHandshakeInterceptors (Collections .singletonList (this .contextHandshakeInterceptor ));
136
+ return handler ;
137
+ }
138
+
139
+
116
140
@ Override
117
141
public void afterConnectionEstablished (WebSocketSession session ) {
118
142
if ("graphql-ws" .equalsIgnoreCase (session .getAcceptedProtocol ())) {
@@ -137,8 +161,15 @@ public void afterConnectionEstablished(WebSocketSession session) {
137
161
138
162
}
139
163
164
+ @ SuppressWarnings ({"unused" , "try" })
140
165
@ Override
141
166
protected void handleTextMessage (WebSocketSession session , TextMessage webSocketMessage ) throws Exception {
167
+ try (Closeable closeable = this .contextHandshakeInterceptor .restoreThreadLocalValue (session )) {
168
+ handleInternal (session , webSocketMessage );
169
+ }
170
+ }
171
+
172
+ private void handleInternal (WebSocketSession session , TextMessage webSocketMessage ) throws IOException {
142
173
GraphQlWebSocketMessage message = decode (webSocketMessage );
143
174
String id = message .getId ();
144
175
Map <String , Object > payload = message .getPayload ();
@@ -174,7 +205,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
174
205
if (subscription != null ) {
175
206
subscription .cancel ();
176
207
}
177
- this .webSocketInterceptor .handleCancelledSubscription (session .getId (), id )
208
+ this .webSocketGraphQlInterceptor .handleCancelledSubscription (session .getId (), id )
178
209
.block (Duration .ofSeconds (10 ));
179
210
}
180
211
return ;
@@ -183,7 +214,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
183
214
GraphQlStatus .closeSession (session , GraphQlStatus .TOO_MANY_INIT_REQUESTS_STATUS );
184
215
return ;
185
216
}
186
- this .webSocketInterceptor .handleConnectionInitialization (session .getId (), payload )
217
+ this .webSocketGraphQlInterceptor .handleConnectionInitialization (session .getId (), payload )
187
218
.defaultIfEmpty (Collections .emptyMap ())
188
219
.publishOn (sessionState .getScheduler ()) // Serial blocking send via single thread
189
220
.doOnNext (ackPayload -> {
@@ -285,7 +316,7 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeSta
285
316
info .dispose ();
286
317
Map <String , Object > connectionInitPayload = info .getConnectionInitPayload ();
287
318
if (connectionInitPayload != null ) {
288
- this .webSocketInterceptor .handleConnectionClosed (id , closeStatus .getCode (), connectionInitPayload );
319
+ this .webSocketGraphQlInterceptor .handleConnectionClosed (id , closeStatus .getCode (), connectionInitPayload );
289
320
}
290
321
}
291
322
}
@@ -296,6 +327,57 @@ public boolean supportsPartialMessages() {
296
327
}
297
328
298
329
330
+ /**
331
+ * {@code HandshakeInterceptor} that propagates ThreadLocal context through
332
+ * the attributes map in {@code WebSocketSession}.
333
+ */
334
+ private static class ContextHandshakeInterceptor implements HandshakeInterceptor {
335
+
336
+ private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor .class .getName ();
337
+
338
+ @ Nullable
339
+ private final ThreadLocalAccessor accessor ;
340
+
341
+ ContextHandshakeInterceptor (@ Nullable ThreadLocalAccessor accessor ) {
342
+ this .accessor = accessor ;
343
+ }
344
+
345
+ @ Override
346
+ public boolean beforeHandshake (
347
+ ServerHttpRequest request , ServerHttpResponse response , WebSocketHandler wsHandler ,
348
+ Map <String , Object > attributes ) {
349
+
350
+ if (this .accessor != null ) {
351
+ Map <String , Object > valuesMap = new LinkedHashMap <>();
352
+ this .accessor .extractValues (valuesMap );
353
+ attributes .put (SAVED_CONTEXT_KEY , valuesMap );
354
+ }
355
+ return true ;
356
+ }
357
+
358
+ @ Override
359
+ public void afterHandshake (
360
+ ServerHttpRequest request , ServerHttpResponse response , WebSocketHandler wsHandler ,
361
+ @ Nullable Exception exception ) {
362
+ }
363
+
364
+ @ SuppressWarnings ("unchecked" )
365
+ public Closeable restoreThreadLocalValue (WebSocketSession session ) {
366
+ if (this .accessor != null ) {
367
+ Map <String , Object > valuesMap = (Map <String , Object >) session .getAttributes ().get (SAVED_CONTEXT_KEY );
368
+ // Uncomment when Boot is updated to use HandshakeInterceptor
369
+ // Assert.state(valuesMap != null, "No context");
370
+ if (valuesMap != null ) {
371
+ this .accessor .restoreValues (valuesMap );
372
+ return () -> this .accessor .resetValues (valuesMap );
373
+ }
374
+ }
375
+ return () -> {};
376
+ }
377
+
378
+ }
379
+
380
+
299
381
private static class GraphQlStatus {
300
382
301
383
private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus (4400 , "Invalid message" );
0 commit comments