Skip to content

Commit d63ae80

Browse files
test: drop FastAPI dependency
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
1 parent 3e1fd3e commit d63ae80

File tree

3 files changed

+113
-119
lines changed

3 files changed

+113
-119
lines changed

src/dispatch/test/client.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from datetime import datetime
22
from typing import Optional
33

4-
import fastapi
54
import grpc
65
import httpx
7-
from fastapi.testclient import TestClient
86

97
from dispatch.sdk.v1 import function_pb2 as function_pb
108
from dispatch.sdk.v1 import function_pb2_grpc as function_grpc
@@ -22,7 +20,7 @@ class EndpointClient:
2220
Note that this is different from dispatch.Client, which is a client
2321
for the Dispatch API. The EndpointClient is a client similar to the one
2422
that Dispatch itself would use to interact with an endpoint that provides
25-
functions, for example a FastAPI app.
23+
functions.
2624
"""
2725

2826
def __init__(
@@ -54,15 +52,6 @@ def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
5452
http_client = httpx.Client(base_url=url)
5553
return EndpointClient(http_client, signing_key)
5654

57-
@classmethod
58-
def from_app(
59-
cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
60-
):
61-
"""Returns an EndpointClient for a Dispatch endpoint bound to a
62-
FastAPI app instance."""
63-
http_client = TestClient(app)
64-
return EndpointClient(http_client, signing_key)
65-
6655

6756
class _HttpxGrpcChannel(grpc.Channel):
6857
def __init__(

tests/test_fastapi.py

+112-8
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,36 @@
33
import pickle
44
import struct
55
import unittest
6-
from typing import Any
6+
from typing import Any, Optional
77
from unittest import mock
88

99
import fastapi
1010
import google.protobuf.any_pb2
1111
import google.protobuf.wrappers_pb2
1212
import httpx
13-
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
13+
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
14+
Ed25519PrivateKey,
15+
Ed25519PublicKey,
16+
)
1417
from fastapi.testclient import TestClient
1518

19+
import dispatch
1620
from dispatch.experimental.durable.registry import clear_functions
1721
from dispatch.fastapi import Dispatch
1822
from dispatch.function import Arguments, Error, Function, Input, Output
1923
from dispatch.proto import _any_unpickle as any_unpickle
2024
from dispatch.sdk.v1 import call_pb2 as call_pb
2125
from dispatch.sdk.v1 import function_pb2 as function_pb
22-
from dispatch.signature import parse_verification_key, public_key_from_pem
26+
from dispatch.signature import (
27+
parse_verification_key,
28+
private_key_from_pem,
29+
public_key_from_pem,
30+
)
2331
from dispatch.status import Status
24-
from dispatch.test import EndpointClient
32+
from dispatch.test import DispatchServer, DispatchService, EndpointClient
2533

2634

27-
def create_dispatch_instance(app, endpoint):
35+
def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
2836
return Dispatch(
2937
app,
3038
endpoint=endpoint,
@@ -33,6 +41,13 @@ def create_dispatch_instance(app, endpoint):
3341
)
3442

3543

44+
def create_endpoint_client(
45+
app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
46+
):
47+
http_client = TestClient(app)
48+
return EndpointClient(http_client, signing_key)
49+
50+
3651
class TestFastAPI(unittest.TestCase):
3752
def test_Dispatch(self):
3853
app = fastapi.FastAPI()
@@ -79,8 +94,7 @@ def my_function(input: Input) -> Output:
7994
f"You told me: '{input.input}' ({len(input.input)} characters)"
8095
)
8196

82-
client = EndpointClient.from_app(app)
83-
97+
client = create_endpoint_client(app)
8498
pickled = pickle.dumps("Hello World!")
8599
input_any = google.protobuf.any_pb2.Any()
86100
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
@@ -102,6 +116,96 @@ def my_function(input: Input) -> Output:
102116
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
103117

104118

119+
signing_key = private_key_from_pem(
120+
"""
121+
-----BEGIN PRIVATE KEY-----
122+
MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
123+
-----END PRIVATE KEY-----
124+
"""
125+
)
126+
127+
verification_key = public_key_from_pem(
128+
"""
129+
-----BEGIN PUBLIC KEY-----
130+
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
131+
-----END PUBLIC KEY-----
132+
"""
133+
)
134+
135+
136+
class TestFullFastapi(unittest.TestCase):
137+
def setUp(self):
138+
self.endpoint_app = fastapi.FastAPI()
139+
endpoint_client = create_endpoint_client(self.endpoint_app, signing_key)
140+
141+
api_key = "0000000000000000"
142+
self.dispatch_service = DispatchService(
143+
endpoint_client, api_key, collect_roundtrips=True
144+
)
145+
self.dispatch_server = DispatchServer(self.dispatch_service)
146+
self.dispatch_client = dispatch.Client(
147+
api_key, api_url=self.dispatch_server.url
148+
)
149+
150+
self.dispatch = Dispatch(
151+
self.endpoint_app,
152+
endpoint="http://function-service", # unused
153+
verification_key=verification_key,
154+
api_key=api_key,
155+
api_url=self.dispatch_server.url,
156+
)
157+
158+
self.dispatch_server.start()
159+
160+
def tearDown(self):
161+
self.dispatch_server.stop()
162+
163+
def test_simple_end_to_end(self):
164+
# The FastAPI server.
165+
@self.dispatch.function
166+
def my_function(name: str) -> str:
167+
return f"Hello world: {name}"
168+
169+
call = my_function.build_call(52)
170+
self.assertEqual(call.function.split(".")[-1], "my_function")
171+
172+
# The client.
173+
[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])
174+
175+
# Simulate execution for testing purposes.
176+
self.dispatch_service.dispatch_calls()
177+
178+
# Validate results.
179+
roundtrips = self.dispatch_service.roundtrips[dispatch_id]
180+
self.assertEqual(len(roundtrips), 1)
181+
_, response = roundtrips[0]
182+
self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52")
183+
184+
def test_simple_missing_signature(self):
185+
@self.dispatch.function
186+
async def my_function(name: str) -> str:
187+
return f"Hello world: {name}"
188+
189+
call = my_function.build_call(52)
190+
self.assertEqual(call.function.split(".")[-1], "my_function")
191+
192+
[dispatch_id] = self.dispatch_client.dispatch([call])
193+
194+
self.dispatch_service.endpoint_client = create_endpoint_client(
195+
self.endpoint_app
196+
) # no signing key
197+
try:
198+
self.dispatch_service.dispatch_calls()
199+
except httpx.HTTPStatusError as e:
200+
assert e.response.status_code == 403
201+
assert e.response.json() == {
202+
"code": "permission_denied",
203+
"message": 'Expected "Signature-Input" header field to be present',
204+
}
205+
else:
206+
assert False, "Expected HTTPStatusError"
207+
208+
105209
def response_output(resp: function_pb.RunResponse) -> Any:
106210
return any_unpickle(resp.exit.result.output)
107211

@@ -120,7 +224,7 @@ def root():
120224
self.app, endpoint="https://127.0.0.1:9999"
121225
)
122226
self.http_client = TestClient(self.app)
123-
self.client = EndpointClient.from_app(self.app)
227+
self.client = create_endpoint_client(self.app)
124228

125229
def execute(
126230
self, func: Function, input=None, state=None, calls=None

tests/test_full.py

-99
This file was deleted.

0 commit comments

Comments
 (0)