From 9f72c7db639404b692707707faa412557af0bc68 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Wed, 20 Mar 2024 14:05:48 +0100 Subject: [PATCH 1/3] add _util.IDDict --- nutils/_util.py | 50 +++++++++++++++++++++++++++++++++++++++++ tests/test_util.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/nutils/_util.py b/nutils/_util.py index 2116c2b55..908fc7c51 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -778,4 +778,54 @@ def wrapper(*args, **kwargs): return wrapper +class IDDict: + '''Mapping from instance (is, not ==) to value. Keys need not be hashable.''' + + def __init__(self): + self.__dict = {} + + def __setitem__(self, key, value): + self.__dict[id(key)] = key, value + + def __getitem__(self, key): + key_, value = self.__dict[id(key)] + assert key_ is key + return value + + def get(self, key, default=None): + kv = self.__dict.get(id(key)) + if kv is None: + return default + key_, value = kv + assert key_ is key + return value + + def __delitem__(self, key): + del self.__dict[id(key)] + + def __len__(self): + return len(self.__dict) + + def keys(self): + return (key for key, value in self.__dict.values()) + + def values(self): + return (value for key, value in self.__dict.values()) + + def items(self): + return self.__dict.values() + + def __iter__(self): + return self.keys() + + def __contains__(self, key): + return self.__dict.__contains__(id(key)) + + def __str__(self): + return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}' + + def __repr__(self): + return self.__str__() + + # vim:sw=4:sts=4:et diff --git a/tests/test_util.py b/tests/test_util.py index 938b48b1c..889c0166a 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -394,3 +394,58 @@ def test_double(self): def test_notimp(self): notimp = self.NotImp() self.assertEqual(self.f(notimp, self.Ten()), (notimp, 10, 2)) + + +class IDDict(TestCase): + + def setUp(self): + self.d = util.IDDict() + self.a, self.b = 'ab' + self.d[self.a] = 1 + self.d[self.b] = 2 + + def test_getitem(self): + self.assertEqual(self.d[self.a], 1) + self.assertEqual(self.d[self.b], 2) + + def test_setitem(self): + c = 'c' + self.d[c] = 3 + self.assertEqual(self.d[c], 3) + + def test_delitem(self): + del self.d[self.a] + self.assertNotIn(self.a, self.d) + self.assertIn(self.b, self.d) + + def test_contains(self): + self.assertIn(self.a, self.d) + self.assertIn(self.b, self.d) + c = 'c' + self.assertNotIn('c', self.d) + + def test_len(self): + self.assertEqual(len(self.d), 2) + + def test_get(self): + self.assertEqual(self.d.get(self.a, 10), 1) + self.assertEqual(self.d.get(self.b, 10), 2) + self.assertEqual(self.d.get('c', 10), 10) + + def test_keys(self): + self.assertEqual(list(self.d.keys()), ['a', 'b']) + + def test_iter(self): + self.assertEqual(list(self.d), ['a', 'b']) + + def test_values(self): + self.assertEqual(list(self.d.values()), [1, 2]) + + def test_items(self): + self.assertEqual(list(self.d.items()), [('a', 1), ('b', 2)]) + + def test_str(self): + self.assertEqual(str(self.d), "{'a': 1, 'b': 2}") + + def test_repr(self): + self.assertEqual(repr(self.d), "{'a': 1, 'b': 2}") From d25e8eddd63d040875ab061543dd6cf366ceb51f Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Wed, 20 Mar 2024 14:06:51 +0100 Subject: [PATCH 2/3] add deep_replace_property, shallow_replace utils This patch adds the deep_replace_property and shallow_replace decorators to _util. The differences between the two constructs are as follows: @_util.deep_replace_property - property - intermediate values cached in object attribute - depth first - recursive @_util.shallow_replace - function - intermediate values cached only during replacement - depth last - non recursive --- nutils/_util.py | 168 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_util.py | 52 ++++++++++++++ 2 files changed, 220 insertions(+) diff --git a/nutils/_util.py b/nutils/_util.py index 908fc7c51..138bb28e3 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -828,4 +828,172 @@ def __repr__(self): return self.__str__() +def _tuple(*args): + return args + + +def _reduce(obj): + 'helper function for deep_replace_property and shallow_replace' + + T = type(obj) + if T in (tuple, list, dict, set, frozenset): + if not obj: # empty containers need not be entered + return + elif T is tuple: + return _tuple, obj + elif T is dict: + return T, (tuple(obj.items()),) + else: + return T, (tuple(obj),) + try: + f, args = obj.__reduce__() + except: + return + else: + return f, args + + +class deep_replace_property: + '''decorator for deep object replacement + + Generates a cached property for deep replacement of reduceable objects, + based on a callable that is applied depth first and recursively on + individual constructor arguments. Intermediate values are stored in the + attribute by the same name of any object that is a descendent of the class + that owns the property. + + Args + ---- + func + Callable which maps an object onto a new object, or ``None`` if no + replacement is made. It must have precisely one positional argument for + the object. + ''' + + identity = object() + recreate = collections.namedtuple('recreate', ['f', 'nargs']) + + def __init__(self, func): + self.func = func + + def __set_name__(self, owner, name): + self.owner = owner + self.name = name + + def __set__(self, obj, value): + raise AttributeError("can't set attribute") + + def __delete__(self, obj): + raise AttributeError("can't delete attribute") + + def __get__(self, obj, objtype=None): + fstack = [obj] # stack of unprocessed objects and command tokens + rstack = [] # stack of processed objects + ostack = [] # stack of original objects to cache new value into + + while fstack: + obj = fstack.pop() + + if isinstance(obj, self.recreate): # recreate object from rstack + f, nargs = obj + r = f(*[rstack.pop() for _ in range(nargs)]) + if isinstance(r, self.owner) and (newr := self.func(r)) is not None: + fstack.append(newr) # recursion + else: + rstack.append(r) + + elif obj is ostack: # store new representation + orig = ostack.pop() + r = rstack[-1] + if r is orig: # this may happen if obj is memoizing + r = self.identity # prevent cyclic reference + orig.__dict__[self.name] = r + + elif isinstance(obj, self.owner): + if (r := obj.__dict__.get(self.name)) is not None: # in cache + rstack.append(r if r is not self.identity else obj) + elif obj in ostack: + index = ostack.index(obj) + raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop of size {len(ostack)-index}') + else: + ostack.append(obj) + fstack.append(ostack) + f, args = obj.__reduce__() + fstack.append(self.recreate(f, len(args))) + fstack.extend(args) + + elif reduced := _reduce(obj): + f, args = reduced + fstack.append(self.recreate(f, len(args))) + fstack.extend(args) + + else: + rstack.append(obj) + + assert not ostack + assert len(rstack) == 1 + return rstack[0] + + +def shallow_replace(func): + '''decorator for deep object replacement + + Generates a deep replacement method for reduceable objects based on a + callable that is applied on individual constructor arguments. The + replacement takes a shallow first approach and stops as soon as the + callable returns a value that is not ``None``. Intermediate values are + flushed upon return. + + Args + ---- + func + Callable which maps an object onto a new object, or ``None`` if no + replacement is made. It must have one positional argument for the object, + and may have any number of additional positional and/or keyword + arguments. + + Returns + ------- + :any:`callable` + The method that searches the object to perform the replacements. + ''' + + recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig']) + + @functools.wraps(func) + def wrapped(target, *funcargs, **funckwargs): + fstack = [target] # stack of unprocessed objects and command tokens + rstack = [] # stack of processed objects + cache = IDDict() # cache of seen objects + + while fstack: + obj = fstack.pop() + + if isinstance(obj, recreate): + f, nargs, orig = obj + r = f(*[rstack.pop() for _ in range(nargs)]) + cache[orig] = r + rstack.append(r) + + elif (r := cache.get(obj)) is not None: + rstack.append(r) + + elif (r := func(obj, *funcargs, **funckwargs)) is not None: + cache[obj] = r + rstack.append(r) + + elif reduced := _reduce(obj): + f, args = reduced + fstack.append(recreate(f, len(args), obj)) + fstack.extend(args) + + else: # obj cannot be reduced + rstack.append(obj) + + assert len(rstack) == 1 + return rstack[0] + + return wrapped + + # vim:sw=4:sts=4:et diff --git a/tests/test_util.py b/tests/test_util.py index 889c0166a..4ea82b926 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -449,3 +449,55 @@ def test_str(self): def test_repr(self): self.assertEqual(repr(self.d), "{'a': 1, 'b': 2}") + + +class replace(TestCase): + + class Base: + def __init__(self, *args): + self.args = args + self.called = False + def __reduce__(self): + return type(self), self.args + @util.deep_replace_property + def simple(self): + assert not self.called, 'caching failure: simple called twice on the same object' + self.called = True + if isinstance(self, replace.Ten): + return replace.Intermediate() # to test recursion + if isinstance(self, replace.Intermediate): + return 10 + + class Ten(Base): pass + class Intermediate(Base): pass + + @staticmethod + @util.shallow_replace + def subs10(obj, value): + if isinstance(obj, replace.Ten): + return value + + def test_deep_simple(self): + ten = self.Ten() + self.assertEqual(ten.simple, 10) + self.assertIn('simple', ten.__dict__) + self.assertEqual(ten.simple, 10) + + def test_deep_nested(self): + ten = self.Ten() + obj = self.Base(5, {7, ten}) + self.assertEqual(type(obj.simple), type(obj)) + self.assertEqual(obj.simple.args, (5, {7, 10})) + self.assertIn('simple', obj.__dict__) + self.assertIn('simple', ten.__dict__) + + def test_shallow_simple(self): + ten = self.Ten() + self.assertEqual(self.subs10(ten, 20), 20) + + def test_shallow_nested(self): + ten = self.Ten() + obj = self.Base(5, {7, ten}) + newobj = self.subs10(obj, 20) + self.assertEqual(type(newobj), type(obj)) + self.assertEqual(newobj.args, (5, {7, 20})) From b0b02a8325a68893ea3c617a54c15282b3d1cc0f Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Mon, 18 Mar 2024 16:56:00 +0100 Subject: [PATCH 3/3] replace function.replace by util variants This patch removes the function.replace decorator and replaces it with _util.deep_replace_property (for simplified and optimized_for_numpy) and _util.shallow_replace (for replace_arguments, _deep_flatten_constants and _combine_loop_concatenates). The specialized variants are faster, use less memory, and offer better cache reuse. --- nutils/evaluable.py | 180 +++++--------------------------------------- 1 file changed, 17 insertions(+), 163 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 107117c03..b3a417ec0 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -46,7 +46,6 @@ import collections.abc import math import treelog as log -import weakref import time import contextlib import subprocess @@ -129,149 +128,6 @@ class ExpensiveEvaluationWarning(warnings.NutilsInefficiencyWarning): pass -def replace(func=None, depthfirst=False, recursive=False, lru=4): - '''decorator for deep object replacement - - Generates a deep replacement method for general objects based on a callable - that is applied (recursively) on individual constructor arguments. - - Args - ---- - func - Callable which maps an object onto a new object, or `None` if no - replacement is made. It must have one positional argument for the object, - and may have any number of additional positional and/or keyword - arguments. - depthfirst : :class:`bool` - If `True`, decompose each object as far a possible, then apply `func` to - all arguments as the objects are reconstructed. Otherwise apply `func` - directly on each new object that is encountered in the decomposition, - proceding only if the return value is `None`. - recursive : :class:`bool` - If `True`, repeat replacement for any object returned by `func` until it - returns `None`. Otherwise perform a single, non-recursive sweep. - lru : :class:`int` - Maximum size of the least-recently-used cache. A persistent weak-key - dictionary is maintained for every unique set of function arguments. When - the size of `lru` is reached, the least recently used cache is dropped. - - Returns - ------- - :any:`callable` - The method that searches the object to perform the replacements. - ''' - - if func is None: - return functools.partial(replace, depthfirst=depthfirst, recursive=recursive, lru=lru) - - signature = inspect.signature(func) - arguments = [] # list of past function arguments, least recently used last - caches = [] # list of weak-key dictionaries matching arguments (above) - - remember = object() # token to signal that rstack[-1] can be cached as the replacement of fstack[-1] - recreate = object() # token to signal that all arguments for object recreation are ready on rstack - pending = object() # token to hold the place of a cachable object pending creation - identity = object() # token to hold the place of the cache value in case it matches key, to avoid circular references - - @functools.wraps(func) - def wrapped(target, *funcargs, **funckwargs): - - # retrieve or create a weak-key dictionary - bound = signature.bind(None, *funcargs, **funckwargs) - bound.apply_defaults() - try: - index = arguments.index(bound.arguments) # by using index, arguments need not be hashable - except ValueError: - index = -1 - cache = weakref.WeakKeyDictionary() - else: - cache = caches[index] - if index != 0: # function arguments are not the most recent (possibly new) - if index > 0 or len(arguments) >= lru: - caches.pop(index) # pop matching (or oldest) item - arguments.pop(index) - caches.insert(0, cache) # insert popped (or new) item to front - arguments.insert(0, bound.arguments) - - fstack = [target] # stack of unprocessed objects and command tokens - rstack = [] # stack of processed objects - _stack = fstack if recursive else rstack - - try: - while fstack: - obj = fstack.pop() - - if obj is recreate: - args = [rstack.pop() for obj in range(fstack.pop())] - f = fstack.pop() - r = f(*args) - if depthfirst: - newr = func(r, *funcargs, **funckwargs) - if newr is not None: - _stack.append(newr) - continue - rstack.append(r) - continue - - if obj is remember: - obj = fstack.pop() - cache[obj] = rstack[-1] if rstack[-1] is not obj else identity - continue - - if obj.__class__ in (tuple, list, dict, set, frozenset): - if not obj: - rstack.append(obj) # shortcut to avoid recreation of empty container - else: - fstack.append(lambda *x, T=type(obj): T(x)) - fstack.append(len(obj)) - fstack.append(recreate) - fstack.extend(obj if not isinstance(obj, dict) else obj.items()) - continue - - try: - r = cache[obj] - except KeyError: # object can be weakly cached, but isn't - cache[obj] = pending - fstack.append(obj) - fstack.append(remember) - except TypeError: # object cannot be referenced or is not hashable - pass - else: # object is in cache - if r is pending: - pending_objs = tuple(k for k, v in cache.items() if v is pending) - index = pending_objs.index(obj) - raise Exception('{}@replace caught in a circular dependence\n'.format(func.__name__) + Tuple(pending_objs[index:]).asciitree().split('\n', 1)[1]) - rstack.append(r if r is not identity else obj) - continue - - if not depthfirst: - newr = func(obj, *funcargs, **funckwargs) - if newr is not None: - _stack.append(newr) - continue - - try: - f, args = obj.__reduce__() - except: # obj cannot be reduced into a constructor and its arguments - rstack.append(obj) - else: - fstack.append(f) - fstack.append(len(args)) - fstack.append(recreate) - fstack.extend(args) - - assert len(rstack) == 1 - - finally: - while fstack: - if fstack.pop() is remember: - assert cache.pop(fstack.pop()) is pending - - return rstack[0] - - return wrapped - - class Evaluable(types.Singleton): 'Base class' @@ -434,36 +290,34 @@ def _format_stack(self, values, e): lines.append(f'{next(stack)} --> {e}') return '\n '.join(lines) - @property - @replace(depthfirst=True, recursive=True) + @util.deep_replace_property def simplified(obj): - if isinstance(obj, Evaluable): - retval = obj._simplified() - if retval is not None and isinstance(obj, Array): - assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval) - return retval + retval = obj._simplified() + if retval is not None and isinstance(obj, Array): + assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval) + return retval def _simplified(self): return @cached_property def optimized_for_numpy(self): - retval = self.simplified._optimized_for_numpy1() or self - retval = retval._deep_flatten_constants() or retval - return retval._combine_loop_concatenates(frozenset()) + return self.simplified \ + ._optimized_for_numpy1 \ + ._deep_flatten_constants() \ + ._combine_loop_concatenates(frozenset()) - @replace(depthfirst=True, recursive=True) + @util.deep_replace_property def _optimized_for_numpy1(obj): - if isinstance(obj, Evaluable): - retval = obj._simplified() or obj._optimized_for_numpy() - if retval is not None and isinstance(obj, Array): - assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__) - return retval + retval = obj._simplified() or obj._optimized_for_numpy() + if retval is not None and isinstance(obj, Array): + assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__) + return retval def _optimized_for_numpy(self): return - @replace(depthfirst=False, recursive=False) + @util.shallow_replace def _deep_flatten_constants(self): if isinstance(self, Array): return self._flatten_constant() @@ -511,7 +365,7 @@ def _combine_loop_concatenates(self, outer_exclude): intbounds = dict(zip(('_lower', '_upper'), lc._intbounds)) if lc.dtype == int else {} replacements[lc] = ArrayFromTuple(combined, i, lc.shape, lc.dtype, **intbounds) if replacements: - self = replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None, recursive=False, depthfirst=False)(self) + self = util.shallow_replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None)(self) else: return self @@ -5087,7 +4941,7 @@ def loop_concatenate_combined(funcs, index): return tuple(ArrayFromTuple(loop, unique_funcs.index(func), tuple(shape), func.dtype) for func, start, stop, *shape in unique_func_data) -@replace +@util.shallow_replace def replace_arguments(value, arguments): '''Replace :class:`Argument` objects in ``value``.