Skip to content

Commit

Permalink
Some cleanup of array constructors.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Feb 7, 2025
1 parent 85f3900 commit 0f80416
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions swiftsimio/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def __rtruediv__(self, b):
def __pow__(self, p):
if self.expr is None:
return cosmo_factor(expr=None, scale_factor=self.scale_factor)
return cosmo_factor(expr=self.expr ** p, scale_factor=self.scale_factor)
return cosmo_factor(expr=self.expr**p, scale_factor=self.scale_factor)

def __lt__(self, b):
return self.a_factor < b.a_factor
Expand Down Expand Up @@ -539,9 +539,9 @@ def __new__(
none is found, uses np.float64
bypass_validation : bool, optional
If True, all input validation is skipped. Using this option may produce
corrupted, invalid units or array data, but can lead to significant speedups
in the input validation logic adds significant overhead. If set, input_units
must be a valid unit object. Defaults to False.
corrupted or invalid data, but can lead to significant speedups
in the input validation logic adds significant overhead. If set, minimally
pass valid values for units, comoving and cosmo_factor. Defaults to False.
input_units : str, optional
deprecated in favour of units option
name : str, optional
Expand All @@ -562,6 +562,22 @@ def __new__(

cosmo_factor: cosmo_factor

if bypass_validation is True:
obj = super().__new__(
cls,
input_array,
units=units,
registry=registry,
dtype=dtype,
bypass_validation=bypass_validation,
name=name,
)
# dtype, units, registry & name handled by unyt
obj.comoving = comoving
obj.cosmo_factor = cosmo_factor if cosmo_factor is not None else NULL_CF
obj.valid_transform = valid_transform
obj.compression = compression
return obj
if isinstance(input_array, cosmo_array):
if comoving:
input_array.convert_to_comoving()
Expand Down Expand Up @@ -594,8 +610,6 @@ def __new__(
input_array = helper_result["args"]
if cosmo_factor is None:
cosmo_factor = _preserve_cosmo_factor(*helper_result["cfs"])
if cosmo_factor is None:
cosmo_factor = NULL_CF
elif all([cf is None for cf in helper_result["cfs"]]):
cosmo_factor = cosmo_factor
else:
Expand Down Expand Up @@ -1046,9 +1060,9 @@ def __new__(
none is found, uses np.float64
bypass_validation : bool, optional
If True, all input validation is skipped. Using this option may produce
corrupted, invalid units or array data, but can lead to significant speedups
in the input validation logic adds significant overhead. If set, input_units
must be a valid unit object. Defaults to False.
corrupted or invalid data, but can lead to significant speedups
in the input validation logic adds significant overhead. If set, minimally
pass valid values for units, comoving and cosmo_factor. Defaults to False.
name : str, optional
The name of the array. Defaults to None. This attribute does not propagate
through mathematical operations, but is preserved under indexing and unit
Expand All @@ -1064,23 +1078,33 @@ def __new__(
Description of the compression filters that were applied to that array in the
hdf5 file.
"""
if not (
bypass_validation
or isinstance(input_scalar, (numeric_type, np.number, np.ndarray))
):
if bypass_validation is True:
ret = super().__new__(
cls,
np.asarray(input_scalar),
units,
registry,
dtype=dtype,
bypass_validation=bypass_validation,
name=name,
cosmo_factor=cosmo_factor,
comoving=comoving,
valid_transform=valid_transform,
compression=compression,
)

if not isinstance(input_scalar, (numeric_type, np.number, np.ndarray)):
raise RuntimeError("cosmo_quantity values must be numeric")

# Use values from kwargs, if None use values from input_scalar
units = getattr(input_scalar, "units", None) if units is None else units
name = getattr(input_scalar, "name", None) if name is None else name
if hasattr(input_scalar, "cosmo_factor") and (cosmo_factor is not None):
cosmo_factor = _preserve_cosmo_factor(
cosmo_factor, getattr(input_scalar, "cosmo_factor")
)
elif cosmo_factor is None and hasattr(input_scalar, "cosmo_factor"):
cosmo_factor = getattr(input_scalar, "cosmo_factor")
elif cosmo_factor is not None and not hasattr(input_scalar, "cosmo_factor"):
pass
else:
cosmo_factor = (
getattr(input_scalar, "cosmo_factor", None)
if cosmo_factor is None
else cosmo_factor
)
if cosmo_factor is None:
cosmo_factor = NULL_CF
comoving = (
getattr(input_scalar, "comoving", None) if comoving is None else comoving
Expand Down

0 comments on commit 0f80416

Please # to comment.