Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Feature/better signatures #62

Merged
merged 29 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a10d2c0
add default .pylintrc
aguschin Sep 30, 2021
33a26ce
adding module docstrings where they can be useful
aguschin Sep 30, 2021
030b865
Pylint ignores and todos
mike0sv Oct 1, 2021
ef6c32d
Update LICENSE
aguschin Sep 30, 2021
79999c1
Create CODEOWNERS
aguschin Oct 1, 2021
f3e3614
Specify issues for todos (#46)
mike0sv Oct 1, 2021
fcaefa4
Add action to release on PyPi (#50)
aguschin Oct 4, 2021
c92b534
add comment-bot workflow (#52)
aguschin Oct 4, 2021
866b5e3
Pylint fixes WIP
mike0sv Oct 4, 2021
c274311
Add release-drafter (#53)
aguschin Oct 4, 2021
fafe79c
Update README.md (#54)
aguschin Oct 4, 2021
8b63933
Update README.md (#55)
aguschin Oct 4, 2021
69c58ac
mlem: bump to 0.1.1
aguschin Oct 4, 2021
7f8bdc3
Merge branch 'main' into add-pylint
mike0sv Oct 4, 2021
b1bd464
Merge branch 'main' into add-pylint
aguschin Oct 4, 2021
802f42a
writing some docstrings and help messages in cli
aguschin Oct 4, 2021
83757cb
fix some Pylance warnings
aguschin Oct 5, 2021
31da9f2
fixing issues raised by pylint
aguschin Oct 5, 2021
ecfbe91
Merge branch 'main' into add-pylint
aguschin Oct 5, 2021
21b572a
add pylint to tests deps
aguschin Oct 5, 2021
94ec9e5
use pylint in test job, not in check
aguschin Oct 5, 2021
66b55ba
install pre-commit in test job
aguschin Oct 5, 2021
82de51c
improved signature inferring
mike0sv Oct 6, 2021
60c88e4
fixing last pylint warnings
mike0sv Oct 6, 2021
aad93c4
Merge branch 'add-pylint' into feature/better-signatures-#21
mike0sv Oct 6, 2021
0ba4e3e
better requirements
mike0sv Oct 6, 2021
e886e82
merge
mike0sv Oct 7, 2021
8320bd5
log and serialization fix
mike0sv Oct 8, 2021
7ff0973
better test
mike0sv Oct 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlem/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class MlemConfig(BaseSettings):
)
AUTOLOAD_EXTS: bool = True
DEFAULT_BRANCH: str = "main"
LOG_LEVEL: str = "INFO"
DEBUG: bool = False

@property
def ADDITIONAL_EXTENSIONS(self) -> List[str]:
Expand Down
4 changes: 4 additions & 0 deletions mlem/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
MLEM_DIR = ".mlem"

PREDICT_METHOD_NAME = "predict"
PREDICT_PROBA_METHOD_NAME = "predict_proba"
PREDICT_ARG_NAME = "data"
67 changes: 46 additions & 21 deletions mlem/contrib/catboost.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
import tempfile
from typing import Any, ClassVar
from typing import Any, ClassVar, Optional

import catboost
from catboost import CatBoostClassifier, CatBoostRegressor
from catboost import CatBoost, CatBoostClassifier, CatBoostRegressor
from fsspec import AbstractFileSystem

from mlem.core.artifacts import Artifacts
from mlem.core.dataset_type import UnspecifiedDatasetType
from mlem.core.model import Argument, ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import LibRequirementsMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import InstallableRequirement, Requirements


class CatBoostModelIO(ModelIO):
Expand Down Expand Up @@ -52,36 +51,62 @@ def _get_model_file_name(self, model):
return self.regressor_file_name


class CatBoostModel(ModelType, ModelHook, LibRequirementsMixin):
class CatBoostModel(ModelType, ModelHook):
"""
:class:`mlem.core.model.ModelType` for CatBoost models.
`.model` attribute is a `catboost.CatBoostClassifier` or `catboost.CatBoostRegressor` instance
"""

libraries: ClassVar = [catboost]
type: ClassVar[str] = "catboost"
io: ModelIO = CatBoostModelIO()
model: ClassVar[Optional[CatBoost]]

@classmethod
def is_object_valid(cls, obj: Any) -> bool:
return isinstance(obj, (CatBoostClassifier, CatBoostRegressor))

@classmethod
def process(cls, obj: Any, **kwargs) -> ModelType:
def process(
cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
# Wrap obj first so the wrapper's own bound methods (model.predict,
# model.predict_proba) can be handed to Signature.from_method below.
model = CatBoostModel(model=obj, methods={})
methods = {
"predict": Signature(
name="predict",
args=[
Argument(key="data", type=UnspecifiedDatasetType())
], # TODO: https://github.com/iterative/mlem/issues/21
returns=UnspecifiedDatasetType(),
)
# "predict" describes the wrapper method, "catboost_predict" the native
# obj.predict; auto_infer is requested only when sample_data was supplied.
"predict": Signature.from_method(
model.predict,
auto_infer=sample_data is not None,
data=sample_data,
),
"catboost_predict": Signature.from_method(
obj.predict,
auto_infer=sample_data is not None,
data=sample_data,
),
}
# predict_proba entries are registered for CatBoostClassifier instances only.
if isinstance(obj, CatBoostClassifier):
methods["predict_proba"] = Signature(
name="predict_proba",
args=[Argument(key="data", type=UnspecifiedDatasetType())],
# TODO: https://github.com/iterative/mlem/issues/21
returns=UnspecifiedDatasetType(),
methods["predict_proba"] = Signature.from_method(
model.predict_proba,
auto_infer=sample_data is not None,
data=sample_data,
)
# NOTE(review): the sample is passed as X= here but as data= in the other
# three entries - presumably the keyword has to match the native method's
# parameter name; confirm against the catboost API.
methods["catboost_predict_proba"] = Signature.from_method(
obj.predict_proba,
auto_infer=sample_data is not None,
X=sample_data,
)
return CatBoostModel(model=obj, methods=methods)
model.methods = methods
return model

def predict(self, data):
    """Run the wrapped model's ``predict`` on ``data`` and return its result."""
    wrapped = self.model
    return wrapped.predict(data)

def predict_proba(self, data):
    """Return the wrapped classifier's ``predict_proba`` output for ``data``.

    Raises:
        ValueError: if the wrapped model is not a ``CatBoostClassifier``.
    """
    classifier = self.model
    if isinstance(classifier, CatBoostClassifier):
        return classifier.predict_proba(data)
    raise ValueError("Not valid type of model for predict_proba method")

def get_requirements(self) -> Requirements:
    """Return the parent class's requirements plus one built from the ``catboost`` module."""
    inherited = super().get_requirements()
    catboost_requirement = InstallableRequirement.from_module(catboost)
    return inherited + catboost_requirement
6 changes: 3 additions & 3 deletions mlem/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _create_handler(
cls, method_name: str, signature: Signature, executor: Callable
):
serializers = {
arg.key: arg.type.get_serializer() for arg in signature.args
arg.name: arg.type_.get_serializer() for arg in signature.args
}
kwargs = {
key: (serializer.get_model(), ...)
Expand All @@ -47,8 +47,8 @@ def _create_handler(

def handler(model: payload_model): # type: ignore[valid-type]
kwargs = {
a.key: serializers[a.key].deserialize(
getattr(model, a.key).dict()
a.name: serializers[a.name].deserialize(
getattr(model, a.name).dict()
)
for a in signature.args
}
Expand Down
39 changes: 22 additions & 17 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
import tempfile
from typing import Any, ClassVar
from typing import Any, ClassVar, Optional

import lightgbm as lgb
from fsspec import AbstractFileSystem

from mlem.constants import PREDICT_METHOD_NAME
from mlem.core.artifacts import Artifacts
from mlem.core.dataset_type import (
DatasetAnalyzer,
DatasetHook,
DatasetType,
DatasetWriter,
UnspecifiedDatasetType,
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import Argument, ModelHook, ModelIO, ModelType, Signature
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import (
InstallableRequirement,
Requirements,
Expand Down Expand Up @@ -100,25 +100,30 @@ class LightGBMModel(ModelType, ModelHook, IsInstanceHookMixin):
io: ModelIO = LightGBMModelIO()

@classmethod
def process(cls, obj: Any, **kwargs) -> ModelType:
return LightGBMModel(
model=obj,
methods={
"predict": Signature(
name="_predict",
args=[Argument(key="data", type=UnspecifiedDatasetType())],
returns=UnspecifiedDatasetType(), # TODO: https://github.com/iterative/mlem/issues/21
)
},
)

def _predict(self, data):
def process(
cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
# Build the wrapper with an empty methods mapping first so that its bound
# `predict` can be handed to Signature.from_method.
gbm_model = LightGBMModel(model=obj, methods={})
gbm_model.methods = {
# Signature of the wrapper's predict; auto_infer is requested only when
# sample_data was supplied.
PREDICT_METHOD_NAME: Signature.from_method(
gbm_model.predict,
auto_infer=sample_data is not None,
data=sample_data,
),
# NOTE(review): auto_infer is `sample_data is None` here but
# `sample_data is not None` for the entry above, so inference for the
# native method is requested only when there is no sample. This looks
# inverted; confirm intent before flipping it - the wrapper's predict
# unwraps lgb.Dataset, and obj.predict may not accept that type directly.
"lightgbm_predict": Signature.from_method(
obj.predict, auto_infer=sample_data is None, data=sample_data
),
}
return gbm_model

def predict(self, data):
    """Predict with the wrapped model, unwrapping ``lgb.Dataset`` inputs to their ``.data`` first."""
    features = data.data if isinstance(data, lgb.Dataset) else data
    return self.model.predict(features)

def get_requirements(self) -> Requirements:
# The result is built by chaining requirement objects with `+`.
# NOTE(review): this diff-rendered hunk shows two candidate first operands
# (Requirements.new(...) and super().get_requirements()); presumably the
# Requirements.new(...) line is the removed one - verify against the repository.
return (
Requirements.new(InstallableRequirement.from_module(mod=lgb))
super().get_requirements()
+ InstallableRequirement.from_module(mod=lgb)
+ LGB_REQUIREMENT
)
62 changes: 28 additions & 34 deletions mlem/contrib/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, ClassVar, Dict
from typing import Any, ClassVar, Optional

import sklearn
from sklearn.base import ClassifierMixin, RegressorMixin

from mlem.core.dataset_type import DatasetAnalyzer, UnspecifiedDatasetType
from mlem.constants import (
PREDICT_ARG_NAME,
PREDICT_METHOD_NAME,
PREDICT_PROBA_METHOD_NAME,
)
from mlem.core.model import (
Argument,
ModelHook,
ModelIO,
ModelType,
Expand All @@ -29,37 +32,28 @@ def is_object_valid(cls, obj: Any) -> bool:
return isinstance(obj, (RegressorMixin, ClassifierMixin))

@classmethod
def process(cls, obj: Any, **kwargs) -> "SklearnModel":
test_data = kwargs.get("test_data")
method_names = ["predict"]
if isinstance(obj, ClassifierMixin):
method_names.append("predict_proba")
if test_data is None:

methods: Dict[str, Signature] = {
m: Signature(
name=m,
args=[Argument(key="X", type=UnspecifiedDatasetType())],
returns=UnspecifiedDatasetType(),
)
for m in method_names
}

else:
methods = {
m: Signature(
name=m,
args=[
Argument(
key="X", type=DatasetAnalyzer.analyze(test_data)
)
],
returns=DatasetAnalyzer.analyze(
getattr(obj, m)(test_data)
),
)
for m in method_names
}
def process(
    cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
    """Build a ``SklearnModel`` for ``obj`` with native and renamed method signatures.

    Each native method gets two entries: one under a ``sklearn_``-prefixed key
    with the signature as produced by ``Signature.from_method``, and one under
    the generic method name whose single argument is renamed to
    ``PREDICT_ARG_NAME``. ``predict_proba`` entries are added only when ``obj``
    has that attribute. Inference is requested only when ``sample_data`` is
    not ``None``.
    """
    infer = sample_data is not None

    def describe(native_method):
        # Returns (native signature, copy keeping only the first argument, renamed).
        native = Signature.from_method(native_method, infer, X=sample_data)
        renamed = native.copy()
        renamed.args = [renamed.args[0].copy()]
        renamed.args[0].name = PREDICT_ARG_NAME
        return native, renamed

    methods = {}
    native_sig, generic_sig = describe(obj.predict)
    methods["sklearn_predict"] = native_sig
    methods[PREDICT_METHOD_NAME] = generic_sig

    if hasattr(obj, "predict_proba"):
        native_sig, generic_sig = describe(obj.predict_proba)
        methods["sklearn_predict_proba"] = native_sig
        methods[PREDICT_PROBA_METHOD_NAME] = generic_sig

    wrapper = SklearnModel(io=SimplePickleIO(), methods=methods)
    return wrapper.bind(obj)

Expand Down
44 changes: 26 additions & 18 deletions mlem/contrib/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@
import xgboost
from fsspec import AbstractFileSystem

from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.numpy import python_type_from_np_string_repr
from mlem.core.artifacts import Artifacts
from mlem.core.dataset_type import (
DatasetHook,
DatasetType,
DatasetWriter,
UnspecifiedDatasetType,
)
from mlem.core.dataset_type import DatasetHook, DatasetType, DatasetWriter
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import Argument, ModelHook, ModelIO, ModelType, Signature
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import (
InstallableRequirement,
Requirements,
Expand Down Expand Up @@ -135,9 +131,7 @@ def load(self, fs: AbstractFileSystem, path):
return model


class XGBoostModel(
XGBoostRequirement, ModelType, ModelHook, IsInstanceHookMixin
):
class XGBoostModel(ModelType, ModelHook, IsInstanceHookMixin):
"""
:class:`~.ModelType` implementation for XGBoost models
"""
Expand All @@ -148,17 +142,31 @@ class XGBoostModel(
io: ModelIO = XGBoostModelIO()

@classmethod
def process(cls, obj: Any, **kwargs) -> ModelType:
def process(
cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
# Wrap obj first so the wrapper's bound `predict` can be handed to
# Signature.from_method below.
model = XGBoostModel(model=obj, methods={})
methods = {
"predict": Signature(
name="_predict",
args=[Argument(key="data", type=UnspecifiedDatasetType())],
returns=UnspecifiedDatasetType(), # TODO: https://github.com/iterative/mlem/issues/21
)
# Signature of the wrapper's predict; auto_infer is requested only when
# sample_data was supplied.
PREDICT_METHOD_NAME: Signature.from_method(
model.predict,
auto_infer=sample_data is not None,
data=sample_data,
),
# NOTE(review): auto_infer is `sample_data is None` here but
# `sample_data is not None` for the entry above, so inference for the
# native method is requested only when there is no sample. This looks
# inverted; confirm intent before flipping it - the wrapper's predict
# converts non-DMatrix input to xgboost.DMatrix, and obj.predict may
# require a DMatrix.
"xgboost_predict": Signature.from_method(
obj.predict, auto_infer=sample_data is None, data=sample_data
),
}
return XGBoostModel(model=obj, methods=methods)
model.methods = methods
return model

def _predict(self, data):
def predict(self, data):
    """Predict with the wrapped model, converting non-``DMatrix`` input to ``xgboost.DMatrix`` first."""
    matrix = data if isinstance(data, xgboost.DMatrix) else xgboost.DMatrix(data)
    return self.model.predict(matrix)

def get_requirements(self) -> Requirements:
    """Return the parent class's requirements plus the ``xgboost`` module and ``XGB_REQUIREMENT``."""
    inherited = super().get_requirements()
    with_module = inherited + InstallableRequirement.from_module(xgboost)
    return with_module + XGB_REQUIREMENT
5 changes: 1 addition & 4 deletions mlem/core/dataset_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PrimitiveType(DatasetType, DatasetHook):
DatasetType for int, str, bool, complex and float types
"""

PRIMITIVES: ClassVar[set] = {int, str, bool, complex, float}
PRIMITIVES: ClassVar[set] = {int, str, bool, complex, float, type(None)}
type: ClassVar[str] = "primitive"

ptype: str
Expand All @@ -104,9 +104,6 @@ def to_type(self):
def deserialize(self, obj):
    """Convert ``obj`` to this primitive type by calling ``self.to_type`` on it.

    ``type(None)`` is one of the supported primitives, but ``NoneType`` cannot
    be called with an argument (it raises ``TypeError``), so a ``None`` payload
    for the ``None`` type is returned as-is instead of being passed through
    the constructor. Every other combination behaves as before.
    """
    target = self.to_type
    if target is type(None) and obj is None:
        return None
    return target(obj)

# def get_spec(self) -> ArgList:
# return [Field(None, self.to_type, False)]

def serialize(self, instance):
    """Check ``instance`` against ``self.to_type`` via ``check_type`` and return it unchanged.

    ``ValueError`` is the exception class handed to ``check_type``.
    """
    self.check_type(instance, self.to_type, ValueError)
    checked = instance
    return checked
Expand Down
2 changes: 1 addition & 1 deletion mlem/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,6 @@ def _find_hook(cls, obj) -> Type[Hook[T]]:
)
raise ValueError(
f"No suitable {cls.base_hook_class.__name__} for object of type "
f"[{type(obj).__name__}]. Registered hooks: {cls.hooks}"
f'"{type(obj).__name__}". Registered hooks: {cls.hooks}'
)
return max(hooks, key=lambda x: x[0])[1]
2 changes: 1 addition & 1 deletion mlem/core/meta_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def serialize(
): # pylint: disable=unused-argument # todo remove later
if not isinstance(obj, MlemObject):
raise ValueError(f"{type(obj)} is not a subclass of MlemObject")
return obj.dict(exclude_unset=True)
return obj.dict(exclude_unset=True, exclude_defaults=True)


T = TypeVar("T")
Expand Down
2 changes: 1 addition & 1 deletion mlem/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_object_metadata(obj: Any, tmp_sample_data=None) -> MlemMeta:
try:
return DatasetMeta.from_data(obj)
except ValueError: # TODO need separate analysis exception
return ModelMeta.from_obj(obj, test_data=tmp_sample_data)
return ModelMeta.from_obj(obj, sample_data=tmp_sample_data)


def save(
Expand Down
Loading