Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Better error message on extension import error (#300)
Browse files Browse the repository at this point in the history
closes #200
  • Loading branch information
mike0sv authored Jun 16, 2022
1 parent b70975d commit c701297
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 7 deletions.
21 changes: 19 additions & 2 deletions mlem/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Literal
from yaml import safe_load

from mlem.core.errors import UnknownImplementation
from mlem.core.errors import ExtensionRequirementError, UnknownImplementation
from mlem.polydantic import PolyModel
from mlem.utils.importing import import_string
from mlem.utils.path import make_posix
Expand Down Expand Up @@ -54,7 +54,24 @@ def load_impl_ext(
eps = load_entrypoints()
for ep in eps.values():
if ep.abs_name == abs_name and ep.name == type_name:
obj = ep.ep.load()
try:
obj = ep.ep.load()
except ImportError as e:
from mlem.ext import ExtensionLoader

ext = ExtensionLoader.builtin_extensions.get(
ep.ep.module_name, None
)
reqs: List[str]
if ext is None:
reqs = [e.name] if e.name is not None else []
extra = None
else:
reqs = ext.reqs
extra = ext.extra
raise ExtensionRequirementError(
ep.name or "", reqs, extra
) from e
if not issubclass(obj, MlemABC):
raise ValueError(f"{obj} is not subclass of MlemABC")
return obj
Expand Down
16 changes: 16 additions & 0 deletions mlem/core/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Exceptions raised by the MLEM."""
from typing import List, Optional

from mlem.constants import MLEM_DIR


Expand Down Expand Up @@ -143,3 +145,17 @@ class UnknownConfigSection(MlemError):
def __init__(self, section: str):
self.section = section
super().__init__(f'Unknown config section "{section}"')


class ExtensionRequirementError(MlemError):
def __init__(self, ext: str, reqs: List[str], extra: Optional[str]):
self.ext = ext
self.reqs = reqs
self.extra = extra
extra_install = (
"" if extra is None else f"`pip install mlem[{extra}]` or "
)
reqs_str = " ".join(reqs)
super().__init__(
f"Extension '{ext}' requires additional dependencies: {extra_install}`pip install {reqs_str}`"
)
10 changes: 7 additions & 3 deletions mlem/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import sys
from types import ModuleType
from typing import Callable, Dict, List, Union
from typing import Callable, Dict, List, Optional, Union

from mlem.config import LOCAL_CONFIG
from mlem.utils.importing import (
Expand Down Expand Up @@ -34,11 +34,15 @@ def __init__(
reqs: List[str],
force: bool = True,
validator: Callable[[], bool] = None,
extra: Optional[str] = "",
):
self.force = force
self.reqs = reqs
self.module = module
self.validator = validator
self.extra = extra
if extra == "":
self.extra = module.split(".")[-1]

def __str__(self):
return f"<Extension {self.module}>"
Expand Down Expand Up @@ -81,7 +85,7 @@ class ExtensionLoader:
builtin_extensions: Dict[str, Extension] = ExtensionDict(
Extension("mlem.contrib.numpy", ["numpy"], False),
Extension("mlem.contrib.pandas", ["pandas"], False),
Extension("mlem.contrib.sklearn", ["sklearn"], False),
Extension("mlem.contrib.sklearn", ["scipy", "scikit-learn"], False),
# Extension('mlem.contrib.tensorflow', ['tensorflow'], False, is_tf_v1),
# Extension('mlem.contrib.tensorflow_v2', ['tensorflow'], False, is_tf_v2),
Extension("mlem.contrib.torch", ["torch"], False),
Expand All @@ -94,7 +98,7 @@ class ExtensionLoader:
Extension("mlem.contrib.docker", ["docker"], False),
Extension("mlem.contrib.fastapi", ["fastapi", "uvicorn"], False),
Extension("mlem.contrib.callable", [], True),
Extension("mlem.contrib.rabbitmq", ["pika"], False),
Extension("mlem.contrib.rabbitmq", ["pika"], False, extra="rmq"),
)

_loaded_extensions: Dict[Extension, ModuleType] = {}
Expand Down
9 changes: 7 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"dill",
"requests",
"isort>=5.10",
"docker",
"pydantic>=1.9.0,<2",
"typer",
"click<8.1",
Expand Down Expand Up @@ -42,12 +41,17 @@
"gcsfs",
"testcontainers",
"emoji",
"lxml",
"openpyxl",
"xlrd",
"tables",
"pyarrow",
]

extras = {
"tests": tests,
"dvc": ["dvc~=2.0"],
"pandas": ["pandas", "lxml", "openpyxl", "xlrd", "tables", "pyarrow"],
"pandas": ["pandas"],
"numpy": ["numpy"],
"sklearn": ["scipy", "scikit-learn"],
"catboost": ["catboost"],
Expand All @@ -65,6 +69,7 @@
"s3": ["s3fs[boto3]>=2021.11.1", "aiobotocore[boto3]>2"],
"ssh": ["bcrypt", "sshfs[bcrypt]>=2021.11.2"],
"rmq": ["pika"],
"docker": ["docker"],
}

extras["all"] = [_ for e in extras.values() for _ in e]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_ext.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mlem import ExtensionLoader
from mlem.utils.entrypoints import (
MLEM_ENTRY_POINT,
find_implementations,
Expand Down Expand Up @@ -27,3 +28,17 @@ def test_all_impls_in_entrypoints():
exts = {e.entry for e in exts.values()}
impls = set(find_implementations()[MLEM_ENTRY_POINT])
assert exts == impls


def test_all_ext_has_pip_extra():
from setup import extras

exts_reqs = {
v.extra: v.reqs
for v in ExtensionLoader.builtin_extensions.values()
if v.extra is not None and len(v.reqs)
}

for name, reqs in exts_reqs.items():
assert name in extras
assert set(reqs) == set(extras[name])

0 comments on commit c701297

Please # to comment.