From 28d6a552c62a9d5d9fd1a21a20f241842c01463d Mon Sep 17 00:00:00 2001 From: mike0sv Date: Sat, 30 Jul 2022 00:27:07 +0300 Subject: [PATCH] replace _ExternalRef logic with persistent_id --- mlem/contrib/callable.py | 77 ++++++---------------------------------- 1 file changed, 11 insertions(+), 66 deletions(-) diff --git a/mlem/contrib/callable.py b/mlem/contrib/callable.py index bec8af44..f5b3fcbe 100644 --- a/mlem/contrib/callable.py +++ b/mlem/contrib/callable.py @@ -1,13 +1,11 @@ -import pickle import posixpath from collections import defaultdict from importlib import import_module from io import BytesIO -from pickle import _Unpickler # type: ignore from typing import Any, Callable, ClassVar, Dict, Optional, Tuple from uuid import uuid4 -from dill import Pickler +from dill import Pickler, Unpickler from mlem.core.artifacts import Artifacts, Storage from mlem.core.hooks import LOW_PRIORITY_VALUE @@ -146,38 +144,12 @@ def __init__(self, model, *args, **kwargs): self.model = model self.refs: Dict[str, Tuple[ModelIO, Any]] = {} - # we couldn't import hook and analyzer at top as it leads to circular import failure known_types = set() for hook in ModelAnalyzer.hooks: if not isinstance(hook, CallableModelType) and hook.valid_types: known_types.update(hook.valid_types) self.known_types = tuple(known_types) - # pickle "hook" for overriding serialization of objects - def save(self, obj, save_persistent_id=True): - """ - Checks if obj has IO. - If it does, serializes object with `ModelIO.dump` - and creates a ref to it. Otherwise, saves object as default pickle would do - :param obj: obj to save - :param save_persistent_id: - :return: - """ - if obj is self.model: - # at starting point, follow usual path not to fall into infinite loop - return super().save(obj, save_persistent_id) - - io = self._get_non_pickle_io(obj) - if io is None: - # no non-Pickle IO found, follow usual path - return super().save(obj, save_persistent_id) - - # found model with non-pickle serialization: - # replace with `_ExternalRef` stub and memorize IO to serialize model aside later - obj_uuid = str(uuid4()) - self.refs[obj_uuid] = (io, obj) - return super().save(_ExternalRef(obj_uuid), save_persistent_id) - def _get_non_pickle_io(self, obj): """ Checks if obj has non-Pickle IO and returns it @@ -200,49 +172,22 @@ def _get_non_pickle_io(self, obj): # non-model object return None + def persistent_id(self, obj: Any) -> Any: + io = self._get_non_pickle_io(obj) + if io is None: + return None + obj_uuid = str(uuid4()) + self.refs[obj_uuid] = (io, obj) + return obj_uuid -# `Unpickler`, unlike `_Unpickler`, doesn't support `load_build` overriding -class _ModelUnpickler(_Unpickler): - """ - A class to unpickle model saved with :class:`_ModelPickler` - :param refs: dict of object uuid -> ref_obj - :param args: pickle._Unpickler args - :param kwargs: pickle._Unpickle kwargs - """ - - dispatch = _Unpickler.dispatch.copy() +class _ModelUnpickler(Unpickler): def __init__(self, refs, *args, **kwargs): super().__init__(*args, **kwargs) self.refs = refs - # pickle "hook" for overriding deserialization of objects - def load_build(self): - """ - Checks if last builded object is :class:`_ExternalRef` and if it is, swaps it with referenced object - :return: - """ - super().load_build() - - # this is the last deserialized object for now - obj = self.stack[-1] - if not isinstance(obj, _ExternalRef): - return - - # replace `_ExternalRef` with a real model it references - self.stack.pop() - self.stack.append(self.refs[obj.ref]) - - dispatch[pickle.BUILD[0]] = load_build # type: ignore - - -class _ExternalRef: - """ - A class to mark objects dumped their own :class:`ModelIO` - """ - - def __init__(self, ref: str): - self.ref = ref + def persistent_load(self, pid: str) -> Any: + return self.refs[pid] class CallableModelType(ModelType, ModelHook):