Skip to content

Commit fbec88d

Browse files
authored
Merge pull request #166 from dispatchrun/simplify-error-output-status-registration
Simplify registration of error/output to status mappings
2 parents bffe62c + 4618470 commit fbec88d

File tree

2 files changed

+141
-21
lines changed

2 files changed

+141
-21
lines changed

src/dispatch/status.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import Any, Callable, Dict, Type
2+
from typing import Any, Callable, Dict, Type, Union
33

44
from dispatch.sdk.v1 import status_pb2 as status_pb
55

@@ -78,16 +78,18 @@ def __str__(self):
7878
Status.NOT_FOUND.__doc__ = "An operation was performed on a non-existent resource"
7979
Status.NOT_FOUND._proto = status_pb.STATUS_NOT_FOUND
8080

81-
_ERROR_TYPES: Dict[Type[Exception], Callable[[Exception], Status]] = {}
82-
_OUTPUT_TYPES: Dict[Type[Any], Callable[[Any], Status]] = {}
81+
_ERROR_TYPES: Dict[Type[Exception], Union[Status, Callable[[Exception], Status]]] = {}
82+
_OUTPUT_TYPES: Dict[Type[Any], Union[Status, Callable[[Any], Status]]] = {}
8383

8484

8585
def status_for_error(error: BaseException) -> Status:
8686
"""Returns a Status that corresponds to the specified error."""
8787
# See if the error matches one of the registered types.
88-
handler = _find_handler(error, _ERROR_TYPES)
89-
if handler is not None:
90-
return handler(error)
88+
status_or_handler = _find_status_or_handler(error, _ERROR_TYPES)
89+
if status_or_handler is not None:
90+
if isinstance(status_or_handler, Status):
91+
return status_or_handler
92+
return status_or_handler(error)
9193
# If not, resort to standard error categorization.
9294
#
9395
# See https://docs.python.org/3/library/exceptions.html
@@ -120,28 +122,43 @@ def status_for_error(error: BaseException) -> Status:
120122
def status_for_output(output: Any) -> Status:
121123
"""Returns a Status that corresponds to the specified output value."""
122124
# See if the output value matches one of the registered types.
123-
handler = _find_handler(output, _OUTPUT_TYPES)
124-
if handler is not None:
125-
return handler(output)
125+
status_or_handler = _find_status_or_handler(output, _OUTPUT_TYPES)
126+
if status_or_handler is not None:
127+
if isinstance(status_or_handler, Status):
128+
return status_or_handler
129+
return status_or_handler(output)
126130

127131
return Status.OK
128132

129133

130134
def register_error_type(
131-
error_type: Type[Exception], handler: Callable[[Exception], Status]
135+
error_type: Type[Exception],
136+
status_or_handler: Union[Status, Callable[[Exception], Status]],
132137
):
133-
"""Register an error type, and a handler which derives a Status from
134-
errors of this type."""
135-
_ERROR_TYPES[error_type] = handler
138+
"""Register an error type to Status mapping.
139+
140+
The caller can either register a base exception and a handler, which
141+
derives a Status from errors of this type. Or, if there's only one
142+
exception to Status mapping to register, the caller can simply pass
143+
the exception class and the associated Status.
144+
"""
145+
_ERROR_TYPES[error_type] = status_or_handler
136146

137147

138-
def register_output_type(output_type: Type[Any], handler: Callable[[Any], Status]):
139-
"""Register an output type, and a handler which derives a Status from
140-
outputs of this type."""
141-
_OUTPUT_TYPES[output_type] = handler
148+
def register_output_type(
149+
output_type: Type[Any], status_or_handler: Union[Status, Callable[[Any], Status]]
150+
):
151+
"""Register an output type to Status mapping.
152+
153+
The caller can either register a base class and a handler, which
154+
derives a Status from other classes of this type. Or, if there's
155+
only one output class to Status mapping to register, the caller can
156+
simply pass the class and the associated Status.
157+
"""
158+
_OUTPUT_TYPES[output_type] = status_or_handler
142159

143160

144-
def _find_handler(obj, types):
161+
def _find_status_or_handler(obj, types):
145162
for cls in type(obj).__mro__:
146163
try:
147164
return types[cls]

tests/dispatch/test_status.py

+106-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import unittest
2+
from typing import Any
23

34
from dispatch import error
45
from dispatch.integrations.http import http_response_code_status
5-
from dispatch.status import Status, status_for_error
6+
from dispatch.status import (
7+
Status,
8+
register_error_type,
9+
register_output_type,
10+
status_for_error,
11+
status_for_output,
12+
)
613

714

