Skip to content

Commit 4a344ad

Browse files
authored
Merge pull request #177 from dispatchrun/bytes-containers
Wrap pickled values in dispatch.sdk.python.v1 container
2 parents 6661a09 + 9fb79ef commit 4a344ad

18 files changed

+195
-143
lines changed

src/dispatch/proto.py

+47-26
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import tblib # type: ignore[import-untyped]
1313
from google.protobuf import descriptor_pool, duration_pb2, message_factory
1414

15+
from dispatch.error import IncompatibleStateError, InvalidArgumentError
1516
from dispatch.id import DispatchID
17+
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
1618
from dispatch.sdk.v1 import call_pb2 as call_pb
1719
from dispatch.sdk.v1 import error_pb2 as error_pb
1820
from dispatch.sdk.v1 import exit_pb2 as exit_pb
@@ -77,18 +79,11 @@ def __init__(self, req: function_pb.RunRequest):
7779

7880
self._has_input = req.HasField("input")
7981
if self._has_input:
80-
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
81-
input_pb = google.protobuf.wrappers_pb2.BytesValue()
82-
req.input.Unpack(input_pb)
83-
input_bytes = input_pb.value
84-
try:
85-
self._input = pickle.loads(input_bytes)
86-
except Exception as e:
87-
self._input = input_bytes
88-
else:
89-
self._input = _pb_any_unpack(req.input)
82+
self._input = _pb_any_unpack(req.input)
9083
else:
91-
self._coroutine_state = req.poll_result.coroutine_state
84+
if req.poll_result.coroutine_state:
85+
raise IncompatibleStateError # coroutine_state is deprecated
86+
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
9287
self._call_results = [
9388
CallResult._from_proto(r) for r in req.poll_result.results
9489
]
@@ -155,15 +150,15 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
155150
def from_poll_results(
156151
cls,
157152
function: str,
158-
coroutine_state: Optional[bytes],
153+
coroutine_state: Any,
159154
call_results: List[CallResult],
160155
error: Optional[Error] = None,
161156
):
162157
return Input(
163158
req=function_pb.RunRequest(
164159
function=function,
165160
poll_result=poll_pb.PollResult(
166-
coroutine_state=coroutine_state,
161+
typed_coroutine_state=_pb_any_pickle(coroutine_state),
167162
results=[result._as_proto() for result in call_results],
168163
error=error._as_proto() if error else None,
169164
),
@@ -232,7 +227,7 @@ def exit(
232227
@classmethod
233228
def poll(
234229
cls,
235-
coroutine_state: Optional[bytes] = None,
230+
coroutine_state: Any = None,
236231
calls: Optional[List[Call]] = None,
237232
min_results: int = 1,
238233
max_results: int = 10,
@@ -247,7 +242,7 @@ def poll(
247242
else None
248243
)
249244
poll = poll_pb.Poll(
250-
coroutine_state=coroutine_state,
245+
typed_coroutine_state=_pb_any_pickle(coroutine_state),
251246
min_results=min_results,
252247
max_results=max_results,
253248
max_wait=max_wait,
@@ -447,21 +442,47 @@ def _as_proto(self) -> error_pb.Error:
447442

448443

449444
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
450-
any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue())
451-
return pickle.loads(value_bytes.value)
445+
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
446+
p = pickled_pb.Pickled()
447+
any.Unpack(p)
448+
return pickle.loads(p.pickled_value)
449+
450+
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
451+
b = google.protobuf.wrappers_pb2.BytesValue()
452+
any.Unpack(b)
453+
return pickle.loads(b.value)
454+
455+
elif not any.type_url and not any.value:
456+
return None
457+
458+
raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")
459+
460+
461+
def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
462+
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
463+
any = google.protobuf.any_pb2.Any()
464+
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
465+
return any
452466

453467

454-
def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
455-
value_bytes = pickle.dumps(x)
456-
pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes)
457-
pb_any = google.protobuf.any_pb2.Any()
458-
pb_any.Pack(pb_bytes)
459-
return pb_any
468+
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
469+
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
470+
p = pickled_pb.Pickled()
471+
any.Unpack(p)
472+
return pickle.loads(p.pickled_value)
460473

474+
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
475+
b = google.protobuf.wrappers_pb2.BytesValue()
476+
any.Unpack(b)
477+
try:
478+
# Assume it's the legacy container for pickled values.
479+
return pickle.loads(b.value)
480+
except Exception as e:
481+
# Otherwise, return the literal bytes.
482+
return b.value
461483

462-
def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any:
463484
pool = descriptor_pool.Default()
464-
msg_descriptor = pool.FindMessageTypeByName(x.TypeName())
485+
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
465486
proto = message_factory.GetMessageClass(msg_descriptor)()
466-
x.Unpack(proto)
487+
any.Unpack(proto)
467488
return proto

src/dispatch/scheduler.py

+17-27
Original file line numberDiff line numberDiff line change
@@ -357,19 +357,17 @@ def _init_state(self, input: Input) -> State:
357357
)
358358

359359
def _rebuild_state(self, input: Input):
360-
logger.debug(
361-
"resuming scheduler with %d bytes of state", len(input.coroutine_state)
362-
)
360+
logger.info("resuming main coroutine")
363361
try:
364-
state = pickle.loads(input.coroutine_state)
362+
state = input.coroutine_state
365363
if not isinstance(state, State):
366364
raise ValueError("invalid state")
367365
if state.version != self.version:
368366
raise ValueError(
369367
f"version mismatch: '{state.version}' vs. current '{self.version}'"
370368
)
371369
return state
372-
except (pickle.PickleError, ValueError) as e:
370+
except ValueError as e:
373371
logger.warning("state is incompatible", exc_info=True)
374372
raise IncompatibleStateError from e
375373

@@ -454,32 +452,24 @@ async def _run(self, input: Input) -> Output:
454452
await asyncio.gather(*asyncio_tasks, return_exceptions=True)
455453
return coroutine_result
456454

457-
# Serialize coroutines and scheduler state.
458-
logger.debug("serializing state")
455+
# Yield to Dispatch.
456+
logger.debug("yielding to Dispatch with %d call(s)", len(pending_calls))
459457
try:
460-
serialized_state = pickle.dumps(state)
458+
return Output.poll(
459+
coroutine_state=state,
460+
calls=pending_calls,
461+
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
462+
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
463+
max_wait_seconds=self.poll_max_wait_seconds,
464+
)
461465
except pickle.PickleError as e:
462466
logger.exception("state could not be serialized")
463467
return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR))
464-
465-
# Close coroutines before yielding.
466-
for suspended in state.suspended.values():
467-
suspended.coroutine.close()
468-
state.suspended = {}
469-
470-
# Yield to Dispatch.
471-
logger.debug(
472-
"yielding to Dispatch with %d call(s) and %d bytes of state",
473-
len(pending_calls),
474-
len(serialized_state),
475-
)
476-
return Output.poll(
477-
coroutine_state=serialized_state,
478-
calls=pending_calls,
479-
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
480-
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
481-
max_wait_seconds=self.poll_max_wait_seconds,
482-
)
468+
finally:
469+
# Close coroutines.
470+
for suspended in state.suspended.values():
471+
suspended.coroutine.close()
472+
state.suspended = {}
483473

