Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Defer type-checking from object_to_item to Item classes #355

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/skore/item/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Item types for the skore package."""

from __future__ import annotations

from contextlib import suppress
from typing import Any

from skore.item.item import Item
from skore.item.item_repository import ItemRepository
from skore.item.media_item import MediaItem
Expand All @@ -8,6 +13,29 @@
from skore.item.primitive_item import PrimitiveItem
from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem


def object_to_item(object: Any) -> Item:
"""Transform an object into an Item."""
for cls in (
PrimitiveItem,
PandasDataFrameItem,
NumpyArrayItem,
SklearnBaseEstimatorItem,
MediaItem,
):
with suppress(ImportError, TypeError):
# ImportError:
# The factories are responsible to import third-party libraries in a
# lazy way. If library is missing, an ImportError exception will
# automatically be thrown.
# TypeError:
# The factories are responsible for checking that parameters are of the
# correct type. If not, they throw a TypeError exception.
return cls.factory(object)

raise NotImplementedError(f"Type '{object.__class__}' is not supported.")


__all__ = [
"Item",
"ItemRepository",
Expand All @@ -16,4 +44,5 @@
"PandasDataFrameItem",
"PrimitiveItem",
"SklearnBaseEstimatorItem",
"object_to_item",
]
2 changes: 1 addition & 1 deletion src/skore/item/media_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def factory(cls, media, *args, **kwargs):
if lazy_is_instance(media, "PIL.Image.Image"):
return cls.factory_pillow(media, *args, **kwargs)

raise NotImplementedError(f"Type '{media.__class__}' is not yet supported")
raise TypeError(f"Type '{media.__class__}' is not supported.")

@classmethod
def factory_bytes(
Expand Down
5 changes: 5 additions & 0 deletions src/skore/item/numpy_array_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def factory(cls, array: numpy.ndarray) -> NumpyArrayItem:
NumpyArrayItem
A new NumpyArrayItem instance.
"""
import numpy

if not isinstance(array, numpy.ndarray):
raise TypeError(f"Type '{array.__class__}' is not supported.")

instance = cls(array_list=array.tolist())

# add array as cached property
Expand Down
5 changes: 5 additions & 0 deletions src/skore/item/pandas_dataframe_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def factory(cls, dataframe: pandas.DataFrame) -> PandasDataFrameItem:
PandasDataFrameItem
A new PandasDataFrameItem instance.
"""
import pandas.core.frame

if not isinstance(dataframe, pandas.core.frame.DataFrame):
raise TypeError(f"Type '{dataframe.__class__}' is not supported.")

instance = cls(dataframe_dict=dataframe.to_dict(orient="tight"))

# add dataframe as cached property
Expand Down
2 changes: 1 addition & 1 deletion src/skore/item/primitive_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,6 @@ def factory(cls, primitive: Primitive) -> PrimitiveItem:
A new PrimitiveItem instance.
"""
if not is_primitive(primitive):
raise ValueError(f"{primitive} is not Primitive.")
raise TypeError(f"Type '{primitive.__class__}' is not supported.")

return cls(primitive=primitive)
4 changes: 4 additions & 0 deletions src/skore/item/sklearn_base_estimator_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI
SklearnBaseEstimatorItem
A new SklearnBaseEstimatorItem instance.
"""
import sklearn.base
import sklearn.utils
import skops.io

if not isinstance(estimator, sklearn.base.BaseEstimator):
raise TypeError(f"Type '{estimator.__class__}' is not supported.")

instance = cls(
estimator_skops=skops.io.dumps(estimator),
estimator_html_repr=sklearn.utils.estimator_html_repr(estimator),
Expand Down
40 changes: 11 additions & 29 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,20 @@
from pathlib import Path
from typing import Any

from skore.item import Item
from skore.item.item_repository import ItemRepository
from skore.item.media_item import MediaItem, lazy_is_instance
from skore.item.numpy_array_item import NumpyArrayItem
from skore.item.pandas_dataframe_item import PandasDataFrameItem
from skore.item.primitive_item import PrimitiveItem, is_primitive
from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
from skore.item import (
Item,
ItemRepository,
MediaItem,
NumpyArrayItem,
PandasDataFrameItem,
PrimitiveItem,
SklearnBaseEstimatorItem,
object_to_item,
)
from skore.layout import Layout, LayoutRepository
from skore.persistence.disk_cache_storage import DirectoryDoesNotExist, DiskCacheStorage


def object_to_item(o: Any) -> Item:
"""Transform an object into an Item."""
if is_primitive(o):
return PrimitiveItem.factory(o)
elif lazy_is_instance(o, "pandas.core.frame.DataFrame"):
return PandasDataFrameItem.factory(o)
elif lazy_is_instance(o, "numpy.ndarray"):
return NumpyArrayItem.factory(o)
elif lazy_is_instance(o, "sklearn.base.BaseEstimator"):
return SklearnBaseEstimatorItem.factory(o)
elif lazy_is_instance(o, "altair.vegalite.v5.schema.core.TopLevelSpec"):
return MediaItem.factory_altair(o)
elif lazy_is_instance(o, "matplotlib.figure.Figure"):
return MediaItem.factory_matplotlib(o)
elif lazy_is_instance(o, "PIL.Image.Image"):
return MediaItem.factory_pillow(o)
else:
raise NotImplementedError(f"Type {o.__class__.__name__} is not supported yet.")


class Project:
"""A project is a collection of items that are stored in a storage."""

Expand All @@ -47,8 +30,7 @@ def __init__(

def put(self, key: str, value: Any):
"""Add a value to the Project."""
item = object_to_item(value)
self.put_item(key, item)
self.put_item(key, object_to_item(value))

def put_item(self, key: str, item: Item):
"""Add an Item to the Project."""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/item/test_media_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def monkeypatch_datetime(self, monkeypatch, MockDatetime):
monkeypatch.setattr("skore.item.item.datetime", MockDatetime)

def test_factory_exception(self):
with pytest.raises(NotImplementedError):
with pytest.raises(TypeError):
MediaItem.factory(None)

def test_factory_bytes(self, mock_nowstr):
Expand Down