Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account.

Fix validation and scale accessor multiscale image #719

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
### Minor

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
- Added `get_pyramid_levels()` utils API

## [0.2.3] - 2024-09-25

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Operations on `SpatialData` objects.
unpad_raster
are_extents_equal
deepcopy
get_pyramid_levels
```

## Models
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"get_centroids",
"read_zarr",
"unpad_raster",
"get_pyramid_levels",
"save_transformations",
"get_dask_backing_files",
"are_extents_equal",
Expand Down Expand Up @@ -75,4 +76,4 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import get_dask_backing_files, save_transformations
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import unpad_raster
from spatialdata._utils import get_pyramid_levels, unpad_raster
22 changes: 3 additions & 19 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from xarray import DataArray

from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import iterate_pyramid_levels
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import (
MappingToCoordinateSystem_t,
SpatialElement,
Expand Down Expand Up @@ -122,22 +122,6 @@ def _write_metadata(
group.attrs["spatialdata_attrs"] = attrs


def _iter_multiscale(
data: DataTree,
attr: str | None,
) -> list[Any]:
# TODO: put this check also in the validator for raster multiscales
for i in data:
variables = set(data[i].variables.keys())
names: set[str] = variables.difference({"c", "z", "y", "x"})
if len(names) != 1:
raise ValueError(f"Invalid variable name: `{names}`.")
name: str = next(iter(names))
if attr is not None:
return [getattr(data[i][name], attr) for i in data]
return [data[i][name] for i in data]


class dircmp(filecmp.dircmp): # type: ignore[type-arg]
"""
Compare the content of dir1 and dir2.
Expand Down Expand Up @@ -241,8 +225,8 @@ def _(element: DataArray) -> list[str]:

@get_dask_backing_files.register(DataTree)
def _(element: DataTree) -> list[str]:
    # Only scale 0 is inspected: its dask graph is the one referencing the on-disk store.
    # NOTE(review): presumably all pyramid levels share the same backing store — confirm.
    scale0_dask_data = get_pyramid_levels(element, attr="data", n=0)
    return _get_backing_files(scale0_dask_data)


@get_dask_backing_files.register(DaskDataFrame)
Expand Down
10 changes: 5 additions & 5 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from spatialdata._io._utils import (
_get_transformations_from_ngff_dict,
_iter_multiscale,
overwrite_coordinate_transformations_raster,
)
from spatialdata._io.format import (
Expand All @@ -26,6 +25,7 @@
RasterFormatV01,
_parse_version,
)
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channels
from spatialdata.models.models import ATTRS_KEY
from spatialdata.transformations._utils import (
Expand Down Expand Up @@ -180,8 +180,8 @@ def _get_group_for_writing_transformations() -> zarr.Group:
group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes
)
elif isinstance(raster_data, DataTree):
data = _iter_multiscale(raster_data, "data")
list_of_input_axes: list[Any] = _iter_multiscale(raster_data, "dims")
data = get_pyramid_levels(raster_data, attr="data")
list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims")
assert len(set(list_of_input_axes)) == 1
input_axes = list_of_input_axes[0]
# saving only the transformations of the first scale
Expand All @@ -191,8 +191,8 @@ def _get_group_for_writing_transformations() -> zarr.Group:
transformations = _get_transformations_xarray(xdata)
assert transformations is not None
assert len(transformations) > 0
chunks = _iter_multiscale(raster_data, "chunks")
# coords = _iter_multiscale(raster_data, "coords")
chunks = get_pyramid_levels(raster_data, "chunks")
# coords = iterate_pyramid_levels(raster_data, "coords")
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
storage_options = [{"chunks": chunk} for chunk in chunks]
write_multi_scale_ngff(
Expand Down
63 changes: 36 additions & 27 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import warnings
from collections.abc import Generator
from itertools import islice
from typing import Any, Callable, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -150,45 +151,53 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
return compute_coordinates(unpadded)


# TODO: probably we want this method to live in multiscale_spatial_image
def multiscale_spatial_image_from_data_tree(data_tree: DataTree) -> DataTree:
    """
    Rebuild a `DataTree` from the single data variable of each of its children.

    .. deprecated::
        Kept only for backward compatibility; emits a `DeprecationWarning` on every call.
    """
    warnings.warn(
        f"{multiscale_spatial_image_from_data_tree} is deprecated and will be removed in version 0.2.0.",
        DeprecationWarning,
        stacklevel=2,
    )
    d = {}
    for k, dt in data_tree.items():
        v = dt.values()
        # each scale is expected to carry exactly one data variable
        assert len(v) == 1
        d[k] = next(iter(v))

    return DataTree.from_dict(d)


def get_pyramid_levels(image: DataTree, attr: str | None = None, n: int | None = None) -> list[Any] | Any:
    """
    Access the data/attribute of the pyramid levels of a multiscale spatial image.

    Parameters
    ----------
    image
        The multiscale spatial image.
    attr
        If `None`, return the data of the pyramid level as a `DataArray`; if not `None`, return the specified
        attribute within the `DataArray` data.
    n
        If not `None`, return only the `n`-th pyramid level.

    Returns
    -------
    A list with one entry per pyramid level (the level's data, or the requested attribute of it), or the single
    entry for level `n` when `n` is given.

    Raises
    ------
    StopIteration
        If `n` is not smaller than the number of pyramid levels.
    """
    levels = iterate_pyramid_levels(image, attr)
    if n is not None:
        # islice() lazily skips the first n levels; next() then consumes level n itself
        # without materializing the remaining levels.
        return next(islice(levels, n, None))
    return list(levels)


def iterate_pyramid_levels(
    data: DataTree,
    attr: str | None,
) -> Generator[Any, None, None]:
    """
    Iterate over the pyramid levels of a multiscale spatial image.

    Parameters
    ----------
    data
        The multiscale spatial image.
    attr
        If `None`, yield the data of each pyramid level as a `DataArray`; if not None, yield the specified
        attribute within the `DataArray` data.

    Yields
    ------
    The data (or the requested attribute of it) of every pyramid level, starting from "scale0".
    """
    # The variable name is read from the first scale and assumed to be shared by every
    # scale (the raster model validator enforces a single common name).
    variable_name: str = next(iter(data["scale0"].ds.keys()))
    for scale_key in data:
        level = data[scale_key][variable_name]
        if attr is None:
            yield level
        else:
            yield getattr(level, attr)


def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: AnnData) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def _(self, data: DataTree) -> None:
if j != k:
raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.")
name = {list(data[i].data_vars.keys())[0] for i in data}
if len(name) > 1:
raise ValueError(f"Wrong name for datatree: `{name}`.")
if len(name) != 1:
raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.")
name = list(name)[0]
for d in data:
super().validate(data[d][name])
Expand Down
4 changes: 2 additions & 2 deletions tests/core/operations/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from spatialdata import SpatialData, get_extent
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.query.relational_query import get_element_instances
from spatialdata._io._utils import _iter_multiscale
from spatialdata._utils import get_pyramid_levels
from spatialdata.models import PointsModel, ShapesModel, TableModel, get_axes_names
from spatialdata.models._utils import get_spatial_axes
from spatialdata.transformations import MapAxis
Expand Down Expand Up @@ -57,7 +57,7 @@ def _get_data_of_largest_scale(raster):
if isinstance(raster, DataArray):
return raster.data.compute()

xdata = next(iter(_iter_multiscale(raster, None)))
xdata = get_pyramid_levels(raster, n=0)
return xdata.data.compute()

for element_name, raster in rasters.items():
Expand Down
Loading