Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix SI handling of special binary methods #860

Merged
merged 4 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
>>> w = SI.Velocity('8km')
Traceback (most recent call last):
...
TypeError: expected [L/T], got [L]
nutils.SI.DimensionError: expected [L/T], got [L]

Explicit subtypes can also be used in function annotations:

Expand Down Expand Up @@ -126,6 +126,10 @@
import functools


class DimensionError(TypeError):
pass


class Dimension(type):

__cache = {} # subtypes
Expand Down Expand Up @@ -220,7 +224,7 @@
q = parse(value)
expect = float if not cls.__powers else cls
if type(q) != expect:
raise TypeError(f'expected {expect.__name__}, got {type(q).__name__}')
raise DimensionError(f'expected {expect.__name__}, got {type(q).__name__}')

Check warning on line 227 in nutils/SI.py

View check run for this annotation

Codecov / codecov/patch

nutils/SI.py#L227

Added line #L227 was not covered by tests
return q

def wrap(cls, value):
Expand Down Expand Up @@ -335,7 +339,7 @@
def __add_like(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
if dim0 != dim1:
raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')
raise DimensionError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')
return dim0.wrap(op(arg0, arg1, *args[2:], **kwargs))

@register('mul', 'multiply', 'matmul')
Expand Down Expand Up @@ -363,7 +367,7 @@
def __setitem(op, *args, **kwargs):
(dim0, arg0), (dim2, arg2) = Quantity.__unpack(args[0], args[2])
if dim0 != dim2:
raise TypeError(f'cannot assign {dim2.__name__} to {dim0.__name__}')
raise DimensionError(f'cannot assign {dim2.__name__} to {dim0.__name__}')
return dim0.wrap(op(arg0, args[1], arg2, *args[3:], **kwargs))

@register('pow', 'power', 'jacobian')
Expand All @@ -372,23 +376,23 @@
return (dim0**args[1]).wrap(op(arg0, *args[1:], **kwargs))

@register('isfinite', 'isnan', 'shape', 'ndim', 'size', 'normal', 'normalized')
def __unary_drop(op, *args, **kwargs):
def __unary_op(op, *args, **kwargs):
(_dim0, arg0), = Quantity.__unpack(args[0])
return op(arg0, *args[1:], **kwargs)

@register('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less',
'less_equal', 'greater', 'greater_equal')
def __binary_drop(op, *args, **kwargs):
def __binary_op(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
if dim0 != dim1:
raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')
raise DimensionError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')

Check warning on line 388 in nutils/SI.py

View check run for this annotation

Codecov / codecov/patch

nutils/SI.py#L388

Added line #L388 was not covered by tests
return op(arg0, arg1, *args[2:], **kwargs)

@register('stack', 'concatenate')
def __stack_like(op, *args, **kwargs):
dims, arg0 = zip(*Quantity.__unpack(*args[0]))
if any(dim != dims[0] for dim in dims[1:]):
raise TypeError(f'incompatible arguments for {op.__name__}: ' + ', '.join(dim.__name__ for dim in dims))
raise DimensionError(f'incompatible arguments for {op.__name__}: ' + ', '.join(dim.__name__ for dim in dims))
return dims[0].wrap(op(arg0, *args[1:], **kwargs))

@register('curvature')
Expand All @@ -410,40 +414,47 @@

## DEFINE OPERATORS

def op(name, with_reverse=False, *, __table=__DISPATCH_TABLE):
dispatch = __table[name]
op = getattr(operator, name)
ret = lambda *args: dispatch(op, *args)
if with_reverse:
ret = ret, lambda self, other: dispatch(op, other, self)
return ret
def op(name, __table=__DISPATCH_TABLE):
dispatch, op = __table[name], getattr(operator, name)
return lambda *args: dispatch(op, *args)

__getitem__ = op('getitem')
__setitem__ = op('setitem')
__neg__ = op('neg')
__pos__ = op('pos')
__abs__ = op('abs')

def op(name, __table=__DISPATCH_TABLE):
dispatch, op = __table[name], getattr(operator, name)
return lambda *args: _try_or_noimp(dispatch, op, *args)

__lt__ = op('lt')
__le__ = op('le')
__eq__ = op('eq')
__ne__ = op('ne')
__gt__ = op('gt')
__ge__ = op('ge')
__add__, __radd__ = op('add', True)
__sub__, __rsub__ = op('sub', True)
__mul__, __rmul__ = op('mul', True)
__matmul__, __rmatmul__ = op('matmul', True)
__truediv, __rtruediv__ = op('truediv', True)
__mod__, __rmod__ = op('mod', True)
__pow__, __rpow__ = op('pow', True)

def op(name, __table=__DISPATCH_TABLE):
dispatch, op = __table[name], getattr(operator, name)
return lambda self, other: _try_or_noimp(dispatch, op, self, other), \
lambda self, other: _try_or_noimp(dispatch, op, other, self)

__add__, __radd__ = op('add')
__sub__, __rsub__ = op('sub')
__mul__, __rmul__ = op('mul')
__matmul__, __rmatmul__ = op('matmul')
__truediv, __rtruediv__ = op('truediv')
__mod__, __rmod__ = op('mod')
__pow__, __rpow__ = op('pow')

del op

def __truediv__(self, other):
if type(other) is str:
return self.__value / self.__class__(other).__value
return self.__truediv(other)

del op

## DISPATCH THIRD PARTY CALLS

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
Expand Down Expand Up @@ -500,6 +511,13 @@
isnumer = False


def _try_or_noimp(func, *args):
try:
return func(*args)
except DimensionError:
return NotImplemented


## SI DIMENSIONS

Dimensionless = Dimension.from_powers({})
Expand Down
18 changes: 12 additions & 6 deletions nutils/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@
with treelog.context('arguments'):
for k, v in bound.arguments.items():
try:
s = stringly.dumps(sig.parameters[k].annotation, v)
s = stringly.dumps(_infer_type(sig.parameters[k]), v)
except:
s = str(v)
treelog.info(f'{k}={s}')
Expand Down Expand Up @@ -630,6 +630,16 @@
treelog.info(f'log written to: {loguri}')


def _infer_type(param):
'''Infer argument type from annotation or default value.'''

if param.annotation is not param.empty:
return param.annotation
if param.default is not param.empty:
return type(param.default)

Check warning on line 639 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L639

Added line #L639 was not covered by tests
raise Exception(f'cannot determine type for argument {param.name!r}')


def cli(f, *, argv=None):
'''Call a function using command line arguments.'''

Expand All @@ -642,11 +652,7 @@
mandatory = set()

for param in inspect.signature(f).parameters.values():
T = param.annotation
if T == param.empty and param.default != param.empty:
T = type(param.default)
if T == param.empty:
raise Exception(f'cannot determine type for argument {param.name!r}')
T = _infer_type(param)
try:
s = stringly.serializer.get(T)
except Exception as e:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_setitem(self):
F = SI.units.N * numpy.zeros(3)
F[0] = SI.Force('1N')
F[1] = SI.Force('2N')
with self.assertRaisesRegex(TypeError, r'cannot assign \[L2\] to \[M\*L/T2\]'):
with self.assertRaisesRegex(SI.DimensionError, r'cannot assign \[L2\] to \[M\*L/T2\]'):
F[2] = SI.Area('10m2')
F[2] = SI.Force('3N')
self.assertTrue(numpy.all(F == SI.units.N * numpy.array([1, 2, 3])))
Expand Down Expand Up @@ -104,18 +104,18 @@ def test_power(self):
def test_add(self):
self.assertEqual(SI.Mass('2kg') + SI.Mass('3kg'), SI.Mass('5kg'))
self.assertEqual(numpy.add(SI.Mass('2kg'), SI.Mass('3kg')), SI.Mass('5kg'))
with self.assertRaisesRegex(TypeError, r'incompatible arguments for add: \[M\], \[L\]'):
with self.assertRaisesRegex(TypeError, r"unsupported operand type\(s\) for \+: '\[M\]' and '\[L\]'"):
SI.Mass('2kg') + SI.Length('3m')

def test_sub(self):
self.assertEqual(SI.Mass('2kg') - SI.Mass('3kg'), SI.Mass('-1kg'))
self.assertEqual(numpy.subtract(SI.Mass('2kg'), SI.Mass('3kg')), SI.Mass('-1kg'))
with self.assertRaisesRegex(TypeError, r'incompatible arguments for sub: \[M\], \[L\]'):
with self.assertRaisesRegex(TypeError, r"unsupported operand type\(s\) for \-: '\[M\]' and '\[L\]'"):
SI.Mass('2kg') - SI.Length('3m')

def test_hypot(self):
self.assertEqual(numpy.hypot(SI.Mass('3kg'), SI.Mass('4kg')), SI.Mass('5kg'))
with self.assertRaisesRegex(TypeError, r'incompatible arguments for hypot: \[M\], \[L\]'):
with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for hypot: \[M\], \[L\]'):
numpy.hypot(SI.Mass('3kg'), SI.Length('4m'))

def test_neg(self):
Expand Down Expand Up @@ -224,15 +224,15 @@ def test_stack(self):
C = SI.Mass('4kg')
D = SI.Time('5s')
self.assertTrue(numpy.all(numpy.stack([A, B, C]) == SI.units.kg * numpy.array([2, 3, 4])))
with self.assertRaisesRegex(TypeError, r'incompatible arguments for stack: \[M\], \[M\], \[M\], \[T\]'):
with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for stack: \[M\], \[M\], \[M\], \[T\]'):
numpy.stack([A, B, C, D])

def test_concatenate(self):
A = SI.units.kg * numpy.array([1, 2])
B = SI.units.kg * numpy.array([3, 4])
C = SI.units.s * numpy.array([5, 6])
self.assertTrue(numpy.all(numpy.concatenate([A, B]) == SI.units.kg * numpy.array([1, 2, 3, 4])))
with self.assertRaisesRegex(TypeError, r'incompatible arguments for concatenate: \[M\], \[M\], \[T\]'):
with self.assertRaisesRegex(SI.DimensionError, r'incompatible arguments for concatenate: \[M\], \[M\], \[T\]'):
numpy.concatenate([A, B, C])

def test_format(self):
Expand Down
Loading