diff --git a/nutils/SI.py b/nutils/SI.py index d0e0a627e..73ea67397 100644 --- a/nutils/SI.py +++ b/nutils/SI.py @@ -35,7 +35,7 @@ >>> w = SI.Velocity('8km') Traceback (most recent call last): ... - TypeError: expected [L/T], got [L] + nutils.SI.DimensionError: expected [L/T], got [L] Explicit subtypes can also be used in function annotations: @@ -126,6 +126,10 @@ import functools +class DimensionError(TypeError): + pass + + class Dimension(type): __cache = {} # subtypes @@ -220,7 +224,7 @@ def __call__(cls, value): q = parse(value) expect = float if not cls.__powers else cls if type(q) != expect: - raise TypeError(f'expected {expect.__name__}, got {type(q).__name__}') + raise DimensionError(f'expected {expect.__name__}, got {type(q).__name__}') return q def wrap(cls, value): @@ -335,7 +339,7 @@ def __unary(op, *args, **kwargs): def __add_like(op, *args, **kwargs): (dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1]) if dim0 != dim1: - raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}') + raise DimensionError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}') return dim0.wrap(op(arg0, arg1, *args[2:], **kwargs)) @register('mul', 'multiply', 'matmul') @@ -363,7 +367,7 @@ def __sqrt(op, *args, **kwargs): def __setitem(op, *args, **kwargs): (dim0, arg0), (dim2, arg2) = Quantity.__unpack(args[0], args[2]) if dim0 != dim2: - raise TypeError(f'cannot assign {dim2.__name__} to {dim0.__name__}') + raise DimensionError(f'cannot assign {dim2.__name__} to {dim0.__name__}') return dim0.wrap(op(arg0, args[1], arg2, *args[3:], **kwargs)) @register('pow', 'power', 'jacobian') @@ -372,23 +376,23 @@ def __pow_like(op, *args, **kwargs): return (dim0**args[1]).wrap(op(arg0, *args[1:], **kwargs)) @register('isfinite', 'isnan', 'shape', 'ndim', 'size', 'normal', 'normalized') - def __unary_drop(op, *args, **kwargs): + def __unary_op(op, *args, **kwargs): (_dim0, arg0), = Quantity.__unpack(args[0]) return op(arg0, *args[1:], **kwargs) @register('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less', 'less_equal', 'greater', 'greater_equal') - def __binary_drop(op, *args, **kwargs): + def __binary_op(op, *args, **kwargs): (dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1]) if dim0 != dim1: - raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}') + raise DimensionError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}') return op(arg0, arg1, *args[2:], **kwargs) @register('stack', 'concatenate') def __stack_like(op, *args, **kwargs): dims, arg0 = zip(*Quantity.__unpack(*args[0])) if any(dim != dims[0] for dim in dims[1:]): - raise TypeError(f'incompatible arguments for {op.__name__}: ' + ', '.join(dim.__name__ for dim in dims)) + raise DimensionError(f'incompatible arguments for {op.__name__}: ' + ', '.join(dim.__name__ for dim in dims)) return dims[0].wrap(op(arg0, *args[1:], **kwargs)) @register('curvature') @@ -410,40 +414,47 @@ def __dotarg(op, *args, **kwargs): ## DEFINE OPERATORS - def op(name, with_reverse=False, *, __table=__DISPATCH_TABLE): - dispatch = __table[name] - op = getattr(operator, name) - ret = lambda *args: dispatch(op, *args) - if with_reverse: - ret = ret, lambda self, other: dispatch(op, other, self) - return ret + def op(name, __table=__DISPATCH_TABLE): + dispatch, op = __table[name], getattr(operator, name) + return lambda *args: dispatch(op, *args) __getitem__ = op('getitem') __setitem__ = op('setitem') __neg__ = op('neg') __pos__ = op('pos') __abs__ = op('abs') + + def op(name, __table=__DISPATCH_TABLE): + dispatch, op = __table[name], getattr(operator, name) + return lambda *args: _try_or_noimp(dispatch, op, *args) + __lt__ = op('lt') __le__ = op('le') __eq__ = op('eq') __ne__ = op('ne') __gt__ = op('gt') __ge__ = op('ge') - __add__, __radd__ = op('add', True) - __sub__, __rsub__ = op('sub', True) - __mul__, __rmul__ = op('mul', True) - __matmul__, __rmatmul__ = op('matmul', True) - __truediv, __rtruediv__ = op('truediv', True) - __mod__, __rmod__ = op('mod', True) - __pow__, __rpow__ = op('pow', True) + + def op(name, __table=__DISPATCH_TABLE): + dispatch, op = __table[name], getattr(operator, name) + return lambda self, other: _try_or_noimp(dispatch, op, self, other), \ + lambda self, other: _try_or_noimp(dispatch, op, other, self) + + __add__, __radd__ = op('add') + __sub__, __rsub__ = op('sub') + __mul__, __rmul__ = op('mul') + __matmul__, __rmatmul__ = op('matmul') + __truediv, __rtruediv__ = op('truediv') + __mod__, __rmod__ = op('mod') + __pow__, __rpow__ = op('pow') + + del op def __truediv__(self, other): if type(other) is str: return self.__value / self.__class__(other).__value return self.__truediv(other) - del op - ## DISPATCH THIRD PARTY CALLS def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @@ -500,6 +511,13 @@ def _split_factors(s): isnumer = False +def _try_or_noimp(func, *args): + try: + return func(*args) + except DimensionError: + return NotImplemented + + ## SI DIMENSIONS Dimensionless = Dimension.from_powers({}) diff --git a/nutils/_util.py b/nutils/_util.py index a4aa57a4c..2116c2b55 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -455,7 +455,7 @@ def log_arguments(*args, **kwargs): with treelog.context('arguments'): for k, v in bound.arguments.items(): try: - s = stringly.dumps(sig.parameters[k].annotation, v) + s = stringly.dumps(_infer_type(sig.parameters[k]), v) except: s = str(v) treelog.info(f'{k}={s}') @@ -630,6 +630,16 @@ def add_htmllog(outrootdir: str = '~/public_html', outrooturi: str = '', scriptn treelog.info(f'log written to: {loguri}') +def _infer_type(param): + '''Infer argument type from annotation or default value.''' + + if param.annotation is not param.empty: + return param.annotation + if param.default is not param.empty: + return type(param.default) + raise Exception(f'cannot determine type for argument {param.name!r}') + + def cli(f, *, argv=None): '''Call a function using command line arguments.''' @@ -642,11 +652,7 @@ def cli(f, *, argv=None): mandatory = set() for param in inspect.signature(f).parameters.values(): - T = param.annotation - if T == param.empty and param.default != param.empty: - T = type(param.default) - if T == param.empty: - raise Exception(f'cannot determine type for argument {param.name!r}') + T = _infer_type(param) try: s = stringly.serializer.get(T) except Exception as e: diff --git a/tests/test_SI.py b/tests/test_SI.py index 7c42e22e0..a21684217 100644 --- a/tests/test_SI.py +++ b/tests/test_SI.py @@ -68,7 +68,7 @@ def test_setitem(self): F = SI.units.N * numpy.zeros(3) F[0] = SI.Force('1N') F[1] = SI.Force('2N') - with self.assertRaisesRegex(TypeError, r'cannot assign \[L2\] to \[M\*L/T2\]'): + with self.assertRaisesRegex(SI.DimensionError, r'cannot assign \[L2\] to \[M\*L/T2\]'): F[2] = SI.Area('10m2') F[2] = SI.Force('3N') self.assertTrue(numpy.all(F == SI.units.N * numpy.array([1, 2, 3]))) @@ -104,18 +104,18 @@ def test_power(self): def test_add(self): self.assertEqual(SI.Mass('2kg') + SI.Mass('3kg'), SI.Mass('5kg')) self.assertEqual(numpy.add(SI.Mass('2kg'), SI.Mass('3kg')), SI.Mass('5kg')) - with self.assertRaisesRegex(TypeError, r'incompatible arguments for add: \[M\], \[L\]'): + with self.assertRaisesRegex(TypeError, r"unsupported operand type\(s\) for \+: '\[M\]' and '\[L\]'"): SI.Mass('2kg') + SI.Length('3m') def test_sub(self): self.assertEqual(SI.Mass('2kg') - SI.Mass('3kg'), SI.Mass('-1kg')) self.assertEqual(numpy.subtract(SI.Mass('2kg'), SI.Mass('3kg')), SI.Mass('-1kg')) - with self.assertRaisesRegex(TypeError, r'incompatible arguments for sub: \[M\], \[L\]'): + with self.assertRaisesRegex(TypeError, r"unsupported operand type\(s\) for \-: '\[M\]' and '\[L\]'"): SI.Mass('2kg') - SI.Length('3m') def test_hypot(self): self.assertEqual(numpy.hypot(SI.Mass('3kg'), SI.Mass('4kg')), SI.Mass('5kg')) - with self.assertRaisesRegex(TypeError, r'incompatible arguments for hypot: \[M\], \[L\]'): + with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for hypot: \[M\], \[L\]'): numpy.hypot(SI.Mass('3kg'), SI.Length('4m')) def test_neg(self): @@ -224,7 +224,7 @@ def test_stack(self): C = SI.Mass('4kg') D = SI.Time('5s') self.assertTrue(numpy.all(numpy.stack([A, B, C]) == SI.units.kg * numpy.array([2, 3, 4]))) - with self.assertRaisesRegex(TypeError, r'incompatible arguments for stack: \[M\], \[M\], \[M\], \[T\]'): + with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for stack: \[M\], \[M\], \[M\], \[T\]'): numpy.stack([A, B, C, D]) def test_concatenate(self): @@ -232,7 +232,7 @@ def test_concatenate(self): B = SI.units.kg * numpy.array([3, 4]) C = SI.units.s * numpy.array([5, 6]) self.assertTrue(numpy.all(numpy.concatenate([A, B]) == SI.units.kg * numpy.array([1, 2, 3, 4]))) - with self.assertRaisesRegex(TypeError, r'incompatible arguments for concatenate: \[M\], \[M\], \[T\]'): + with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for concatenate: \[M\], \[M\], \[T\]'): numpy.concatenate([A, B, C]) def test_format(self):