815
class TestErrorStatus(unittest.TestCase):
@@ -56,13 +63,49 @@ class CustomError(Exception):
5663
pass
5764

5865
def handler(error: Exception) -> Status:
66+
assert isinstance(error, CustomError)
5967
return Status.OK
6068

61-
from dispatch.status import register_error_type
62-
6369
register_error_type(CustomError, handler)
6470
assert status_for_error(CustomError()) is Status.OK
6571

72+
def test_status_for_custom_error_with_base_handler(self):
73+
class CustomBaseError(Exception):
74+
pass
75+
76+
class CustomError(CustomBaseError):
77+
pass
78+
79+
def handler(error: Exception) -> Status:
80+
assert isinstance(error, CustomBaseError)
81+
assert isinstance(error, CustomError)
82+
return Status.TCP_ERROR
83+
84+
register_error_type(CustomBaseError, handler)
85+
assert status_for_error(CustomError()) is Status.TCP_ERROR
86+
87+
def test_status_for_custom_error_with_status(self):
88+
class CustomError(Exception):
89+
pass
90+
91+
register_error_type(CustomError, Status.THROTTLED)
92+
assert status_for_error(CustomError()) is Status.THROTTLED
93+
94+
def test_status_for_custom_error_with_base_status(self):
95+
class CustomBaseError(Exception):
96+
pass
97+
98+
class CustomError(CustomBaseError):
99+
pass
100+
101+
class CustomError2(CustomBaseError):
102+
pass
103+
104+
register_error_type(CustomBaseError, Status.THROTTLED)
105+
register_error_type(CustomError2, Status.INVALID_ARGUMENT)
106+
assert status_for_error(CustomError()) is Status.THROTTLED
107+
assert status_for_error(CustomError2()) is Status.INVALID_ARGUMENT
108+
66109
def test_status_for_custom_timeout(self):
67110
class CustomError(TimeoutError):
68111
pass
@@ -90,6 +133,66 @@ def test_status_for_DispatchError(self):
90133
assert status_for_error(error.NotFoundError()) is Status.NOT_FOUND
91134
assert status_for_error(error.DispatchError()) is Status.UNSPECIFIED
92135

136+
def test_status_for_custom_output(self):
137+
class CustomOutput:
138+
pass
139+
140+
assert status_for_output(CustomOutput()) is Status.OK # default
141+
142+
def test_status_for_custom_output_with_handler(self):
143+
class CustomOutput:
144+
pass
145+
146+
def handler(output: Any) -> Status:
147+
assert isinstance(output, CustomOutput)
148+
return Status.DNS_ERROR
149+
150+
register_output_type(CustomOutput, handler)
151+
assert status_for_output(CustomOutput()) is Status.DNS_ERROR
152+
153+
def test_status_for_custom_output_with_base_handler(self):
154+
class CustomOutputBase:
155+
pass
156+
157+
class CustomOutputError(CustomOutputBase):
158+
pass
159+
160+
class CustomOutputSuccess(CustomOutputBase):
161+
pass
162+
163+
def handler(output: Any) -> Status:
164+
assert isinstance(output, CustomOutputBase)
165+
if isinstance(output, CustomOutputError):
166+
return Status.DNS_ERROR
167+
assert isinstance(output, CustomOutputSuccess)
168+
return Status.OK
169+
170+
register_output_type(CustomOutputBase, handler)
171+
assert status_for_output(CustomOutputSuccess()) is Status.OK
172+
assert status_for_output(CustomOutputError()) is Status.DNS_ERROR
173+
174+
def test_status_for_custom_output_with_status(self):
175+
class CustomOutputBase:
176+
pass
177+
178+
class CustomOutputChild1(CustomOutputBase):
179+
pass
180+
181+
class CustomOutputChild2(CustomOutputBase):
182+
pass
183+
184+
register_output_type(CustomOutputBase, Status.PERMISSION_DENIED)
185+
register_output_type(CustomOutputChild1, Status.TCP_ERROR)
186+
assert status_for_output(CustomOutputChild1()) is Status.TCP_ERROR
187+
assert status_for_output(CustomOutputChild2()) is Status.PERMISSION_DENIED
188+
189+
def test_status_for_custom_output_with_base_status(self):
190+
class CustomOutput(Exception):
191+
pass
192+
193+
register_output_type(CustomOutput, Status.THROTTLED)
194+
assert status_for_output(CustomOutput()) is Status.THROTTLED
195+
93196

94197
class TestHTTPStatusCodes(unittest.TestCase):
95198
def test_http_response_code_status_400(self):

0 commit comments

Comments
 (0)