This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 83f169f

add sklearn pipelines (#251)
* add sklearn pipelines
closes #218

* small fix of unrelated bug but who is gonna notice

* fix tests

* another small bug ;))

* and another one
closes #240

* and another one
closes #250

* Update tests/contrib/test_sklearn.py

Co-authored-by: Alexander Guschin <1aguschin@gmail.com>
mike0sv and aguschin authored May 17, 2022
1 parent a724ce6 commit 83f169f
Showing 5 changed files with 45 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mlem/cli/main.py
@@ -16,7 +16,7 @@
 from typer.core import TyperCommand, TyperGroup
 from yaml import safe_load

-from mlem import version
+from mlem import CONFIG, version
 from mlem.analytics import send_cli_call
 from mlem.constants import MLEM_DIR, PREDICT_METHOD_NAME
 from mlem.core.base import MlemABC, build_mlem_object
@@ -226,7 +226,7 @@ def mlem_callback(
         logger = logging.getLogger("mlem")
         logger.handlers[0].setLevel(logging.DEBUG)
         logger.setLevel(logging.DEBUG)
-    ctx.obj = {"traceback": traceback}
+    ctx.obj = {"traceback": traceback or CONFIG.DEBUG}


 def _extract_examples(
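
The callback change means a full traceback is shown whenever either the CLI flag is passed or the config-level debug switch is on. Below is a minimal sketch of how such a switch could be wired with pydantic settings; the MlemConfig class, the MLEM_ env prefix, and the handle_error helper are illustrative assumptions, not MLEM's actual code.

# Illustrative sketch only: class name, env prefix, and helper are assumptions.
from pydantic import BaseSettings  # pydantic v1, matching the "pydantic>=1.9.0,<2" pin


class MlemConfig(BaseSettings):
    DEBUG: bool = False

    class Config:
        env_prefix = "MLEM_"  # e.g. MLEM_DEBUG=true would flip the switch


CONFIG = MlemConfig()


def handle_error(exc: Exception, traceback_flag: bool) -> None:
    # Mirrors the callback logic: re-raise (showing the traceback) if either
    # the --traceback flag or the config/env debug switch is enabled.
    if traceback_flag or CONFIG.DEBUG:
        raise exc
    print(f"Error: {exc}")
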
26 changes: 26 additions & 0 deletions mlem/contrib/sklearn.py
@@ -2,6 +2,7 @@

 import sklearn
 from sklearn.base import ClassifierMixin, RegressorMixin
+from sklearn.pipeline import Pipeline

 from mlem.constants import (
     PREDICT_ARG_NAME,
@@ -67,3 +68,28 @@ def get_requirements(self) -> Requirements:
         return super().get_requirements() + InstallableRequirement.from_module(
             sklearn
         )
+
+
+class SklearnPipelineType(SklearnModel):
+    valid_types: ClassVar = (Pipeline,)
+    type: ClassVar = "sklearn_pipeline"
+
+    @classmethod
+    def process(
+        cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
+    ) -> ModelType:
+        mt = SklearnModel(io=SimplePickleIO(), methods={}).bind(obj)
+        predict = obj.predict
+        predict_args = {"X": sample_data}
+        if hasattr(predict, "__wrapped__"):
+            predict = predict.__wrapped__
+            predict_args["self"] = obj
+        sk_predict_sig = Signature.from_method(
+            predict, auto_infer=True, **predict_args
+        )
+        mt.methods["sklearn_predict"] = sk_predict_sig
+        predict_sig = sk_predict_sig.copy()
+        predict_sig.args[0].name = "data"
+        predict_sig.varkw = None
+        mt.methods[PREDICT_METHOD_NAME] = predict_sig
+        return mt
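
The new model type reuses SklearnModel's pickle IO and exposes the pipeline's predict twice: under its sklearn name and as MLEM's generic predict method, with the first argument renamed to data and **kwargs dropped. The __wrapped__ branch handles the case where sklearn wraps predict in a decorator: the underlying function is inspected instead and self is bound explicitly. A rough usage sketch follows; mlem.api.save/load and the sample_data argument are assumed to be the public entry points and may differ in detail.

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

from mlem.api import load, save  # assumed public API

X = np.random.rand(20, 4)
y = np.random.randint(0, 2, 20)

pipe = Pipeline([("scaler", StandardScaler()), ("svc", SVC())])
pipe.fit(X, y)

# sample_data lets the analyzer infer the predict signature for the pipeline
save(pipe, "svc_pipeline", sample_data=X)

restored = load("svc_pipeline")
print(restored.predict(X[:3]))
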
2 changes: 1 addition & 1 deletion mlem/core/artifacts.py
@@ -253,7 +253,7 @@ def from_fs_path(cls, fs: AbstractFileSystem, path: str):

 class LocalStorage(FSSpecStorage):
     type: ClassVar = "local"
-    fs = LocalFileSystem()
+    fs: AbstractFileSystem = LocalFileSystem()

     def get_base_path(self):
         return self.uri
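
The artifacts.py tweak (the "small fix of unrelated bug" from the commit message) only adds a type annotation to the overriding default. A small sketch of the pattern, assuming a pydantic-v1 style model; the FSStorage/LocalFSStorage names below are simplified stand-ins, not MLEM's real classes.

from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem
from pydantic import BaseModel


class FSStorage(BaseModel):
    """Simplified stand-in for FSSpecStorage."""

    class Config:
        arbitrary_types_allowed = True  # AbstractFileSystem is not a native pydantic type

    uri: str = ""


class LocalFSStorage(FSStorage):
    # With the annotation, pydantic treats `fs` as a typed field with a default;
    # an un-annotated assignment would be a plain class attribute instead.
    fs: AbstractFileSystem = LocalFileSystem()


print(type(LocalFSStorage().fs).__name__)  # "LocalFileSystem"
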
3 changes: 2 additions & 1 deletion setup.py
@@ -5,7 +5,7 @@
 install_requires = [
     "dill",
     "requests",
-    "isort>4",
+    "isort>=5.10",
     "docker",
     "pydantic>=1.9.0,<2",
     "typer",
@@ -148,6 +148,7 @@
         "model_type.catboost = mlem.contrib.catboost:CatBoostModel",
         "model_type.lightgbm = mlem.contrib.lightgbm:LightGBMModel",
         "model_type.sklearn = mlem.contrib.sklearn:SklearnModel",
+        "model_type.sklearn_pipeline = mlem.contrib.sklearn:SklearnPipelineType",
         "model_type.xgboost = mlem.contrib.xgboost:XGBoostModel",
         "model_type.torch = mlem.contrib.torch:TorchModel",
         "packager.docker = mlem.contrib.docker.base:DockerImagePackager",
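
The new entry point registers SklearnPipelineType as a plugin so it can be discovered without an explicit import. A sketch of how such entry points can be enumerated at runtime; the "mlem.contrib" group name is a guess for illustration (the actual group is defined higher up in setup.py, outside this hunk).

# Sketch of entry-point discovery; the group name is an assumption.
from importlib.metadata import entry_points


def load_model_types(group: str = "mlem.contrib") -> dict:
    loaded = {}
    # entry_points(group=...) is the Python 3.10+ selectable API;
    # older interpreters would use entry_points().get(group, []).
    for ep in entry_points(group=group):
        if ep.name.startswith("model_type."):
            # e.g. "model_type.sklearn_pipeline" -> SklearnPipelineType class
            loaded[ep.name] = ep.load()
    return loaded
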
16 changes: 14 additions & 2 deletions tests/contrib/test_sklearn.py
@@ -4,6 +4,9 @@
 import numpy as np
 import pytest
 from sklearn.linear_model import LinearRegression, LogisticRegression
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import SVC

 from mlem.contrib.numpy import NumpyNdarrayType
 from mlem.contrib.sklearn import SklearnModel
@@ -38,13 +41,22 @@ def regressor(inp_data, out_data):
     return lr


+@pytest.fixture()
+def pipeline(inp_data, out_data):
+    pipe = Pipeline([("scaler", StandardScaler()), ("svc", SVC())])
+    pipe.fit(inp_data, out_data)
+    return pipe
+
+
 @pytest.fixture
 def lgbm_model(inp_data, out_data):
     lgbm_regressor = lgb.LGBMRegressor()
     return lgbm_regressor.fit(inp_data, out_data)


-@pytest.mark.parametrize("model_fixture", ["classifier", "regressor"])
+@pytest.mark.parametrize(
+    "model_fixture", ["classifier", "regressor", "pipeline"]
+)
 def test_hook(model_fixture, inp_data, request):
     model = request.getfixturevalue(model_fixture)
     data_type = DatasetAnalyzer.analyze(inp_data)
@@ -77,7 +89,7 @@ def test_hook_lgb(lgbm_model, inp_data):
     )


-@pytest.mark.parametrize("model", ["classifier", "regressor"])
+@pytest.mark.parametrize("model", ["classifier", "regressor", "pipeline"])
 def test_model_type__predict(model, inp_data, request):
     model = request.getfixturevalue(model)
     model_type = ModelAnalyzer.analyze(model, sample_data=inp_data)
