Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Client classes for Servers #185

Merged
merged 17 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlem/contrib/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections.abc import Callable
from types import ModuleType
from typing import ClassVar, List, Type
Expand All @@ -13,6 +14,8 @@
from mlem.runtime.server.base import Server
from mlem.ui import EMOJI_NAILS, echo

logger = logging.getLogger(__name__)


def rename_recursively(model: Type[BaseModel], prefix: str):
model.__name__ = f"{prefix}{model.__name__}"
Expand All @@ -21,6 +24,12 @@ def rename_recursively(model: Type[BaseModel], prefix: str):
rename_recursively(field.type_, prefix)


def _create_schema_route(app: FastAPI, interface: Interface):
schema = interface.get_descriptor().dict()
logger.debug("Creating /interface.json route with schema: %s", schema)
app.add_api_route("/interface.json", lambda: schema, tags=["schema"])


class FastAPIServer(Server, LibRequirementsMixin):
libraries: ClassVar[List[ModuleType]] = [uvicorn, fastapi]
type: ClassVar[str] = "fastapi"
Expand Down Expand Up @@ -64,6 +73,7 @@ def handler(model: payload_model): # type: ignore[valid-type]

def app_init(self, interface: Interface):
app = FastAPI()
_create_schema_route(app, interface)

for method, signature in interface.iter_methods():
executor = interface.get_method_executor(method)
Expand Down
101 changes: 96 additions & 5 deletions mlem/runtime/client/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,105 @@
from abc import ABC
import logging
from abc import ABC, abstractmethod
from typing import Callable, ClassVar, Optional

import requests
from pydantic import BaseModel, parse_obj_as

from mlem.core.base import MlemObject
from mlem.core.errors import WrongMethodError
from mlem.core.model import Signature
from mlem.runtime.interface.base import ExecutionError, InterfaceDescriptor

logger = logging.getLogger(__name__)


class BaseClient(MlemObject, ABC):
"""TODO: https://github.com/iterative/mlem/issues/40"""
abs_name: ClassVar[str] = "client"

@property
def interface(self):
return self._interface_factory()

@property
def methods(self):
return self.interface.methods

@abstractmethod
def _interface_factory(self) -> InterfaceDescriptor:
raise NotImplementedError()

@abstractmethod
def _call_method(self, name, args):
raise NotImplementedError()

def __getattr__(self, name):
if name not in self.methods:
raise WrongMethodError(f"{name} method is not exposed by server")
return _MethodCall(
base_url=self.base_url,
method=self.methods[name],
call_method=self._call_method,
)


class _MethodCall(BaseModel):
base_url: str
method: Signature
call_method: Callable

def __call__(self, *args, **kwargs):
if args and kwargs:
raise ValueError(
"Parameters should be passed either in positional or in keyword fashion, not both"
)
if len(args) > len(self.method.args) or len(kwargs) > len(
self.method.args
):
raise ValueError(
f"Too much parameters given, expected: {len(self.method.args)}"
)

data = {}
for i, arg in enumerate(self.method.args):
obj = None
if len(args) > i:
obj = args[i]
if arg.name in kwargs:
obj = kwargs[arg.name]
if obj is None:
raise ValueError(
f'Parameter with name "{arg.name}" (position {i}) should be passed'
)

data[arg.name] = arg.type_.serialize(obj)

logger.debug(
'Calling server method "%s", args: %s ...', self.method.name, data
)
out = self.call_method(self.method.name, data)
logger.debug("Server call returned %s", out)
return self.method.returns.get_serializer().deserialize(out)


class HTTPClient(BaseClient):
"""TODO: https://github.com/iterative/mlem/issues/40"""
host: str = "0.0.0.0"
port: Optional[int] = 8080

@property
def base_url(self):
if self.port:
return f"http://{self.host}:{self.port}"
return f"http://{self.host}"

def _interface_factory(self) -> InterfaceDescriptor:
resp = requests.get(f"{self.base_url}/interface.json")
return parse_obj_as(InterfaceDescriptor, resp.json())

