|
1 | 1 | import enum
|
2 |
| -from typing import Any, Callable, Dict, Type |
| 2 | +from typing import Any, Callable, Dict, Type, Union |
3 | 3 |
|
4 | 4 | from dispatch.sdk.v1 import status_pb2 as status_pb
|
5 | 5 |
|
@@ -78,16 +78,18 @@ def __str__(self):
|
78 | 78 | Status.NOT_FOUND.__doc__ = "An operation was performed on a non-existent resource"
|
79 | 79 | Status.NOT_FOUND._proto = status_pb.STATUS_NOT_FOUND
|
80 | 80 |
|
81 |
| -_ERROR_TYPES: Dict[Type[Exception], Callable[[Exception], Status]] = {} |
82 |
| -_OUTPUT_TYPES: Dict[Type[Any], Callable[[Any], Status]] = {} |
| 81 | +_ERROR_TYPES: Dict[Type[Exception], Union[Status, Callable[[Exception], Status]]] = {} |
| 82 | +_OUTPUT_TYPES: Dict[Type[Any], Union[Status, Callable[[Any], Status]]] = {} |
83 | 83 |
|
84 | 84 |
|
85 | 85 | def status_for_error(error: BaseException) -> Status:
|
86 | 86 | """Returns a Status that corresponds to the specified error."""
|
87 | 87 | # See if the error matches one of the registered types.
|
88 |
| - handler = _find_handler(error, _ERROR_TYPES) |
89 |
| - if handler is not None: |
90 |
| - return handler(error) |
| 88 | + status_or_handler = _find_status_or_handler(error, _ERROR_TYPES) |
| 89 | + if status_or_handler is not None: |
| 90 | + if isinstance(status_or_handler, Status): |
| 91 | + return status_or_handler |
| 92 | + return status_or_handler(error) |
91 | 93 | # If not, resort to standard error categorization.
|
92 | 94 | #
|
93 | 95 | # See https://docs.python.org/3/library/exceptions.html
|
@@ -120,28 +122,43 @@ def status_for_error(error: BaseException) -> Status:
|
120 | 122 | def status_for_output(output: Any) -> Status:
|
121 | 123 | """Returns a Status that corresponds to the specified output value."""
|
122 | 124 | # See if the output value matches one of the registered types.
|
123 |
| - handler = _find_handler(output, _OUTPUT_TYPES) |
124 |
| - if handler is not None: |
125 |
| - return handler(output) |
| 125 | + status_or_handler = _find_status_or_handler(output, _OUTPUT_TYPES) |
| 126 | + if status_or_handler is not None: |
| 127 | + if isinstance(status_or_handler, Status): |
| 128 | + return status_or_handler |
| 129 | + return status_or_handler(output) |
126 | 130 |
|
127 | 131 | return Status.OK
|
128 | 132 |
|
129 | 133 |
|
130 | 134 | def register_error_type(
|
131 |
| - error_type: Type[Exception], handler: Callable[[Exception], Status] |
| 135 | + error_type: Type[Exception], |
| 136 | + status_or_handler: Union[Status, Callable[[Exception], Status]], |
132 | 137 | ):
|
133 |
| - """Register an error type, and a handler which derives a Status from |
134 |
| - errors of this type.""" |
135 |
| - _ERROR_TYPES[error_type] = handler |
| 138 | + """Register an error type to Status mapping. |
| 139 | +
|
| 140 | + The caller can either register a base exception and a handler, which |
| 141 | + derives a Status from errors of this type. Or, if there's only one |
| 142 | + exception to Status mapping to register, the caller can simply pass |
| 143 | + the exception class and the associated Status. |
| 144 | + """ |
| 145 | + _ERROR_TYPES[error_type] = status_or_handler |
136 | 146 |
|
137 | 147 |
|
138 |
| -def register_output_type(output_type: Type[Any], handler: Callable[[Any], Status]): |
139 |
| - """Register an output type, and a handler which derives a Status from |
140 |
| - outputs of this type.""" |
141 |
| - _OUTPUT_TYPES[output_type] = handler |
| 148 | +def register_output_type( |
| 149 | + output_type: Type[Any], status_or_handler: Union[Status, Callable[[Any], Status]] |
| 150 | +): |
| 151 | + """Register an output type to Status mapping. |
| 152 | +
|
| 153 | + The caller can either register a base class and a handler, which |
| 154 | + derives a Status from other classes of this type. Or, if there's |
| 155 | + only one output class to Status mapping to register, the caller can |
| 156 | + simply pass the class and the associated Status. |
| 157 | + """ |
| 158 | + _OUTPUT_TYPES[output_type] = status_or_handler |
142 | 159 |
|
143 | 160 |
|
144 |
| -def _find_handler(obj, types): |
| 161 | +def _find_status_or_handler(obj, types): |
145 | 162 | for cls in type(obj).__mro__:
|
146 | 163 | try:
|
147 | 164 | return types[cls]
|
|
0 commit comments