Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Check X-Error when WS connection fails #2568

Merged
merged 3 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.D/2568.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added handling of WebSocket connection error reason "X-Error" header.
71 changes: 43 additions & 28 deletions neuro-sdk/src/neuro_sdk/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)

import aiohttp
from aiohttp import WSMessage
from aiohttp import ClientWebSocketResponse, WSServerHandshakeError
from multidict import CIMultiDict
from yarl import URL

Expand All @@ -33,7 +33,6 @@
ServerNotAvailable,
)
from ._tracing import gen_trace_id
from ._utils import asyncgeneratorcontextmanager

log = logging.getLogger(__package__)

Expand Down Expand Up @@ -109,6 +108,23 @@ def session(self) -> aiohttp.ClientSession:
async def close(self) -> None:
pass

def _raise_error(self, status_code: int, err_text: str) -> None:
try:
payload = jsonmodule.loads(err_text)
except ValueError:
# One example would be a HEAD request for application/json
payload = {}
if "error" in payload:
err_text = payload["error"]
else:
payload = {}
if status_code == 400 and "errno" in payload:
os_errno: Any = payload["errno"]
os_errno = errno.__dict__.get(os_errno, os_errno)
raise OSError(os_errno, err_text)
err_cls = self._exception_map.get(status_code, IllegalArgumentError)
raise err_cls(err_text)

@asynccontextmanager
async def request(
self,
Expand Down Expand Up @@ -154,31 +170,21 @@ async def request(
read_bufsize=2 ** 22,
) as resp:
if 400 <= resp.status:
err_text = await resp.text()
if resp.content_type.lower() == "application/json":
try:
payload = jsonmodule.loads(err_text)
except ValueError:
# One example would be a HEAD request for application/json
payload = {}
if "error" in payload:
err_text = payload["error"]
else:
payload = {}
if resp.status == 400 and "errno" in payload:
os_errno: Any = payload["errno"]
os_errno = errno.__dict__.get(os_errno, os_errno)
raise OSError(os_errno, err_text)
err_cls = self._exception_map.get(resp.status, IllegalArgumentError)
raise err_cls(err_text)
self._raise_error(resp.status, await resp.text())
else:
yield resp

@asyncgeneratorcontextmanager
@asynccontextmanager
async def ws_connect(
self, abs_url: URL, auth: str, *, headers: Optional[Dict[str, str]] = None
) -> AsyncIterator[WSMessage]:
# TODO: timeout
self,
abs_url: URL,
*,
auth: str,
headers: Optional[Dict[str, str]] = None,
heartbeat: Optional[float] = None,
timeout: Optional[float] = 10.0,
receive_timeout: Optional[float] = None,
) -> AsyncIterator[ClientWebSocketResponse]:
assert abs_url.is_absolute(), abs_url
log.debug("Fetch web socket: %s", abs_url)