host: str
port: int
def _call_method(self, name, args): # pylint: disable=R1710
ret = requests.post(f"{self.base_url}/{name}", json=args)
if ret.status_code == 200: # pylint: disable=R1705
return ret.json()
elif ret.status_code == 400:
raise ExecutionError(ret.json()["error"])
else:
ret.raise_for_status()
127 changes: 127 additions & 0 deletions tests/runtime/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import platform

import numpy as np
import pytest
from fastapi.testclient import TestClient

from mlem.constants import PREDICT_ARG_NAME, PREDICT_METHOD_NAME
from mlem.contrib.fastapi import FastAPIServer
from mlem.contrib.numpy import NumpyNdarrayType
from mlem.core.dataset_type import DatasetAnalyzer
from mlem.core.errors import WrongMethodError
from mlem.core.model import Argument, Signature
from mlem.core.objects import ModelMeta
from mlem.runtime.client.base import HTTPClient
from mlem.runtime.interface.base import ModelInterface


@pytest.fixture
def signature(train):
data_type = DatasetAnalyzer.analyze(train)
returns_type = NumpyNdarrayType(
shape=(None,),
dtype="int32" if platform.system() == "Windows" else "int64",
)
kwargs = {"varkw": None}
return Signature(
name=PREDICT_METHOD_NAME,
args=[Argument(name=PREDICT_ARG_NAME, type_=data_type)],
returns=returns_type,
**kwargs,
)


@pytest.fixture
def interface(model, train):
model = ModelMeta.from_obj(model, sample_data=train)
interface = ModelInterface.from_model(model)
return interface


@pytest.fixture
def client(interface):
app = FastAPIServer().app_init(interface)
return TestClient(app)


@pytest.fixture
def request_get_mock(mocker, client):
def patched_get(url, params=None, **kwargs):
url = url[len("http://") :]
return client.get(url, params=params, **kwargs)

return mocker.patch(
"mlem.runtime.client.base.requests.get",
side_effect=patched_get,
)


@pytest.fixture
def request_post_mock(mocker, client):
def patched_post(url, data=None, json=None, **kwargs):
url = url[len("http://") :]
return client.post(url, data=data, json=json, **kwargs)

return mocker.patch(
"mlem.runtime.client.base.requests.post",
side_effect=patched_post,
)


@pytest.fixture
def mlem_client(request_get_mock, request_post_mock):
client = HTTPClient(host="", port=None)
return client


@pytest.mark.parametrize("port", [None, 80])
def test_mlem_client_base_url(port):
client = HTTPClient(host="", port=port)
assert client.base_url == f"http://:{port}" if port else "http://"


@pytest.mark.parametrize("use_keyword", [False, True])
def test_interface_endpoint(mlem_client, train, signature, use_keyword):
assert PREDICT_METHOD_NAME in mlem_client.methods
assert mlem_client.methods[PREDICT_METHOD_NAME] == signature
if use_keyword:
assert np.array_equal(
getattr(mlem_client, PREDICT_METHOD_NAME)(data=train),
np.array([0] * 50 + [1] * 50 + [2] * 50),
)
else:
assert np.array_equal(
getattr(mlem_client, PREDICT_METHOD_NAME)(train),
np.array([0] * 50 + [1] * 50 + [2] * 50),
)


def test_wrong_endpoint(mlem_client):
with pytest.raises(WrongMethodError):
mlem_client.dummy_method()


def test_data_validation_more_params_than_expected(mlem_client, train):
with pytest.raises(ValueError) as e:
getattr(mlem_client, PREDICT_METHOD_NAME)(train, 2)
assert str(e.value) == "Too much parameters given, expected: 1"


def test_data_validation_params_in_positional_and_keyword(mlem_client, train):
with pytest.raises(ValueError) as e:
getattr(mlem_client, f"sklearn_{PREDICT_METHOD_NAME}")(
train, check_input=False
)
assert (
str(e.value)
== "Parameters should be passed either in positional or in keyword fashion, not both"
)


def test_data_validation_params_with_wrong_name(mlem_client, train):
with pytest.raises(ValueError) as e:
getattr(mlem_client, PREDICT_METHOD_NAME)(X=train)
assert (
str(e.value)
== 'Parameter with name "data" (position 0) should be passed'
)