Skip to content

Commit

Permalink
Merge pull request #10 from channelcat/master
Browse files Browse the repository at this point in the history
merge upstream master branch
  • Loading branch information
yunstanford authored Aug 10, 2017
2 parents 80f27b1 + df4a149 commit bda6c85
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sanic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sanic.app import Sanic
from sanic.blueprints import Blueprint

__version__ = '0.5.4'
__version__ = '0.6.0'

__all__ = ['Sanic', 'Blueprint']
7 changes: 5 additions & 2 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,12 @@ def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
return handler

# Decorator
def websocket(self, uri, host=None, strict_slashes=False):
def websocket(self, uri, host=None, strict_slashes=False,
subprotocols=None):
"""Decorate a function to be registered as a websocket route
:param uri: path of the URL
:param subprotocols: optional list of strings with the supported
subprotocols
:param host:
:return: decorated function
"""
Expand All @@ -236,7 +239,7 @@ async def websocket_handler(request, *args, **kwargs):
# On Python3.5 the Transport classes in asyncio do not
# have a get_protocol() method as in uvloop
protocol = request.transport._protocol
ws = await protocol.websocket_handshake(request)
ws = await protocol.websocket_handshake(request, subprotocols)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
Expand Down
5 changes: 3 additions & 2 deletions sanic/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def cookies(self):

class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'transport', 'streaming_fn',
'status', 'content_type', 'headers', '_cookies')
'transport', 'streaming_fn', 'status',
'content_type', 'headers', '_cookies'
)

def __init__(self, streaming_fn, status=200, headers=None,
content_type='text/plain'):
Expand Down
14 changes: 13 additions & 1 deletion sanic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def write_response(self, response):
else:
super().write_response(response)

async def websocket_handshake(self, request):
async def websocket_handshake(self, request, subprotocols=None):
# let the websockets package do the handshake with the client
headers = []

Expand All @@ -57,6 +57,17 @@ def set_header(k, v):
except InvalidHandshake:
raise InvalidUsage('Invalid websocket request')

subprotocol = None
if subprotocols and 'Sec-Websocket-Protocol' in request.headers:
# select a subprotocol
client_subprotocols = [p.strip() for p in request.headers[
'Sec-Websocket-Protocol'].split(',')]
for p in client_subprotocols:
if p in subprotocols:
subprotocol = p
set_header('Sec-Websocket-Protocol', subprotocol)
break

# write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
for k, v in headers:
Expand All @@ -69,5 +80,6 @@ def set_header(k, v):
max_size=self.websocket_max_size,
max_queue=self.websocket_max_queue
)
self.websocket.subprotocol = subprotocol
self.websocket.connection_made(request.transport)
return self.websocket
43 changes: 43 additions & 0 deletions tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def test_websocket_route():

@app.websocket('/ws')
async def handler(request, ws):
assert ws.subprotocol is None
ev.set()

request, response = app.test_client.get('/ws', headers={
Expand All @@ -352,6 +353,48 @@ async def handler(request, ws):
assert ev.is_set()


def test_websocket_route_with_subprotocols():
app = Sanic('test_websocket_route')
results = []

@app.websocket('/ws', subprotocols=['foo', 'bar'])
async def handler(request, ws):
results.append(ws.subprotocol)

request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'bar'})
assert response.status == 101

request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'bar, foo'})
assert response.status == 101

request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Protocol': 'baz'})
assert response.status == 101

request, response = app.test_client.get('/ws', headers={
'Upgrade': 'websocket',
'Connection': 'upgrade',
'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
'Sec-WebSocket-Version': '13'})
assert response.status == 101

assert results == ['bar', 'bar', None, None]


def test_route_duplicate():
app = Sanic('test_route_duplicate')

Expand Down

0 comments on commit bda6c85

Please # to comment.