diff --git a/lib/websocket.js b/lib/websocket.js index 280364a1b..50d85576a 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -660,22 +660,50 @@ function initAsClient(websocket, address, protocols, options) { if (serverProt) websocket._protocol = serverProt; - if (perMessageDeflate) { + const secWebSocketExtensions = res.headers['sec-websocket-extensions']; + + if (secWebSocketExtensions !== undefined) { + if (!perMessageDeflate) { + const message = + 'Server sent a Sec-WebSocket-Extensions header but no extension ' + + 'was requested'; + abortHandshake(websocket, socket, message); + return; + } + + let extensions; + try { - const extensions = parse(res.headers['sec-websocket-extensions']); + extensions = parse(secWebSocketExtensions); + } catch (err) { + const message = 'Invalid Sec-WebSocket-Extensions header'; + abortHandshake(websocket, socket, message); + return; + } + + const extensionNames = Object.keys(extensions); + + if (extensionNames.length) { + if ( + extensionNames.length !== 1 || + extensionNames[0] !== PerMessageDeflate.extensionName + ) { + const message = + 'Server indicated an extension that was not requested'; + abortHandshake(websocket, socket, message); + return; + } - if (extensions[PerMessageDeflate.extensionName]) { + try { perMessageDeflate.accept(extensions[PerMessageDeflate.extensionName]); - websocket._extensions[PerMessageDeflate.extensionName] = - perMessageDeflate; + } catch (err) { + const message = 'Invalid Sec-WebSocket-Extensions header'; + abortHandshake(websocket, socket, message); + return; } - } catch (err) { - abortHandshake( - websocket, - socket, - 'Invalid Sec-WebSocket-Extensions header' - ); - return; + + websocket._extensions[PerMessageDeflate.extensionName] = + perMessageDeflate; } } diff --git a/test/websocket.test.js b/test/websocket.test.js index 5642d968d..a1a1cda43 100644 --- a/test/websocket.test.js +++ b/test/websocket.test.js @@ -663,7 +663,40 @@ describe('WebSocket', () => { }); }); - it('fails if the Sec-WebSocket-Extensions response header is invalid', (done) => { + it('fails if an unexpected Sec-WebSocket-Extensions header is received', (done) => { + server.once('upgrade', (req, socket) => { + const key = crypto + .createHash('sha1') + .update(req.headers['sec-websocket-key'] + GUID) + .digest('base64'); + + socket.end( + 'HTTP/1.1 101 Switching Protocols\r\n' + + 'Upgrade: websocket\r\n' + + 'Connection: Upgrade\r\n' + + `Sec-WebSocket-Accept: ${key}\r\n` + + 'Sec-WebSocket-Extensions: foo\r\n' + + '\r\n' + ); + }); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`, { + perMessageDeflate: false + }); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual( + err.message, + 'Server sent a Sec-WebSocket-Extensions header but no extension ' + + 'was requested' + ); + ws.on('close', () => done()); + }); + }); + + it('fails if the Sec-WebSocket-Extensions header is invalid (1/2)', (done) => { server.once('upgrade', (req, socket) => { const key = crypto .createHash('sha1') @@ -693,6 +726,97 @@ describe('WebSocket', () => { }); }); + it('fails if the Sec-WebSocket-Extensions header is invalid (2/2)', (done) => { + server.once('upgrade', (req, socket) => { + const key = crypto + .createHash('sha1') + .update(req.headers['sec-websocket-key'] + GUID) + .digest('base64'); + + socket.end( + 'HTTP/1.1 101 Switching Protocols\r\n' + + 'Upgrade: websocket\r\n' + + 'Connection: Upgrade\r\n' + + `Sec-WebSocket-Accept: ${key}\r\n` + + 'Sec-WebSocket-Extensions: ' + + 'permessage-deflate; client_max_window_bits=7\r\n' + + '\r\n' + ); + }); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual( + err.message, + 'Invalid Sec-WebSocket-Extensions header' + ); + ws.on('close', () => done()); + }); + }); + + it('fails if an unexpected extension is received (1/2)', (done) => { + server.once('upgrade', (req, socket) => { + const key = crypto + .createHash('sha1') + .update(req.headers['sec-websocket-key'] + GUID) + .digest('base64'); + + socket.end( + 'HTTP/1.1 101 Switching Protocols\r\n' + + 'Upgrade: websocket\r\n' + + 'Connection: Upgrade\r\n' + + `Sec-WebSocket-Accept: ${key}\r\n` + + 'Sec-WebSocket-Extensions: foo\r\n' + + '\r\n' + ); + }); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual( + err.message, + 'Server indicated an extension that was not requested' + ); + ws.on('close', () => done()); + }); + }); + + it('fails if an unexpected extension is received (2/2)', (done) => { + server.once('upgrade', (req, socket) => { + const key = crypto + .createHash('sha1') + .update(req.headers['sec-websocket-key'] + GUID) + .digest('base64'); + + socket.end( + 'HTTP/1.1 101 Switching Protocols\r\n' + + 'Upgrade: websocket\r\n' + + 'Connection: Upgrade\r\n' + + `Sec-WebSocket-Accept: ${key}\r\n` + + 'Sec-WebSocket-Extensions: permessage-deflate,foo\r\n' + + '\r\n' + ); + }); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual( + err.message, + 'Server indicated an extension that was not requested' + ); + ws.on('close', () => done()); + }); + }); + it('fails if server sends a subprotocol when none was requested', (done) => { const wss = new WebSocket.Server({ server });