Skip to content

Commit

Permalink
Replace some internal uses of unyt_array with cosmo_array.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Jan 28, 2025
1 parent ef43675 commit c86b9c9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
55 changes: 43 additions & 12 deletions swiftsimio/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@

import warnings

import unyt
import h5py

import numpy as np

from swiftsimio.metadata.objects import SWIFTMetadata

from swiftsimio.objects import InvalidSnapshot
from swiftsimio.objects import (
InvalidSnapshot,
cosmo_array,
cosmo_quantity,
cosmo_factor,
a,
)

from swiftsimio.accelerated import ranges_from_array
from typing import Dict


class SWIFTMask(object):
Expand Down Expand Up @@ -229,11 +233,19 @@ def _unpack_cell_metadata(self):
self.counts[key] = counts[sort]

# Also need to sort centers in the same way
self.centers = unyt.unyt_array(centers_handle[:][sort], units=self.units.length)
self.centers = cosmo_array(
centers_handle[:][sort],
units=self.units.length,
comoving=True,
cosmo_factor=cosmo_factor(a**1, self.metadata.scale_factor),
)

# Note that we cannot assume that these are cubic, unfortunately.
self.cell_size = unyt.unyt_array(
metadata_handle.attrs["size"], units=self.units.length
self.cell_size = cosmo_array(
metadata_handle.attrs["size"],
units=self.units.length,
comoving=True,
cosmo_factor=cosmo_factor(a**1, self.metadata.scale_factor),
)

return
Expand All @@ -242,8 +254,8 @@ def constrain_mask(
self,
group_name: str,
quantity: str,
lower: unyt.array.unyt_quantity,
upper: unyt.array.unyt_quantity,
lower: cosmo_quantity,
upper: cosmo_quantity,
):
"""
Constrains the mask further for a given particle type, and bounds a
Expand All @@ -263,10 +275,10 @@ def constrain_mask(
quantity : str
quantity being constrained
lower : unyt.array.unyt_quantity
lower : ~swiftsimio.objects.cosmo_quantity
constraint lower bound
upper : unyt.array.unyt_quantity
upper : ~swiftsimio.objects.cosmo_quantity
constraint upper bound
See Also
Expand Down Expand Up @@ -300,13 +312,32 @@ def constrain_mask(

handle = handle_dict[quantity]

physical_dict = {
k: v
for k, v in zip(group_metadata.field_names, group_metadata.field_physicals)
}

physical = physical_dict[quantity]

cosmologies_dict = {
k: v
for k, v in zip(
group_metadata.field_names, group_metadata.field_cosmologies
)
}

cosmology_factor = cosmologies_dict[quantity]

# Load in the relevant data.

with h5py.File(self.metadata.filename, "r") as file:
# Surprisingly this is faster than just using the boolean
# indexing because h5py has slow indexing routines.
data = unyt.unyt_array(
np.take(file[handle], np.where(current_mask)[0], axis=0), units=unit
data = cosmo_array(
np.take(file[handle], np.where(current_mask)[0], axis=0),
units=unit,
comoving=not physical,
cosmo_factor=cosmology_factor,
)

new_mask = np.logical_and.reduce([data > lower, data <= upper])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from swiftsimio import load, mask
import numpy as np

from unyt import unyt_array as array, dimensionless
from unyt import dimensionless
from swiftsimio import cosmo_array as array


@requires("cosmological_volume.hdf5")
Expand Down

0 comments on commit c86b9c9

Please # to comment.