Skip to content

Commit 90b97d9

Browse files
authored
Merge pull request #179 from dispatchrun/use-proto-for-serialization-if-available
Avoid pickling primitive values and proto messages
2 parents bb6ec79 + 6b283a4 commit 90b97d9

File tree

5 files changed

+293
-59
lines changed

5 files changed

+293
-59
lines changed

src/dispatch/any.py

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from __future__ import annotations
2+
3+
import pickle
4+
from datetime import datetime, timedelta, timezone
5+
from typing import Any
6+
7+
import google.protobuf.any_pb2
8+
import google.protobuf.duration_pb2
9+
import google.protobuf.empty_pb2
10+
import google.protobuf.message
11+
import google.protobuf.struct_pb2
12+
import google.protobuf.timestamp_pb2
13+
import google.protobuf.wrappers_pb2
14+
from google.protobuf import descriptor_pool, message_factory
15+
16+
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
17+
18+
INT64_MIN = -9223372036854775808
19+
INT64_MAX = 9223372036854775807
20+
21+
22+
def marshal_any(value: Any) -> google.protobuf.any_pb2.Any:
    """Pack a Python value into a ``google.protobuf.Any`` container.

    Values are encoded with a well-known protobuf type where one fits:

    * ``None`` -> ``Empty``
    * ``bool`` -> ``BoolValue``
    * ``int`` within the int64 range -> ``Int64Value``
    * ``float`` / ``str`` / ``bytes`` -> ``DoubleValue`` / ``StringValue`` /
      ``BytesValue``
    * ``datetime`` -> ``Timestamp``, ``timedelta`` -> ``Duration``
    * ``list`` / ``dict`` -> ``struct_pb2.Value`` when ``as_struct_value``
      accepts the contents
    * protobuf messages are packed as-is

    Anything else (including integers outside the int64 range and
    containers that ``as_struct_value`` rejects) is pickled and wrapped in
    a ``Pickled`` message.

    Args:
        value: The value to serialize.

    Returns:
        An ``Any`` message wrapping the encoded value.
    """
    if value is None:
        value = google.protobuf.empty_pb2.Empty()
    elif isinstance(value, bool):
        # bool is tested before int because bool is a subclass of int.
        value = google.protobuf.wrappers_pb2.BoolValue(value=value)
    elif isinstance(value, int) and INT64_MIN <= value <= INT64_MAX:
        # To keep things simple, serialize all integers as int64 on the wire.
        # For larger integers, fall through and use pickle.
        value = google.protobuf.wrappers_pb2.Int64Value(value=value)
    elif isinstance(value, float):
        value = google.protobuf.wrappers_pb2.DoubleValue(value=value)
    elif isinstance(value, str):
        value = google.protobuf.wrappers_pb2.StringValue(value=value)
    elif isinstance(value, bytes):
        value = google.protobuf.wrappers_pb2.BytesValue(value=value)
    elif isinstance(value, datetime):
        # Note: datetime only supports microsecond granularity.
        # Timestamp.nanos is always a non-negative offset from `seconds`, so
        # `seconds` must be the floor of the timestamp. Clearing the
        # microseconds first yields a whole number of seconds; truncating the
        # fractional timestamp with int() would round pre-epoch values
        # toward zero and shift them by one second.
        seconds = int(value.replace(microsecond=0).timestamp())
        nanos = value.microsecond * 1000
        value = google.protobuf.timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
    elif isinstance(value, timedelta):
        # Note: timedelta only supports microsecond granularity.
        # timedelta normalizes itself so that only `days` can be negative;
        # `seconds` and `microseconds` are always >= 0. Build the floor-based
        # pair from the integer fields (no float rounding), then shift it so
        # that seconds and nanos carry the same sign, as Duration requires.
        seconds = value.days * 86400 + value.seconds
        nanos = value.microseconds * 1000
        if seconds < 0 and nanos > 0:
            seconds += 1
            nanos -= 1_000_000_000
        value = google.protobuf.duration_pb2.Duration(seconds=seconds, nanos=nanos)

    if isinstance(value, (list, dict)):
        try:
            value = as_struct_value(value)
        except ValueError:
            pass  # Not representable as a struct value; fall through to pickle.

    if not isinstance(value, google.protobuf.message.Message):
        value = pickled_pb.Pickled(pickled_value=pickle.dumps(value))

    packed = google.protobuf.any_pb2.Any()
    if value.DESCRIPTOR.full_name.startswith("dispatch.sdk."):
        # Dispatch's own message types are published under a dedicated
        # type URL prefix rather than the protobuf default.
        packed.Pack(value, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
    else:
        packed.Pack(value)

    return packed
64+
65+
66+
def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any:
    """Unpack a ``google.protobuf.Any`` container into a Python value.

    The contained message type is looked up by name in the default
    descriptor pool. Well-known types are converted to native Python
    values (``Empty`` -> ``None``, wrapper types -> their ``value``,
    ``Timestamp`` -> UTC ``datetime``, ``Duration`` -> ``timedelta``,
    ``struct_pb2.Value`` -> plain Python data), ``Pickled`` payloads are
    unpickled, and any other message is returned as-is.

    SECURITY: ``Pickled`` payloads (and legacy ``BytesValue`` payloads) are
    passed to ``pickle.loads``, which can execute arbitrary code. Only
    unmarshal containers that come from a trusted source.

    Args:
        any: The container to unpack.

    Returns:
        The decoded Python value, or the unpacked protobuf message when
        there is no native conversion. An empty container (no type URL and
        no value) decodes to ``None``.
    """
    # An unset Any has no type name to look up in the descriptor pool;
    # treat it as "no value" instead of failing on the lookup.
    if not any.type_url and not any.value:
        return None

    pool = descriptor_pool.Default()
    msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
    proto = message_factory.GetMessageClass(msg_descriptor)()
    any.Unpack(proto)

    wrappers = google.protobuf.wrappers_pb2

    if isinstance(proto, pickled_pb.Pickled):
        return pickle.loads(proto.pickled_value)

    if isinstance(proto, google.protobuf.empty_pb2.Empty):
        return None

    # Scalar wrappers all expose the native value via `.value`.
    scalar_wrappers = (
        wrappers.BoolValue,
        wrappers.Int32Value,
        wrappers.Int64Value,
        wrappers.UInt32Value,
        wrappers.UInt64Value,
        wrappers.FloatValue,
        wrappers.DoubleValue,
        wrappers.StringValue,
    )
    if isinstance(proto, scalar_wrappers):
        return proto.value

    if isinstance(proto, wrappers.BytesValue):
        try:
            # Assume it's the legacy container for pickled values.
            return pickle.loads(proto.value)
        except Exception:
            # Otherwise, return the literal bytes.
            return proto.value

    if isinstance(proto, google.protobuf.timestamp_pb2.Timestamp):
        return proto.ToDatetime(tzinfo=timezone.utc)

    if isinstance(proto, google.protobuf.duration_pb2.Duration):
        return proto.ToTimedelta()

    if isinstance(proto, google.protobuf.struct_pb2.Value):
        return from_struct_value(proto)

    return proto
120+
121+
122+
def as_struct_value(value: Any) -> google.protobuf.struct_pb2.Value:
    """Convert JSON-like Python data into a ``struct_pb2.Value``.

    Supported inputs are ``None``, ``bool``, ``int``, ``float``, ``str``,
    and ``list`` / ``dict`` (with ``str`` keys) of supported inputs,
    converted recursively.

    Args:
        value: The value to convert.

    Returns:
        The equivalent ``struct_pb2.Value``.

    Raises:
        ValueError: If the value (or anything nested in it) has an
            unsupported type, a dict has a non-string key, or an integer
            cannot be represented exactly as a double.
    """
    if value is None:
        null_value = google.protobuf.struct_pb2.NullValue.NULL_VALUE
        return google.protobuf.struct_pb2.Value(null_value=null_value)

    if isinstance(value, bool):
        # bool is tested before int because bool is a subclass of int.
        return google.protobuf.struct_pb2.Value(bool_value=value)

    if isinstance(value, int):
        # number_value is a double. Reject integers that would overflow or
        # lose precision so callers can pick a lossless encoding instead;
        # raising ValueError (not OverflowError) keeps the documented
        # "unsupported" contract that callers catch.
        try:
            number = float(value)
        except OverflowError:
            raise ValueError("unsupported value") from None
        if number != value:
            raise ValueError("unsupported value")
        return google.protobuf.struct_pb2.Value(number_value=number)

    if isinstance(value, float):
        return google.protobuf.struct_pb2.Value(number_value=value)

    if isinstance(value, str):
        return google.protobuf.struct_pb2.Value(string_value=value)

    if isinstance(value, list):
        list_value = google.protobuf.struct_pb2.ListValue(
            values=[as_struct_value(item) for item in value]
        )
        return google.protobuf.struct_pb2.Value(list_value=list_value)

    if isinstance(value, dict):
        # Struct field names must be strings.
        if not all(isinstance(key, str) for key in value):
            raise ValueError("unsupported object key")

        struct_value = google.protobuf.struct_pb2.Struct(
            fields={key: as_struct_value(item) for key, item in value.items()}
        )
        return google.protobuf.struct_pb2.Value(struct_value=struct_value)

    raise ValueError("unsupported value")
153+
154+
155+
def from_struct_value(value: google.protobuf.struct_pb2.Value) -> Any:
    """Convert a ``struct_pb2.Value`` back into plain Python data.

    Null becomes ``None``; bool, number and string values are returned
    directly; list and struct values are converted recursively into
    ``list`` and ``dict`` objects.

    Args:
        value: The protobuf value to convert.

    Returns:
        The equivalent Python value.

    Raises:
        RuntimeError: If none of the value's variants is set.
    """
    # All variants of struct_pb2.Value live in the "kind" oneof, so a single
    # lookup tells which one (if any) is populated.
    kind = value.WhichOneof("kind")

    if kind == "null_value":
        return None

    if kind in ("bool_value", "number_value", "string_value"):
        return getattr(value, kind)

    if kind == "list_value":
        return [from_struct_value(item) for item in value.list_value.values]

    if kind == "struct_value":
        return {
            name: from_struct_value(field)
            for name, field in value.struct_value.fields.items()
        }

    raise RuntimeError(f"invalid struct_pb2.Value: {value}")

src/dispatch/proto.py

+9-55
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tblib # type: ignore[import-untyped]
1212
from google.protobuf import descriptor_pool, duration_pb2, message_factory
1313

14+
from dispatch.any import marshal_any, unmarshal_any
1415
from dispatch.error import IncompatibleStateError, InvalidArgumentError
1516
from dispatch.id import DispatchID
1617
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
@@ -78,11 +79,11 @@ def __init__(self, req: function_pb.RunRequest):
7879

7980
self._has_input = req.HasField("input")
8081
if self._has_input:
81-
self._input = _pb_any_unpack(req.input)
82+
self._input = unmarshal_any(req.input)
8283
else:
8384
if req.poll_result.coroutine_state:
8485
raise IncompatibleStateError # coroutine_state is deprecated
85-
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
86+
self._coroutine_state = unmarshal_any(req.poll_result.typed_coroutine_state)
8687
self._call_results = [
8788
CallResult._from_proto(r) for r in req.poll_result.results
8889
]
@@ -141,7 +142,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
141142
return Input(
142143
req=function_pb.RunRequest(
143144
function=function,
144-
input=_pb_any_pickle(input),
145+
input=marshal_any(input),
145146
)
146147
)
147148

@@ -157,7 +158,7 @@ def from_poll_results(
157158
req=function_pb.RunRequest(
158159
function=function,
159160
poll_result=poll_pb.PollResult(
160-
typed_coroutine_state=_pb_any_pickle(coroutine_state),
161+
typed_coroutine_state=marshal_any(coroutine_state),
161162
results=[result._as_proto() for result in call_results],
162163
error=error._as_proto() if error else None,
163164
),
@@ -241,7 +242,7 @@ def poll(
241242
else None
242243
)
243244
poll = poll_pb.Poll(
244-
typed_coroutine_state=_pb_any_pickle(coroutine_state),
245+
typed_coroutine_state=marshal_any(coroutine_state),
245246
min_results=min_results,
246247
max_results=max_results,
247248
max_wait=max_wait,
@@ -279,7 +280,7 @@ class Call:
279280
correlation_id: Optional[int] = None
280281

281282
def _as_proto(self) -> call_pb.Call:
282-
input_bytes = _pb_any_pickle(self.input)
283+
input_bytes = marshal_any(self.input)
283284
return call_pb.Call(
284285
correlation_id=self.correlation_id,
285286
endpoint=self.endpoint,
@@ -301,7 +302,7 @@ def _as_proto(self) -> call_pb.CallResult:
301302
output_any = None
302303
error_proto = None
303304
if self.output is not None:
304-
output_any = _pb_any_pickle(self.output)
305+
output_any = marshal_any(self.output)
305306
if self.error is not None:
306307
error_proto = self.error._as_proto()
307308

@@ -317,7 +318,7 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
317318
output = None
318319
error = None
319320
if proto.HasField("output"):
320-
output = _any_unpickle(proto.output)
321+
output = unmarshal_any(proto.output)
321322
if proto.HasField("error"):
322323
error = Error._from_proto(proto.error)
323324

@@ -438,50 +439,3 @@ def _as_proto(self) -> error_pb.Error:
438439
return error_pb.Error(
439440
type=self.type, message=self.message, value=value, traceback=self.traceback
440441
)
441-
442-
443-
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
444-
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
445-
p = pickled_pb.Pickled()
446-
any.Unpack(p)
447-
return pickle.loads(p.pickled_value)
448-
449-
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
450-
b = google.protobuf.wrappers_pb2.BytesValue()
451-
any.Unpack(b)
452-
return pickle.loads(b.value)
453-
454-
elif not any.type_url and not any.value:
455-
return None
456-
457-
raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")
458-
459-
460-
def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
461-
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
462-
any = google.protobuf.any_pb2.Any()
463-
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
464-
return any
465-
466-
467-
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
468-
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
469-
p = pickled_pb.Pickled()
470-
any.Unpack(p)
471-
return pickle.loads(p.pickled_value)
472-
473-
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
474-
b = google.protobuf.wrappers_pb2.BytesValue()
475-
any.Unpack(b)
476-
try:
477-
# Assume it's the legacy container for pickled values.
478-
return pickle.loads(b.value)
479-
except Exception as e:
480-
# Otherwise, return the literal bytes.
481-
return b.value
482-
483-
pool = descriptor_pool.Default()
484-
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
485-
proto = message_factory.GetMessageClass(msg_descriptor)()
486-
any.Unpack(proto)
487-
return proto

0 commit comments

Comments
 (0)