Skip to content

Commit

Permalink
Support __round__ for cosmo_quantity and dot for cosmo_array.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Feb 14, 2025
1 parent 1b3feae commit 71c9212
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
6 changes: 6 additions & 0 deletions swiftsimio/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
_arctan2_cosmo_factor,
_comparison_cosmo_factor,
_prepare_array_func_args,
_default_binary_wrapper,
)

try:
Expand Down Expand Up @@ -1260,6 +1261,7 @@ def __setstate__(self, state: Tuple) -> None:
__getitem__ = _propagate_cosmo_array_attributes_to_result(
_ensure_result_is_cosmo_array_or_quantity(unyt_array.__getitem__)
)
dot = _default_binary_wrapper(unyt_array.dot, _multiply_cosmo_factor)

# Also wrap some array "properties":
T = property(_propagate_cosmo_array_attributes_to_result(unyt_array.transpose))
Expand Down Expand Up @@ -1799,3 +1801,7 @@ def __new__(
if ret.size > 1:
raise RuntimeError("cosmo_quantity instances must be scalars")
return ret

__round__ = _propagate_cosmo_array_attributes_to_result(
_ensure_result_is_cosmo_array_or_quantity(unyt_quantity.__round__)
)
27 changes: 27 additions & 0 deletions tests/test_cosmo_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,16 @@ def test_iter(self):
for cq in ca(np.arange(3)):
assert isinstance(cq, cosmo_quantity)

def test_dot(self):
"""
Make sure that we get a cosmo_array when we use array attribute dot.
"""
res = ca(np.arange(3)).dot(ca(np.arange(3)))
assert isinstance(res, cosmo_quantity)
assert res.comoving is False
assert res.cosmo_factor == cosmo_factor(a ** 2, 0.5)
assert res.valid_transform is True


class TestCosmoQuantity:
"""
Expand Down Expand Up @@ -810,6 +820,23 @@ def test_propagation_func(self, func, args):
assert res.cosmo_factor == cosmo_factor(a ** 1, 1.0)
assert res.valid_transform is True

def test_round(self):
"""
Test that attributes propagate through the round builtin.
"""
cq = cosmo_quantity(
1.03,
u.m,
comoving=False,
cosmo_factor=cosmo_factor(a ** 1, 1.0),
valid_transform=True,
)
res = round(cq)
assert res.value == 1.0
assert res.comoving is False
assert res.cosmo_factor == cosmo_factor(a ** 1, 1.0)
assert res.valid_transform is True

def test_scalar_return_func(self):
"""
Make sure that default-wrapped functions that take a cosmo_array and return a
Expand Down

0 comments on commit 71c9212

Please # to comment.