3
3
import pickle
4
4
import struct
5
5
import unittest
6
- from typing import Any
6
+ from typing import Any , Optional
7
7
from unittest import mock
8
8
9
9
import fastapi
10
10
import google .protobuf .any_pb2
11
11
import google .protobuf .wrappers_pb2
12
12
import httpx
13
- from cryptography .hazmat .primitives .asymmetric .ed25519 import Ed25519PublicKey
13
+ from cryptography .hazmat .primitives .asymmetric .ed25519 import (
14
+ Ed25519PrivateKey ,
15
+ Ed25519PublicKey ,
16
+ )
14
17
from fastapi .testclient import TestClient
15
18
19
+ import dispatch
16
20
from dispatch .experimental .durable .registry import clear_functions
17
21
from dispatch .fastapi import Dispatch
18
22
from dispatch .function import Arguments , Error , Function , Input , Output
19
23
from dispatch .proto import _any_unpickle as any_unpickle
20
24
from dispatch .sdk .v1 import call_pb2 as call_pb
21
25
from dispatch .sdk .v1 import function_pb2 as function_pb
22
- from dispatch .signature import parse_verification_key , public_key_from_pem
26
+ from dispatch .signature import (
27
+ parse_verification_key ,
28
+ private_key_from_pem ,
29
+ public_key_from_pem ,
30
+ )
23
31
from dispatch .status import Status
24
- from dispatch .test import EndpointClient
32
+ from dispatch .test import DispatchServer , DispatchService , EndpointClient
25
33
26
34
27
- def create_dispatch_instance (app , endpoint ):
35
+ def create_dispatch_instance (app : fastapi . FastAPI , endpoint : str ):
28
36
return Dispatch (
29
37
app ,
30
38
endpoint = endpoint ,
@@ -33,6 +41,13 @@ def create_dispatch_instance(app, endpoint):
33
41
)
34
42
35
43
44
+ def create_endpoint_client (
45
+ app : fastapi .FastAPI , signing_key : Optional [Ed25519PrivateKey ] = None
46
+ ):
47
+ http_client = TestClient (app )
48
+ return EndpointClient (http_client , signing_key )
49
+
50
+
36
51
class TestFastAPI (unittest .TestCase ):
37
52
def test_Dispatch (self ):
38
53
app = fastapi .FastAPI ()
@@ -79,8 +94,7 @@ def my_function(input: Input) -> Output:
79
94
f"You told me: '{ input .input } ' ({ len (input .input )} characters)"
80
95
)
81
96
82
- client = EndpointClient .from_app (app )
83
-
97
+ client = create_endpoint_client (app )
84
98
pickled = pickle .dumps ("Hello World!" )
85
99
input_any = google .protobuf .any_pb2 .Any ()
86
100
input_any .Pack (google .protobuf .wrappers_pb2 .BytesValue (value = pickled ))
@@ -102,6 +116,96 @@ def my_function(input: Input) -> Output:
102
116
self .assertEqual (output , "You told me: 'Hello World!' (12 characters)" )
103
117
104
118
119
+ signing_key = private_key_from_pem (
120
+ """
121
+ -----BEGIN PRIVATE KEY-----
122
+ MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
123
+ -----END PRIVATE KEY-----
124
+ """
125
+ )
126
+
127
+ verification_key = public_key_from_pem (
128
+ """
129
+ -----BEGIN PUBLIC KEY-----
130
+ MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
131
+ -----END PUBLIC KEY-----
132
+ """
133
+ )
134
+
135
+
136
+ class TestFullFastapi (unittest .TestCase ):
137
+ def setUp (self ):
138
+ self .endpoint_app = fastapi .FastAPI ()
139
+ endpoint_client = create_endpoint_client (self .endpoint_app , signing_key )
140
+
141
+ api_key = "0000000000000000"
142
+ self .dispatch_service = DispatchService (
143
+ endpoint_client , api_key , collect_roundtrips = True
144
+ )
145
+ self .dispatch_server = DispatchServer (self .dispatch_service )
146
+ self .dispatch_client = dispatch .Client (
147
+ api_key , api_url = self .dispatch_server .url
148
+ )
149
+
150
+ self .dispatch = Dispatch (
151
+ self .endpoint_app ,
152
+ endpoint = "http://function-service" , # unused
153
+ verification_key = verification_key ,
154
+ api_key = api_key ,
155
+ api_url = self .dispatch_server .url ,
156
+ )
157
+
158
+ self .dispatch_server .start ()
159
+
160
+ def tearDown (self ):
161
+ self .dispatch_server .stop ()
162
+
163
+ def test_simple_end_to_end (self ):
164
+ # The FastAPI server.
165
+ @self .dispatch .function
166
+ def my_function (name : str ) -> str :
167
+ return f"Hello world: { name } "
168
+
169
+ call = my_function .build_call (52 )
170
+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
171
+
172
+ # The client.
173
+ [dispatch_id ] = self .dispatch_client .dispatch ([my_function .build_call (52 )])
174
+
175
+ # Simulate execution for testing purposes.
176
+ self .dispatch_service .dispatch_calls ()
177
+
178
+ # Validate results.
179
+ roundtrips = self .dispatch_service .roundtrips [dispatch_id ]
180
+ self .assertEqual (len (roundtrips ), 1 )
181
+ _ , response = roundtrips [0 ]
182
+ self .assertEqual (any_unpickle (response .exit .result .output ), "Hello world: 52" )
183
+
184
+ def test_simple_missing_signature (self ):
185
+ @self .dispatch .function
186
+ async def my_function (name : str ) -> str :
187
+ return f"Hello world: { name } "
188
+
189
+ call = my_function .build_call (52 )
190
+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
191
+
192
+ [dispatch_id ] = self .dispatch_client .dispatch ([call ])
193
+
194
+ self .dispatch_service .endpoint_client = create_endpoint_client (
195
+ self .endpoint_app
196
+ ) # no signing key
197
+ try :
198
+ self .dispatch_service .dispatch_calls ()
199
+ except httpx .HTTPStatusError as e :
200
+ assert e .response .status_code == 403
201
+ assert e .response .json () == {
202
+ "code" : "permission_denied" ,
203
+ "message" : 'Expected "Signature-Input" header field to be present' ,
204
+ }
205
+ else :
206
+ assert False , "Expected HTTPStatusError"
207
+
208
+
105
209
def response_output (resp : function_pb .RunResponse ) -> Any :
106
210
return any_unpickle (resp .exit .result .output )
107
211
@@ -120,7 +224,7 @@ def root():
120
224
self .app , endpoint = "https://127.0.0.1:9999"
121
225
)
122
226
self .http_client = TestClient (self .app )
123
- self .client = EndpointClient . from_app (self .app )
227
+ self .client = create_endpoint_client (self .app )
124
228
125
229
def execute (
126
230
self , func : Function , input = None , state = None , calls = None
0 commit comments