Skip to content

Commit bccc800

Browse files
authored
Merge pull request #153 from stealthrocket/run-request-dispatch-id
Extract new proto fields
2 parents 93439d8 + 9351491 commit bccc800

File tree

7 files changed

+69
-19
lines changed

7 files changed

+69
-19
lines changed

src/dispatch/proto.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tblib # type: ignore[import-untyped]
1313
from google.protobuf import descriptor_pool, duration_pb2, message_factory
1414

15+
from dispatch.id import DispatchID
1516
from dispatch.sdk.v1 import call_pb2 as call_pb
1617
from dispatch.sdk.v1 import error_pb2 as error_pb
1718
from dispatch.sdk.v1 import exit_pb2 as exit_pb
@@ -51,6 +52,9 @@ class Input:
5152
"""
5253

5354
__slots__ = (
55+
"dispatch_id",
56+
"parent_dispatch_id",
57+
"root_dispatch_id",
5458
"_has_input",
5559
"_input",
5660
"_coroutine_state",
@@ -59,6 +63,10 @@ class Input:
5963
)
6064

6165
def __init__(self, req: function_pb.RunRequest):
66+
self.dispatch_id = req.dispatch_id
67+
self.parent_dispatch_id = req.parent_dispatch_id
68+
self.root_dispatch_id = req.root_dispatch_id
69+
6270
self._has_input = req.HasField("input")
6371
if self._has_input:
6472
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
@@ -285,6 +293,7 @@ class CallResult:
285293
correlation_id: Optional[int] = None
286294
output: Optional[Any] = None
287295
error: Optional[Error] = None
296+
dispatch_id: DispatchID = ""
288297

289298
def _as_proto(self) -> call_pb.CallResult:
290299
output_any = None
@@ -295,7 +304,10 @@ def _as_proto(self) -> call_pb.CallResult:
295304
error_proto = self.error._as_proto()
296305

297306
return call_pb.CallResult(
298-
correlation_id=self.correlation_id, output=output_any, error=error_proto
307+
correlation_id=self.correlation_id,
308+
output=output_any,
309+
error=error_proto,
310+
dispatch_id=self.dispatch_id,
299311
)
300312

301313
@classmethod
@@ -308,7 +320,10 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
308320
error = Error._from_proto(proto.error)
309321

310322
return CallResult(
311-
correlation_id=proto.correlation_id, output=output, error=error
323+
correlation_id=proto.correlation_id,
324+
output=output,
325+
error=error,
326+
dispatch_id=proto.dispatch_id,
312327
)
313328

314329
@classmethod

src/dispatch/sdk/v1/call_pb2.py

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/call_pb2.pyi

+4-1
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,19 @@ class Call(_message.Message):
3535
) -> None: ...
3636

3737
class CallResult(_message.Message):
38-
__slots__ = ("correlation_id", "output", "error")
38+
__slots__ = ("correlation_id", "output", "error", "dispatch_id")
3939
CORRELATION_ID_FIELD_NUMBER: _ClassVar[int]
4040
OUTPUT_FIELD_NUMBER: _ClassVar[int]
4141
ERROR_FIELD_NUMBER: _ClassVar[int]
42+
DISPATCH_ID_FIELD_NUMBER: _ClassVar[int]
4243
correlation_id: int
4344
output: _any_pb2.Any
4445
error: _error_pb2.Error
46+
dispatch_id: str
4547
def __init__(
4648
self,
4749
correlation_id: _Optional[int] = ...,
4850
output: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
4951
error: _Optional[_Union[_error_pb2.Error, _Mapping]] = ...,
52+
dispatch_id: _Optional[str] = ...,
5053
) -> None: ...

src/dispatch/sdk/v1/dispatch_pb2.py

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/function_pb2.py

+6-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/function_pb2.pyi

+17-1
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,34 @@ from dispatch.sdk.v1 import status_pb2 as _status_pb2
1414
DESCRIPTOR: _descriptor.FileDescriptor
1515

1616
class RunRequest(_message.Message):
17-
__slots__ = ("function", "input", "poll_result")
17+
__slots__ = (
18+
"function",
19+
"input",
20+
"poll_result",
21+
"dispatch_id",
22+
"parent_dispatch_id",
23+
"root_dispatch_id",
24+
)
1825
FUNCTION_FIELD_NUMBER: _ClassVar[int]
1926
INPUT_FIELD_NUMBER: _ClassVar[int]
2027
POLL_RESULT_FIELD_NUMBER: _ClassVar[int]
28+
DISPATCH_ID_FIELD_NUMBER: _ClassVar[int]
29+
PARENT_DISPATCH_ID_FIELD_NUMBER: _ClassVar[int]
30+
ROOT_DISPATCH_ID_FIELD_NUMBER: _ClassVar[int]
2131
function: str
2232
input: _any_pb2.Any
2333
poll_result: _poll_pb2.PollResult
34+
dispatch_id: str
35+
parent_dispatch_id: str
36+
root_dispatch_id: str
2437
def __init__(
2538
self,
2639
function: _Optional[str] = ...,
2740
input: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
2841
poll_result: _Optional[_Union[_poll_pb2.PollResult, _Mapping]] = ...,
42+
dispatch_id: _Optional[str] = ...,
43+
parent_dispatch_id: _Optional[str] = ...,
44+
root_dispatch_id: _Optional[str] = ...,
2945
) -> None: ...
3046

3147
class RunResponse(_message.Message):

src/dispatch/test/service.py

+16
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def Dispatch(self, request: dispatch_pb.DispatchRequest, context):
110110
run_request = function_pb.RunRequest(
111111
function=call.function,
112112
input=call.input,
113+
dispatch_id=dispatch_id,
114+
root_dispatch_id=dispatch_id,
113115
)
114116
self.queue.append((dispatch_id, run_request, CallType.CALL))
115117

@@ -207,6 +209,8 @@ def dispatch_calls(self):
207209
assert dispatch_id not in self.pollers
208210
poller = Poller(
209211
id=dispatch_id,
212+
parent_id=request.parent_dispatch_id,
213+
root_id=request.root_dispatch_id,
210214
function=request.function,
211215
coroutine_state=response.poll.coroutine_state,
212216
waiting={},
@@ -219,6 +223,9 @@ def dispatch_calls(self):
219223
child_request = function_pb.RunRequest(
220224
function=call.function,
221225
input=call.input,
226+
dispatch_id=child_dispatch_id,
227+
parent_dispatch_id=request.dispatch_id,
228+
root_dispatch_id=request.root_dispatch_id,
222229
)
223230

224231
_next_queue.append(
@@ -239,6 +246,9 @@ def dispatch_calls(self):
239246
tail_call_request = function_pb.RunRequest(
240247
function=tail_call.function,
241248
input=tail_call.input,
249+
dispatch_id=request.dispatch_id,
250+
parent_dispatch_id=request.parent_dispatch_id,
251+
root_dispatch_id=request.root_dispatch_id,
242252
)
243253
_next_queue.append((dispatch_id, tail_call_request, CallType.CALL))
244254

@@ -269,6 +279,9 @@ def dispatch_calls(self):
269279
len(poller.results),
270280
)
271281
poll_results_request = function_pb.RunRequest(
282+
dispatch_id=poller.id,
283+
parent_dispatch_id=poller.parent_id,
284+
root_dispatch_id=poller.root_id,
272285
function=poller.function,
273286
poll_result=poll_pb.PollResult(
274287
coroutine_state=poller.coroutine_state,
@@ -349,6 +362,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
349362
@dataclass
350363
class Poller:
351364
id: DispatchID
365+
parent_id: DispatchID
366+
root_id: DispatchID
367+
352368
function: str
353369

354370
coroutine_state: bytes

0 commit comments

Comments
 (0)