Expand All @@ -187,11 +193,20 @@ async def ws_connect(
else:
real_headers = CIMultiDict()
real_headers["Authorization"] = auth

async with self._session.ws_connect(abs_url, headers=real_headers) as ws:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
yield msg
try:
async with self._session.ws_connect(
abs_url,
headers=real_headers,
heartbeat=heartbeat,
timeout=timeout, # type: ignore
receive_timeout=receive_timeout,
) as ws:
yield ws
except WSServerHandshakeError as e:
err_text = str(e)
if e.headers:
err_text = e.headers.get("X-Error", err_text)
self._raise_error(e.status, err_text)


def _ensure_schema(db: sqlite3.Connection, *, update: bool) -> bool:
Expand Down
107 changes: 43 additions & 64 deletions neuro-sdk/src/neuro_sdk/_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from ._config import Config
from ._core import _Core
from ._errors import NDJSONError, ResourceNotFound, StdStreamError
from ._errors import NDJSONError, StdStreamError
from ._images import (
_DummyProgress,
_raise_on_error_chunk,
Expand Down Expand Up @@ -550,9 +550,10 @@ async def top(
try:
received_any = False
async with self._core.ws_connect(url, auth=auth) as ws:
async for resp in ws:
yield _job_telemetry_from_api(resp.json())
received_any = True
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
yield _job_telemetry_from_api(msg.json())
received_any = True
if not received_any:
raise ValueError(f"Job is not running. Job Id = {id}")
except WSServerHandshakeError as e:
Expand Down Expand Up @@ -639,35 +640,28 @@ async def _port_forward(
loop = asyncio.get_event_loop()
url = self._get_monitoring_url(cluster_name)
url = url / id / "port_forward" / str(job_port)
auth = await self._config._api_auth()
ws = await self._core._session.ws_connect(
async with self._core.ws_connect(
url,
headers={"Authorization": auth},
timeout=None, # type: ignore
auth=await self._config._api_auth(),
timeout=None,
receive_timeout=None,
heartbeat=30,
)
tasks = []
tasks.append(loop.create_task(self._port_reader(ws, writer)))
tasks.append(loop.create_task(self._port_writer(ws, reader)))
try:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task
writer.close()
await writer.wait_closed()
await ws.close()
) as ws:
tasks = []
tasks.append(loop.create_task(self._port_reader(ws, writer)))
tasks.append(loop.create_task(self._port_writer(ws, reader)))
try:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
for task in tasks:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task
writer.close()
await writer.wait_closed()
except asyncio.CancelledError:
raise
except WSServerHandshakeError as e:
if e.headers and "X-Error" in e.headers:
log.error(f"Error during port-forwarding: {e.headers['X-Error']}")
log.exception("Unhandled exception during port-forwarding")
writer.close()
except Exception:
log.exception("Unhandled exception during port-forwarding")
writer.close()
Expand Down Expand Up @@ -712,26 +706,17 @@ async def attach(
)
auth = await self._config._api_auth()

try:
ws = await self._core._session.ws_connect(
url,
headers={
"Authorization": auth,
aiohttp.hdrs.SEC_WEBSOCKET_PROTOCOL: "v2.channels.neu.ro",
},
timeout=None, # type: ignore
receive_timeout=None,
heartbeat=30,
)
except aiohttp.ClientResponseError as ex:
if ex.status == 404:
raise ResourceNotFound(f"Job {id!r} is not running")
raise

try:
async with self._core.ws_connect(
url,
auth=auth,
headers={
aiohttp.hdrs.SEC_WEBSOCKET_PROTOCOL: "v2.channels.neu.ro",
},
timeout=None,
receive_timeout=None,
heartbeat=30,
) as ws:
yield StdStream(ws)
finally:
await ws.close()

@asynccontextmanager
async def exec(
Expand All @@ -755,23 +740,17 @@ async def exec(
)
auth = await self._config._api_auth()

try:
ws = await self._core._session.ws_connect(
url,
headers={"Authorization": auth},
timeout=None, # type: ignore
receive_timeout=None,
heartbeat=30,
)
except aiohttp.ClientResponseError as ex:
if ex.status == 404:
raise ResourceNotFound(f"Job {id!r} is not running")
raise

try:
yield StdStream(ws)
finally:
await ws.close()
async with self._core.ws_connect(
url,
auth=auth,
timeout=None,
receive_timeout=None,
heartbeat=30,
) as ws:
try:
yield StdStream(ws)
finally:
await ws.close()

async def send_signal(self, id: str, *, cluster_name: Optional[str] = None) -> None:
url = self._get_monitoring_url(cluster_name) / id / "kill"
Expand Down
41 changes: 39 additions & 2 deletions neuro-sdk/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import sqlite3
import ssl
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -97,10 +98,11 @@ async def handler(request: web.Request) -> web.Response:
async def test_raise_for_status_contains_error_message(
aiohttp_server: _TestServerFactory, api_factory: _ApiFactory
) -> None:
ERROR_MSG = '{"error": "this is the error message"}'
ERROR_MSG = "this is the error message"
ERROR_PAYLOAD = json.dumps({"error": ERROR_MSG})

async def handler(request: web.Request) -> web.Response:
raise web.HTTPBadRequest(text=ERROR_MSG)
raise web.HTTPBadRequest(text=ERROR_PAYLOAD)

app = web.Application()
app.router.add_get("/test", handler)
Expand Down Expand Up @@ -129,6 +131,41 @@ async def handler(request: web.Request) -> web.Response:
assert resp.status == 200


async def test_raise_for_status_no_error_message_ws(
aiohttp_server: _TestServerFactory, api_factory: _ApiFactory
) -> None:
async def handler(request: web.Request) -> web.Response:
raise web.HTTPBadRequest()

app = web.Application()
app.router.add_get("/test", handler)
srv = await aiohttp_server(app)

async with api_factory(srv.make_url("/")) as api:
with pytest.raises(IllegalArgumentError, match="400"):
async with api.ws_connect(abs_url=srv.make_url("test"), auth="auth"):
pass


async def test_raise_for_status_contains_error_message_ws(
aiohttp_server: _TestServerFactory, api_factory: _ApiFactory
) -> None:
ERROR_MSG = "this is the error message"
ERROR_PAYLOAD = json.dumps({"error": ERROR_MSG})

async def handler(request: web.Request) -> web.Response:
raise web.HTTPBadRequest(text=ERROR_PAYLOAD, headers={"X-Error": ERROR_PAYLOAD})

app = web.Application()
app.router.add_get("/test", handler)
srv = await aiohttp_server(app)

async with api_factory(srv.make_url("/")) as api:
with pytest.raises(IllegalArgumentError, match=f"^{ERROR_MSG}$"):
async with api.ws_connect(abs_url=srv.make_url("test"), auth="auth"):
pass


# ### Cookies tests ###


Expand Down
2 changes: 1 addition & 1 deletion neuro-sdk/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ async def test_top_nonexisting_job(
aiohttp_server: _TestServerFactory, make_client: _MakeClient
) -> None:
async def handler(request: web.Request) -> web.Response:
raise web.HTTPBadRequest()
raise web.HTTPBadRequest(headers={"X-Error": "job job-id not found"})

app = web.Application()
app.router.add_get("/jobs/job-id/top", handler)
Expand Down