Skip to content

Commit 9a74971

Browse files
committed
Fix a bug where websocket subprotocols were not forwarded
1 parent 2f41556 commit 9a74971

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

src/rules/websockets/websocket-handlers.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
} from '../../util/request-utils';
2626
import {
2727
findRawHeader,
28+
findRawHeaders,
2829
objectHeadersToRaw,
2930
pairFlatRawHeaders,
3031
rawHeadersToObjectPreservingCase
@@ -329,7 +330,12 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
329330
// header object internally.
330331
const headers = rawHeadersToObjectPreservingCase(rawHeaders);
331332

332-
const upstreamWebSocket = new WebSocket(wsUrl, {
333+
// Subprotocols have to be handled explicitly. WS takes control of the headers itself,
334+
// and checks the response, so we need to parse the client headers and use them manually:
335+
const subprotocols = findRawHeaders(rawHeaders, 'sec-websocket-protocol')
336+
.flatMap(([_k, value]) => value.split(',').map(p => p.trim()));
337+
338+
const upstreamWebSocket = new WebSocket(wsUrl, subprotocols, {
333339
maxPayload: 0,
334340
agent,
335341
lookup: getDnsLookupFunction(this.lookupOptions),

src/util/header-utils.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export const findRawHeader = (rawHeaders: RawHeaders, targetKey: string) =>
3030
export const findRawHeaderIndex = (rawHeaders: RawHeaders, targetKey: string) =>
3131
rawHeaders.findIndex(([key]) => key.toLowerCase() === targetKey);
3232

33-
const findRawHeaders = (rawHeaders: RawHeaders, targetKey: string) =>
33+
export const findRawHeaders = (rawHeaders: RawHeaders, targetKey: string) =>
3434
rawHeaders.filter(([key]) => key.toLowerCase() === targetKey);
3535

3636
/**

test/integration/websockets.spec.ts

+12-7
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,17 @@ nodeOnly(() => {
146146
it("forwards the incoming requests's headers", async () => {
147147
mockServer.forAnyWebSocket().thenPassThrough();
148148

149-
const ws = new WebSocket(`ws://localhost:${wsPort}`, {
150-
agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`),
151-
headers: {
152-
'echo-headers': 'true',
153-
'Funky-HEADER-casing': 'Header-Value'
149+
const ws = new WebSocket(
150+
`ws://localhost:${wsPort}`,
151+
['subprotocol-a', 'subprotocol-b'],
152+
{
153+
agent: new HttpProxyAgent(`http://localhost:${mockServer.port}`),
154+
headers: {
155+
'echo-headers': 'true',
156+
'Funky-HEADER-casing': 'Header-Value'
157+
}
154158
}
155-
});
159+
);
156160

157161
const response = await new Promise<Buffer>((resolve, reject) => {
158162
ws.on('message', resolve);
@@ -172,7 +176,8 @@ nodeOnly(() => {
172176
[ 'Sec-WebSocket-Version', '13' ],
173177
[ 'Connection', 'Upgrade' ],
174178
[ 'Upgrade', 'websocket' ],
175-
[ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ]
179+
[ 'Sec-WebSocket-Extensions', 'permessage-deflate; client_max_window_bits' ],
180+
[ 'Sec-WebSocket-Protocol', 'subprotocol-a,subprotocol-b' ]
176181
]);
177182
});
178183

0 commit comments

Comments
 (0)