484474

485475
async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]):

src/dispatch/sdk/python/__init__.py

Whitespace-only changes.

src/dispatch/sdk/python/v1/__init__.py

Whitespace-only changes.

src/dispatch/sdk/python/v1/pickled_pb2.py

+32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import ClassVar as _ClassVar
2+
from typing import Optional as _Optional
3+
4+
from google.protobuf import descriptor as _descriptor
5+
from google.protobuf import message as _message
6+
7+
DESCRIPTOR: _descriptor.FileDescriptor
8+
9+
class Pickled(_message.Message):
10+
__slots__ = ("pickled_value",)
11+
PICKLED_VALUE_FIELD_NUMBER: _ClassVar[int]
12+
pickled_value: bytes
13+
def __init__(self, pickled_value: _Optional[bytes] = ...) -> None: ...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2+
"""Client and server classes corresponding to protobuf-defined services."""
3+
import grpc

src/dispatch/sdk/v1/call_pb2.py

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/call_pb2.pyi

+11-1
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,34 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2
1414
DESCRIPTOR: _descriptor.FileDescriptor
1515

1616
class Call(_message.Message):
17-
__slots__ = ("correlation_id", "endpoint", "function", "input", "expiration")
17+
__slots__ = (
18+
"correlation_id",
19+
"endpoint",
20+
"function",
21+
"input",
22+
"expiration",
23+
"version",
24+
)
1825
CORRELATION_ID_FIELD_NUMBER: _ClassVar[int]
1926
ENDPOINT_FIELD_NUMBER: _ClassVar[int]
2027
FUNCTION_FIELD_NUMBER: _ClassVar[int]
2128
INPUT_FIELD_NUMBER: _ClassVar[int]
2229
EXPIRATION_FIELD_NUMBER: _ClassVar[int]
30+
VERSION_FIELD_NUMBER: _ClassVar[int]
2331
correlation_id: int
2432
endpoint: str
2533
function: str
2634
input: _any_pb2.Any
2735
expiration: _duration_pb2.Duration
36+
version: str
2837
def __init__(
2938
self,
3039
correlation_id: _Optional[int] = ...,
3140
endpoint: _Optional[str] = ...,
3241
function: _Optional[str] = ...,
3342
input: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
3443
expiration: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ...,
44+
version: _Optional[str] = ...,
3545
) -> None: ...
3646

3747
class CallResult(_message.Message):

src/dispatch/sdk/v1/dispatch_pb2.py

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)